In [1]:
import torch
import torch.nn as nn

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

from tqdm import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
batch_size = 512 if device.type == "cuda" else 128
lr = 0.01
epochs = 20

In [4]:
train_data = datasets.MNIST(root="data", train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST(root="data", train=False, download=True, transform=transforms.ToTensor())

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

In [5]:
class Teacher(nn.Module):
  def __init__(self, dropout=0.2):
    super().__init__()
    self.model = nn.Sequential(
      nn.Dropout(p=dropout),
      nn.Linear(784, 1024),
      nn.ReLU(),
      nn.Linear(1024, 1024),
      nn.ReLU(),
      nn.Linear(1024, 10),
    )

  def forward(self, x):
    x = torch.flatten(x, 1)
    return self.model(x)

In [6]:
teacher = Teacher().to(device)
optimizer = torch.optim.Adam(teacher.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
loss_fn = nn.CrossEntropyLoss()

In [7]:
acc = -1.
for e in range(epochs):
  teacher.train()
  for x, y in (pbar := tqdm(train_dataloader)):
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad(set_to_none=True)

    out = teacher(x)
    loss = loss_fn(out, y)
    loss.backward()
    optimizer.step()

    pbar.set_description(f"Epoch {e+1}/{epochs} Loss: {loss.item():.4f}")
    pbar.set_postfix(acc=acc)

  teacher.eval()
  with torch.no_grad():
    total = 0
    correct = 0
    for x, y in test_dataloader:
      x, y = x.to(device), y.to(device)
      out = teacher(x)
      _, pred = torch.max(out, dim=1)
      total += y.size(0)
      correct += (pred == y).sum().item()
  acc = (correct / total)
  scheduler.step()
print(f"Final accuracy: {acc*100:.2f}%")

Epoch 1/20 Loss: 0.1984: 100%|██████████| 118/118 [00:12<00:00,  9.50it/s, acc=-1]
Epoch 2/20 Loss: 0.2688: 100%|██████████| 118/118 [00:09<00:00, 11.92it/s, acc=0.958]
Epoch 3/20 Loss: 0.0266: 100%|██████████| 118/118 [00:09<00:00, 12.86it/s, acc=0.964]
Epoch 4/20 Loss: 0.1111: 100%|██████████| 118/118 [00:08<00:00, 13.43it/s, acc=0.97]
Epoch 5/20 Loss: 0.0580: 100%|██████████| 118/118 [00:09<00:00, 12.15it/s, acc=0.974]
Epoch 6/20 Loss: 0.0985: 100%|██████████| 118/118 [00:09<00:00, 12.14it/s, acc=0.973]
Epoch 7/20 Loss: 0.0300: 100%|██████████| 118/118 [00:10<00:00, 11.55it/s, acc=0.976]
Epoch 8/20 Loss: 0.0457: 100%|██████████| 118/118 [00:09<00:00, 12.11it/s, acc=0.977]
Epoch 9/20 Loss: 0.1548: 100%|██████████| 118/118 [00:08<00:00, 13.33it/s, acc=0.979]
Epoch 10/20 Loss: 0.0028: 100%|██████████| 118/118 [00:09<00:00, 13.07it/s, acc=0.979]
Epoch 11/20 Loss: 0.0709: 100%|██████████| 118/118 [00:09<00:00, 12.00it/s, acc=0.979]
Epoch 12/20 Loss: 0.0005: 100%|██████████| 118/118 [00:0

Final accuracy: 98.23%


In [8]:
class Student(nn.Module):
  def __init__(self):
    super().__init__()
    self.model = nn.Sequential(
      nn.Linear(784, 32),
      nn.ReLU(),
      nn.Linear(32, 10),
    )

  def forward(self, x):
    x = torch.flatten(x, 1)
    return self.model(x)

In [9]:
student = Student().to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

In [10]:
acc = -1.
for e in range(epochs):
  student.train()
  for x, y in (pbar := tqdm(train_dataloader)):
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad(set_to_none=True)

    out = student(x)
    loss = loss_fn(out, y)
    loss.backward()
    optimizer.step()

    pbar.set_description(f"Epoch {e+1}/{epochs} Loss: {loss.item():.4f}")
    pbar.set_postfix(acc=acc)

  student.eval()
  with torch.no_grad():
    total = 0
    correct = 0
    for x, y in test_dataloader:
      x, y = x.to(device), y.to(device)
      out = student(x)
      _, pred = torch.max(out, dim=1)
      total += y.size(0)
      correct += (pred == y).sum().item()
  acc = (correct / total)
  scheduler.step()
print(f"Final accuracy: {acc*100:.2f}%")

Epoch 1/20 Loss: 0.3278: 100%|██████████| 118/118 [00:08<00:00, 13.17it/s, acc=-1]
Epoch 2/20 Loss: 0.2085: 100%|██████████| 118/118 [00:09<00:00, 12.37it/s, acc=0.934]
Epoch 3/20 Loss: 0.1641: 100%|██████████| 118/118 [00:09<00:00, 12.48it/s, acc=0.946]
Epoch 4/20 Loss: 0.2248: 100%|██████████| 118/118 [00:09<00:00, 12.53it/s, acc=0.951]
Epoch 5/20 Loss: 0.0870: 100%|██████████| 118/118 [00:08<00:00, 13.72it/s, acc=0.961]
Epoch 6/20 Loss: 0.1375: 100%|██████████| 118/118 [00:08<00:00, 13.12it/s, acc=0.959]
Epoch 7/20 Loss: 0.0729: 100%|██████████| 118/118 [00:09<00:00, 12.36it/s, acc=0.963]
Epoch 8/20 Loss: 0.2365: 100%|██████████| 118/118 [00:09<00:00, 12.54it/s, acc=0.962]
Epoch 9/20 Loss: 0.0217: 100%|██████████| 118/118 [00:09<00:00, 12.41it/s, acc=0.966]
Epoch 10/20 Loss: 0.1790: 100%|██████████| 118/118 [00:08<00:00, 13.39it/s, acc=0.966]
Epoch 11/20 Loss: 0.1003: 100%|██████████| 118/118 [00:09<00:00, 13.05it/s, acc=0.967]
Epoch 12/20 Loss: 0.0118: 100%|██████████| 118/118 [00:

Final accuracy: 96.91%


In [11]:
teacher.eval()
student = Student().to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

kl_div_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)

temperature = 2
soft_targets_weight = temperature ** 2
hard_targets_weight = 0.5

In [12]:
acc = -1.
for e in range(epochs):
  student.train()
  for x, y in (pbar := tqdm(train_dataloader)):
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad(set_to_none=True)

    with torch.no_grad():
      teacher_logits = teacher(x)
    out = student(x)

    soft_targets = (teacher_logits / temperature).log_softmax(dim=-1)
    soft_prob = (out / temperature).log_softmax(dim=-1)

    soft_targets_loss = kl_div_loss(soft_prob, soft_targets)
    hard_targets_loss = loss_fn(out, y)

    loss = soft_targets_weight * soft_targets_loss + hard_targets_weight * hard_targets_loss
    loss.backward()
    optimizer.step()

    pbar.set_description(f"Epoch {e+1}/{epochs} Loss: {loss.item():.4f}")
    pbar.set_postfix(acc=acc)

  student.eval()
  if e % 5 == 4:
    with torch.no_grad():
      total = 0
      correct = 0
      for x, y in test_dataloader:
        x, y = x.to(device), y.to(device)
        out = student(x)
        _, pred = torch.max(out, dim=1)
        total += y.size(0)
        correct += (pred == y).sum().item()
  acc = (correct / total)
  scheduler.step()
print(f"Final accuracy: {acc*100:.2f}%")

Epoch 1/20 Loss: 0.9064: 100%|██████████| 118/118 [00:09<00:00, 13.06it/s, acc=-1]
Epoch 2/20 Loss: 0.6239: 100%|██████████| 118/118 [00:09<00:00, 12.42it/s, acc=0.969]
Epoch 3/20 Loss: 0.4436: 100%|██████████| 118/118 [00:09<00:00, 12.24it/s, acc=0.969]
Epoch 4/20 Loss: 0.4051: 100%|██████████| 118/118 [00:08<00:00, 13.58it/s, acc=0.969]
Epoch 5/20 Loss: 0.7404: 100%|██████████| 118/118 [00:09<00:00, 12.31it/s, acc=0.969]
Epoch 6/20 Loss: 0.3185: 100%|██████████| 118/118 [00:09<00:00, 12.14it/s, acc=0.965]
Epoch 7/20 Loss: 0.3779: 100%|██████████| 118/118 [00:09<00:00, 12.36it/s, acc=0.965]
Epoch 8/20 Loss: 0.1877: 100%|██████████| 118/118 [00:08<00:00, 13.64it/s, acc=0.965]
Epoch 9/20 Loss: 0.3368: 100%|██████████| 118/118 [00:09<00:00, 12.26it/s, acc=0.965]
Epoch 10/20 Loss: 0.2357: 100%|██████████| 118/118 [00:09<00:00, 12.34it/s, acc=0.965]
Epoch 11/20 Loss: 0.2943: 100%|██████████| 118/118 [00:09<00:00, 12.83it/s, acc=0.968]
Epoch 12/20 Loss: 0.4608: 100%|██████████| 118/118 [00:

Final accuracy: 97.08%
