<a href="https://colab.research.google.com/github/elhamsh93/image_classification/blob/main/knowledge_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Report the NVIDIA GPUs visible to this runtime; on a CPU-only runtime the command is missing.
!nvidia-smi

/bin/bash: nvidia-smi: command not found


In [1]:
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchmetrics
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.11.4


In [2]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchmetrics
import torchvision
from torchvision.models import mobilenet_v2
from torchvision import datasets, transforms 
import os
import sys

import numpy as np


In [3]:
import logging

# Notebook-wide logger used by train()/test()/Checkpoint to report progress.
logger = logging.getLogger('KnowledgeDistillation')
# A fresh logger inherits the root level (WARNING by default), which silently drops
# every logger.info(...) call; enable INFO and give the logger its own stdout handler.
logger.setLevel(logging.INFO)
if not logger.handlers:  # re-running this cell must not stack duplicate handlers
  _handler = logging.StreamHandler(sys.stdout)
  _handler.setFormatter(logging.Formatter('%(message)s'))
  logger.addHandler(_handler)
# Do not also forward records to the root logger, which may have a handler of its own.
logger.propagate = False

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the RNGs (weight init, data shuffling, random augmentation) so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [5]:
class AverageMeter(object):
  """Track the latest value and the running count-weighted average of a metric.

  Tensors passed to ``update`` are detached and, when they hold a single
  element, converted to Python numbers. The meter therefore never keeps an
  autograd graph (or device memory) alive across batches.
  """
  def __init__(self):
    self.reset()

  def reset(self):
    """Zero the stored value, sum, count and average."""
    self.val = 0
    self.avg = 0
    self.sum = 0
    self.count = 0

  def update(self, val, n=1):
    """Record ``val`` as observed ``n`` times (e.g. a batch mean over ``n`` samples).

    Args:
      val: the new measurement (number or tensor).
      n: weight of this measurement in the running average.
    """
    if torch.is_tensor(val):
      # Accumulating a graph-attached loss tensor would chain every batch's graph together.
      val = val.detach()
      if val.numel() == 1:
        val = val.item()
    self.val = val
    self.sum += val * n
    self.count += n
    self.avg = self.sum / self.count

In [7]:
# Per-channel (mean, std) used to normalise CIFAR-10 images.
# NOTE(review): these std values are the widely copied ones; the per-pixel std of the
# CIFAR-10 training set is closer to (0.247, 0.243, 0.261). Harmless for training,
# but confirm if exact normalisation matters.
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training augmentation: random 32x32 crop of the 4-px padded image plus a random horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(*stats),
])

# Evaluation: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*stats),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
# shuffle=True: SGD needs a new batch order every epoch; with shuffle=False
# every epoch saw exactly the same batches in the same order.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)
# Evaluation results do not depend on order, so the test set is not shuffled.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 40405589.00it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [13]:
def train(epoch):
  """Run one supervised training epoch of the global ``net`` over ``trainloader``.

  Uses the notebook globals ``net``, ``optimizer``, ``criterion``,
  ``scheduler``, ``trainloader``, ``device`` and ``logger``. The LR scheduler
  is stepped once per batch, so its ``T_max`` is measured in optimizer steps.

  Args:
    epoch: epoch index, used only for logging.
  """
  net.train()
  loss_total = AverageMeter()
  accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)
  for inputs, targets in trainloader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    # Store a Python float weighted by batch size. Passing the raw loss tensor kept
    # every batch's autograd graph alive for the whole epoch.
    loss_total.update(loss.item(), inputs.size(0))
    accuracy.update(outputs.detach().softmax(dim=-1), targets)
    scheduler.step()

  acc = accuracy.compute()
  logger.info(f'Train: Epoch:{epoch} Loss: {loss_total.avg:.4} Accuracy:{acc:.4}' )


def test(epoch, checkpoint):
  """Evaluate the global ``net`` on ``testloader`` and checkpoint it if it is the best so far.

  Uses the notebook globals ``net``, ``criterion``, ``testloader``, ``device``
  and ``logger``.

  Args:
    epoch: epoch index, used for logging and stored in the checkpoint.
    checkpoint: a ``Checkpoint`` whose ``save(acc, epoch=...)`` keeps the best model.
  """
  net.eval()
  loss_total = AverageMeter()
  # .to(device) rather than .cuda(): the original crashed on CPU-only runtimes.
  accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)
  with torch.no_grad():
    for inputs, targets in testloader:
      inputs, targets = inputs.to(device), targets.to(device)
      outputs = net(inputs)
      loss = criterion(outputs, targets)
      loss_total.update(loss.item(), inputs.size(0))
      accuracy.update(outputs.softmax(dim=-1), targets)
  acc = accuracy.compute()
  logger.info(f'Test: Epoch:{epoch} Loss:{loss_total.avg:.4} Accuracy:{acc:.4}')
  # Checkpoint.save takes (acc, epoch=...). The old call passed 'ckpt' positionally as
  # `epoch` and also epoch=epoch, which raised TypeError on the first evaluation.
  checkpoint.save(acc, epoch=epoch)


class Checkpoint(object):
  """Save a model's weights whenever a new best accuracy is reached.

  Files are written to ``checkpoint/<model_name>.pth`` under the current
  working directory. Each file holds a dict with keys ``'net'``
  (the state_dict), ``'acc'`` and ``'epoch'``.
  """
  def __init__(self, model_name):
    self.best_acc = 0.
    self.folder = 'checkpoint'
    self.model_name = model_name
    os.makedirs(self.folder, exist_ok=True)

  def save(self, acc, epoch=-1, model=None):
    """Save ``model`` if ``acc`` beats the best accuracy seen so far.

    Args:
      acc: accuracy for this evaluation (number or single-element tensor).
      epoch: epoch index stored alongside the weights.
      model: module to save; defaults to the notebook-global ``net`` for
        backward compatibility.
    """
    # A plain float compares cleanly and loads without the tensor's original device.
    acc = float(acc)
    if acc > self.best_acc:
      logger.info('Saving checkpoint....')
      if model is None:
        model = net
      state = {
          'net': model.state_dict(),
          'acc': acc,
          'epoch': epoch,
      }
      path = os.path.join(os.path.abspath(self.folder), self.model_name + '.pth')
      torch.save(state, path)
      self.best_acc = acc

In [14]:
class LeNet5(nn.Module):
  """LeNet-5-style CNN for 3x32x32 images that returns ``num_classes`` logits.

  Spatial sizes: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5, so the
  flattened feature vector has 16 * 5 * 5 = 400 entries.

  Args:
    num_classes: size of the output layer (default 10, as for CIFAR-10).
  """
  def __init__(self, num_classes=10):
    super(LeNet5, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(400, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, num_classes)

  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    # Keep the batch dimension explicitly. view(-1, 400) could silently fold a
    # wrongly-sized input into a different batch size instead of failing.
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

In [15]:
# Baseline student: LeNet-5 trained from scratch, moved to the selected device.
net = LeNet5().to(device)
net

LeNet5(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
# T_max is counted in optimizer steps: train() calls scheduler.step() every batch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=20000,
    eta_min=1e-6,
)

In [None]:
# Train the LeNet baseline for 50 epochs, evaluating (and keeping the best weights) after each.
start = 0
end = 50
checkpoint = Checkpoint('ckpt-lenet')
for epoch in range(start, end):
  train(epoch)
  test(epoch, checkpoint)


In [None]:
# Checkpoint writes to ./checkpoint/<name>.pth; load on CPU so this works without a GPU.
ckpt = torch.load(os.path.join('checkpoint', 'ckpt-lenet.pth'), map_location='cpu')
# Use double quotes outside the f-string: reusing ' inside '...' is a SyntaxError before Python 3.12.
logger.info(f"Best Accuracy: {float(ckpt['acc']):.4}")

Train the teacher model (ImageNet-pretrained MobileNetV2)

In [None]:
# Teacher: ImageNet-pretrained MobileNetV2 with its final Linear layer replaced by a 10-class head.
# `weights=` is the torchvision >= 0.13 API; the positional `True` relied on the deprecated
# `pretrained` shim. IMAGENET1K_V1 is the checkpoint that pretrained=True used to select.
teacher = mobilenet_v2(weights='IMAGENET1K_V1')
teacher.classifier[1] = nn.Linear(teacher.classifier[1].in_features, 10)
teacher.to(device)

In [None]:
# train(), test() and Checkpoint.save all operate on the global `net`, so point it at the
# teacher for this run. Previously the optimizer was built over the LeNet's parameters, the
# LeNet was retrained, and its weights were saved as ckpt-teacher.
net = teacher
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01,
                      momentum=0.9, weight_decay=5e-4)
# The scheduler is stepped once per batch, so T_max must span every batch of the 50-epoch run.
# With T_max=200 the cosine LR oscillated with a period of ~400 batches (about one epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=50 * len(trainloader), eta_min=1e-6)

In [None]:
# Fine-tune the teacher for 50 epochs, keeping the best test-accuracy weights.
start = 0
end = 50
checkpoint = Checkpoint('ckpt-teacher')
for epoch in range(start, end):
  train(epoch)
  test(epoch, checkpoint)


In [None]:
def loss_fn_kd(outputs, labels, teacher_outputs, alpha, T):
  """Knowledge-distillation loss: soft-target KL term plus hard-label cross-entropy.

  Args:
    outputs: student logits, (batch, num_classes).
    labels: integer class targets, (batch,).
    teacher_outputs: teacher logits, same shape as ``outputs``.
    alpha: weight of the soft-target term; the hard-label term gets ``1 - alpha``.
    T: softmax temperature applied to both sets of logits.

  Returns:
    Scalar loss ``alpha * T**2 * KL(teacher_T || student_T) + (1 - alpha) * CE(outputs, labels)``.
  """
  soft_student = F.log_softmax(outputs / T, dim=1)
  soft_teacher = F.softmax(teacher_outputs / T, dim=1)
  # reduction='batchmean' yields the true per-sample KL divergence. The default 'mean'
  # (as in nn.KLDivLoss()) also divides by num_classes, shrinking the soft-target term.
  kd_term = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
  ce_term = F.cross_entropy(outputs, labels)
  return kd_term * (alpha * T * T) + ce_term * (1. - alpha)

In [None]:
def train(epoch, alpha=0.6, T=10):
  """Run one distillation epoch: train the global ``net`` (student) against ``teacher``.

  The student is optimized with ``loss_fn_kd`` using the teacher's logits as
  soft targets. The teacher runs under ``torch.no_grad()`` and is not updated.
  Uses the notebook globals ``net``, ``teacher``, ``optimizer``, ``scheduler``,
  ``trainloader``, ``device`` and ``logger``. The scheduler is stepped once per batch.

  Args:
    epoch: epoch index, used only for logging.
    alpha: weight of the distillation term (default keeps the previous 0.6).
    T: softmax temperature (default keeps the previous 10).
  """
  net.train()
  # The teacher only produces targets; keep its normalisation/dropout layers in inference mode.
  teacher.eval()
  loss_total = AverageMeter()
  accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)
  for inputs, targets in trainloader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    outputs = net(inputs)
    with torch.no_grad():
      teacher_outputs = teacher(inputs)
    loss = loss_fn_kd(outputs, targets, teacher_outputs, alpha, T)
    loss.backward()
    optimizer.step()
    # Store a Python float weighted by batch size. The raw tensor kept every batch's graph alive.
    loss_total.update(loss.item(), inputs.size(0))
    accuracy.update(outputs.detach().softmax(dim=-1), targets)
    scheduler.step()

  acc = accuracy.compute()
  logger.info(f'Train: Epoch:{epoch} Loss: {loss_total.avg:.4} Accuracy:{acc:.4}' )


def test(epoch, checkpoint):
  """Evaluate the global ``net`` on ``testloader`` and checkpoint it if it is the best so far.

  Uses the notebook globals ``net``, ``criterion``, ``testloader``, ``device``
  and ``logger``. The test loss is the plain cross-entropy on hard labels.

  Args:
    epoch: epoch index, used for logging and stored in the checkpoint.
    checkpoint: a ``Checkpoint`` whose ``save(acc, epoch=...)`` keeps the best model.
  """
  net.eval()
  loss_total = AverageMeter()
  # .to(device) rather than .cuda(): the original crashed on CPU-only runtimes.
  accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)
  with torch.no_grad():
    for inputs, targets in testloader:
      inputs, targets = inputs.to(device), targets.to(device)
      outputs = net(inputs)
      loss = criterion(outputs, targets)
      loss_total.update(loss.item(), inputs.size(0))
      accuracy.update(outputs.softmax(dim=-1), targets)
  acc = accuracy.compute()
  logger.info(f'Test: Epoch:{epoch} Loss:{loss_total.avg:.4} Accuracy:{acc:.4}')
  # Checkpoint.save takes (acc, epoch=...). The old call passed 'ckpt' positionally as
  # `epoch` and also epoch=epoch, which raised TypeError on the first evaluation.
  checkpoint.save(acc, epoch=epoch)


In [None]:
# Fresh LeNet-5 student for distillation, moved to the selected device.
net = LeNet5().to(device)
net

In [None]:
# Rebuild the MobileNetV2 architecture with a 10-class head and restore the fine-tuned teacher.
# No ImageNet download is needed: load_state_dict (strict) overwrites every parameter and buffer.
teacher = mobilenet_v2()
teacher.classifier[1] = nn.Linear(teacher.classifier[1].in_features, 10)
# Checkpoint writes to ./checkpoint/<name>.pth; map_location lets a GPU-saved file load on CPU.
teacher_ckpt = torch.load(os.path.join('checkpoint', 'ckpt-teacher.pth'), map_location=device)
teacher.load_state_dict(teacher_ckpt['net'])
teacher.eval()
teacher.to(device)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
# T_max is counted in optimizer steps: train() calls scheduler.step() every batch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=20000,
    eta_min=1e-6,
)

In [None]:
# Distill the teacher into the LeNet student for 50 epochs, keeping the best weights.
start = 0
end = 50
checkpoint = Checkpoint('ckpt-student')
for epoch in range(start, end):
  train(epoch)
  test(epoch, checkpoint)
