In this demo, we use knowledge distillation to train a ResNet-18 model for image classification. We will show how to provide teacher model, student model, data loaders, inference pipeline and other arguments to the toolkit and start knowledge distillation training.

In [4]:
!pip install pytorch-lightning

Collecting pytorch-lightning
[?25l  Downloading https://files.pythonhosted.org/packages/3c/9e/ddf2230626f5a56238d01a0739c62243170db1b136f8e0697db4402bec6d/pytorch_lightning-1.2.4-py3-none-any.whl (829kB)
[K     |▍                               | 10kB 17.0MB/s eta 0:00:01[K     |▉                               | 20kB 21.0MB/s eta 0:00:01[K     |█▏                              | 30kB 16.5MB/s eta 0:00:01[K     |█▋                              | 40kB 15.5MB/s eta 0:00:01[K     |██                              | 51kB 12.5MB/s eta 0:00:01[K     |██▍                             | 61kB 11.4MB/s eta 0:00:01[K     |██▊                             | 71kB 10.8MB/s eta 0:00:01[K     |███▏                            | 81kB 11.5MB/s eta 0:00:01[K     |███▋                            | 92kB 11.5MB/s eta 0:00:01[K     |████                            | 102kB 11.2MB/s eta 0:00:01[K     |████▍                           | 112kB 11.2MB/s eta 0:00:01[K     |████▊                   

## Download the toolkit 

In [1]:
!git clone https://github.com/georgian-io/Knowledge-Distillation-Toolkit.git

Cloning into 'Knowledge-Distillation-Toolkit'...
remote: Enumerating objects: 837, done.[K
remote: Counting objects: 100% (837/837), done.[K
remote: Compressing objects: 100% (671/671), done.[K
remote: Total 837 (delta 185), reused 771 (delta 140), pack-reused 0[K
Receiving objects: 100% (837/837), 3.31 MiB | 21.59 MiB/s, done.
Resolving deltas: 100% (185/185), done.


In [3]:
%cd Knowledge-Distillation-Toolkit/

/content/Knowledge-Distillation-Toolkit


In [4]:
import yaml
from collections import ChainMap

import torch
import torch.nn.functional as F
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision import datasets, transforms

from knowledge_distillation.kd_training import KnowledgeDistillationTraining

## Define the student model and teacher model

In [6]:
class StudentModel(ResNet):
    def __init__(self):
        super(StudentModel, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10) #ResNet18
        self.conv1 = torch.nn.Conv2d(1, 64,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3), bias=False)

    def forward(self, batch, temperature=1):
        logits = super(StudentModel, self).forward(batch)
        logits = logits / temperature
        prob = F.softmax(logits, dim=0)
        log_prob = F.log_softmax(logits, dim=0)
        return {"logits":logits, "prob":prob, "log_prob":log_prob}

class TeacherModel(ResNet):
    def __init__(self):
        super(TeacherModel, self).__init__(BasicBlock, [3, 4, 6, 3], num_classes=10) #ResNet34
        self.conv1 = torch.nn.Conv2d(1, 64,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3), bias=False)

    def forward(self, batch, temperature=1):
        logits = super(TeacherModel, self).forward(batch)
        logits = logits / temperature
        prob = F.softmax(logits, dim=0)
        log_prob = F.log_softmax(logits, dim=0)
        return {"logits":logits, "prob":prob, "log_prob":log_prob}

## Define the inference pipeline

In [7]:
class inference_pipeline:

    def __init__(self, device):
        self.device = device

    def run_inference_pipeline(self, model, data_loader):
        accuracy = 0
        model.eval()
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                X, y = data[0].to(self.device), data[1].to(self.device)
                outputs = model(X)
                predicted = torch.max(outputs["prob"], 1)[1]
                accuracy += predicted.eq(y.view_as(predicted)).sum().item()
        accuracy = accuracy / len(data_loader.dataset)
        return {"inference_result": accuracy}

In [8]:
def get_data_for_kd_training(batch):
    data = torch.cat([sample[0] for sample in batch], dim=0)
    data = data.unsqueeze(1)
    return data,

## Read from demo_config.yaml, which contains all argument set up

In [10]:
config = yaml.load(open('./examples/resnet_compression_demo/demo_config.yaml','r'), Loader=yaml.FullLoader)
device = torch.device("cuda")

## Create training and validation data loaders
We will use the MNIST dataset

In [17]:
# Create data loaders for training and validation
transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
              ])
train_kwargs = {'batch_size': 16, 'num_workers': 0}
test_kwargs = {'batch_size': 1000, 'num_workers': 0}
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_data_loader = torch.utils.data.DataLoader(train_dataset, collate_fn=get_data_for_kd_training, **train_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
val_data_loaders = {"accuracy_on_validation_set": test_loader}

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


## Create an instance of inference pipeline

In [18]:
# Create inference pipeline for validating the student model
inference_pipeline_example = inference_pipeline(device)

## Create an instance of student model and teacher model

In [21]:
# Create student and teacher model
student_model = StudentModel()
teacher_model = TeacherModel()
teacher_model.load_state_dict(torch.load("./examples/resnet_compression_demo/trained_model/resnet34_teacher.pt"))

<All keys matched successfully>

## Pass data loaders, student and teacher model, inference pipeline and other argument set up into `KnowledgeDistillationTraining`

In [22]:
# Train a student model with knowledge distillation and get its performance on dev set
KD_resnet = KnowledgeDistillationTraining(train_data_loader = train_data_loader,
                                          val_data_loaders = val_data_loaders,
                                          inference_pipeline = inference_pipeline_example,
                                          student_model = student_model,
                                          teacher_model = teacher_model,
                                          num_gpu_used = config["knowledge_distillation"]["general"]["num_gpu_used"],
                                          final_loss_coeff_dict = config["knowledge_distillation"]["final_loss_coeff"],
                                          logging_param = ChainMap(config["knowledge_distillation"]["general"],
                                                                   config["knowledge_distillation"]["optimization"],
                                                                   config["knowledge_distillation"]["final_loss_coeff"],
                                                                   config["knowledge_distillation"]["pytorch_lightning_trainer"]),
                                          **ChainMap(config["knowledge_distillation"]["optimization"],
                                                     config["knowledge_distillation"]["pytorch_lightning_trainer"],
                                                     config["knowledge_distillation"]["comet_info"])
                                          )

Global seed set to 32
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.


## Start knowledge distillation training

In [23]:
KD_resnet.start_kd_training()


  | Name          | Type         | Params
-----------------------------------------------
0 | student_model | StudentModel | 11.2 M
1 | teacher_model | TeacherModel | 21.3 M
-----------------------------------------------
32.5 M    Trainable params
0         Non-trainable params
32.5 M    Total params
129.836   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…


accuracy_on_validation_set :0.0666

GPU 0 current active MB: 131.81951999999998
GPU 0 current reserved MB: 157.2864


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.7412

GPU 0 current active MB: 266.921984
GPU 0 current reserved MB: 299.892736


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8167

GPU 0 current active MB: 266.923008
GPU 0 current reserved MB: 297.79558399999996


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8405

GPU 0 current active MB: 266.923008
GPU 0 current reserved MB: 299.892736


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8569

GPU 0 current active MB: 266.923008
GPU 0 current reserved MB: 297.79558399999996


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8696

GPU 0 current active MB: 266.923008
GPU 0 current reserved MB: 297.79558399999996



As the above output shows, validation accuracy of the student model improves in every training epoch