# Assignment 3. Knowledge Distillation

## Goals

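The goal of this lab is to understand how a small model (Student) can effectively learn the knowledge of a large model (Teacher) through **Knowledge Distillation**, and to compare several distillation methods experimentally.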


## Contents

1. **Baseline Training (Cross-Entropy Loss)**
    - Compare the accuracy of the pretrained Teacher with that of a Student trained using Cross-Entropy Loss only.
2. **Knowledge Distillation (Soft Targets)**
    - Apply Knowledge Distillation using the Teacher's softmax outputs and analyze the effect of the temperature and the loss weights.
3. **Cosine Loss Minimization (Cosine Loss)**  
    - Extract the convolutional features of the Teacher and Student and train with CosineEmbeddingLoss so that their internal representations become more similar.
4. **Intermediate Regressor (Regressor + MSE)**
    - Align the Teacher's feature map with the Student's regressed feature map using MSE, so the Student learns the intermediate representation directly.

# Environment Setup

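This lab implements Knowledge Distillation with PyTorch and Torchvision. We first import the required libraries. Note that the code moves models and data to the GPU with `.cuda()` throughout, so a CUDA-capable GPU is required as written.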

## Import Modules
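- `torch`, `torch.nn`, `torch.optim`: core PyTorch functionality, neural network modules, and optimization algorithms
- `torchvision.transforms`, `torchvision.datasets`: loading and preprocessing the CIFAR-10 dataset
- `collections.OrderedDict`: an ordered dict used later to define the model layers with names in a fixed order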

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from collections import OrderedDict

## Data Loading: CIFAR-10

This lab uses the CIFAR-10 dataset to evaluate the effect of Knowledge Distillation. CIFAR-10 consists of 32x32 color images in 10 classes (50,000 training and 10,000 test images).

In [2]:
# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='D:\\data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='D:\\data', train=False, download=True, transform=transforms_cifar)

# Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=0)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to D:\data\cifar-10-python.tar.gz


100%|███████████████████████████████████████████████████████████████████████████████| 170M/170M [00:11<00:00, 15.3MB/s]


Extracting D:\data\cifar-10-python.tar.gz to D:\data
Files already downloaded and verified


In [3]:
import random
import numpy as np
import os

def set_seed(seed=44):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

# Training DataLoader with a fixed seed, so every experiment sees the same shuffle order
def get_train_loader(train_dataset, seed=44):
    set_seed(seed)
    def seed_worker(worker_id):
        worker_seed = seed + worker_id
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    g = torch.Generator()
    g.manual_seed(seed)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True,
                                               num_workers=0, generator=g, worker_init_fn=seed_worker)
    return train_loader
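To see what the seeded `torch.Generator` in `get_train_loader()` buys us, the sketch below builds two loaders with the same seed and compares their first batches. It uses a small toy `TensorDataset` instead of CIFAR-10 so it runs without downloading anything; `make_loader` is a hypothetical helper that mirrors the `generator=` argument used above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loader(dataset, seed):
    # Hypothetical helper: like get_train_loader(), a seeded generator drives the shuffle
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, generator=g)

# Toy dataset standing in for CIFAR-10: 16 samples whose label is simply their index
toy = TensorDataset(torch.arange(16).float().unsqueeze(1), torch.arange(16))

first_a = next(iter(make_loader(toy, seed=44)))[1]
first_b = next(iter(make_loader(toy, seed=44)))[1]

# One full epoch still visits every sample exactly once, just in a fixed order
epoch_labels = torch.cat([y for _, y in make_loader(toy, seed=44)])

print(first_a.tolist(), first_b.tolist())
```

Because the shuffle order depends only on the generator seed, every training run in this notebook that calls `get_train_loader(train_dataset)` sees the batches in the same order, which keeps the comparison between loss functions fair.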

## Load Pretrained Model Weights (VGG on CIFAR-10)

A key prerequisite for Knowledge Distillation is a **high-performing Teacher model**. Here we load the weights of a VGG model pretrained on CIFAR-10 and prepare to use it as the Teacher.

In [4]:
os.environ["TORCH_HOME"] = "D:\\data"
state_dict_url = "https://github.com/SKKU-ESLAB/pytorch-models/releases/download/samsung/vgg.cifar.pretrained.pth"
state_dict = torch.hub.load_state_dict_from_url(state_dict_url, progress=True)
state_dict = state_dict["state_dict"]

Downloading: "https://github.com/SKKU-ESLAB/pytorch-models/releases/download/samsung/vgg.cifar.pretrained.pth" to D:\data\hub\checkpoints\vgg.cifar.pretrained.pth
100%|█████████████████████████████████████████████████████████████████████████████| 35.2M/35.2M [00:04<00:00, 8.88MB/s]


## Define Teacher and Student Models

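We define two model architectures for the Knowledge Distillation experiments. Both are VGG-style CNNs that end with global average pooling and a linear classifier: the **Teacher (VGGCifar9)** is deeper and wider (8 convolution layers, up to 512 channels), while the **Student (VGGCifar5)** is much simpler (4 convolution layers, up to 256 channels).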

In [5]:
class VGGCifar9(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, 64, 3, padding=1, bias=False)),
            ('bn0', nn.BatchNorm2d(64)),
            ('relu0', nn.ReLU(True)),
            ('conv1', nn.Conv2d(64, 128, 3, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(128)),
            ('relu1', nn.ReLU(True)),
            ('pool0', nn.MaxPool2d(2)),
            ('conv2', nn.Conv2d(128, 256, 3, padding=1, bias=False)),
            ('bn2', nn.BatchNorm2d(256)),
            ('relu2', nn.ReLU(True)),
            ('conv3', nn.Conv2d(256, 256, 3, padding=1, bias=False)),
            ('bn3', nn.BatchNorm2d(256)),
            ('relu3', nn.ReLU(True)),
            ('pool1', nn.MaxPool2d(2)),
            ('conv4', nn.Conv2d(256, 512, 3, padding=1, bias=False)),
            ('bn4', nn.BatchNorm2d(512)),
            ('relu4', nn.ReLU(True)),
            ('conv5', nn.Conv2d(512, 512, 3, padding=1, bias=False)),
            ('bn5', nn.BatchNorm2d(512)),
            ('relu5', nn.ReLU(True)),
            ('pool2', nn.MaxPool2d(2)),
            ('conv6', nn.Conv2d(512, 512, 3, padding=1, bias=False)),
            ('bn6', nn.BatchNorm2d(512)),
            ('relu6', nn.ReLU(True)),
            ('conv7', nn.Conv2d(512, 512, 3, padding=1, bias=False)),
            ('bn7', nn.BatchNorm2d(512)),
            ('relu7', nn.ReLU(True)),
            ('pool3', nn.MaxPool2d(2)),
        ]))
        self.classifier = nn.Linear(512, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x

class VGGCifar5(nn.Module):
    def __init__(self) -> None:
        # Generate the same scratch model
        set_seed()
        super().__init__()
        self.backbone = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, 64, 3, padding=1, bias=False)),
            ('bn0', nn.BatchNorm2d(64)),
            ('relu0', nn.ReLU(True)),
            ('pool0', nn.MaxPool2d(2)),
            ('conv1', nn.Conv2d(64, 128, 3, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(128)),
            ('relu1', nn.ReLU(True)),
            ('pool1', nn.MaxPool2d(2)),
            ('conv2', nn.Conv2d(128, 256, 3, padding=1, bias=False)),
            ('bn2', nn.BatchNorm2d(256)),
            ('relu2', nn.ReLU(True)),
            ('pool2', nn.MaxPool2d(2)),
            ('conv3', nn.Conv2d(256, 256, 3, padding=1, bias=False)),
            ('bn3', nn.BatchNorm2d(256)),
            ('relu3', nn.ReLU(True)),
            ('pool3', nn.MaxPool2d(2)),
        ]))
        self.classifier = nn.Linear(256, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x

# 3.1. Baseline Training (Cross-Entropy Loss)

## Train & Test Functions

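Training and evaluation use the two functions below, and every model in this notebook follows the same loop structure. `train()` minimizes the cross-entropy loss with Adam and a cosine-annealing learning-rate schedule that is stepped after every batch; `test()` returns the top-1 accuracy (%) on the given loader.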

In [6]:
def train(model,
          train_loader,
          epochs,
          learning_rate):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs * len(train_loader))

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.cuda(), labels.cuda()

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model,
         test_loader):
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.cuda(), labels.cuda()

            outputs = model(inputs)
            if isinstance(outputs, tuple):
                outputs = outputs[0]  # for multiple outputs
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy
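`train()` creates `CosineAnnealingLR` with `T_max = epochs * len(train_loader)` and calls `scheduler.step()` after every batch, so the learning rate follows a single half-cosine from `learning_rate` down to 0 over the whole run instead of restarting each epoch. The sketch below reproduces that schedule on a dummy parameter, assuming a hypothetical run of 2 epochs with 5 batches each, and compares it with the closed-form cosine formula.

```python
import math
import torch
import torch.optim as optim

epochs, batches_per_epoch, base_lr = 2, 5, 0.01  # hypothetical run length
total_steps = epochs * batches_per_epoch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=base_lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps)

lrs = []
for step in range(total_steps):
    optimizer.step()   # a real loop would compute a loss and call backward() first
    scheduler.step()   # stepped once per batch, as in train()
    lrs.append(optimizer.param_groups[0]["lr"])

# Closed form of the cosine schedule with eta_min = 0, for comparison
expected = [0.5 * base_lr * (1 + math.cos(math.pi * t / total_steps))
            for t in range(1, total_steps + 1)]

print([round(lr, 6) for lr in lrs])
```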

## Load & Evaluate Teacher Model

We initialize the **Teacher model (VGGCifar9)**, which serves as the reference for Knowledge Distillation, and load the pretrained weights obtained above. We then evaluate its performance on the CIFAR-10 test set.

In [7]:
teacher_model = VGGCifar9().cuda()
teacher_model.load_state_dict(state_dict)
test_accuracy_teacher = test(teacher_model, test_loader)

Test Accuracy: 92.95%


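## Checking Model Initialization Consistency

For a **fair comparison** across the Knowledge Distillation experiments, every Student model must start from the same initial weights. Here we create two `VGGCifar5` models with the same architecture and compare the weight norm of their first convolution layer. Because `VGGCifar5.__init__` calls `set_seed()` before creating its layers, both norms should be identical.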

In [8]:
# Print the norm of the first layer of the initial lightweight model
student_model = VGGCifar5().cuda()
print("Norm of 1st layer of student_model:", torch.norm(student_model.backbone[0].weight).item())

# Print the norm of the first layer of the new lightweight model
student_model_2 = VGGCifar5().cuda()
print("Norm of 1st layer of student_model_2:", torch.norm(student_model_2.backbone[0].weight).item())

Norm of 1st layer of student_model: 4.603283882141113
Norm of 1st layer of student_model_2: 4.603283882141113
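Comparing the norm of one layer is a quick sanity check, but two different weight tensors could in principle share the same norm. A stricter check compares every tensor in the state dict with `torch.equal`. The sketch below uses a hypothetical two-layer `TinyStudent` that seeds itself in `__init__`, the same pattern `VGGCifar5` uses with `set_seed()`.

```python
import torch
import torch.nn as nn

class TinyStudent(nn.Module):
    # Hypothetical stand-in for VGGCifar5: seeding inside __init__ makes every instance identical
    def __init__(self, seed: int = 44) -> None:
        torch.manual_seed(seed)
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
        self.fc = nn.Linear(8, 10)

def same_init(a: nn.Module, b: nn.Module) -> bool:
    # True only if every parameter and buffer matches exactly
    sa, sb = a.state_dict(), b.state_dict()
    return sa.keys() == sb.keys() and all(torch.equal(sa[k], sb[k]) for k in sa)

print(same_init(TinyStudent(), TinyStudent()))        # same seed
print(same_init(TinyStudent(), TinyStudent(seed=1)))  # different seed
```

The same `same_init` check could be applied to two `VGGCifar5` instances in place of the norm comparison above.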


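## Comparing Model Parameter Counts

To quantify the **difference in complexity** between the Teacher and Student models, we print the total number of parameters of each. This supports the core premise of Knowledge Distillation numerically: *transfer knowledge from an accurate but heavy Teacher to a lightweight, fast Student*. Here the Student has roughly 10% of the Teacher's parameters.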

In [9]:
total_params_teacher = "{:,}".format(sum(p.numel() for p in teacher_model.parameters()))
print(f"Teacher model parameters: {total_params_teacher}")
total_params_student = "{:,}".format(sum(p.numel() for p in student_model.parameters()))
print(f"Student model parameters: {total_params_student}")

Teacher model parameters: 9,228,362
Student model parameters: 964,170
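The printed totals can be reproduced by hand: each 3x3 `Conv2d` without bias has `in * out * 9` weights, each `BatchNorm2d` has 2 learnable parameters per channel (weight and bias; the running statistics are buffers, not parameters), and the final `Linear` has `in * 10 + 10`. A plain-Python sketch of that arithmetic, with the channel lists read off the two model definitions:

```python
def vgg_param_count(channels, num_classes=10, k=3):
    # channels: conv widths in order, starting with the 3 RGB input channels
    total = 0
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        total += c_in * c_out * k * k   # Conv2d weight (bias=False)
        total += 2 * c_out              # BatchNorm2d weight + bias
    total += channels[-1] * num_classes + num_classes  # Linear classifier
    return total

teacher = vgg_param_count([3, 64, 128, 256, 256, 512, 512, 512, 512])  # VGGCifar9
student = vgg_param_count([3, 64, 128, 256, 256])                      # VGGCifar5
print(f"{teacher:,} {student:,} ratio={student / teacher:.3f}")
```

Most of the Teacher's extra parameters come from the four 512-channel convolutions, each of which alone is larger than the entire Student.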


## Training the Student Alone (Cross-Entropy Only)

In this step, the Student model is trained alone, **without any help from the Teacher**. This gives the **baseline** accuracy against which the gains from Knowledge Distillation are measured later.

In [10]:
student_model = VGGCifar5().cuda()
train(student_model, get_train_loader(train_dataset), epochs=5, learning_rate=0.01)
test_accuracy_student_ce = test(student_model, test_loader)

Epoch 1/5, Loss: 1.3603065112972503
Epoch 2/5, Loss: 0.8317949570658262
Epoch 3/5, Loss: 0.5934104708302052
Epoch 4/5, Loss: 0.39868114389421994
Epoch 5/5, Loss: 0.26499412363142616
Test Accuracy: 82.44%


## Accuracy Summary

In [11]:
print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_student_ce:.2f}%")

Teacher accuracy: 92.95%
Student accuracy without teacher: 82.44%


# 3.2. Knowledge Distillation (Soft Targets)

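## [Exercise 1] Defining the Knowledge Distillation Training Function

The function below is a Knowledge Distillation (KD) training loop that uses the **Teacher model's soft outputs** to guide the Student model.  
On top of standard Cross-Entropy training, it adds a loss on the soft targets so that the Student also learns the structure of the Teacher's predictions, i.e. how the Teacher spreads probability over the non-target classes.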


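### Key KD Concepts

- **Soft Targets**: the probability distribution obtained by applying softmax to the Teacher's logits divided by the temperature
- **Temperature (T)**: controls how flat the softmax distribution is; higher values preserve more information about the relative similarity between classes
- **Loss combination**:
  - `CrossEntropyLoss`: supervised loss on the ground-truth labels
  - `KL divergence loss`: minimizes the difference between the Teacher's soft-target distribution and the Student's softened prediction distribution, scaled by `T**2`
  - The two losses are combined as a weighted sum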
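To see what the temperature does, the sketch below softens one made-up logit vector at several values of `T` and measures the entropy of the result. As `T` grows, probability mass moves from the top class to the others and the entropy rises toward its maximum of log(10) for 10 classes. The logits are hypothetical, chosen only for illustration.

```python
import math
import torch
import torch.nn.functional as F

# Hypothetical teacher logits for one image (class 0 is the predicted class)
logits = torch.tensor([8.0, 3.0, 2.5, 1.0, 0.5, 0.0, -0.5, -1.0, -2.0, -3.0])

def entropy(p: torch.Tensor) -> float:
    # Shannon entropy in nats
    return float(-(p * p.log()).sum())

temps = (1, 4, 10)
probs = {T: F.softmax(logits / T, dim=-1) for T in temps}
ents = {T: entropy(p) for T, p in probs.items()}
tops = {T: float(p.max()) for T, p in probs.items()}

for T in temps:
    print(f"T={T:>2}  top-class prob={tops[T]:.3f}  entropy={ents[T]:.3f}  (max {math.log(10):.3f})")
```

At `T=1` the distribution is almost one-hot and carries little more than the hard label; at `T=10` (the value used in this notebook) the ranking of the wrong classes becomes visible to the Student.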

In [12]:
def train_knowledge_distillation(teacher,
                                 student,
                                 train_loader,
                                 epochs,
                                 learning_rate,
                                 T,  # temperature
                                 soft_target_loss_weight,
                                 ce_loss_weight):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs * len(train_loader))

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.cuda(), labels.cuda()

            optimizer.zero_grad()

            ##################### YOUR CODE STARTS HERE #####################
            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

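            # Soften the teacher and student logits with temperature T
            # Hint: nn.functional.softmax(), nn.functional.log_softmax()
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            teacher_log_prob = nn.functional.log_softmax(teacher_logits / T, dim=-1)
            student_log_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Soft targets loss: KL(teacher || student), averaged over the batch and scaled by T**2 as suggested by the
            # authors of the paper "Distilling the knowledge in a neural network".
            # log_softmax is used instead of softmax(...).log() to avoid log(0) when a probability underflows.
            soft_targets_loss = torch.sum(soft_targets * (teacher_log_prob - student_log_prob)) / student_logits.size(0) * (T**2)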

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
            ##################### YOUR CODE ENDS HERE #######################

            loss.backward()
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
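The soft-target term in `train_knowledge_distillation()` is the KL divergence KL(p_teacher || p_student) between the temperature-softened distributions, summed over classes, averaged over the batch, and multiplied by `T**2`. As a cross-check, the sketch below computes the same quantity by hand and with PyTorch's `F.kl_div(..., reduction="batchmean")`; the random logits stand in for real model outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 10.0
teacher_logits = torch.randn(8, 10)  # random stand-ins for teacher(inputs)
student_logits = torch.randn(8, 10)  # random stand-ins for student(inputs)

# Manual form: sum over classes and batch, divide by batch size, scale by T**2
soft_targets = F.softmax(teacher_logits / T, dim=-1)
manual = torch.sum(soft_targets * (F.log_softmax(teacher_logits / T, dim=-1)
                                   - F.log_softmax(student_logits / T, dim=-1))) / student_logits.size(0) * T**2

# Library form: kl_div takes log-probabilities as input and probabilities as target
library = F.kl_div(F.log_softmax(student_logits / T, dim=-1), soft_targets, reduction="batchmean") * T**2

print(manual.item(), library.item())
```

The `T**2` factor compensates for the gradients of the softened loss shrinking roughly as 1/T², so the soft-target and cross-entropy terms stay on comparable scales when `T` changes.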

## Running Knowledge Distillation Training

Using the `train_knowledge_distillation()` function defined above, we train the Student model on the Teacher's soft predictions. We then measure test accuracy to compare against **the CE-only baseline**.

In [13]:
student_model = VGGCifar5().cuda()
train_knowledge_distillation(teacher=teacher_model,
                             student=student_model,
                             train_loader=get_train_loader(train_dataset),
                             epochs=5,
                             learning_rate=0.01,
                             T=10,
                             soft_target_loss_weight=0.5,
                             ce_loss_weight=0.5)

test_accuracy_student_ce_and_kd = test(student_model, test_loader)

Epoch 1/5, Loss: 7.115619836256022
Epoch 2/5, Loss: 3.4852485614054647
Epoch 3/5, Loss: 2.233771598857382
Epoch 4/5, Loss: 1.5219025212480588
Epoch 5/5, Loss: 1.1669407553989868
Test Accuracy: 85.43%


## Accuracy Summary

In [15]:
print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_student_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_student_ce_and_kd:.2f}%")

Teacher accuracy: 92.95%
Student accuracy without teacher: 82.44%
Student accuracy with CE + KD: 85.43%


# 3.3. Cosine Loss Minimization (Cosine Loss)

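## Defining Models for Cosine-Similarity-Based KD

In this experiment, we aim to improve training by aligning the **internal (hidden) representations of the Teacher and Student models**.  
To do so, we modify the VGG architecture so that `forward()` also returns a **flattened convolutional feature representation**. The Teacher's flattened output (512x2x2 = 2048 values) is average-pooled to 1024 values so that it matches the Student's (256x2x2 = 1024).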

In [16]:
class VGGCifar9_Cosine(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, 64, 3, padding=1, bias=False)),
            ('bn0', nn.BatchNorm2d(64)),
            ('relu0', nn.ReLU(True)),
            ('conv1', nn.Conv2d(64, 128, 3, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(128)),
            ('relu1', nn.ReLU(True)),
            ('pool0', nn.MaxPool2d(2)),
            ('conv2', nn.Conv2d(128, 256, 3, padding=1, bias=False)),
            ('bn2', nn.BatchNorm2d(256)),
            ('relu2', nn.ReLU(True)),
            ('conv3', nn.Conv2d(256, 256, 3, padding=1, bias=False)),
            ('bn3', nn.BatchNorm2d(256)),
            ('relu3', nn.ReLU(True)),
            ('pool1', nn.MaxPool2d(2)),
            ('conv4', nn.Conv2d(256, 512, 3, padding=1, bias=False)),
            ('bn4', nn.BatchNorm2d(512)),
            ('relu4', nn.ReLU(True)),
            ('conv5', nn.Conv2d(512, 512, 3, padding=1, bias=False)),
            ('bn5', nn.BatchNorm2d(512)),
            ('relu5', nn.ReLU(True)),
            ('pool2', nn.MaxPool2d(2)),
            ('conv6', nn.Conv2d(512, 512, 3, padding=1, bias=False)),
            ('bn6', nn.BatchNorm2d(512)),
            ('relu6', nn.ReLU(True)),
            ('conv7', nn.Conv2d(512, 512, 3, padding=1, bias=False)),
            ('bn7', nn.BatchNorm2d(512)),
            ('relu7', nn.ReLU(True)),
            ('pool3', nn.MaxPool2d(2)),
        ]))
        self.classifier = nn.Linear(512, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        conv_output = torch.flatten(x, 1)
        conv_output_after_pooling = torch.nn.functional.avg_pool1d(conv_output, 2)
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x, conv_output_after_pooling

class VGGCifar5_Cosine(nn.Module):
    def __init__(self) -> None:
        # Generate the same scratch model
        set_seed()
        super().__init__()
        self.backbone = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, 64, 3, padding=1, bias=False)),
            ('bn0', nn.BatchNorm2d(64)),
            ('relu0', nn.ReLU(True)),
            ('pool0', nn.MaxPool2d(2)),
            ('conv1', nn.Conv2d(64, 128, 3, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(128)),
            ('relu1', nn.ReLU(True)),
            ('pool1', nn.MaxPool2d(2)),
            ('conv2', nn.Conv2d(128, 256, 3, padding=1, bias=False)),
            ('bn2', nn.BatchNorm2d(256)),
            ('relu2', nn.ReLU(True)),
            ('pool2', nn.MaxPool2d(2)),
            ('conv3', nn.Conv2d(256, 256, 3, padding=1, bias=False)),
            ('bn3', nn.BatchNorm2d(256)),
            ('relu3', nn.ReLU(True)),
            ('pool3', nn.MaxPool2d(2)),
        ]))
        self.classifier = nn.Linear(256, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        conv_output = torch.flatten(x, 1)
        x = x.mean([2, 3])
        # conv_output = x
        x = self.classifier(x)
        return x, conv_output

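## Initializing Models for Cosine-Loss-Based KD

We initialize the Teacher and Student using the **representation-level KD architecture** defined above.  
The Teacher reuses the pretrained `VGGCifar9` weights as-is, since its layer names are unchanged, while the Student is freshly initialized for training.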


In [17]:
teacher_model_cosine = VGGCifar9_Cosine().cuda()
teacher_model_cosine.load_state_dict(state_dict)

student_model_cosine = VGGCifar5_Cosine().cuda()

## Checking Representation Dimensions for Cosine Distillation

To apply CosineEmbeddingLoss, the **hidden representations** returned by the Teacher and Student must have the same dimensionality.  
The code below passes a random input batch through each model and prints the shapes of the **logits** and the **flattened convolutional feature vectors**.

In [18]:
# Create a sample input tensor
sample_input = torch.randn(128, 3, 32, 32).cuda() # Batch size: 128, Channels: 3, Image size: 32x32

# Pass the input through the student
logits, hidden_representation = student_model_cosine(sample_input)

# Print the shapes of the tensors
print("Student logits shape:", logits.shape) # batch_size x total_classes
print("Student hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size

# Pass the input through the teacher
logits, hidden_representation = teacher_model_cosine(sample_input)

# Print the shapes of the tensors
print("Teacher logits shape:", logits.shape) # batch_size x total_classes
print("Teacher hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size

Student logits shape: torch.Size([128, 10])
Student hidden representation shape: torch.Size([128, 1024])
Teacher logits shape: torch.Size([128, 10])
Teacher hidden representation shape: torch.Size([128, 1024])
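The pooling step in `VGGCifar9_Cosine.forward()` can be checked directly. Because `torch.flatten` is channel-major (index `c*4 + h*2 + w` for a 2x2 map), `avg_pool1d` with kernel size 2 averages each pair of horizontally adjacent positions within the same channel; it does not mix channels. The sketch below verifies this on a random tensor with the Teacher's feature-map shape. It adds an explicit channel dimension before pooling, which should be equivalent to the unbatched 2D call in the model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
teacher_fmap = torch.randn(4, 512, 2, 2)         # Teacher backbone output shape, batch of 4
flat = torch.flatten(teacher_fmap, 1)            # (4, 2048), index = c*4 + h*2 + w
pooled = F.avg_pool1d(flat.unsqueeze(1), kernel_size=2).squeeze(1)  # (4, 1024)

# Output element i averages flat[:, 2i] and flat[:, 2i+1], i.e. positions (c, h, 0) and (c, h, 1)
pairwise = teacher_fmap.mean(dim=3).flatten(1)   # mean over the width axis, then flatten

print(flat.shape, pooled.shape)
print(torch.allclose(pooled, pairwise))
```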


## [Exercise 2] Defining the Cosine-Similarity-Based KD Training Function

This function trains the Student by maximizing the similarity between the **hidden representations of the Teacher and Student** via **CosineEmbeddingLoss**.  
Unlike soft-label KD, this is a form of **representation-level distillation** and focuses on improving the quality of the Student's feature extractor.

In [19]:
def train_cosine_loss(teacher,
                      student,
                      train_loader,
                      epochs,
                      learning_rate,
                      hidden_rep_loss_weight,
                      ce_loss_weight):
    ce_loss = nn.CrossEntropyLoss()
    cosine_loss = nn.CosineEmbeddingLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs * len(train_loader))

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.cuda(), labels.cuda()

            optimizer.zero_grad()

            ##################### YOUR CODE STARTS HERE #####################
            # Forward pass with the teacher model and keep only the hidden representation
            with torch.no_grad():
                _, teacher_hidden_representation = teacher(inputs)

            # Forward pass with the student model
            student_logits, student_hidden_representation = student(inputs)

            # Calculate the cosine loss. The target is a vector of ones: with target = 1, CosineEmbeddingLoss
            # computes 1 - cos(x, y), so minimizing the loss increases the cosine similarity.
            # Hint: in cosine_loss(x, y, target), target is a vector of ones; use torch.ones(inputs.size(0)).cuda()
            hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation,
                                          target=torch.ones(inputs.size(0)).cuda())

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss
            ##################### YOUR CODE ENDS HERE #######################

            loss.backward()
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
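With an all-ones target, `nn.CosineEmbeddingLoss` (default `reduction="mean"`) reduces to the batch mean of `1 - cos(x, y)`. Minimizing it therefore pushes the Student's features toward the direction of the Teacher's while ignoring their magnitude. A quick check on random vectors (batch size and dimensionality are arbitrary choices here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
student_feat = torch.randn(16, 1024)
teacher_feat = torch.randn(16, 1024)
target = torch.ones(student_feat.size(0))

loss = nn.CosineEmbeddingLoss()(student_feat, teacher_feat, target)
manual = (1 - F.cosine_similarity(student_feat, teacher_feat, dim=1)).mean()

# Scaling the student features leaves the loss unchanged: only the direction matters
scaled = nn.CosineEmbeddingLoss()(3.0 * student_feat, teacher_feat, target)
print(loss.item(), manual.item(), scaled.item())
```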

## Cosine-Similarity-Based Knowledge Distillation Experiment

In this experiment, the Student model is trained with **CosineEmbeddingLoss** on the similarity between the **internal feature vectors** of the Teacher and Student.  
Instead of soft labels, this method performs distillation by encouraging similar feature-level representations.

In [20]:
# Train and test the lightweight network with cross entropy loss + cosine loss
student_model_cosine = VGGCifar5_Cosine().cuda()
train_cosine_loss(teacher=teacher_model_cosine,
                  student=student_model_cosine,
                  train_loader=get_train_loader(train_dataset),
                  epochs=5,
                  learning_rate=0.01,
                  hidden_rep_loss_weight=0.5,
                  ce_loss_weight=0.5)

test_accuracy_student_ce_and_cosine_loss = test(student_model_cosine, test_loader)

Epoch 1/5, Loss: 0.8624538840235346
Epoch 2/5, Loss: 0.5542182272200085
Epoch 3/5, Loss: 0.4163331224028107
Epoch 4/5, Loss: 0.31045055736208815
Epoch 5/5, Loss: 0.2399151202129281
Test Accuracy: 82.91%


## Accuracy Summary

In [22]:
print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_student_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_student_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + CosineLoss: {test_accuracy_student_ce_and_cosine_loss:.2f}%")

Teacher accuracy: 92.95%
Student accuracy without teacher: 82.44%
Student accuracy with CE + KD: 85.43%
Student accuracy with CE + CosineLoss: 82.91%


# 3.4. Intermediate Regressor (Regressor + MSE)

## Comparing Feature Map Shapes

In Hint-based Knowledge Distillation, aligning the **intermediate feature maps** of the Teacher and Student (e.g., with MSE) requires their **convolutional output shapes** to match.  
As preparation, the code below compares the **backbone output shapes** of the two models.

In [23]:
# Pass the sample input only from the convolutional feature extractor
convolutional_fe_output_student = student_model.backbone(sample_input)
convolutional_fe_output_teacher = teacher_model.backbone(sample_input)

# Print their shapes
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)

Student's feature extractor output shape:  torch.Size([128, 256, 2, 2])
Teacher's feature extractor output shape:  torch.Size([128, 512, 2, 2])


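## Defining Models with a Regressor for Hint-based KD

In this experiment, a **trainable regressor** is introduced to directly align the Teacher's intermediate feature map with the Student's.  
Since the Teacher's and Student's convolutional outputs have different numbers of channels (512 vs. 256), the Student's feature map is passed through the **regressor** (a 3x3 convolution followed by BatchNorm) to bring it to the same shape as the Teacher's.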

In [24]:
class VGGCifar9_Regressor(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, 64, 3, padding=1, bias=False)),
            ('bn0', nn.BatchNorm2d(64)),
            ('relu0', nn.ReLU(True)),
            ('conv1', nn.Conv2d(64, 128, 3, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(128)),
            ('relu1', nn.ReLU(True)),
            ('pool0', nn.MaxPool2d(2)),
            ('conv2', nn.Conv2d(128, 256, 3, padding=1, bias=False)),
            ('bn2', nn.BatchNorm2d(256)),
            ('relu2', nn.ReLU(True)),
            ('conv3', nn.Conv2d(256, 256, 3, padding=1, bias=False)),
            ('bn3', nn.BatchNorm2d(256)),
            ('relu3', nn.ReLU(True)),
            ('pool1', nn.MaxPool2d(2)),
            ('conv4', nn.Conv2d(256, 512, 3, padding=1, bias=False)),
            ('bn4', nn.BatchNorm2d(512)),
            ('relu4', nn.ReLU(True)),
            ('conv5', nn.Conv2d(512, 512, 3, padding=1, bias=False)),
            ('bn5', nn.BatchNorm2d(512)),
            ('relu5', nn.ReLU(True)),
            ('pool2', nn.MaxPool2d(2)),
            ('conv6', nn.Conv2d(512, 512, 3, padding=1, bias=False)),
            ('bn6', nn.BatchNorm2d(512)),
            ('relu6', nn.ReLU(True)),
            ('conv7', nn.Conv2d(512, 512, 3, padding=1, bias=False)),
            ('bn7', nn.BatchNorm2d(512)),
            ('relu7', nn.ReLU(True)),
            ('pool3', nn.MaxPool2d(2)),
        ]))
        self.classifier = nn.Linear(512, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        conv_feature_map = x
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x, conv_feature_map

class VGGCifar5_Regressor(nn.Module):
    def __init__(self) -> None:
        # Generate the same scratch model
        set_seed()
        super().__init__()
        self.backbone = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, 64, 3, padding=1, bias=False)),
            ('bn0', nn.BatchNorm2d(64)),
            ('relu0', nn.ReLU(True)),
            ('pool0', nn.MaxPool2d(2)),
            ('conv1', nn.Conv2d(64, 128, 3, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(128)),
            ('relu1', nn.ReLU(True)),
            ('pool1', nn.MaxPool2d(2)),
            ('conv2', nn.Conv2d(128, 256, 3, padding=1, bias=False)),
            ('bn2', nn.BatchNorm2d(256)),
            ('relu2', nn.ReLU(True)),
            ('pool2', nn.MaxPool2d(2)),
            ('conv3', nn.Conv2d(256, 256, 3, padding=1, bias=False)),
            ('bn3', nn.BatchNorm2d(256)),
            ('relu3', nn.ReLU(True)),
            ('pool3', nn.MaxPool2d(2)),
        ]))
        self.regressor = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1, bias=False),
            nn.BatchNorm2d(512)
        )
        self.classifier = nn.Linear(256, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        regressor_output = self.regressor(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x, regressor_output
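The regressor's job is shape matching: a 3x3 convolution with `padding=1` keeps the 2x2 spatial size and maps 256 channels to 512, so its output can be compared element-wise with the Teacher's feature map. The sketch below rebuilds the same `Conv2d` + `BatchNorm2d` stack, feeds it a random batch shaped like the Student backbone output (random input is an assumption for illustration), and counts the parameters it adds.

```python
import torch
import torch.nn as nn

# Same layers as VGGCifar5_Regressor.regressor
regressor = nn.Sequential(
    nn.Conv2d(256, 512, 3, padding=1, bias=False),
    nn.BatchNorm2d(512),
)
student_fmap = torch.randn(128, 256, 2, 2)   # shape of the Student backbone output
out = regressor(student_fmap)

extra_params = sum(p.numel() for p in regressor.parameters())
print(out.shape, f"{extra_params:,}")
```

The regressor alone has about 1.18M parameters, more than the rest of the Student (964,170). The logits do not depend on it, so it is only useful for computing the hint loss during training and could be dropped afterwards.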

## Initializing the Teacher Model for Hint-based KD and Loading Weights

Hint-based Knowledge Distillation **uses the Teacher's intermediate feature map as a supervisory signal**.  
To this end, we use `VGGCifar9_Regressor`, which loads the pretrained `VGGCifar9` weights but is **modified to also return the feature map**.
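`load_state_dict` matches tensors purely by parameter name (for example `backbone.conv0.weight` or `classifier.bias`), not by what `forward()` does. That is why the same pretrained `state_dict` loads into `VGGCifar9`, `VGGCifar9_Cosine`, and `VGGCifar9_Regressor`: the layer names are identical and only the return values differ. A minimal sketch with two hypothetical tiny models:

```python
import torch
import torch.nn as nn

class Plain(nn.Module):
    # Hypothetical model returning logits only
    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1))
        self.classifier = nn.Linear(8, 10)

    def forward(self, x):
        return self.classifier(self.backbone(x).mean([2, 3]))

class WithFeatures(Plain):
    # Same layer names, different forward(): also returns the feature map
    def forward(self, x):
        fmap = self.backbone(x)
        return self.classifier(fmap.mean([2, 3])), fmap

src, dst = Plain(), WithFeatures()
result = dst.load_state_dict(src.state_dict())   # strict=True by default
print(result.missing_keys, result.unexpected_keys)

x = torch.randn(2, 3, 8, 8)
logits, fmap = dst(x)
print(torch.allclose(logits, src(x)), fmap.shape)
```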

In [25]:
teacher_model_reg = VGGCifar9_Regressor().cuda()
teacher_model_reg.load_state_dict(state_dict)

<All keys matched successfully>

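## [Exercise 3] Defining the Hint-based Knowledge Distillation Training Function (MSE Loss)

This function implements Hint-based KD training, which uses **Mean Squared Error (MSE) Loss** to align the Teacher's and Student's representations at the level of the **intermediate feature map**.  
This encourages the Student's feature extractor to mimic the Teacher's intermediate representations.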


**Training Concept Summary**

| Loss           | Role                                       |
|----------------|--------------------------------------------|
| CrossEntropy   | Classification training on the ground-truth labels |
| MSE Loss       | Minimizes the error between the Teacher feature map and the Student's regressed feature map |
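`nn.MSELoss()` with its default `reduction="mean"` averages the squared error over every element of the feature map (batch x 512 x 2 x 2), so the hint term does not grow with batch size or feature-map size. The sketch below checks that against the explicit formula and forms the same weighted sum used in `train_mse_loss()`; all tensors are random stand-ins for real model outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
regressed = torch.randn(8, 512, 2, 2)     # stands in for the Student regressor output
teacher_fmap = torch.randn(8, 512, 2, 2)  # stands in for the Teacher feature map
logits = torch.randn(8, 10)               # stands in for the Student logits
labels = torch.randint(0, 10, (8,))

hint_loss = nn.MSELoss()(regressed, teacher_fmap)
manual = ((regressed - teacher_fmap) ** 2).sum() / regressed.numel()
label_loss = nn.CrossEntropyLoss()(logits, labels)

feature_map_weight, ce_loss_weight = 0.5, 0.5   # the weights used in this notebook
loss = feature_map_weight * hint_loss + ce_loss_weight * label_loss
print(hint_loss.item(), manual.item(), loss.item())
```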

In [26]:
def train_mse_loss(teacher,
                   student,
                   train_loader,
                   epochs,
                   learning_rate,
                   feature_map_weight,
                   ce_loss_weight):
    ce_loss = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs * len(train_loader))

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.cuda(), labels.cuda()

            optimizer.zero_grad()

            ##################### YOUR CODE STARTS HERE #####################
            # Again ignore teacher logits
            with torch.no_grad():
                _, teacher_feature_map = teacher(inputs)

            # Forward pass with the student model
            student_logits, regressor_feature_map = student(inputs)

            # Calculate the MSE loss between the regressed student feature map and the teacher feature map
            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss
            ##################### YOUR CODE ENDS HERE #######################

            loss.backward()
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

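## Hint-based KD (Regressor + MSE Loss) Experiment

In this experiment, we perform **Hint-based Knowledge Distillation**, which encourages the Student to directly mimic the Teacher's **intermediate feature map**.  
To this end, a **trainable regressor layer** is added to the Student, and its output is aligned with the Teacher's representation using **MSE Loss**.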

In [27]:
student_model_reg = VGGCifar5_Regressor().cuda()
train_mse_loss(teacher=teacher_model_reg,
               student=student_model_reg,
               train_loader=get_train_loader(train_dataset),
               epochs=5,
               learning_rate=0.01,
               feature_map_weight=0.5,
               ce_loss_weight=0.5)

test_accuracy_student_ce_and_mse_loss = test(student_model_reg, test_loader)

Epoch 1/5, Loss: 0.7331331200764307
Epoch 2/5, Loss: 0.4394583729526881
Epoch 3/5, Loss: 0.3140947205559982
Epoch 4/5, Loss: 0.20931232139430084
Epoch 5/5, Loss: 0.13904573413950708
Test Accuracy: 83.11%


## Accuracy Summary

In [28]:
print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_student_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_student_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + CosineLoss: {test_accuracy_student_ce_and_cosine_loss:.2f}%")
print(f"Student accuracy with CE + RegressorMSE: {test_accuracy_student_ce_and_mse_loss:.2f}%")

Teacher accuracy: 92.95%
Student accuracy without teacher: 82.44%
Student accuracy with CE + KD: 85.43%
Student accuracy with CE + CosineLoss: 82.91%
Student accuracy with CE + RegressorMSE: 83.11%
