In [None]:
import os
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchmetrics
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

# Configurations
CONFIG = {
    # Seed for torch's RNGs (weight init, DataLoader shuffling, random flips);
    # applied right below so that Restart & Run All reproduces the same run.
    'seed': 42,
    'data': {
        'train_dir': 'data/train',
        'test_dir': 'data/test',
        'batch_size': 128
    },
    'model': {
        'num_classes': 10
    },
    'training': {
        'epochs': 50,
        'learning_rate': 0.1,
        'momentum': 0.9,
        'weight_decay': 5e-4,
        'alpha': 0.5,          # weight of the distillation term in loss_fn_kd
        'temperature': 3.0     # softmax temperature T in loss_fn_kd
    },
    'logging': {
        'log_dir': 'logs'
    }
}

# torch.manual_seed seeds the CPU and all CUDA generators.
torch.manual_seed(CONFIG['seed'])

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'



# Teacher Model

In [None]:
# Teacher Model
class TeacherNet(nn.Module):
    """Larger CNN used as the distillation teacher.

    Three conv -> ReLU -> 2x2 max-pool stages (3 -> 64 -> 128 -> 256 channels),
    then a residual block and a two-layer classifier head. The head's input size
    (256 * 4 * 4) assumes 32x32 inputs, which the three pools reduce to 4x4.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.res_block1 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256)
        )
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Map images (N, 3, 32, 32) to logits (N, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Identity shortcut around res_block1 (no activation after the sum).
        x = self.res_block1(x) + x
        # flatten(1) keeps the batch dimension fixed: a wrongly sized input now
        # fails loudly in fc1 instead of being silently re-split into a different
        # number of rows, which view(-1, 256 * 4 * 4) would do.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Student Model

In [None]:
# Student Model
class StudentNet(nn.Module):
    """Smaller CNN trained to mimic TeacherNet.

    Same layout as the teacher with half the channels (3 -> 32 -> 64 -> 128),
    a residual block, and a 256-unit hidden layer. The head's input size
    (128 * 4 * 4) assumes 32x32 inputs, which three 2x2 pools reduce to 4x4.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.res_block = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128)
        )
        # 32 -> 16 -> 8 -> 4 after three pools, so each sample has 128 * 4 * 4
        # features. 128 * 8 * 8 would not match and made view() merge four samples
        # into one row.
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Map images (N, 3, 32, 32) to logits (N, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Identity shortcut around res_block (no activation after the sum).
        x = self.res_block(x) + x
        # Flatten per sample so the batch dimension is preserved.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Knowledge Distillation Loss

In [None]:


# Knowledge Distillation Loss
def loss_fn_kd(outputs, labels, teacher_outputs, alpha=0.5, T=3.0):
    """Knowledge-distillation loss (Hinton et al. style).

    Computes

        alpha * T^2 * KL(softmax(teacher/T) || softmax(student/T))
        + (1 - alpha) * CE(student, labels)

    Args:
        outputs: student logits, assumed (batch, num_classes).
        labels: integer class targets, assumed (batch,).
        teacher_outputs: teacher logits, same shape as `outputs`.
        alpha: weight of the distillation term; (1 - alpha) weights the CE term.
        T: softmax temperature. The T^2 factor is the conventional rescaling
           that keeps the soft-target term comparable across temperatures.

    Returns:
        Scalar loss tensor.
    """
    # 'batchmean' sums the KL over classes and averages over the batch, which is
    # the true KL divergence. nn.KLDivLoss's default ('mean') also divides by the
    # number of classes and silently shrinks the distillation term.
    distill = F.kl_div(
        F.log_softmax(outputs / T, dim=1),
        F.softmax(teacher_outputs / T, dim=1),
        reduction='batchmean',
    ) * (alpha * T * T)
    hard = F.cross_entropy(outputs, labels) * (1. - alpha)
    return distill + hard

# Load Data

In [None]:

# Load Data
def get_dataloaders(train_dir, test_dir, batch_size):
    """Build train/test DataLoaders from two ImageFolder directory trees.

    Both splits are resized to 32x32, converted to tensors, and normalised so
    [0, 1] pixel values map to [-1, 1]. Random horizontal flipping is applied
    to the training split only, so evaluation is deterministic.

    Args:
        train_dir: root of the training ImageFolder (one sub-directory per class).
        test_dir: root of the test ImageFolder.
        batch_size: samples per batch for both loaders.

    Returns:
        (train_loader, test_loader); only the training loader shuffles.

    Raises:
        ValueError: if the two folders define different class -> index mappings.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # No random augmentation at evaluation time, otherwise test metrics change
    # between runs for the same model.
    eval_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
    test_dataset = datasets.ImageFolder(root=test_dir, transform=eval_transform)
    # ImageFolder derives label indices from the sorted sub-directory names of
    # each root. If the two roots differ, the same index means different classes.
    if train_dataset.class_to_idx != test_dataset.class_to_idx:
        raise ValueError(
            f"class folders differ between {train_dir!r} and {test_dir!r}: "
            f"{train_dataset.classes} vs {test_dataset.classes}"
        )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


# Training Loop

In [None]:
# Training Loop
def _train_one_epoch(student, teacher, loader, optimizer, alpha, temperature):
    """Run one distillation epoch over `loader`, updating only `student`.

    Returns:
        (mean_loss, accuracy): per-sample mean of `loss_fn_kd` and the student's
        top-1 accuracy over every training sample seen in this epoch.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    student.train()
    total_loss = 0.0
    correct = 0
    seen = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        student_outputs = student(inputs)
        with torch.no_grad():
            teacher_outputs = teacher(inputs)
        loss = loss_fn_kd(student_outputs, targets, teacher_outputs,
                          alpha=alpha, T=temperature)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch does not skew the mean.
        batch_size = targets.size(0)
        total_loss += loss.item() * batch_size
        correct += (student_outputs.argmax(dim=1) == targets).sum().item()
        seen += batch_size
    if seen == 0:
        raise ValueError("training loader yielded no batches")
    return total_loss / seen, correct / seen


def _evaluate(model, loader):
    """Return top-1 accuracy of `model` on `loader` (eval mode, no gradients); NaN if empty."""
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            correct += (model(inputs).argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)
    return correct / seen if seen else float('nan')


def train():
    """Distil a TeacherNet into a StudentNet using the settings in CONFIG.

    Each epoch does four things:
      - trains the student on the training split with `loss_fn_kd`;
      - evaluates the student on the test split;
      - prints a summary line;
      - logs loss, accuracy and learning rate to TensorBoard under
        CONFIG['logging']['log_dir'].

    Distillation only helps with a trained teacher. Set
    CONFIG['model']['teacher_checkpoint'] to a saved TeacherNet state_dict to
    load one. Otherwise a warning is printed and the randomly initialised
    teacher is used.

    Returns:
        The trained StudentNet.
    """
    data_cfg = CONFIG['data']
    model_cfg = CONFIG['model']
    train_cfg = CONFIG['training']

    train_loader, test_loader = get_dataloaders(
        data_cfg['train_dir'], data_cfg['test_dir'], data_cfg['batch_size'])
    teacher = TeacherNet(num_classes=model_cfg['num_classes']).to(device)
    student = StudentNet(num_classes=model_cfg['num_classes']).to(device)

    teacher_ckpt = model_cfg.get('teacher_checkpoint')
    if teacher_ckpt:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        teacher.load_state_dict(torch.load(teacher_ckpt, map_location=device))
    else:
        print("WARNING: CONFIG['model']['teacher_checkpoint'] is not set; "
              "distilling from a randomly initialised teacher.")
    # Freeze the teacher. eval() makes its BatchNorm layers use running
    # statistics and stops the student's batches from updating them.
    teacher.eval()

    optimizer = optim.SGD(student.parameters(),
                          lr=train_cfg['learning_rate'],
                          momentum=train_cfg['momentum'],
                          weight_decay=train_cfg['weight_decay'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_cfg['epochs'])
    writer = SummaryWriter(log_dir=CONFIG['logging']['log_dir'])

    epochs = train_cfg['epochs']
    try:
        for epoch in range(epochs):
            train_loss, train_acc = _train_one_epoch(
                student, teacher, train_loader, optimizer,
                train_cfg['alpha'], train_cfg['temperature'])
            test_acc = _evaluate(student, test_loader)

            writer.add_scalar('loss/train', train_loss, epoch)
            writer.add_scalar('accuracy/train', train_acc, epoch)
            writer.add_scalar('accuracy/test', test_acc, epoch)
            writer.add_scalar('lr', scheduler.get_last_lr()[0], epoch)
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {train_loss:.4f}, "
                  f"Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}")
            scheduler.step()
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer.close()
    return student

# In a notebook kernel __name__ is also "__main__", so running this cell trains.
if __name__ == "__main__":
    train()