In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader, random_split

In [2]:
# Device configuration GPU/CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## VGG-16 Model


In [3]:
class VGG16(nn.Module):
    def __init__(self, num_classes=10):
        super(VGG16, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )



    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [4]:
# Hyperparameters
num_classes = 10
num_epochs = 30
batch_size = 64
learning_rate = 0.01

## Dataset

In [None]:
# Load CIFAR-100 dataset
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761))
])

dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Split dataset into train and validation sets (80% train, 20% validation)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
model = VGG16(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

# Training Original VGG16

In [7]:
total_step = len(train_loader)
model.train()
for epoch in range(num_epochs):

    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:  # Print every 100 mini-batches
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')
            running_loss = 0.0

    # Validate the model
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader)
    val_accuracy = 100 * correct / total
    print(f'Epoch [{epoch + 1}/30], Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

torch.save(model.state_dict(), "model.h5")

Epoch [1/30], Step [100/625], Loss: 2.1577
Epoch [1/30], Step [200/625], Loss: 1.8895
Epoch [1/30], Step [300/625], Loss: 1.7619
Epoch [1/30], Step [400/625], Loss: 1.6587
Epoch [1/30], Step [500/625], Loss: 1.5539
Epoch [1/30], Step [600/625], Loss: 1.5102
Epoch [1/30], Validation Loss: 1.5053, Validation Accuracy: 46.20%
Epoch [2/30], Step [100/625], Loss: 2.2861
Epoch [2/30], Step [200/625], Loss: 2.1639
Epoch [2/30], Step [300/625], Loss: 2.0202
Epoch [2/30], Step [400/625], Loss: 1.9621
Epoch [2/30], Step [500/625], Loss: 1.9606
Epoch [2/30], Step [600/625], Loss: 1.9274
Epoch [2/30], Validation Loss: 1.8952, Validation Accuracy: 22.71%
Epoch [3/30], Step [100/625], Loss: 1.9423
Epoch [3/30], Step [200/625], Loss: 1.8883
Epoch [3/30], Step [300/625], Loss: 1.8446
Epoch [3/30], Step [400/625], Loss: 1.8531
Epoch [3/30], Step [500/625], Loss: 1.8059
Epoch [3/30], Step [600/625], Loss: 1.8096
Epoch [3/30], Validation Loss: 1.7377, Validation Accuracy: 29.60%


KeyboardInterrupt: 

The code in the cell looks correct and should work without any errors, given that all the necessary variables and modules are already defined and imported in the previous cells. However, to ensure that the code runs smoothly, you should make sure that the `train_loader`, `val_loader`, `model`, `criterion`, `optimizer`, `device`, and `num_epochs` are properly defined and imported.



Made changes.

In [8]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * correct / total:.2f}%') 

Accuracy of the network on the 10000 test images: 76.99 %


# Training Pruned VGG16 (TODO!)

# Training with Knowledge Distillation

In [4]:
def train_student(student_model, teacher_model, train_loader, val_loader, num_epochs, soft_target_loss_weight, ce_loss_weight, temperature):
    student_model.train()
    teacher_model.eval()

    optimizer = torch.optim.SGD(student_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

    ce_loss = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            # Forward pass through teacher model
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)
                teacher_probs = nn.functional.softmax(teacher_outputs / temperature, dim=1)

            # Forward pass through student model
            student_outputs = student_model(inputs)
            student_probs = nn.functional.log_softmax(student_outputs / temperature, dim=1)

            # Compute distillation loss
            soft_target_loss = torch.sum(teacher_probs * (teacher_probs.log() - student_probs))/ student_probs.size()[0] * (temperature**2)

            label_loss = ce_loss(student_outputs, labels)

            loss = soft_target_loss_weight * soft_target_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:  # Print every 100 mini-batches
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')
                running_loss = 0.0

        # Validate the student model
        student_model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = student_model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_accuracy = 100 * correct / total
        print(f'Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

        student_model.train()

    torch.save(student_model.state_dict(), "student_model.h5")