In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F


# Caculate the number of trainable parameters
def _calc_width(net):
    import numpy as np
    net_params = filter(lambda p: p.requires_grad, net.parameters())
    weight_count = 0
    for param in net_params:
        weight_count += np.prod(param.size())
    return weight_count



In [11]:
# Load Generator
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.label_emb = nn.Embedding(10, 10)
        
        self.model = nn.Sequential(
            nn.Linear(110, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    
    def forward(self, z, labels):
        z = z.view(z.size(0), 100)
        c = self.label_emb(labels)
        x = torch.cat([z, c], 1)
        out = self.model(x)
        return out.view(x.size(0),1, 28, 28)
generator=Generator().cuda()
PATH="./generator_state.pt"
generator.load_state_dict(torch.load(PATH))
generator.eval()
def generate_digits( batch_size,noise_dim=100):
    z = torch.randn(batch_size, noise_dim, device='cuda')
    labels = torch.randint(0, 10, (batch_size,), device='cuda')
    with torch.no_grad():
        images = generator(z, labels)
    return images,labels


In [12]:
# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Define the Teacher Model
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64*7*7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 64*7*7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define the Student Model
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32*7*7, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 32*7*7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [13]:

# Train the Teacher Model
teacher_model = TeacherModel().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)

def train_standalone(model):
    model.train()
    for epoch in range(10):
        for _ in range(len(train_loader)):
            images, labels = generate_digits(64,100)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')

train_standalone(teacher_model)



Epoch 1, Loss: 0.013882556930184364
Epoch 2, Loss: 0.01285826787352562
Epoch 3, Loss: 0.0016865209909155965
Epoch 4, Loss: 0.026519469916820526
Epoch 5, Loss: 0.007001764141023159
Epoch 6, Loss: 0.030459141358733177
Epoch 7, Loss: 0.0029465716797858477
Epoch 8, Loss: 0.00031537990435026586
Epoch 9, Loss: 0.0002815399202518165
Epoch 10, Loss: 0.06564853340387344


In [15]:

# Evaluate both models
def evaluate(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

teacher_acc = evaluate(teacher_model)
print(f'Teacher Model Accuracy: {teacher_acc}%')


Teacher Model Accuracy: 94.39%


In [16]:

standalone_student_model = StudentModel().cuda()
optimizer = optim.Adam(standalone_student_model.parameters(), lr=0.001)

train_standalone(standalone_student_model)
standalone_student_acc = evaluate(standalone_student_model)
print(f'Standalone Student Model Accuracy: {standalone_student_acc}%')

# Knowledge Distillation Loss Function
def kd_loss(student_logits, teacher_logits, labels, T=3, alpha=0.5):
    kd_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(student_logits/T, dim=1),
                                                  F.softmax(teacher_logits/T, dim=1)) * (T*T)
    ce_loss = criterion(student_logits, labels)
    return alpha * kd_loss + (1 - alpha) * ce_loss

# Train the Student Model using Knowledge Distillation
student_model = StudentModel().cuda()
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

def train_student():
    teacher_model.eval()
    student_model.train()
    for epoch in range(10):
        for _ in range(len(train_loader)):
            images, labels = generate_digits(64,100)
            optimizer.zero_grad()
            student_outputs = student_model(images)
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
            loss = kd_loss(student_outputs, teacher_outputs, labels)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')

train_student()
student_acc = evaluate(student_model)
print(f'Student Model Accuracy: {student_acc}%')

Epoch 1, Loss: 0.043353449553251266
Epoch 2, Loss: 0.003263709368184209
Epoch 3, Loss: 0.011174341663718224
Epoch 4, Loss: 0.007313861511647701
Epoch 5, Loss: 0.00018130919488612562
Epoch 6, Loss: 0.02880851924419403
Epoch 7, Loss: 0.010320514440536499
Epoch 8, Loss: 0.08353959023952484
Epoch 9, Loss: 0.015486414544284344
Epoch 10, Loss: 0.0006192185683175921
Standalone Student Model Accuracy: 93.53%
Epoch 1, Loss: 0.31982678174972534
Epoch 2, Loss: 0.07603195309638977
Epoch 3, Loss: 0.13417203724384308
Epoch 4, Loss: 0.07136975973844528
Epoch 5, Loss: 0.031091423705220222
Epoch 6, Loss: 0.04806558042764664
Epoch 7, Loss: 0.020203091204166412
Epoch 8, Loss: 0.04069128632545471
Epoch 9, Loss: 0.0252186618745327
Epoch 10, Loss: 0.030238721519708633
Student Model Accuracy: 94.53%
