# KD (Knowledge Distillation, 지식 증류)

In [1]:
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# CIFAR-10 train/test splits: images are converted to tensors and normalized
# per channel.
# NOTE(review): these mean/std values are the widely used ImageNet statistics,
# not statistics computed on CIFAR-10 — confirm this is intentional.
transforms_cifar = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Both splits share the storage location, download flag and preprocessing.
_cifar_kwargs = dict(root='./data', download=True, transform=transforms_cifar)
train_dataset = datasets.CIFAR10(train=True, **_cifar_kwargs)
test_dataset = datasets.CIFAR10(train=False, **_cifar_kwargs)

# Only the training loader shuffles; evaluation order is fixed.
_loader_kwargs = dict(batch_size=32, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

## Prep

### Teacher / Student 네트워크 구성

In [3]:
# Teacher model
class TeacherNN(nn.Module):
    """Teacher CNN: two pooled stages of two 3x3 convolutions, then a 2-layer MLP.

    For a 3x32x32 input, the two 2x2 max-pools leave a 32x8x8 feature map. That
    is the 2048 flattened features the first ``Linear`` layer expects.

    Args:
        num_classes: Number of logits produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Attribute names and layer order must stay as they are: TeacherCosine
        # and TeacherMSE load this model's state_dict, which is keyed by them.
        self.features = nn.Sequential(
            # Stage 1: 32x32 -> 16x16
            nn.Conv2d(3, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Stage 2: 16x16 -> 8x8
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        feature_map = self.features(x)
        return self.classifier(feature_map.flatten(1))

# Student model
class StudentNN(nn.Module):
    """Lightweight student CNN: two conv+pool stages with 16 channels, then a 2-layer MLP.

    For a 3x32x32 input, the two 2x2 max-pools leave a 16x8x8 feature map. That
    is the 1024 flattened features the first ``Linear`` layer expects.

    Args:
        num_classes: Number of logits produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 32x32 -> 16x16
            nn.Conv2d(16, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 16x16 -> 8x8
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        flat = self.features(x).flatten(1)
        return self.classifier(flat)

### 학습 함수

In [4]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train ``model`` in place with Adam and cross-entropy loss.

    The model is moved to ``device`` before training, as ``test``,
    ``train_cosine_loss`` and ``train_mse_loss`` do. Callers therefore do not
    have to pre-move it. The original only moved the batches, which fails with a
    device mismatch for a model left on the CPU while ``device`` is CUDA.
    After each epoch the mean per-batch loss is printed.

    Args:
        model: Network returning class logits; its parameters are updated in place.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: Device the model and every batch are moved to.
    """
    # Move first so the optimizer is built over the parameters on their final device.
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0  # sum of per-batch losses for this epoch
        for x, y_true in tqdm(train_loader):
            x, y_true = x.to(device), y_true.to(device)

            optimizer.zero_grad()
            y_pred = model(x)

            loss = criterion(y_pred, y_true)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Guard against an empty loader (len() == 0) instead of raising ZeroDivisionError.
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / max(len(train_loader), 1)}")

def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and return top-1 accuracy in percent.

    The model is moved to ``device`` and put in eval mode, and inference runs
    without gradient tracking. The accuracy is also printed with two decimals.

    Args:
        model: Network whose forward returns class logits.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the model and every batch are moved to.

    Returns:
        Percentage of correctly classified samples.
    """
    model.to(device)
    model.eval()

    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Predicted class = index of the largest logit in each row.
            predictions = model(inputs).max(dim=1).indices

            n_seen += labels.size(0)
            n_correct += (predictions == labels).sum().item()

    accuracy = 100 * n_correct / n_seen
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

### 모델 선언 및 모델별 파라미터 수 확인
- `teacherModel`: Teacher network
- `student_1`: Student network without KD
- `student_2`: Student network with KD

In [None]:
# Re-seed before building each network so initialization is reproducible. As a
# result the two students (trained without / with KD) start from identical weights.
torch.manual_seed(42)
teacherModel = TeacherNN(num_classes=10).to(device)

torch.manual_seed(42)
student_1 = StudentNN(num_classes=10).to(device)

torch.manual_seed(42)
student_2 = StudentNN(num_classes=10).to(device)


# Total parameter counts, formatted with thousands separators.
total_params_teacher, total_params_student_1, total_params_student_2 = (
    f"{sum(p.numel() for p in _net.parameters()):,}"
    for _net in (teacherModel, student_1, student_2)
)

for _name, _count in (
    ("teacherModel", total_params_teacher),
    ("student_1", total_params_student_1),
    ("student_2", total_params_student_2),
):
    print(f"{_name} parameters: {_count}")

## 실험

### 실험 0: KD 없이 각 모델 성능 비교

In [None]:
# Baseline: train the teacher on its own with plain cross-entropy, then evaluate it.
train(
    teacherModel,
    train_loader,
    epochs=20,
    learning_rate=1e-3,
    device=device,
)
test_accuracy_teacher = test(model=teacherModel, test_loader=test_loader, device=device)

In [None]:
# Baseline: train the student without any teacher signal, then evaluate it.
# NOTE(review): this student trains for 20 epochs, but the distilled students in
# the later experiments train for 10. The accuracy comparison is therefore not
# epoch-matched — confirm this is intended.
train(
    student_1,
    train_loader,
    epochs=20,
    learning_rate=1e-3,
    device=device,
)
test_accuracy_student_1 = test(model=student_1, test_loader=test_loader, device=device)

### 실험 1: KD 적용

![](https://pytorch.org/tutorials/_static/img/knowledge_distillation/distillation_output_loss.png)

In [None]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    """Train ``student`` on the true labels and on ``teacher``'s temperature-softened outputs.

    Per batch, with teacher logits ``t``, student logits ``s`` and batch size ``B``::

        soft_loss = sum(p_t * (log p_t - log p_s)) / B * T**2
                    where p_t = softmax(t / T), log p_s = log_softmax(s / T)
        loss      = soft_target_loss_weight * soft_loss
                    + ce_loss_weight * CrossEntropy(s, y_true)

    ``soft_loss`` is the batch-mean KL divergence KL(p_t || p_s). The ``T**2``
    factor follows Hinton et al. (2015): it keeps the soft-target gradients on a
    scale comparable to the cross-entropy term as ``T`` changes. Only
    ``student`` is updated. ``teacher`` runs in eval mode under
    ``torch.no_grad()``.

    Args:
        teacher: Trained model whose forward returns class logits.
        student: Model whose forward returns class logits; trained in place.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate for the student.
        T: Softmax temperature; values > 1 soften both distributions.
        soft_target_loss_weight: Weight of the distillation (soft-target) term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.
        device: Device the models and every batch are moved to.
    """
    teacher.to(device)
    student.to(device)

    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()   # teacher is frozen; eval mode disables e.g. its dropout
    student.train()

    for epoch in range(epochs):
        running_loss = 0.0  # sum of per-batch losses for this epoch

        for x, y_true in tqdm(train_loader):
            x, y_true = x.to(device), y_true.to(device)

            optimizer.zero_grad()

            # Teacher logits are fixed targets; no autograd graph is needed for them.
            with torch.no_grad():
                teacher_logits = teacher(x)

            student_logits = student(x)

            # Temperature-softened distributions. The teacher's log-probabilities
            # come from log_softmax rather than softmax(...).log(). This avoids
            # 0 * (-inf) = NaN when a probability underflows to exactly 0.
            teacher_log_probs = nn.functional.log_softmax(teacher_logits / T, dim=-1)
            soft_targets = teacher_log_probs.exp()
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Batch-mean KL(teacher || student), scaled by T^2.
            soft_targets_loss = torch.sum(soft_targets * (teacher_log_probs - soft_prob)) / soft_prob.size(0) * (T ** 2)

            # Standard cross-entropy against the hard labels.
            label_loss = ce_loss(student_logits, y_true)

            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Mean per-batch loss for the epoch; guard against an empty loader.
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / max(len(train_loader), 1)}")


In [None]:
# Experiment 1: distil the trained teacher into student_2 with soft targets
# (T=2) plus cross-entropy, then compare against the earlier baselines.
train_knowledge_distillation(
    teacher=teacherModel,
    student=student_2,
    train_loader=train_loader,
    epochs=10,
    learning_rate=1e-3,
    T=2,
    soft_target_loss_weight=0.25,
    ce_loss_weight=0.75,
    device=device,
)
test_accuracy_student_2 = test(model=student_2, test_loader=test_loader, device=device)

for _label, _acc in (
    ("Teacher accuracy", test_accuracy_teacher),
    ("Student accuracy without teacher", test_accuracy_student_1),
    ("Student accuracy with CE + KD", test_accuracy_student_2),
):
    print(f"{_label}: {_acc:.2f}%")

### 실험 2: KD 적용, Cosine loss

![](https://pytorch.org/tutorials/_static/img/knowledge_distillation/cosine_loss_distillation.png)

In [10]:
class TeacherCosine(nn.Module):
    """Teacher CNN (same layout as ``TeacherNN``) that also exposes a pooled hidden representation.

    ``forward`` returns ``(logits, pooled)``. ``pooled`` is the flattened conv
    output (2048 values for a 3x32x32 input) average-pooled over adjacent pairs
    down to 1024 values. That width matches the flattened conv output of
    ``StudentNNCosine``, so the two can be compared with a cosine loss.

    Args:
        num_classes: Number of logits produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Same attribute names / layer order as TeacherNN so its state_dict loads.
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        flat = torch.flatten(self.features(x), 1)
        logits = self.classifier(flat)
        # A 2-D (batch, 2048) input is pooled along its last dimension,
        # producing (batch, 1024).
        pooled = torch.nn.functional.avg_pool1d(flat, 2)
        return logits, pooled


class StudentNNCosine(nn.Module):
    """Student CNN (same layout as ``StudentNN``) that also returns its flattened conv output.

    ``forward`` returns ``(logits, flat)``. ``flat`` is the 16x8x8 conv feature
    map flattened to 1024 values (for a 3x32x32 input). It is used as the
    hidden representation matched against the teacher's.

    Args:
        num_classes: Number of logits produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        flat = self.features(x).flatten(1)
        return self.classifier(flat), flat

In [None]:
# Re-use the trained teacher's weights in the cosine variant (identical layer layout).
teacherCosine = TeacherCosine(num_classes=10).to(device)
teacherCosine.load_state_dict(teacherModel.state_dict())

# Sanity check: equal first-layer weight norms show the weights were copied.
for _name, _net in (("teacherModel", teacherModel), ("teacherCosine", teacherCosine)):
    print(f"Norm of 1st layer for {_name}:", torch.norm(_net.features[0].weight).item())

# Seed as for the other students so the cosine student starts from the same initialization.
torch.manual_seed(42)
studentCosine = StudentNNCosine(num_classes=10).to(device)
print("Norm of 1st layer for studentCosine:", torch.norm(studentCosine.features[0].weight).item())

In [None]:
# Smoke test: push a random CIFAR-sized batch through both cosine models and
# print the output shapes. The cosine loss needs the student and teacher hidden
# representations to have the same shape.
sample_input = torch.randn(128, 3, 32, 32).to(device)

for _role, _net in (("Student", studentCosine), ("Teacher", teacherCosine)):
    logits, hidden_representation = _net(sample_input)
    print(f"{_role} logits shape:", logits.shape)  # batch_size x total_classes
    print(f"{_role} hidden representation shape:", hidden_representation.shape)  # batch_size x hidden_representation_size

In [None]:
def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):
    """Train ``student`` so its hidden representation aligns (by cosine) with ``teacher``'s.

    Both models' forward must return ``(logits, hidden_representation)`` with
    hidden representations of equal shape. Per batch::

        loss = hidden_rep_loss_weight * CosineEmbeddingLoss(h_s, h_t, target=1)
               + ce_loss_weight * CrossEntropy(student_logits, y_true)

    With target 1 the cosine-embedding loss is ``1 - cos(h_s, h_t)`` averaged
    over the batch, so it pushes the representations to point in the same
    direction. Only ``student`` is updated. ``teacher`` runs in eval mode
    under ``torch.no_grad()``.

    Args:
        teacher: Trained model returning ``(logits, hidden_representation)``.
        student: Model returning ``(logits, hidden_representation)``; trained in place.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate for the student.
        hidden_rep_loss_weight: Weight of the cosine (representation) term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.
        device: Device the models and every batch are moved to.
    """
    teacher.to(device)
    student.to(device)

    ce_loss = nn.CrossEntropyLoss()
    cosine_loss = nn.CosineEmbeddingLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()   # teacher is frozen; eval mode disables e.g. its dropout
    student.train()

    for epoch in range(epochs):
        running_loss = 0.0  # sum of per-batch losses for this epoch

        for x, y_true in tqdm(train_loader):
            x, y_true = x.to(device), y_true.to(device)

            optimizer.zero_grad()

            # Only the teacher's hidden representation is needed, and no gradients.
            with torch.no_grad():
                _, teacher_hidden_representation = teacher(x)

            student_logits, student_hidden_representation = student(x)

            # Target +1 for every sample: maximize cosine similarity.
            hidden_rep_loss = cosine_loss(
                student_hidden_representation,
                teacher_hidden_representation,
                target=torch.ones(x.size(0), device=device),
            )
            label_loss = ce_loss(student_logits, y_true)

            loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Mean per-batch loss for the epoch; guard against an empty loader.
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / max(len(train_loader), 1)}")


In [14]:
def test_multiple_outputs(model, test_loader, device):
    """Evaluate a model whose forward returns ``(logits, extra)`` and return accuracy in percent.

    Only the logits are used for prediction. The second output (the hidden
    representation / feature map used by the distillation losses) is ignored
    at inference. The accuracy is also printed with two decimals.

    Args:
        model: Network whose forward returns exactly two outputs, logits first.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the model and every batch are moved to.

    Returns:
        Percentage of correctly classified samples.
    """
    model.to(device)
    model.eval()

    hits = 0
    seen = 0
    with torch.no_grad():
        for batch_x, batch_y in tqdm(test_loader):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits, _unused = model(batch_x)
            predicted = logits.max(dim=1).indices

            seen += batch_y.size(0)
            hits += (predicted == batch_y).sum().item()

    accuracy = 100 * hits / seen
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [None]:
# Experiment 2: align the student's hidden representation with the teacher's
# pooled one via a cosine-embedding loss, then evaluate on the logits only.
train_cosine_loss(
    teacher=teacherCosine,
    student=studentCosine,
    train_loader=train_loader,
    epochs=10,
    learning_rate=1e-3,
    hidden_rep_loss_weight=0.25,
    ce_loss_weight=0.75,
    device=device,
)

test_accuracy_student_ce_and_cosine_loss = test_multiple_outputs(
    model=studentCosine, test_loader=test_loader, device=device
)

### 실험 3: KD 적용, Regression loss

![](https://pytorch.org/tutorials/_static/img/knowledge_distillation/fitnets_knowledge_distill.png)

In [16]:
class TeacherMSE(nn.Module):
    """Teacher CNN (same layout as ``TeacherNN``) that also returns its last conv feature map.

    ``forward`` returns ``(logits, feature_map)``. ``feature_map`` is the
    un-flattened output of ``self.features`` (32x8x8 for a 3x32x32 input). It is
    the regression target for the student's regressor.

    Args:
        num_classes: Number of logits produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Same attribute names / layer order as TeacherNN so its state_dict loads.
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        feature_map = self.features(x)
        logits = self.classifier(torch.flatten(feature_map, 1))
        return logits, feature_map

    
class StudentMSE(nn.Module):
    """Student CNN (same layout as ``StudentNN``) plus a regressor onto the teacher's feature space.

    ``forward`` returns ``(logits, regressor_output)``. The regressor is a single
    3x3 convolution (16 -> 32 channels, no activation) applied to the student's
    conv feature map. For a 3x32x32 input it yields a 32x8x8 map, the same
    shape as ``TeacherMSE``'s feature map. The logits are computed from the
    student features, not from the regressor output.

    Args:
        num_classes: Number of logits produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Construction order (features, regressor, classifier) fixes the seeded
        # initialization; keep it as is.
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Extra regressor: one conv layer with no non-linearity (a linear map per position).
        self.regressor = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        student_features = self.features(x)
        regressor_output = self.regressor(student_features)
        logits = self.classifier(student_features.flatten(1))
        return logits, regressor_output

In [None]:
def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):
    """Train ``student`` so its regressor output matches ``teacher``'s conv feature map (MSE).

    Both models' forward must return ``(logits, feature_map)``, and the feature
    maps must have equal shape. Per batch::

        loss = feature_map_weight * MSE(student_regressor_map, teacher_feature_map)
               + ce_loss_weight * CrossEntropy(student_logits, y_true)

    Only ``student`` is updated. ``teacher`` runs in eval mode under
    ``torch.no_grad()``.

    Args:
        teacher: Trained model returning ``(logits, feature_map)``.
        student: Model returning ``(logits, regressor_feature_map)``; trained in place.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate for the student.
        feature_map_weight: Weight of the feature-map MSE term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.
        device: Device the models and every batch are moved to.
    """
    teacher.to(device)
    student.to(device)

    ce_loss = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()   # teacher is frozen; eval mode disables e.g. its dropout
    student.train()

    for epoch in range(epochs):
        running_loss = 0.0  # sum of per-batch losses for this epoch

        for x, y_true in tqdm(train_loader):
            x, y_true = x.to(device), y_true.to(device)

            optimizer.zero_grad()

            # Only the teacher's feature map is needed, and no gradients.
            with torch.no_grad():
                _, teacher_feature_map = teacher(x)

            student_logits, regressor_feature_map = student(x)

            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)
            label_loss = ce_loss(student_logits, y_true)

            loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Mean per-batch loss for the epoch; guard against an empty loader.
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / max(len(train_loader), 1)}")


In [None]:
# Student for the regressor + MSE experiment, seeded like the other students.
torch.manual_seed(42)
studentMse = StudentMSE(num_classes=10)
studentMse.to(device)

# Teacher with the same layer layout as teacherModel; copy the trained weights.
teacherMse = TeacherMSE(num_classes=10)
teacherMse.to(device)
teacherMse.load_state_dict(teacherModel.state_dict())

In [None]:
# Experiment 3: regress the student's feature map onto the teacher's (MSE) plus
# cross-entropy, then evaluate on the logits only.
train_mse_loss(
    teacher=teacherMse,
    student=studentMse,
    train_loader=train_loader,
    epochs=10,
    learning_rate=1e-3,
    feature_map_weight=0.25,
    ce_loss_weight=0.75,
    device=device,
)
test_accuracy_student_ce_and_mse_loss = test_multiple_outputs(
    model=studentMse, test_loader=test_loader, device=device
)

In [None]:
# Final comparison of every model's test accuracy.
for _label, _acc in (
    ("Teacher accuracy", test_accuracy_teacher),
    ("Student accuracy without teacher", test_accuracy_student_1),
    ("Student accuracy with CE + KD", test_accuracy_student_2),
    ("Student accuracy with CE + CosineLoss", test_accuracy_student_ce_and_cosine_loss),
    ("Student accuracy with CE + RegressorMSE", test_accuracy_student_ce_and_mse_loss),
):
    print(f"{_label}: {_acc:.2f}%")