This notebook demonstrates how to distill a pre-trained neural network using born-again neural networks (BAN).

In [3]:
import copy
import math
import random
import time
from collections import OrderedDict, defaultdict
from typing import Union, List
import torch.nn.functional as F
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader

from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm

# Fail fast on CPU-only runtimes: the rest of the notebook moves models and
# batches to the GPU.
assert torch.cuda.is_available(), (
    "The current runtime does not have CUDA support. "
    "Please go to menu bar (Runtime - Change runtime type) and select GPU"
)

In [3]:
### define the VGG class that describes the architecture of the pretrained model
class VGG(nn.Module):
  """Small VGG-style CNN: conv-bn-relu stages with max-pooling, then a linear head.

  `ARCH` lists the backbone in order: an integer is the output channel count
  of a 3x3 conv-bn-relu stage, and 'M' is a 2x2 max-pool. Backbone sub-modules
  are named "<kind><index>" (e.g. conv0, bn0, relu0, pool0), with a separate
  running index per kind.
  """
  ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']

  def __init__(self) -> None:
    super().__init__()

    named_layers = OrderedDict()
    next_index = defaultdict(int)

    def register(kind: str, module: nn.Module) -> None:
      # Name each module after its kind plus how many of that kind came before.
      named_layers[f"{kind}{next_index[kind]}"] = module
      next_index[kind] += 1

    channels = 3  # the first conv takes 3 input channels
    for spec in self.ARCH:
      if spec == 'M':
        register("pool", nn.MaxPool2d(2))
        continue
      # conv (no bias) -> batch norm -> in-place ReLU
      register("conv", nn.Conv2d(channels, spec, 3, padding=1, bias=False))
      register("bn", nn.BatchNorm2d(spec))
      register("relu", nn.ReLU(True))
      channels = spec

    self.backbone = nn.Sequential(named_layers)
    self.classifier = nn.Linear(512, 10)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Map a batch of images to class logits.

    For 32x32 inputs: [N, 3, 32, 32] -> backbone -> [N, 512, 2, 2]
    -> spatial mean -> [N, 512] -> classifier -> [N, 10].
    """
    features = self.backbone(x)
    pooled = features.mean(dim=(2, 3))  # global average pool over H and W
    return self.classifier(pooled)

In [6]:
### load pretrained model
# NOTE(review): depending on the PyTorch version and the `weights_only` setting,
# torch.load may unpickle arbitrary objects -- only load checkpoints from a
# trusted source.
model = VGG().cuda()
checkpoint = torch.load("./vgg.cifar.pretrained.pth", map_location="cpu")

# Last expression of the cell: displays the key-matching report.
model.load_state_dict(checkpoint["state_dict"])


<All keys matched successfully>

In [7]:
### create train and test dataset
image_size = 32

# Training images get random crop + horizontal flip augmentation;
# test images are only converted to tensors.
transforms = {
    "train": Compose([
        RandomCrop(image_size, padding=4),
        RandomHorizontalFlip(),
        ToTensor(),
    ]),
    "test": ToTensor(),
}

# Build each split's CIFAR-10 dataset and its loader; only the training
# split is shuffled.
dataset, dataloader = {}, {}
for split in ("train", "test"):
  dataset[split] = CIFAR10(
    root="data/cifar10",
    train=(split == "train"),
    download=True,
    transform=transforms[split],
  )
  dataloader[split] = DataLoader(
    dataset[split],
    batch_size=512,
    shuffle=(split == "train"),
    num_workers=0,
    pin_memory=True,
  )

100%|██████████| 170M/170M [00:03<00:00, 43.6MB/s]


In [10]:
### create class object for BornAgainNN distillation
class BornAgainNN:
  """Born-again network (BAN) distillation driver.

  Trains successive "generations" of students: the first generation is
  distilled from the supplied pre-trained model, and each later generation
  is distilled from the student produced by the previous one.

  Args:
    model: pre-trained network; used as the first teacher and as the source
      of the initial weights loaded into every student.
    model_class: zero-argument callable returning a new network whose
      state dict is compatible with `model`.
    lr: learning rate of the Adam optimizer created for each student.
    train_loader: iterable of (images, labels) batches used for training.
    test_loader: iterable of (images, labels) batches used for evaluation.
    alpha: weight of the hard-label cross-entropy term; the distillation
      term is weighted by (1 - alpha).
    T: temperature dividing student and teacher logits before the softmax.
  """

  def __init__(self, model, model_class, lr, train_loader, test_loader, alpha=0.4, T=5):
    self.VGG = model_class
    self.pretrained_model = model

    self.lr = lr

    self.criterion = nn.CrossEntropyLoss()
    self.kd_loss_fn = nn.KLDivLoss(reduction='batchmean')
    self.train_loader = train_loader
    self.test_loader = test_loader
    self.alpha = alpha
    self.T = T

  @staticmethod
  def _device_of(module):
    """Return the device of `module`'s first parameter (CPU if it has none)."""
    first_param = next(module.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")

  def train_generation(self, student, teacher, epochs):
    """Train `student` for `epochs` passes over the training loader.

    Args:
      student: network to optimize; batches are moved to its device.
      teacher: network providing soft targets, or None to train with plain
        cross-entropy only. Expected to live on the same device as `student`.
      epochs: number of passes over `self.train_loader`.

    Returns:
      The same `student` object, trained in place.
    """
    device = self._device_of(student)
    # evaluate() leaves a model in eval mode, so select train mode explicitly.
    student.train()
    if teacher is not None:
      teacher.eval()

    optimizer = Adam(student.parameters(), lr=self.lr)
    for epoch in range(epochs):
      total_loss = 0
      for images, labels in self.train_loader:
        images, labels = images.to(device), labels.to(device)
        student_logits = student(images)
        loss = self.criterion(student_logits, labels)

        if teacher is not None:
          with torch.no_grad():
            teacher_logits = teacher(images)

          # Soft targets: KLDivLoss takes log-probabilities as input and
          # probabilities as target.
          soft_student = F.log_softmax(student_logits / self.T, dim=1)
          soft_teacher = F.softmax(teacher_logits / self.T, dim=1)

          # The distillation term is scaled by T^2 before mixing with the
          # hard-label loss.
          loss_kd = self.kd_loss_fn(soft_student, soft_teacher) * self.T * self.T
          loss = self.alpha * loss + (1 - self.alpha) * loss_kd

        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()        # compute gradients
        optimizer.step()       # apply the parameter update
        total_loss += loss.item()

      # Reported value is the sum of per-batch losses over the epoch.
      print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")
    return student

  def run_generation(self, epoch=7, generation=3):
    """Run `generation` rounds of distillation with `epoch` epochs each.

    Every student starts from the pre-trained model's weights, is trained
    against the current teacher, evaluated on the test loader, and then
    becomes the teacher of the next round. One line per generation is
    appended to "log.txt" in the current working directory.
    """
    device = self._device_of(self.pretrained_model)
    teacher = self.pretrained_model

    for i in range(generation):
      student = self.VGG().to(device)
      # Initialise from the model held by this object rather than from a
      # module-level checkpoint variable, so the class is self-contained.
      student.load_state_dict(self.pretrained_model.state_dict())
      student = self.train_generation(student, teacher, epochs=epoch)
      acc = self.evaluate(student)

      # A new student object is built on every iteration, so the trained
      # one can serve as the next teacher without being copied.
      teacher = student

      with open("log.txt", "a") as file:
        file.write(f"generation {i + 1}, Accuracy: {acc}\n")

  def evaluate(self, model):
    """Return top-1 accuracy of `model` on the test loader, and print it.

    Puts `model` in eval mode and leaves it there.
    """
    device = self._device_of(model)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
      for images, labels in self.test_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    acc = correct / total
    print(f"Accuracy: {acc:.4f}")
    return acc



In [2]:
# Distill three generations of five epochs each, starting from the
# pretrained model as the first teacher.
trainer = BornAgainNN(
    model=model,
    model_class=VGG,
    lr=0.002,
    train_loader=dataloader["train"],
    test_loader=dataloader["test"],
)
trainer.run_generation(epoch=5, generation=3)