In [None]:
import sys
import math
sys.path.append("..")


import torch
from torch import nn
from torchvision.datasets import MNIST, CIFAR100
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from algorithms.vanilla_distilation import distill
from trainer import Trainer

# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch's global RNG so weight init and DataLoader shuffling are
# reproducible under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

%load_ext autoreload
%autoreload 2

#### Define a very simple convolutional architecture

In [2]:
class SimpleConvNet(nn.Module):
    """Plain stack of 3x3 convolutions followed by a single linear classifier.

    Every convolution uses ``padding=1`` with the default stride of 1, so the
    spatial size of the input is preserved through the whole stack and the
    classifier sees ``height * width * n_filters`` features.

    Args:
        n_layers: Number of ``Conv2d`` layers (must be >= 1). Consecutive
            convolutions are separated by a SiLU activation.
        n_filters: Output channels of every convolution.
        input_shape: ``(channels, height, width)`` of a single input image.
        n_classes: Number of output logits.

    Raises:
        ValueError: If ``n_layers < 1``.
    """

    def __init__(
        self,
        n_layers=5,
        n_filters=32,
        input_shape=(1, 28, 28),
        n_classes=10
    ):
        super().__init__()
        if n_layers < 1:
            # The first convolution is always built, so 0 or negative values
            # would otherwise silently produce a 1-layer network.
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        in_channels, height, width = input_shape[0], input_shape[1], input_shape[2]

        # Layout: conv, (SiLU, conv) * (n_layers - 1); kept as a ModuleList so
        # parameter names (convs.0, convs.2, ...) stay stable.
        self.convs = nn.ModuleList([nn.Conv2d(in_channels, n_filters, 3, padding=1)])
        for _ in range(n_layers - 1):
            self.convs.append(nn.SiLU())
            self.convs.append(nn.Conv2d(n_filters, n_filters, 3, padding=1))

        # padding=1 keeps H and W unchanged, hence height * width * n_filters inputs.
        self.linear = nn.Sequential(
            nn.SiLU(),
            nn.Flatten(),
            nn.Linear(height * width * n_filters, n_classes)
        )

    def forward(self, x):
        """Map a ``(batch, C, H, W)`` image batch to ``(batch, n_classes)`` logits."""
        for layer in self.convs:
            x = layer(x)
        return self.linear(x)

#### Initialize the teacher, the student, and a second student-sized network trained from scratch as a reference baseline

In [3]:
# All three networks consume 3x32x32 CIFAR-100 images and emit 100 logits.
CIFAR100_INPUT_SHAPE = (3, 32, 32)
CIFAR100_N_CLASSES = 100

# One shared kwargs dict guarantees the distilled student and the
# from-scratch reference have exactly the same architecture.
student_kwargs = dict(
    n_layers=2,
    n_filters=16,
    input_shape=CIFAR100_INPUT_SHAPE,
    n_classes=CIFAR100_N_CLASSES,
)

teacher = SimpleConvNet(
    n_layers=10,
    n_filters=64,
    input_shape=CIFAR100_INPUT_SHAPE,
    n_classes=CIFAR100_N_CLASSES,
).to(device)
student = SimpleConvNet(**student_kwargs).to(device)
student_reference = SimpleConvNet(**student_kwargs).to(device)

In [4]:
# Hyperparameters for `Trainer`; lr, batch_size and max_epochs are also
# reused by the distillation cell.
config = dict(
    lr=1e-3,
    batch_size=512,
    max_iters=-1,  # presumably "no iteration cap" for Trainer — TODO confirm
    max_epochs=7,
    # Assuming these count optimizer steps: at batch_size=512 an epoch is ~98
    # batches, so 36 steps is ~0.37 epochs (consistent with the logged
    # "Epoch: 0.369").
    evaluate_every=36,
    log_every=36,
    verbose=True,
    optimizer="AdamW",
)

#### Load CIFAR-100

In [5]:
# ToTensor yields pixels in [0, 1]; normalizing with mean=std=0.5 per RGB
# channel maps them to [-1, 1].
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
)

# Both splits share the same root and preprocessing; the train split is built first.
data_root = "./data"
train_dataset, test_dataset = (
    CIFAR100(root=data_root, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [6]:
# One Trainer, built from `config`, is shared by the teacher and reference-student
# training cells and by the final evaluation cells.
trainer = Trainer(train_config=config)

#### Train the teacher and the reference student (the CIFAR-100 test split doubles as the validation set)

In [7]:
# NOTE: validation runs on the CIFAR-100 *test* split, so the logged
# "validation accuracy" values are test-set numbers.
teacher_train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
teacher_val_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False)

# NOTE(review): trainer.train presumably updates the model in place and returns
# it — if so, re-running this cell keeps training the same teacher. Confirm.
teacher = trainer.train(
    model=teacher,
    train_loader=teacher_train_loader,
    val_loader=teacher_val_loader,
)

Training iters:   0%|          | 0/350000 [00:00<?, ?it/s]

eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.313e+00, Last validation accuracy: 0.0557, Epoch: 0.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.169e+00, Last validation accuracy: 0.0922, Epoch: 0.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.858e+00, Last validation accuracy: 0.1440, Epoch: 1.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.566e+00, Last validation accuracy: 0.1702, Epoch: 1.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.270e+00, Last validation accuracy: 0.2079, Epoch: 2.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.160e+00, Last validation accuracy: 0.2477, Epoch: 2.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.671e+00, Last validation accuracy: 0.2747, Epoch: 3.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.613e+00, Last validation accuracy: 0.2943, Epoch: 3.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.145e+00, Last validation accuracy: 0.3091, Epoch: 4.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.319e+00, Last validation accuracy: 0.3256, Epoch: 4.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.676e+00, Last validation accuracy: 0.3095, Epoch: 5.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.975e+00, Last validation accuracy: 0.3332, Epoch: 5.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.212e+00, Last validation accuracy: 0.3086, Epoch: 6.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.367e+00, Last validation accuracy: 0.3218, Epoch: 6.737
Loss: 1.426e+00, Last validation accuracy: 0.3218, Epoch: 7.000


In [8]:
# Same recipe as the teacher, applied to the student-sized baseline; the test
# split again serves as the validation set.
reference_train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
reference_val_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False)

student_reference = trainer.train(
    model=student_reference,
    train_loader=reference_train_loader,
    val_loader=reference_val_loader,
)

Training iters:   0%|          | 0/350000 [00:00<?, ?it/s]

eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.818e+00, Last validation accuracy: 0.1488, Epoch: 0.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.573e+00, Last validation accuracy: 0.1823, Epoch: 0.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.388e+00, Last validation accuracy: 0.2069, Epoch: 1.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.311e+00, Last validation accuracy: 0.2120, Epoch: 1.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.062e+00, Last validation accuracy: 0.2244, Epoch: 2.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.175e+00, Last validation accuracy: 0.2321, Epoch: 2.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.897e+00, Last validation accuracy: 0.2314, Epoch: 3.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.989e+00, Last validation accuracy: 0.2379, Epoch: 3.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.523e+00, Last validation accuracy: 0.2389, Epoch: 4.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.783e+00, Last validation accuracy: 0.2458, Epoch: 4.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.415e+00, Last validation accuracy: 0.2458, Epoch: 5.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.638e+00, Last validation accuracy: 0.2553, Epoch: 5.737


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.357e+00, Last validation accuracy: 0.2489, Epoch: 6.369


eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.522e+00, Last validation accuracy: 0.2514, Epoch: 6.737
Loss: 2.417e+00, Last validation accuracy: 0.2514, Epoch: 7.000


#### Distill the teacher's knowledge into the student

In [9]:
# Distillation hyperparameters forwarded to `distill`.
DISTILL_ALPHA = 0.6        # passed as `alpha`; presumably the soft/hard loss mix — see algorithms.vanilla_distilation
DISTILL_TEMPERATURE = 2.5  # passed as `T`; presumably the softmax temperature
# NOTE(review): these AdamW betas are set explicitly here, while the reference
# student is trained by `Trainer` with its own AdamW settings (not visible in this
# notebook). If they differ, the distilled-vs-reference comparison is confounded.
DISTILL_BETAS = (0.8, 0.9)

distill_train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)

student = distill(
    teacher=teacher,
    student=student,
    optimizer=torch.optim.AdamW(student.parameters(), lr=config["lr"], betas=DISTILL_BETAS),
    train_loader=distill_train_loader,
    test_loader=None,
    # len(DataLoader) == ceil(len(dataset) / batch_size) when drop_last=False,
    # i.e. exactly `max_epochs` passes over the training set.
    iterations=len(distill_train_loader) * config["max_epochs"],
    test_freq=-1,
    alpha=DISTILL_ALPHA,
    T=DISTILL_TEMPERATURE,
)

Training - 686/686 [██████████████████████████████] ELP: 00:51, accuracy: 0.3396 - 


#### Test accuracy of the reference student (trained from scratch)

In [13]:
# Test-set accuracy of the student trained from scratch (the baseline).
reference_eval_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False)
trainer.evaluate(reference_eval_loader, model=student_reference)

eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

{'Accuracy': 0.2569}

#### Test accuracy of the distilled student

In [10]:
# Test-set accuracy of the distilled student, comparable to the baseline cell.
distilled_eval_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False)
trainer.evaluate(distilled_eval_loader, model=student)

eval batch:   0%|          | 0/5 [00:00<?, ?it/s]

{'Accuracy': 0.2831}