In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

In [8]:
num_classes=10

In [9]:
device = torch.device('cuda')

# Define your smaller ResNet student model
# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(131072, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(65536, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [10]:
# Define the teacher model
teacher_model = DeepNN(10)

# Define the student model
student_model = LightNN(10)

# Define the discriminator model
discriminator = nn.Sequential(
    nn.Linear(num_classes, 128),
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid()
)

# Loss functions
criterion_classification = nn.CrossEntropyLoss()
criterion_kl_divergence = nn.KLDivLoss()

# Optimizers
optimizer_student = optim.SGD(student_model.parameters(), lr=0.1)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.001)

# Data preprocessing
transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor()])

# Load your dataset
train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Training loop
num_epochs = 10
alpha = 0.1  # Adjust this hyperparameter

for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        optimizer_student.zero_grad()
        optimizer_discriminator.zero_grad()

        # Forward pass with the teacher model
        with torch.no_grad():
            teacher_outputs = teacher_model(inputs)

        # Forward pass with the student model
        student_outputs = student_model(inputs)

        # Forward pass with the discriminator
        student_predictions = discriminator(student_outputs)
        teacher_predictions = discriminator(teacher_outputs)

        # Adversarial loss for the discriminator
        loss_discriminator = -torch.log(teacher_predictions).mean() - torch.log(1 - student_predictions).mean()

        # Knowledge distillation loss
        loss_kl_divergence = criterion_kl_divergence(student_outputs, teacher_outputs)

        # Classification loss
        loss_classification = criterion_classification(student_outputs, labels)

        # Total loss for the student model
        loss_student = loss_classification + alpha * loss_kl_divergence

        # Backpropagate and update weights for the student model
        loss_student.backward()
        optimizer_student.step()

        # Backpropagate and update weights for the discriminator
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Calculate accuracy for the student model
        _, predicted = student_outputs.max(1)
        accuracy = (predicted == labels).sum().item() / labels.size(0)

        # Print loss and accuracy for the student model and discriminator
        print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{step}/{len(train_loader)}]')
        print(f'Student Loss: {loss_student.item():.4f}, Student Accuracy: {accuracy:.4f}')
        print(f'Discriminator Loss: {loss_discriminator.item():.4f}')

        step += 1

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|███████████████████████| 170498071/170498071 [00:05<00:00, 30519600.06it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x2048 and 131072x512)