* [GitHub](https://github.com/lif31up/knowledge-distillation)
* [GitBook - Knoweldge Distillation](https://lif31up.gitbook.io/blog/transfer-learning-and-knowledge-distillation/distilling-the-knowledge-in-a-neural-network)
* [GitBook - Vision Transformer](https://lif31up.gitbook.io/blog/geometric-learning-and-computer-vision/an-image-is-worth-16x16-words-transformers-for-image-recognition-at-scale)

In [2]:
from torch import nn
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision as tv
from torch.optim import lr_scheduler
import torch.nn.functional as F
import copy

In [3]:
from google.colab import drive
drive.mount('../content/drive')

Drive already mounted at ../content/drive; to attempt to forcibly remount, call drive.mount("../content/drive", force_remount=True).


In [4]:
TEACH_SAVE_TO = "../content/drive/MyDrive/Colab Notebooks/knoweldge_distillation/teacher.bin"
TEACH_LOAD_FROM = TEACH_SAVE_TO
STNDT_SAVE_TO = "../content/drive/MyDrive/Colab Notebooks/knoweldge_distillation/student.bin"
STNDT_LOAD_FROM = STNDT_SAVE_TO

# ViT/DistillViT for MNIST from scratch
This implementation is inspired by:
[Distilling the Knowledge in a Neural Network (2015)](https://arxiv.org/abs/1503.02531) by Geoffrey Hinton, Oriol Vinyals, Jeff Dean.
[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (2021)](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (2021)](https://arxiv.org/abs/2010.11929)

Distillation (or Knowledge Distillation) is a model compression technique where a small model is trained to mimic a large, complex model by learning its "thought process (or soft probabilities)". Most large state-of-the-art models are incredibly accurate but come with high costs—computation, memory, and latency. Distillation captures the knowledge inside those models and packs it into a more efficient one.

The Vision Transformer (ViT) attains excellent results when pretrained at sufficient scale and transferred to tasks with fewer datapoints. When pretrained on the public ImageNet-21k dataset or the in-house JFT-300M dataset, ViT approaches or beats state-of-the-art image recognition benchmarks.

- **Task:** Image Recognition
- **Dataset:** MNIST-10

## ViTs (Vision Transformers)
The researchers experimented with applying a standard Transformer directly to images with minimal modifications. They split an image into patches and provide the sequence of linear embeddings of these patches as input to a Transformer. Image patches are treated the same way as tokens in NLP applications. The model is trained on image classification in a supervised fashion.

When trained on mid-sized datasets like ImageNet without strong regularization, these models achieve modest accuracies—a few percentage points below ResNets of comparable size. This seemingly discouraging result is expected: Transformers lack the inductive biases inherent to CNNs, such as translation equivariance and locality. As a result, they don't generalize well when trained on insufficient data.

However, the picture changes when models are trained on larger datasets (14–300M images). Large-scale training trumps inductive bias. The Vision Transformer (ViT) attains excellent results when pretrained at sufficient scale and transferred to tasks with fewer datapoints. When pretrained on the public ImageNet-21k dataset or the in-house JFT-300M dataset, ViT approaches or beats state-of-the-art image recognition benchmarks.

### Method
The model design follows the original Transformer as closely as possible. This intentionally simple setup offers a key advantage: scalable NLP architectures and their efficient implementations can be used almost out of the box.

Figure 1 shows an overview of the model. The standard Transformer receives a 1D sequence of token embeddings as input. To handle 2D images, we reshape the image $x$ into a sequence of flattened 2D patches $x_p$ using embed $\text{embed}(x → x_p)$:

$$
\text{embed}(x \in \mathbb{R}^{H \times W \times C}) = x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}
$$

- $C$ is the number of channels.
- $(H, W)$ is the resolution of the original image (height and width).
- $(P,P)$ is the resolution of each image patch (expressed as $x_p$).
- $N = HW / P^2$ is the resulting number of patches, which also serves as the input sequence length for the Transformer.

Like BERT's `[class]` token, we add a learnable embedding to the start of the embedded patch sequence $(z_0^0 x_{\text{class}})$. Its state at the Transformer encoder's output $(z_L^0)$ becomes the image representation $y$.

1. $z_0 = [ {x}_{\text{class}}; {x}^1_p E; {x}^2_p E, ..., {x}_p^N E ] + E_{\text{pos}} \quad\text{where is }E \in \mathbb{R}^{(P^2 \cdot C) \times D}, E_{\text{pos}} \in \mathbb{R}^{(N + 1) \times D}$
2. $z’_{\mathcal{l}} = \text{MSA}(\text{LN}(z_{\mathcal{l} - 1})) + z_{\mathcal{l - 1}}, \quad\text{where is }\mathcal{l} = 1, ..., L$
3. $z{\mathcal{l}} = \text{MLP}(\text{LN}(z'_{\mathcal{l}})) + z'_{\mathcal{l}}, \quad\text{where is }\mathcal{l} = 1, ..., L$
4. $y = \text{LN}(z_L^0)$

Vision Transformer has much less image-specific inductive bias than CNNs. In CNNs, 2D neighborhood structure and translation equivariance are built into each layer. In ViT, only the MLPs are local and translationally equivariant, while the self-attention layers are global. Other than that, the position embeddings at initialization time carry no information about the 2D positions of the patches and all spatial relations between the patches have to be learned
from scratch.

In [5]:
class ViT(nn.Module):
  def __init__(self, config):
    super(ViT, self).__init__()
    self.config = config
    self.stacks = nn.ModuleList([EncoderStack(self.config) for _ in range(config.n_stacks)])
    self.flatten = nn.Flatten(start_dim=1)
    self.cls = nn.Parameter(torch.zeros(config.dim))
    self.fc = self._get_fc(self.config.dummy).apply(self.config.init_weights)
  # __init__

  def add_cls(self, x):
    cls = self.cls.expand(x.shape[0], 1, -1)
    x = torch.cat([x, cls], dim=1)
    return x
  # add_cls

  def forward(self, x):
    x = self.add_cls(x)
    for stack in self.stacks: x = stack(x)
    return self.fc(self.flatten(x))
  # forward

  def _get_fc(self, dummy):
    with torch.no_grad():
      cls = self.cls.expand(1, -1)
      dummy = torch.cat([dummy, cls], dim=0)
      for stack in self.stacks: dummy = stack(dummy)
    dummy = dummy.flatten(start_dim=0)
    return nn.Linear(dummy.shape[0], self.config.output_dim, bias=self.config.bias)
  # _get_fc
# Transformer

class EncoderStack(nn.Module):
  def __init__(self, config):
    super(EncoderStack, self).__init__()
    self.config = config
    self.mt_attn = MultiHeadAttention(config, mode="scaled")
    self.ffn = nn.ModuleList()
    for _ in range(config.n_hidden):
      self.ffn.append(nn.Linear(config.dim, config.dim, bias=config.bias))
    self.activation, self.ln = nn.GELU(), nn.LayerNorm(config.dim)
    self.dropout = nn.Dropout(config.dropout)

    self.apply(self.config.init_weights)
  # __init__

  def forward(self, x):
    res = x
    x = self.ln(self.mt_attn(x) + res)
    res = x
    for i, layer in enumerate(self.ffn):
      if i != len(self.ffn): x = self.dropout(self.activation(layer(x)))
      else: x = self.dropout(layer(x))
    return self.ln(x + res)
  # forward
# EncoderStack

class MultiHeadAttention(nn.Module):
  def __init__(self, config, mode="scaled"):
    super(MultiHeadAttention, self).__init__()
    assert config.dim % config.n_heads == 0, "Dimension must be divisible by number of heads"
    self.config = config
    self.sqrt_d_k, self.mode = (config.dim // config.n_heads) ** 0.5, mode
    self.w_q, self.w_k = nn.Linear(config.dim, config.dim, bias=config.bias), nn.Linear(config.dim, config.dim, bias=config.bias)
    self.w_v, self.w_o = nn.Linear(config.dim, config.dim, bias=config.bias), nn.Linear(config.dim, config.dim, bias=config.bias)
    self.ln, self.dropout, self.softmax = nn.LayerNorm(config.dim), nn.Dropout(config.attention_dropout), nn.Softmax(dim=1)

    self.apply(self.config.init_weights)
  # __init__

  def forward(self, x, y=None):
    Q = self.w_q(x)
    (K, V) = (self.w_k(x), self.w_v(x)) if self.mode != "cross" else (self.w_k(y), self.w_v(y))
    raw_attn_scores = torch.matmul(Q, K.transpose(-2, -1))
    down_scaled_raw_attn_scores = raw_attn_scores / self.sqrt_d_k
    if self.mode == "masked":
      masked_indices = torch.rand(*down_scaled_raw_attn_scores.shape[:-1], 1) < self.config.mask_prob
      down_scaled_raw_attn_scores[masked_indices] = float("-inf")
    attn_scores = self.softmax(down_scaled_raw_attn_scores)
    attn_scores = self.dropout(attn_scores)
    return self.ln(torch.matmul(attn_scores, V) + x)
  # attn_score
# MultiHeadAttention

In [6]:
class Embedder(Dataset):
  def __init__(self, dataset, config):
    super(Embedder, self).__init__()
    self.dataset, self.config = dataset, config
    self.is_consolidated = False
  # __init__

  def __len__(self): return len(self.dataset)

  def __getitem__(self, item):
    if self.is_consolidated: return self.dataset[item][0], self.dataset[item][1]
    feature, label = self.dataset[item]
    patches = feature.unfold(1, 30, 30).unfold(2, 30, 30).permute(1, 2, 0, 3, 4)
    flatten_patches = torch.reshape(input=patches, shape=(9, -1))
    label = F.one_hot(torch.tensor(label), num_classes=10).float()
    return flatten_patches, label
  # __getitem__

  def consolidate(self):
    buffer = list()
    progression = tqdm(self)
    for feature, label in progression: buffer.append((feature, label))
    self.dataset, self.is_consolidated = buffer, True
    return self
  # consolidate
# Embedder

def load_CIFAR_10(transform, path='./data', trainset_len=1000, testset_len=500):
  # trainset, testset are provided as torch.nn.utils.dataset
  trainset = tv.datasets.MNIST(root=path, train=True, download=True, transform=transform)
  trainset_indices = torch.randperm(trainset.__len__()).tolist()[:trainset_len]
  trainset = Subset(dataset=trainset, indices=trainset_indices)
  testset = tv.datasets.MNIST(root=path, train=False, download=True, transform=transform)
  testset_indices = torch.randperm(testset.__len__()).tolist()[:testset_len]
  testset = Subset(dataset=testset, indices=testset_indices)
  return trainset, testset
#load_CIFAR_10

def get_transform_CIFAR_10(input_size=135):
  return tv.transforms.Compose([
    # 1. Augmentation for better generalization
    tv.transforms.RandomResizedCrop(input_size, scale=(0.8, 1.0)),
    tv.transforms.RandomHorizontalFlip(),
    tv.transforms.ColorJitter(brightness=0.1, contrast=0.1),  # If RGB
    # 2. Resize and ToTensor
    tv.transforms.Resize((input_size, input_size)),
    tv.transforms.ToTensor(),
    # 3. Normalization using ImageNet statistics for pre-trained models
    tv.transforms.Normalize(
      mean=[0.485],
      std=[0.229]
    ),
  ])  # TRANSFORM
# get_transform_CIFAR_10

### Configuration and Train

In [7]:
class Config:
  def __init__(self, is_teacher=False):
    self.iters = 50
    self.batch_size = 16
    self.trainset_len, self.testset_len = 10000, 1000
    self.dummy = None

    self.n_heads = 3
    self.n_stacks = 6
    self.n_hidden = 3
    self.dim = 900
    self.output_dim = 10
    self.bias = True

    self.dropout = 0.1
    self.attention_dropout = 0.1
    self.eps = 1e-3
    self.betas = (0.9, 0.98)
    self.epochs = 5
    self.batch_size = 16
    self.lr = 1e-4
    self.alpha = 0.4
    self.clip_grad = False
    self.mask_prob = 0.3
    self.init_weights = init_weights
  # __init__
# Config

def init_weights(m):
  if isinstance(m, nn.Linear):
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None: nn.init.zeros_(m.bias)
# init_weights

In [8]:
config = Config()

In [9]:
# load dataset, transform from folder
cifar_10_transform = get_transform_CIFAR_10(input_size=90)
trainset, testset = load_CIFAR_10(path='./data', transform=cifar_10_transform, trainset_len=config.trainset_len, testset_len=config.testset_len)

In [10]:
# embed dataset (3 times 3 patches)
trainset = Embedder(dataset=trainset, config=config).consolidate()
config.dummy = trainset.__getitem__(0)[0]
trainset = DataLoader(dataset=trainset, batch_size=config.batch_size)
testset = Embedder(dataset=testset, config=config).consolidate()
testset = DataLoader(dataset=testset, batch_size=config.batch_size)

100%|██████████| 10000/10000 [00:09<00:00, 1088.39it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1187.82it/s]


In [11]:
teacher = ViT(config=config)

In [12]:
def train(model:nn.Module, path: str, config: Config, trainset, device):
  model.to(device)
  model.train()

  # optim, criterion, scheduler
  optim = torch.optim.Adam(model.parameters(), lr=config.lr, eps=config.eps)
  criterion = nn.CrossEntropyLoss()
  scheduler = lr_scheduler.StepLR(optim, step_size=5, gamma=0.1)

  progression = tqdm(range(config.iters))
  for _ in progression:
    for feature, label in trainset:
      feature, label = feature.to(device, non_blocking=True), label.to(device, non_blocking=True)
      pred = model(feature)
      loss = criterion(pred, label)
      optim.zero_grad()
      loss.backward()
      optim.step()
    # for feature label
    scheduler.step()
    progression.set_postfix(loss=loss.item())
  # for in progression

  features = {
    "sate": model.state_dict(),
    "config": config
  } # feature
  torch.save(features, f"{path}")
# train

In [13]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train(model=teacher, path=TEACH_SAVE_TO, config=config, trainset=trainset, device=device)

100%|██████████| 50/50 [10:40<00:00, 12.81s/it, loss=0.0268]


### Evaluate the teacher

In [14]:
def evaluate(model, dataset, device):
  model.to(device)
  model.eval()
  correct, n_total = 0, 0
  for feature, label in tqdm(dataset):
    feature, label = feature.to(device, non_blocking=True), label.to(device, non_blocking=True)
    output = model.forward(feature)
    output = torch.softmax(output, dim=-1)
    pred = torch.argmax(input=output, dim=-1)
    label = torch.argmax(input=label, dim=-1)
    for p, l in zip(pred, label):
      if p == l: correct += 1
      n_total += 1
  # for
  print(f"Accuracy: {correct / n_total:.4f}")
# eval

In [15]:
evaluate(model=teacher, dataset=testset, device=device)

100%|██████████| 63/63 [00:00<00:00, 109.89it/s]

Accuracy: 0.9160





## Knowledge Distillation
**Distillation (or Knowledge Distillation)** is a model compression technique where a small model is trained to mimic a large, complex model by learning its *"thought process (or soft probabilities)"*. Most large state-of-the-art models are incredibly accurate but come with high costs—computation, memory, and latency. Distillation captures the knowledge inside those models and packs it into a more efficient one.

- The goal is to create a smaller, faster model that can be deployed on devices with limited computational resources.
- A well-trained distilled model can even surpass the original, as the softened labels provide richer information and act as a regularizer.

### Teacher/Student and Hard/Soft Targets

Distillation relies on key concepts: soft and hard targets, and the teacher-student model relationship:

- The **teacher model** is a large, pre-trained, highly accurate model. It's already an expert on the task and is typically provided as frozen.
- The **student model** is a smaller, more efficient (with fewer layers or parameters) model.

**Hard targets** are the ground-truth labels from the dataset. They provide only a single correct answer and don't convey relationships between classes—for example, that a cat is more similar to a dog than to an airplane.

**Soft targets** are the output probabilities from the teacher model's final softmax layer. A well-trained teacher looking at a cat might output probabilities like `["airplane": 1e-5, "cat": 0.90, "dog": 0.05]`. This reveals a small chance it could be a dog, but certainly not an airplane. This nuanced information is called *"dark knowledge"*, and it's what the student model learns from.

### Temperature-Scaled Softmax

To make the soft targets more informative, distillation uses a **temperature parameter** $T$ in the softmax function:

- The standard softmax function is $P_i = \frac{ \exp{(z_i)} }{ \sum_{j}{\exp{(z_j)}} }$.
- The temperature-scaled softmax is $P_i = \frac{ \exp{(z_i / T)} }{ \sum_{j}{\exp{(z_j / T})} }$.

There are some attributes of $T$:

- When $T = 1$, the function behaves as standard softmax.
- When $T > 1$, it softens the probability distribution, making it less extreme (for example, `[0.05, 0.90, 0.05]` might become `[0.15, 0.70, 0.15]`).
- During training, $T$ is set to a high value (like $3$ or $4$), then reset to $1$ for final inference.

Softened probabilities provide much more information—they amplify the small differences between non-target classes, making it easier for the student to learn the relationships the teacher has discovered.

### Understanding Total Loss

Unlike typical deep learning training, the student model is trained using a combined loss function that balances two distributions:

- The **distillation loss** measures how closely the student's soft predictions match the teacher's soft targets.
- The **student loss (or hard target loss)** measures how well the student's predictions match the original ground-truth labels.

This combined loss is called *“total loss”* and calculated via:

$$
\text{Total Loss} = \alpha \cdot \mathcal{L}_{\text{distillation}} + (1 - \alpha) \cdot \mathcal{L}_{\text{student}}
$$
$$
\mathcal{L}_{\text{distillation}}(P_{\text{soft}} | Q_{\text{soft}}) = \sum_{i}{P_{\text{soft}}(i) \log{\big{(} \frac{P_{\text{soft}}{(i)}}{Q_{ \text{soft}}{(i)}} \big{)}}}
$$
$$
\mathcal{L}_{\text{student}}(P|Q_{\text{hard}}) = \text{CrossEntropyLoss}
$$

- $Q_{\text{soft}}$ represents soft targets, while $Q_{\text{hard}}$ represents hard targets.
- $\alpha$ is a hyperparameter that balances the two losses.
- In this formula, $\mathcal{L}_{\text{distillation}}$ uses **KL Divergence**: $\sum_{i}{P(i) \log{\big( \frac{P(i)}{Q(i)} \big{)}}}$

During training, the combined loss procedure works as follows:

1. Pass an input through the teacher model (which is frozen) to get $Q_{\text{soft}}$—theses are softend using temperature param $T$.
2. Pass the same input through the student model to get its predictions—these are also softened using $T$.
3. Combined the two losses:
    1. Calculate $\mathcal{L}_{\text{distillation}}$ between the student's and teacher's softened predictions.
    2. Calculate $\mathcal{L}_{\text{student}}$ between the student's output (at $T = 1$, not softened) and the ground-truth labels (hard targets).
4. Backpropagate the total loss and update the student's weights.

In [16]:
student_config = copy.deepcopy(config)
student_config.n_stacks = 3
student_config.temperature = 2.0
student_config.alpha = 0.35

In [17]:
def distillate(student, teacher, dataset, config, path, device):
  student.to(device)
  teacher.to(device)
  student.train()
  teacher.eval()

  # optim, criterion, scheduler
  optim = torch.optim.Adam(student.parameters(), lr=config.lr, eps=config.eps)
  criterion = nn.CrossEntropyLoss()
  scheduler = lr_scheduler.StepLR(optim, step_size=5, gamma=0.1)

  progression = tqdm(range(config.iters))
  for _ in progression:
    for feature, label in dataset:
      feature, label = feature.to(device), label.to(device)
      soft_label = F.softmax(teacher(feature) / config.temperature, dim=-1)
      output = student(feature)
      distill_loss = criterion(output, soft_label)
      student_loss = criterion(output, label)
      loss = (config.alpha * distill_loss) + (1 - config.alpha) * student_loss
      optim.zero_grad()
      loss.backward()
      optim.step()
    # for feature, label
    scheduler.step()
    progression.set_postfix(loss=loss.item())

    features = {
      "sate": student.state_dict(),
      "config": config
    } # feature
    torch.save(features, f"{path}")
  # fpr


In [18]:
student = ViT(student_config)

In [19]:
teacher_data = torch.load(f=TEACH_LOAD_FROM, map_location=torch.device('cpu'), weights_only=False)
teacher = ViT(config)
teacher.load_state_dict(teacher_data['sate'])

<All keys matched successfully>

In [20]:
distillate(student=student, teacher=teacher, dataset=trainset, config=student_config, path=STNDT_SAVE_TO, device=device)

100%|██████████| 50/50 [14:37<00:00, 17.55s/it, loss=0.181]


In [21]:
evaluate(model=student, dataset=testset, device=device)

100%|██████████| 63/63 [00:00<00:00, 260.26it/s]

Accuracy: 0.9210



