# knowledge distillation

In [1]:
# Mount Google Drive at /content/drive (this import is only available on Google Colab).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Switch to the project folder on Drive; later cells use paths relative to it
# (./data for the dataset, ./checkpoint for model weights).
%cd /content/drive/MyDrive/ColabNotebooks/workspace/knowledge_distillation
# List the folder contents as a quick sanity check.
!ls

/content/drive/MyDrive/ColabNotebooks/workspace/knowledge_distillation
checkpoint  data  knowledge_distillation.ipynb	wandb


## Setting

In [3]:
!pip install timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.13-py3-none-any.whl (549 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m549.1/549.1 KB[0m [31m28.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: huggingface-hub, timm
Successfully installed huggingface-hub-0.13.3 timm-0.6.13


In [4]:
!pip install wandb -q

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m61.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 KB[0m [31m21.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.2/199.2 KB[0m [31m24.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 KB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [5]:
import os

import wandb

# Read the W&B API key from the environment instead of hard-coding it in the notebook,
# so the key never ends up in a shared or committed copy.
# When WANDB_API_KEY is unset, key=None leaves credential lookup to wandb itself
# (presumably a stored login or an interactive prompt - see the wandb.login docs).
api_key = os.environ.get('WANDB_API_KEY') or None  # optional
wandb.login(key=api_key)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import os

import torchvision

import timm

## Dataset

In [7]:
import torchvision
import torchvision.transforms as transforms

# Use the GPU when one is available; fall back to the CPU so the notebook still
# runs (slowly) on a runtime without CUDA instead of failing on the first .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Training-time preprocessing: resize CIFAR-10 images to 224x224 and convert to tensors.
# The augmentations below are intentionally left disabled.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    # transforms.RandomCrop(32, padding=4),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Evaluation-time preprocessing: same resize + tensor conversion, no augmentation.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# CIFAR-10 is stored under ./data and downloaded on first use.
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2)
     

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Report the number of samples in the training split, then in the test split.
print(len(train_dataset))
print(len(test_dataset))

50000
10000


## Model

In [9]:
# Teacher model: ViT-Base/16 (224x224 input) with a 10-class head.
teacher_net = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes = 10)

# Freeze every teacher parameter outside the classification head.
# Selecting by name prefix avoids relying on the order in which
# named_parameters() yields its entries.
for name, param in teacher_net.named_parameters():
    if not name.startswith('head.'):
        param.requires_grad = False

# Load the fine-tuned teacher weights. map_location='cpu' lets a checkpoint that was
# saved from a GPU load on a CPU-only runtime as well.
pretrained_vitWeight = './checkpoint/epoch15_vit_base_patch16_224.pt'
checkpoint_vit = torch.load(pretrained_vitWeight, map_location='cpu')
teacher_net.load_state_dict(checkpoint_vit['net'])
# The teacher is not optimized in this notebook, so put it in inference mode.
teacher_net.eval()

# Student model: ResNet-34 created with pretrained=True; its final layer is replaced
# by a 10-way classifier sized from the existing layer instead of a hard-coded width.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over `pretrained=`.
student_net = torchvision.models.resnet34(pretrained=True)
student_net.fc = nn.Linear(student_net.fc.in_features, 10)

student_criterion = nn.CrossEntropyLoss()
# student_optimizer = optim.Adam(student_net.parameters(), lr=1e-4)
student_optimizer = optim.SGD(student_net.parameters(), lr=1e-4, momentum=0.9, weight_decay=0.0002)

# pretrained_resnetWeight = '/content/drive/MyDrive/ColabNotebooks/workspace/knowledge_distillation/checkpoint/best_mode.pt'
# checkpoint_resnet = torch.load(pretrained_resnetWeight)
# student_net.load_state_dict(checkpoint_resnet['net'])


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 272MB/s]


### Distillation loss

In [10]:
# knowledge distillation loss
def distillation_loss(student_scores, targets, teacher_scores, T, alpha):
    """Knowledge-distillation loss: weighted sum of a soft-target and a hard-target term.

    loss = alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))
           + (1 - alpha) * CrossEntropy(student, targets)

    Args:
        student_scores: student logits; assumed shape (batch, num_classes).
        targets: ground-truth class indices, as accepted by F.cross_entropy.
        teacher_scores: teacher logits with the same shape as student_scores.
        T: softmax temperature applied to both sets of logits for the soft term.
        alpha: weight of the soft-target term; the hard-label term gets 1 - alpha.

    Returns:
        Scalar tensor holding the combined loss.
    """
    # Temperature-softened distributions over the class dimension.
    # F.kl_div expects log-probabilities as input and probabilities as target.
    soft_student = F.log_softmax(student_scores / T, dim=1)
    soft_teacher = F.softmax(teacher_scores / T, dim=1)

    # 'batchmean' sums over classes and averages over the batch, which is the
    # actual KL divergence per sample (the default 'mean' also divides by the
    # number of classes).
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')

    # Ordinary cross-entropy of the student against the hard labels.
    student_loss = F.cross_entropy(student_scores, targets)

    # The soft term is scaled by T^2 so its gradient magnitude stays comparable
    # to the hard term when the temperature changes.
    return soft_loss * (alpha * T * T) + student_loss * (1. - alpha)

# # val loss
# loss_func = nn.CrossEntropyLoss()

## Train

In [11]:
# train
def knowledge_distillation_train(epoch, student_net, teacher_net, student_optimizer, device):
    """Train the student for one epoch against the teacher's outputs.

    Iterates over the module-level `train_loader`, optimizes the student with
    `distillation_loss` (T=20.0, alpha=0.7), prints progress every 100 batches
    and logs the epoch accuracy and loss to wandb.

    Args:
        epoch: epoch number, used for printing and logging only.
        student_net: model being trained.
        teacher_net: model whose outputs serve as soft targets; it is not updated.
        student_optimizer: optimizer that updates the student.
        device: device both models and every batch are moved to.
    """
    print('\n[ Train epoch: %d ]' % epoch)
    student_net.to(device)
    teacher_net.to(device)
    student_net.train()
    # The teacher only supplies targets, so it runs in inference mode.
    teacher_net.eval()

    loss_sum = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = targets.size(0)
        student_optimizer.zero_grad()

        student_outputs = student_net(inputs)
        # No gradients are needed for the teacher; skipping the autograd graph saves memory.
        with torch.no_grad():
            teacher_outputs = teacher_net(inputs)
        loss = distillation_loss(student_outputs, targets, teacher_outputs, T=20.0, alpha=0.7)
        loss.backward()
        student_optimizer.step()

        batch_loss = loss.item()
        # The loss is a batch average, so weight it by the batch size before summing.
        loss_sum += batch_loss * batch_size

        _, predicted = student_outputs.max(1)
        batch_correct = predicted.eq(targets).sum().item()
        total += batch_size
        correct += batch_correct

        if batch_idx % 100 == 0:
            print('\nCurrent batch:', str(batch_idx))
            print('Current benign train accuracy:', str(batch_correct / batch_size))
            print('Current benign train loss:', batch_loss)

    # Guard against an empty loader instead of dividing by zero.
    train_acc = 100. * correct / total if total else 0.0
    # Average over the whole epoch as a plain float (not the last batch's loss tensor).
    train_loss = loss_sum / total if total else 0.0
    print('\nTotal benign train accuarcy:', train_acc)
    print('Total benign train loss:', train_loss)

    # Logging is best-effort: training continues if wandb is unavailable.
    # NOTE(review): the 'Tran Loss' key is kept as-is so existing runs/charts still line up.
    try:
        wandb.log({'Train Accuracy': train_acc, 'Tran Loss': train_loss, 'epoch':epoch})
    except Exception as err:
        print('No wandb')
        print('wandb.log failed with:', repr(err))

In [16]:
# test
# Best accuracy seen so far. Best-model checkpointing is not active in
# student_model_test, so this value is currently never updated.
best_test_acc = 0


def student_model_test(epoch, net, criterion, file_name):
    """Evaluate `net` on the module-level `test_loader` and save a checkpoint.

    Prints and logs (to wandb) the test accuracy and the average test loss,
    then writes the model weights to ./checkpoint/epoch{epoch}_{file_name}.
    A checkpoint is written on every call.

    Args:
        epoch: epoch number, used for printing, logging and the checkpoint name.
        net: model to evaluate.
        criterion: loss callable applied as criterion(outputs, targets); the
            average assumes it returns the mean loss over the batch.
        file_name: file-name suffix of the checkpoint.
    """
    print('\n[ Test epoch: %d ]' % epoch)
    net.to(device)
    net.eval()
    loss_sum = 0.0  # sum of per-sample losses
    correct = 0
    total = 0

    # Evaluation needs no gradients; no_grad avoids building the autograd graph.
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = targets.size(0)
            total += batch_size

            outputs = net(inputs)
            # Weight the batch-mean loss by the batch size so that dividing by
            # `total` below yields a true per-sample average.
            loss_sum += criterion(outputs, targets).item() * batch_size

            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    test_acc = 100. * correct / total if total else 0.0
    test_loss = loss_sum / total if total else 0.0
    print('\nTest accuarcy:', test_acc)
    print('Test average loss:', test_loss)

    # Save the weights together with the epoch they belong to.
    state = {
        'net': net.state_dict(),
        'epoch': epoch
    }
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, f'./checkpoint/epoch{epoch}_{file_name}')
    print('Model Saved!')

    # Logging is best-effort: evaluation results are still printed if wandb is unavailable.
    try:
        wandb.log({'Test Accuracy': test_acc, 'Test Loss': test_loss, 'epoch':epoch})
    except Exception as err:
        print('No wandb')
        print('wandb.log failed with:', repr(err))


In [17]:
# wandb config
# Run metadata recorded with the W&B run. The 'model' entry is also used as the run name.
WANDB_CONFIG = dict(
    model='resnet34_dislled_VitBase',
    batch_size=128,
    epoch=50,
    pretrained='imagenet',
    optimizer='sgd',
    lr=1e-4,
    momentum=0.9,
    weight_decay=0.0002,
    criterion='crossentropy',
)

# Start the W&B run for this experiment.
wandb.init(
    project='knowledge_distillation',
    name=WANDB_CONFIG['model'],
    config=WANDB_CONFIG,
)

VBox(children=(Label(value='0.001 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.120305…

In [18]:
# Optional resume: uncomment the lines below to restore the student weights
# (and the saved epoch) from a checkpoint before training.
# pretrained_resnetWeight = '/content/drive/MyDrive/ColabNotebooks/workspace/knowledge_distillation/checkpoint/best_mode.pt'
# checkpoint_resnet = torch.load(pretrained_resnetWeight)
# student_net.load_state_dict(checkpoint_resnet['net'])

# epoch = checkpoint_resnet['epoch']
# First epoch number passed to the training loop.
start_epoch = 1

In [None]:
# The checkpoint files are named after the model entry of the run config.
file_name = WANDB_CONFIG['model']
num_epochs = WANDB_CONFIG['epoch']

try:
    # Epochs are numbered from start_epoch (1 by default), so the upper bound is
    # inclusive: with the default settings this runs exactly num_epochs epochs.
    for epoch in range(start_epoch, num_epochs + 1):
        knowledge_distillation_train(epoch, student_net, teacher_net, student_optimizer, device)
        student_model_test(epoch, student_net, student_criterion, f'{file_name}.pt')
finally:
    # Close the W&B run even when training is interrupted or raises.
    wandb.finish()