<a href="https://colab.research.google.com/github/look4pritam/KnowledgeDistillation/blob/master/Notebooks/KnowledgeDistillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install PyTorch Lightning (imported below as `pl`).
!pip install pytorch-lightning

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# Teacher Model.

In [None]:
class TeacherCNN(nn.Module):
    """Three-block convolutional classifier used as the distillation teacher.

    Each block is a 3x3 convolution (padding 1), a 2x2 max-pool and a ReLU,
    so a 3-channel 32x32 image shrinks to 4x4 after three blocks. Two fully
    connected layers then map the flattened 128x4x4 features to 10 logits.
    """

    def __init__(self):
        super().__init__()
        # These attribute names become the state_dict keys used when the
        # teacher is saved and reloaded, so they must not change.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return unnormalized class logits for a batch of images."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            # Each block halves the spatial size: 32 -> 16 -> 8 -> 4.
            features = F.relu(F.max_pool2d(conv(features), 2))
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Lightning Module for Training the Teacher Model

In [None]:
class TeacherLightning(pl.LightningModule):
    """LightningModule that trains a classifier with plain cross-entropy.

    Args:
        model: Network to train (a ``TeacherCNN`` in this notebook); its
            logits are scored against integer class labels.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, model, learning_rate=1e-3):
        super().__init__()
        # ``model`` is an nn.Module whose weights are already stored in the
        # checkpoint's state_dict. Lightning recommends excluding such
        # arguments from the saved hyperparameters rather than pickling the
        # module a second time. As a result, ``load_from_checkpoint`` must be
        # given ``model=...`` explicitly.
        self.save_hyperparameters(ignore=["model"])
        self.model = model
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the wrapped model's logits for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss of one training batch."""
        x, y = batch
        loss = self.criterion(self(x), y)
        self.log("train_loss_teacher", loss)
        return loss

    def _shared_eval_step(self, batch):
        """Return ``(loss, accuracy)`` of the model on one labelled batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy (shown in the progress bar)."""
        loss, acc = self._shared_eval_step(batch)
        self.log("val_loss_teacher", loss, prog_bar=True)
        self.log("val_acc_teacher", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log test loss and accuracy; return the batch accuracy."""
        loss, acc = self._shared_eval_step(batch)
        self.log("test_loss_teacher", loss)
        self.log("test_acc_teacher", acc)
        return acc

    def configure_optimizers(self):
        """Use Adam over all parameters with the saved learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

# Student Model (Trained by Knowledge Distillation from the Teacher)

In [None]:
class StudentModel(pl.LightningModule):
    """Compact CNN trained on CIFAR-10 by distilling a frozen teacher.

    The training loss is::

        (1 - alpha) * CE(student_logits, y)
        + alpha * T**2 * KL(softmax(teacher_logits / T) || softmax(student_logits / T))

    The ``T**2`` factor is the conventional scaling from Hinton et al. (2015).
    It compensates for the ``1 / T**2`` shrinkage of the soft-target gradients.

    Args:
        teacher_model: Trained network that supplies soft targets. Its
            parameters are frozen, and it is kept in eval mode for the
            lifetime of this module.
        temperature: Softmax temperature ``T`` applied to both sets of logits.
        alpha: Weight of the distillation term; the cross-entropy term is
            weighted by ``1 - alpha``.
        lr: Learning rate for Adam.
    """

    def __init__(self, teacher_model, temperature=3, alpha=0.5, lr=1e-3):
        super().__init__()
        # Do not pickle the teacher nn.Module into hparams: its weights are
        # already in the state_dict. ``load_from_checkpoint`` therefore needs
        # ``teacher_model=...`` passed explicitly.
        self.save_hyperparameters(ignore=["teacher_model"])
        self.teacher = teacher_model
        self.temperature = temperature
        self.alpha = alpha
        self.lr = lr

        # Freeze teacher parameters
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False

        # Student architecture: two conv blocks take a 32x32 input to 8x8
        # (matching fc1's 32 * 8 * 8 input size).
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def train(self, mode=True):
        """Set train/eval mode on the student while pinning the teacher to eval.

        ``self.teacher`` is a registered submodule, so ``nn.Module.train``
        would otherwise propagate ``mode`` to it whenever Lightning switches
        this module into training. That would re-enable any dropout or
        batch-norm updates inside the teacher.
        """
        super().train(mode)
        self.teacher.eval()
        return self

    def setup(self, stage=None):
        # The teacher is a registered submodule and moves with this module;
        # the explicit move is kept as a safeguard.
        self.teacher.to(self.device)

    def forward(self, x):
        """Return the student's unnormalized class logits."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def _distillation_loss(self, student_logits, teacher_logits):
        """Return ``T**2 * KL(teacher || student)`` on temperature-softened logits."""
        temperature = self.temperature
        student_log_prob = F.log_softmax(student_logits / temperature, dim=1)
        teacher_prob = F.softmax(teacher_logits / temperature, dim=1)
        # F.kl_div(input=log q, target=p) computes KL(p || q); 'batchmean'
        # divides the summed divergence by the batch size.
        kl_loss = F.kl_div(student_log_prob, teacher_prob, reduction="batchmean")
        return (temperature ** 2) * kl_loss

    def training_step(self, batch, batch_idx):
        """Compute the blended CE + distillation loss for one batch and log its parts."""
        x, y = batch

        # The teacher only provides targets; no graph is needed through it.
        with torch.no_grad():
            teacher_logits = self.teacher(x)

        student_logits = self(x)

        ce_loss = F.cross_entropy(student_logits, y)
        distillation_loss = self._distillation_loss(student_logits, teacher_logits)
        loss = (1 - self.alpha) * ce_loss + self.alpha * distillation_loss

        self.log("train_loss_student", loss)
        self.log("ce_loss", ce_loss)
        self.log("dist_loss", distillation_loss)
        return loss

    def _shared_eval_step(self, batch):
        """Return the student's hard-label ``(loss, accuracy)`` on one batch."""
        x, y = batch
        student_logits = self(x)
        loss = F.cross_entropy(student_logits, y)
        acc = (student_logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy (shown in the progress bar)."""
        loss, acc = self._shared_eval_step(batch)
        self.log("val_loss_student", loss, prog_bar=True)
        self.log("val_acc_student", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log test loss and accuracy; return the batch accuracy."""
        loss, acc = self._shared_eval_step(batch)
        self.log("test_loss_student", loss)
        self.log("test_acc_student", acc)
        return acc

    def configure_optimizers(self):
        """Use Adam over the student's trainable parameters only.

        The frozen teacher parameters (``requires_grad=False``) are excluded
        instead of being handed to the optimizer.
        """
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=self.lr)

# CIFAR-10 Data Module.

In [None]:
class CIFAR10DataModule(pl.LightningDataModule):
    """CIFAR-10 train/validation/test loaders with per-channel normalization.

    The official CIFAR-10 test split serves both as the validation set during
    ``fit`` and as the test set; there is no separate held-out validation split.

    Args:
        batch_size: Samples per batch for every loader.
        data_dir: Directory where torchvision stores/reads CIFAR-10.
        num_workers: DataLoader worker processes. The default 0 loads data in
            the main process, as before.
    """

    def __init__(self, batch_size=32, data_dir='./data', num_workers=0):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        # NOTE(review): these mean/std constants are the widely copied CIFAR-10
        # values. The std triple is reportedly not the exact per-channel std of
        # the training set. They are kept unchanged so models trained with this
        # module keep seeing identical inputs; verify before altering.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``data_dir``."""
        datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        ``stage`` is a Lightning stage name ("fit", "validate", "test", ...)
        or ``None``, in which case every split is built. "validate" is handled
        so that ``trainer.validate`` finds ``val_dataset``.
        """
        if stage in ('fit', None):
            self.train_dataset = self._cifar10(train=True)
        if stage in ('fit', 'validate', None):
            self.val_dataset = self._cifar10(train=False)
        if stage in ('test', None):
            self.test_dataset = self._cifar10(train=False)

    def _cifar10(self, train):
        """Return the normalized CIFAR-10 train (``train=True``) or test split."""
        return datasets.CIFAR10(root=self.data_dir, train=train,
                                transform=self.transform)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Unshuffled loader over the validation (= CIFAR-10 test) split."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Unshuffled loader over the CIFAR-10 test split."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

# Train the Teacher Model

In [None]:
# Build the teacher network and wrap it in its Lightning training module.
# teacher_lightning.model is this same teacher_model object, so training
# updates teacher_model's weights in place.
teacher_model = TeacherCNN()
teacher_lightning = TeacherLightning(model=teacher_model, learning_rate=1e-3)

In [None]:
# CIFAR-10 data module, reused for both teacher and student training.
dm = CIFAR10DataModule(batch_size=32)

In [None]:
# Train the teacher for 20 epochs, logging to TensorBoard under logs/teacher.
# `devices=1` works with both the CPU and GPU accelerators that
# accelerator="auto" may select. The previous `devices=None` is rejected by
# Lightning 2.x's CPU accelerator on CPU-only runtimes.
trainer_teacher = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1,
    logger=pl.loggers.TensorBoardLogger("logs/teacher"),
)
trainer_teacher.fit(teacher_lightning, dm)

# Save the Trained Teacher's Weights

In [None]:
# Save only the teacher's weights. teacher_lightning wraps this same module
# object, so the state_dict holds the weights learned by trainer_teacher.fit.
torch.save(teacher_model.state_dict(), "teacher_model.pth")

# Train the Student Model by Distilling the Trained Teacher

In [None]:
# Reload the trained weights into a fresh TeacherCNN.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# runtime (Lightning moves the module to the right device later).
# weights_only=True restricts unpickling to tensors and plain containers.
teacher = TeacherCNN()
teacher.load_state_dict(
    torch.load("teacher_model.pth", map_location="cpu", weights_only=True)
)
teacher.eval()

In [None]:
# Wrap the reloaded teacher in the distillation module: temperature T=3 and
# alpha=0.5, i.e. hard-label and distillation losses weighted equally.
student = StudentModel(teacher_model=teacher, temperature=3, alpha=0.5, lr=1e-3)

In [None]:
# Distill for 10 epochs, logging to TensorBoard under logs/student.
# `devices=1` works with both the CPU and GPU accelerators that
# accelerator="auto" may select. The previous `devices=None` is rejected by
# Lightning 2.x's CPU accelerator on CPU-only runtimes.
trainer_student = pl.Trainer(
    max_epochs=10,
    accelerator="auto",
    devices=1,
    logger=pl.loggers.TensorBoardLogger("logs/student"),
)

In [None]:
# Train the student. The teacher's parameters were frozen (requires_grad=False)
# in StudentModel.__init__, so only the student layers are updated.
trainer_student.fit(student, dm)