<a href="https://colab.research.google.com/github/haepaly/Knowledge-Distillation/blob/main/%5B%EA%B8%B0%EA%B3%84%ED%95%99%EC%8A%B5%5D_Knowledge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[1] 기본 torch 관련 함수 import

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim

[2] GPU 사용 확인

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
device  # shown as the cell output, e.g. device(type='cuda') on a GPU runtime

[3] Data 처리 관련

In [None]:
# Per-channel normalization shared by the training and test pipelines.
# NOTE(review): these mean/std values appear to be the commonly quoted CIFAR-100
# statistics; the usual CIFAR-10 values are roughly (0.4914, 0.4822, 0.4465) /
# (0.2470, 0.2435, 0.2616). Kept unchanged here -- confirm which is intended.
normalize = transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))

# Training-time data augmentation: random horizontal flips and random 32x32 crops
# from a 4-pixel zero-padded image, then tensor conversion and normalization.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    normalize,
])

# Evaluation must be deterministic: no random flips/crops on the test set, only the
# same tensor conversion and normalization used during training.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# CIFAR-10 download (skipped if already present under /data).
train_dataset = datasets.CIFAR10(root='/data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='/data', train=False, download=True, transform=test_transform)

[4] 하이퍼 파라미터 설정 및 Loader 설정
 * 고가 GPU의 과도한 사용 경쟁을 막기 위해 아래 파라미터는 수정하지 말 것

In [None]:
# Hyper-parameters shared by every training run in this notebook.
# Don't change them for this assignment.
epochs = 60            # number of passes over train_loader in each training cell
batchsize = 128        # samples per mini-batch (train and test)
learning_rate = 0.01   # fixed for convenience
momentum = 0.9         # SGD momentum
weight_decay = 0.0001  # weight decay passed to SGD

# Both loaders share batch size and worker count; only the training data is
# reshuffled each epoch, the test data keeps a fixed order.
loader_kwargs = dict(batch_size=batchsize, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

[5] Teacher Model 생성 및 학습 수행
* VGG style 10 layer network (8 conv layers + 2 FC layers)

In [None]:
class Teacher(nn.Module):
    """VGG-style teacher network: 8 conv layers followed by 2 linear layers.

    Every conv layer is a 3x3 convolution (padding 1) + BatchNorm + ReLU. The
    conv layers are grouped into four stages, each ending in a 2x2/stride-2
    max-pool.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super(Teacher, self).__init__()
        self.features = nn.Sequential(
            # Stage 1: 3 -> 32 channels.
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Stage 2: 32 -> 64 channels.
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Stage 3: 64 -> 128 channels.
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Stage 4: 128 -> 256 channels, four conv layers.
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.fc_layers = nn.Sequential(
            # 1024 = 256 channels * 2 * 2, which assumes 32x32 inputs
            # (four stride-2 pools: 32 -> 16 -> 8 -> 4 -> 2).
            nn.Linear(1024, 128),
            # NOTE(review): there is no activation between the two Linear layers, so
            # they compose to a single linear map. Inserting nn.ReLU() here would shift
            # the Sequential indices and break existing checkpoints, so it is left as is.
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes) for an image batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 1024)
        x = self.fc_layers(x)
        return x

# Build the teacher on the selected device, with a cross-entropy objective and SGD.
teacher = Teacher().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    teacher.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)

print("============Training start=============")
for epoch in range(epochs):
    teacher.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = teacher(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # Epoch statistics, computed on the detached logits.
        running_loss += loss.item()
        preds = logits.data.max(1)[1]
        n_seen += y.size(0)
        n_correct += preds.eq(y.data).cpu().sum()

    # Progress is reported only on epochs 0, 10, 20, ...
    if epoch % 10 == 0:
        mean_loss = running_loss / (batch_idx + 1)
        train_acc = 100. * n_correct / n_seen
        print('Epoch: {:3d} | Batch_idx: {:3d} |  Loss: {:.4f} | Acc: {:3.2f}%'.format(
            epoch, batch_idx, mean_loss, train_acc))
print("============Training finished=============")


# Test-set accuracy, computed in eval mode without building autograd graphs.
teacher.eval()
with torch.no_grad():
    hits = 0
    seen = 0
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        preds = teacher(x).data.max(1)[1]
        seen += y.size(0)
        hits += (preds == y).sum().item()
    val_acc = 100 * hits / seen
    print('Accuracy on the test set: {}'.format(val_acc))

# Checkpoint: last epoch index, model weights, optimizer state and last-batch loss.
checkpoint = {
    'epoch': epoch,
    'model_state_dict': teacher.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}
torch.save(checkpoint, '/content/teacher_model.pth')

[6] Student Model 생성 및 학습
* VGG style small network (2 conv layers + 1 FC layer)


In [None]:
class Student(nn.Module):
    """Small VGG-style student network: 2 conv layers followed by 1 linear layer.

    Each conv layer is a 3x3 convolution (padding 1) + BatchNorm + ReLU, followed
    by a 2x2/stride-2 max-pool.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super(Student, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 1024 = 16 channels * 8 * 8, which assumes 32x32 inputs (two stride-2 pools).
        self.fc_layers = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes) for an image batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 1024)
        x = self.fc_layers(x)
        return x

# Baseline student trained only on the hard labels (no distillation).
student = Student().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    student.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)

print("============Training start=============")
for epoch in range(epochs):
    student.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        logits = student(x)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Epoch statistics, computed on the detached logits.
        running_loss += loss.item()
        preds = logits.data.max(1)[1]
        n_seen += y.size(0)
        n_correct += preds.eq(y.data).cpu().sum()

    # Progress is reported only on epochs 0, 10, 20, ...
    if epoch % 10 == 0:
        mean_loss = running_loss / (batch_idx + 1)
        train_acc = 100. * n_correct / n_seen
        print('Epoch: {:3d} | Batch_idx: {:3d} |  Loss: {:.4f} | Acc: {:3.2f}%'.format(
            epoch, batch_idx, mean_loss, train_acc))

print("============Training finished=============")

# Test-set accuracy, computed in eval mode without building autograd graphs.
student.eval()
with torch.no_grad():
    hits = 0
    seen = 0
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        preds = student(x).data.max(1)[1]
        seen += y.size(0)
        hits += (preds == y).sum().item()
    val_acc = 100 * hits / seen
    print('Accuracy on the test set: {}'.format(val_acc))

# Checkpoint: last epoch index, model weights, optimizer state and last-batch loss.
checkpoint = {
    'epoch': epoch,
    'model_state_dict': student.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}
torch.save(checkpoint, '/content/student_model.pth')

[7] Knowledge Distillation

In [None]:
# Rebuild the teacher and load the weights saved by the teacher-training cell.
trained_teacher = Teacher().to(device)
# map_location lets a checkpoint written on a GPU runtime also load on a CPU-only one.
model_ckp = torch.load('/content/teacher_model.pth', map_location=device)
trained_teacher.load_state_dict(model_ckp['model_state_dict'])
# The teacher only provides soft targets. eval() makes its BatchNorm layers use the
# stored running statistics (and stop overwriting them) instead of per-batch statistics.
trained_teacher.eval()

# Start from a freshly initialised Student with its own optimizer. The `student` left over
# from the previous cell has already been trained with cross-entropy for `epochs` epochs;
# continuing from it would mix that training into the KD result.
student = Student().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(student.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

# Knowledge-distillation parameters: lambda_ weights the KD term, T is the softmax temperature.
# NOTE(review): with lambda_ = 0.0001 the KD term contributes very little to the total
# loss -- confirm this weighting is intended.
lambda_ = 0.0001
T = 4.5
# reduction='batchmean' sums the KL over classes and averages over the batch, which is the
# actual KL divergence; the default 'mean' also divides by the number of classes.
kl_div_loss = nn.KLDivLoss(reduction='batchmean')

print("============Training start=============")
for epoch in range(epochs):
    student.train()
    train_loss = 0
    correct = 0
    total = 0

    for idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        s_output = student(images)
        # Teacher logits are fixed targets: no autograd graph or gradients are needed.
        with torch.no_grad():
            t_output = trained_teacher(images)

        loss_SL = criterion(s_output, labels)  # standard (hard-label) cross-entropy loss
        loss_KD = kl_div_loss(F.log_softmax(s_output / T, dim=1),
                              F.softmax(t_output / T, dim=1))
        # total_loss = (1 - lambda) * loss_SL + lambda * T^2 * loss_KD
        # (the T^2 scaling of the soft-target term follows Hinton et al., 2015).
        loss = (1 - lambda_) * loss_SL + lambda_ * T * T * loss_KD

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(s_output.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()

    # Progress is reported only on epochs 0, 10, 20, ...
    if epoch % 10 == 0:
        print('Epoch: {:3d} | Batch_idx: {:3d} |  Loss: {:.4f} | Acc: {:3.2f}%'.format(
            epoch, idx, train_loss / (idx + 1), 100. * correct / total))

print("============Training finished=============")


student.eval()  # evaluation mode
with torch.no_grad():  # no gradients needed for evaluation
    correct = 0
    val_acc = 0
    total = 0

    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = student(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    val_acc = 100 * correct / total
print('Accuracy on the test set: {}'.format(val_acc))


# Checkpoint of the KD-trained student: last epoch index, weights, optimizer state
# and last-batch loss.
torch.save({
    'epoch': epoch,
    'model_state_dict': student.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),},
    '/content/KD_model.pth')