In [1]:
import warnings; warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms

from tqdm.auto import trange


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Arguments shared by both CIFAR-10 splits: same storage directory, images
# converted with ToTensor, and the archive downloaded if needed.
_cifar_kwargs = dict(root="dataset/", transform=transforms.ToTensor(), download=True)
_batch_size = 1024

train_dataset = torchvision.datasets.CIFAR10(train=True, **_cifar_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, **_cifar_kwargs)

# Create train and test dataloaders; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=_batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [2]:
class CNN_Block(nn.Module):
    """Two consecutive Conv(3x3, padding 1, no bias) -> BatchNorm -> ReLU stages.

    The first convolution maps ``in_channels`` to ``out_channels`` at stride 1;
    the second keeps ``out_channels`` and uses stride 2.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_bn_relu(channels_in, stride):
            # One Conv -> BatchNorm -> ReLU stage producing `out_channels` maps.
            return [
                nn.Conv2d(
                    channels_in,
                    out_channels,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        stages = conv_bn_relu(in_channels, stride=1)
        stages += conv_bn_relu(out_channels, stride=2)
        self.block = nn.Sequential(*stages)

    def forward(self, input):
        """Apply both stages to ``input`` and return the result."""
        return self.block(input)


class CNN(nn.Module):
    """Convolutional classifier: a stack of ``CNN_Block`` stages plus a two-layer MLP head.

    Args:
        in_channels: number of channels of the input images.
        inter_channels: output channels of each ``CNN_Block`` stage, in order.
            Defaults to ``[64, 128, 256]``. An empty list yields a plain MLP
            on the flattened input.
        image_size: side length of the input images; the flattened feature
            size is computed as ``channels * side * side``, i.e. for square inputs.
        n_classes: number of output logits.
    """

    def __init__(self, in_channels=3, inter_channels=None, image_size=32, n_classes=10):
        super().__init__()

        if inter_channels is None:
            inter_channels = [64, 128, 256]

        image_size = int(image_size)

        model = []
        for out_channels in inter_channels:
            model.append(CNN_Block(in_channels, out_channels))
            in_channels = out_channels
            # Each block ends in a stride-2, kernel-3, padding-1 convolution,
            # which maps a side of n to ceil(n / 2). Tracking it per block keeps
            # the head consistent for sizes not divisible by 2 ** n_blocks.
            image_size = (image_size + 1) // 2

        model.append(nn.Flatten())

        # `in_channels` now holds the channel count leaving the last block
        # (or the original input channels when there are no blocks).
        in_features = in_channels * image_size * image_size
        out_features = in_features // 4
        model.append(nn.Linear(in_features, out_features))
        model.append(nn.ReLU())

        in_features = out_features
        out_features = n_classes
        model.append(nn.Linear(in_features, out_features))

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Return the ``n_classes`` logits for a batch of images."""
        return self.model(input)

In [3]:
def count_parameters(model):
    """Return the total element count over the model's parameters that require gradients."""
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable


# Build the CNN with its default configuration (inter_channels=[64, 128, 256])
# and report how many trainable parameters it has.
large_model = CNN()
print(f"Large model has {count_parameters(large_model):,} parameters.")

Large model has 5,351,882 parameters.


In [4]:
def train(model, n_epochs=15):
    """Train ``model`` on ``train_loader`` and print test accuracy after every epoch.

    The model is moved to ``device`` and optimised with AdamW (default
    hyper-parameters) on the cross-entropy loss. After each epoch the model is
    switched to eval mode and scored on ``test_loader``; the best accuracy and
    the epoch it occurred in are printed at the end.

    Args:
        model: module mapping a batch of inputs to per-class logits.
        n_epochs: number of passes over ``train_loader``.

    Returns:
        None. Results are reported via ``print``.
    """
    model.to(device)
    optim = torch.optim.AdamW(model.parameters())

    best_epoch, best_accuracy = -1, 0
    for epoch in trange(n_epochs):
        model.train()
        for input, target in train_loader:
            input, target = input.to(device), target.to(device)
            pred = model(input)
            loss = F.cross_entropy(pred, target)

            optim.zero_grad()
            loss.backward()
            optim.step()

        model.eval()
        correct, total = 0, 0
        # Evaluation needs no gradients: disabling autograd avoids building a
        # graph for every test batch.
        with torch.no_grad():
            for input, target in test_loader:
                input, target = input.to(device), target.to(device)
                pred = model(input).argmax(dim=1)
                # .item() keeps the running count a plain Python int instead of
                # a device tensor.
                correct += (pred == target).sum().item()
                total += target.numel()

        accuracy = 100 * correct / total
        print(f"Epoch {epoch}: accuracy {accuracy:.1f}%")
        if accuracy > best_accuracy:
            best_epoch = epoch
            best_accuracy = accuracy

    print(f"Best accuracy {best_accuracy:.1f}% after epoch {best_epoch}")

In [5]:
# Train the large model for the default 15 epochs; it is reused below as the
# distillation teacher.
train(large_model)

  0%|          | 0/15 [00:00<?, ?it/s]

Epoch 0: accuracy 38.2%
Epoch 1: accuracy 50.6%
Epoch 2: accuracy 60.6%
Epoch 3: accuracy 67.1%
Epoch 4: accuracy 72.5%
Epoch 5: accuracy 71.9%
Epoch 6: accuracy 72.6%
Epoch 7: accuracy 76.2%
Epoch 8: accuracy 72.9%
Epoch 9: accuracy 76.5%
Epoch 10: accuracy 76.5%
Epoch 11: accuracy 70.2%
Epoch 12: accuracy 77.5%
Epoch 13: accuracy 78.4%
Epoch 14: accuracy 79.0%
Best accuracy 79.0% after epoch 14


In [6]:
# Build a smaller CNN with two stages of 16 and 32 channels and report its
# trainable-parameter count.
smaller_model = CNN(inter_channels=[16, 32])
print(f"Smaller model has {count_parameters(smaller_model):,} parameters.")

Smaller model has 1,070,970 parameters.


In [7]:
# Baseline: train the smaller model on its own, without a teacher.
train(smaller_model)

  0%|          | 0/15 [00:00<?, ?it/s]

Epoch 0: accuracy 46.0%
Epoch 1: accuracy 59.6%
Epoch 2: accuracy 61.2%
Epoch 3: accuracy 62.3%
Epoch 4: accuracy 64.8%
Epoch 5: accuracy 68.0%
Epoch 6: accuracy 67.4%
Epoch 7: accuracy 68.4%
Epoch 8: accuracy 69.5%
Epoch 9: accuracy 68.5%
Epoch 10: accuracy 69.0%
Epoch 11: accuracy 68.9%
Epoch 12: accuracy 69.5%
Epoch 13: accuracy 68.8%
Epoch 14: accuracy 69.2%
Best accuracy 69.5% after epoch 8


In [8]:
def train_distill(student_model, teacher_model, n_epochs=15, alpha=0.7, T=5):
    """Train ``student_model`` on ``train_loader`` with knowledge distillation.

    The loss for each batch is
    ``alpha * T**2 * KL(teacher || student) + (1 - alpha) * CE(student, target)``,
    where both distributions in the KL term are softmaxes of the logits divided
    by the temperature ``T``. Only the student is optimised (AdamW, default
    hyper-parameters). After each epoch the student is scored on
    ``test_loader``; the best accuracy and its epoch are printed at the end.

    Side effects: both models are moved to ``device`` and the teacher is put
    in eval mode.

    Args:
        student_model: module being trained; maps inputs to per-class logits.
        teacher_model: module providing the soft targets; it is not updated.
        n_epochs: number of passes over ``train_loader``.
        alpha: weight of the distillation term; ``1 - alpha`` weights the
            cross-entropy term.
        T: softmax temperature used for the distillation term.

    Returns:
        None. Results are reported via ``print``.
    """
    student_model.to(device)
    teacher_model.to(device)
    # The teacher only provides targets: run it in inference mode so that
    # mode-dependent layers (e.g. BatchNorm, Dropout) behave deterministically
    # and do not depend on whatever mode the caller left it in.
    teacher_model.eval()
    optim = torch.optim.AdamW(student_model.parameters())

    best_epoch, best_accuracy = -1, 0
    for epoch in trange(n_epochs):
        student_model.train()
        for input, target in train_loader:
            input, target = input.to(device), target.to(device)
            student_pred = student_model(input)
            # No graph is needed for the teacher; this also prevents
            # loss.backward() from reaching the teacher's parameters.
            with torch.no_grad():
                teacher_pred = teacher_model(input)

            student_logprobs = F.log_softmax(student_pred / T, dim=-1)
            teacher_probs = F.softmax(teacher_pred / T, dim=-1)
            distill_loss = F.kl_div(student_logprobs, teacher_probs, reduction="batchmean")

            ce_loss = F.cross_entropy(student_pred, target)
            # T ** 2 rescales the soft-target term, whose logits were divided by T.
            loss = alpha * distill_loss * (T ** 2) + (1 - alpha) * ce_loss

            optim.zero_grad()
            loss.backward()
            optim.step()

        student_model.eval()
        correct, total = 0, 0
        # Evaluation needs no gradients: disabling autograd avoids building a
        # graph for every test batch.
        with torch.no_grad():
            for input, target in test_loader:
                input, target = input.to(device), target.to(device)
                pred = student_model(input).argmax(dim=1)
                # .item() keeps the running count a plain Python int instead of
                # a device tensor.
                correct += (pred == target).sum().item()
                total += target.numel()

        accuracy = 100 * correct / total
        print(f"Epoch {epoch}: accuracy {accuracy:.1f}%")
        if accuracy > best_accuracy:
            best_epoch = epoch
            best_accuracy = accuracy

    print(f"Best accuracy {best_accuracy:.1f}% after epoch {best_epoch}")

In [9]:
# Distill the already-trained large model (teacher) into a freshly initialised
# student with the same architecture as `smaller_model`.
student_model = CNN(inter_channels=[16, 32])
teacher_model = large_model
# alpha=.8 weights the distillation term at 0.8 and cross-entropy at 0.2;
# T=10 is the softmax temperature.
train_distill(student_model, teacher_model, alpha=.8, T=10)

  0%|          | 0/15 [00:00<?, ?it/s]

Epoch 0: accuracy 43.9%
Epoch 1: accuracy 54.4%
Epoch 2: accuracy 59.8%
Epoch 3: accuracy 63.4%
Epoch 4: accuracy 64.4%
Epoch 5: accuracy 65.4%
Epoch 6: accuracy 67.3%
Epoch 7: accuracy 69.0%
Epoch 8: accuracy 68.8%
Epoch 9: accuracy 70.2%
Epoch 10: accuracy 69.3%
Epoch 11: accuracy 70.7%
Epoch 12: accuracy 71.8%
Epoch 13: accuracy 71.4%
Epoch 14: accuracy 73.0%
Best accuracy 73.0% after epoch 14
