In [None]:
!pip install pytorch-lightning



In [None]:
import yaml
from collections import ChainMap

import torch
import torch.nn.functional as F
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision import datasets, transforms

from kd_training import KnowledgeDistillationTraining

In [None]:
class StudentModel(ResNet):
    """ResNet-18 student network (``BasicBlock``, layers ``[2, 2, 2, 2]``) with 10 classes.

    ``conv1`` is replaced by a convolution that accepts a single input channel.
    """

    def __init__(self):
        super(StudentModel, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10) #ResNet18
        # Replace the first convolution with one that takes 1 input channel.
        self.conv1 = torch.nn.Conv2d(1, 64,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3), bias=False)

    def forward(self, batch, temperature=1):
        """Run the network and return temperature-scaled logits and distributions.

        Args:
            batch: Input tensor forwarded to ``ResNet.forward``.
            temperature: Divisor applied to the logits before the softmax;
                ``1`` leaves them unchanged.

        Returns:
            dict with keys ``"logits"`` (scaled logits), ``"prob"`` (softmax of
            the scaled logits) and ``"log_prob"`` (log-softmax of the same).
        """
        logits = super(StudentModel, self).forward(batch)
        logits = logits / temperature
        # Normalise over the class axis (dim 1) so every sample gets its own
        # distribution. dim=0 would normalise each class across the batch.
        prob = F.softmax(logits, dim=1)
        log_prob = F.log_softmax(logits, dim=1)
        return {"logits":logits, "prob":prob, "log_prob":log_prob}

class TeacherModel(ResNet):
    """ResNet-34 teacher network (``BasicBlock``, layers ``[3, 4, 6, 3]``) with 10 classes.

    ``conv1`` is replaced by a convolution that accepts a single input channel.
    """

    def __init__(self):
        super(TeacherModel, self).__init__(BasicBlock, [3, 4, 6, 3], num_classes=10) #ResNet34
        # Replace the first convolution with one that takes 1 input channel.
        self.conv1 = torch.nn.Conv2d(1, 64,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3), bias=False)

    def forward(self, batch, temperature=1):
        """Run the network and return temperature-scaled logits and distributions.

        Args:
            batch: Input tensor forwarded to ``ResNet.forward``.
            temperature: Divisor applied to the logits before the softmax;
                ``1`` leaves them unchanged.

        Returns:
            dict with keys ``"logits"`` (scaled logits), ``"prob"`` (softmax of
            the scaled logits) and ``"log_prob"`` (log-softmax of the same).
        """
        logits = super(TeacherModel, self).forward(batch)
        logits = logits / temperature
        # Normalise over the class axis (dim 1) so every sample gets its own
        # distribution. dim=0 would normalise each class across the batch.
        prob = F.softmax(logits, dim=1)
        log_prob = F.log_softmax(logits, dim=1)
        return {"logits":logits, "prob":prob, "log_prob":log_prob}

In [None]:
class inference_pipeline:
    """Computes top-1 accuracy of a model over a data loader."""

    def __init__(self, device):
        # Device that every input/label batch is moved to before the forward pass.
        self.device = device

    def run_inference_pipeline(self, model, data_loader):
        """Evaluate ``model`` on ``data_loader`` and return its accuracy.

        Args:
            model: Callable module whose output is a mapping with a ``"prob"``
                entry; the arg-max over dim 1 of that entry is the prediction.
            data_loader: Iterable of batches where index 0 is the input and
                index 1 the labels; must expose ``.dataset`` with a length.

        Returns:
            ``{"inference_result": correct / len(data_loader.dataset)}``.
        """
        model.eval()  # the model is left in eval mode when this returns
        num_correct = 0
        with torch.no_grad():
            for batch in data_loader:
                inputs = batch[0].to(self.device)
                targets = batch[1].to(self.device)
                probabilities = model(inputs)["prob"]
                _, predicted = torch.max(probabilities, 1)
                matches = predicted.eq(targets.view_as(predicted))
                num_correct += matches.sum().item()
        return {"inference_result": num_correct / len(data_loader.dataset)}

In [None]:
def get_data_for_kd_training(batch):
    """Collate function that keeps only the first element of every sample.

    Args:
        batch: Sequence of samples; ``sample[0]`` must be a tensor. Any other
            elements of a sample (e.g. labels) are ignored.

    Returns:
        A one-element tuple holding the tensors concatenated along dim 0 with
        a new axis inserted at position 1.
    """
    inputs = [sample[0] for sample in batch]
    merged = torch.cat(inputs, dim=0).unsqueeze(1)
    return (merged,)

In [None]:
# Read the demo configuration; the context manager closes the file handle
# as soon as parsing finishes instead of leaving it open for the kernel's lifetime.
with open('./demo_config.yaml', 'r') as config_file:
    # NOTE(review): FullLoader is fine for a trusted local file; prefer
    # yaml.safe_load if this config could ever come from an untrusted source.
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Use the GPU when one is visible, otherwise fall back to CPU so the
# validation pipeline can still move batches to a valid device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the MNIST data loaders used for distillation (train) and validation (test).
# The Normalize constants are passed as (mean,), (std,) — presumably the usual
# MNIST statistics; confirm before reusing on another dataset.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
train_kwargs = dict(batch_size=16, num_workers=0)
test_kwargs = dict(batch_size=1000, num_workers=0)

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# The training loader uses the custom collate function, which yields inputs only;
# the validation loader keeps the default collation and therefore yields (input, label).
train_data_loader = torch.utils.data.DataLoader(
    train_dataset,
    collate_fn=get_data_for_kd_training,
    **train_kwargs
)
test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

# Validation loaders are passed around as a dict keyed by name.
val_data_loaders = {"accuracy_on_validation_set": test_loader}

In [None]:
# Inference pipeline used to validate the student model; batches are moved to `device`.
inference_pipeline_example = inference_pipeline(device=device)

In [None]:
# Create student and teacher model
student_model = StudentModel()
teacher_model = TeacherModel()

# Deserialize the checkpoint onto CPU: the freshly built teacher lives on CPU,
# and this avoids a failure when the file was saved from a GPU that is not
# available here. load_state_dict copies the values into the existing parameters.
teacher_state_dict = torch.load("resnet34_teacher.pt", map_location="cpu")
teacher_model.load_state_dict(teacher_state_dict)

<All keys matched successfully>

In [None]:
# Train a student model with knowledge distillation and get its performance on dev set
kd_config = config["knowledge_distillation"]

# Parameters handed over for logging. ChainMap resolves duplicate keys in
# favour of the earliest mapping listed.
kd_logging_param = ChainMap(
    kd_config["general"],
    kd_config["optimization"],
    kd_config["final_loss_coeff"],
    kd_config["pytorch_lightning_trainer"],
)

# Remaining config sections are expanded into keyword arguments.
kd_extra_kwargs = ChainMap(
    kd_config["optimization"],
    kd_config["pytorch_lightning_trainer"],
    kd_config["comet_info"],
)

KD_resnet = KnowledgeDistillationTraining(
    train_data_loader=train_data_loader,
    val_data_loaders=val_data_loaders,
    inference_pipeline=inference_pipeline_example,
    student_model=student_model,
    teacher_model=teacher_model,
    num_gpu_used=kd_config["general"]["num_gpu_used"],
    final_loss_coeff_dict=kd_config["final_loss_coeff"],
    logging_param=kd_logging_param,
    **kd_extra_kwargs
)

Global seed set to 32
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


In [None]:
# Start the knowledge-distillation run on the trainer object built above.
KD_resnet.start_kd_training()


  | Name          | Type         | Params
-----------------------------------------------
0 | student_model | StudentModel | 11.2 M
1 | teacher_model | TeacherModel | 21.3 M
-----------------------------------------------
32.5 M    Trainable params
0         Non-trainable params
32.5 M    Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…


accuracy_on_validation_set :0.0969

GPU 0 current active MB: 131.81951999999998
GPU 0 current reserved MB: 157.2864


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.7284

GPU 0 current active MB: 278.28992
GPU 0 current reserved MB: 312.475648


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8059

GPU 0 current active MB: 278.28992
GPU 0 current reserved MB: 312.475648


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8302

GPU 0 current active MB: 278.28992
GPU 0 current reserved MB: 312.475648


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8462

GPU 0 current active MB: 278.28992
GPU 0 current reserved MB: 312.475648


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8577

GPU 0 current active MB: 278.28992
GPU 0 current reserved MB: 312.475648


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8675

GPU 0 current active MB: 278.28992
GPU 0 current reserved MB: 312.475648


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.874

GPU 0 current active MB: 278.28992
GPU 0 current reserved MB: 312.475648


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8793

GPU 0 current active MB: 278.28992
GPU 0 current reserved MB: 312.475648


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8844

GPU 0 current active MB: 278.28992
GPU 0 current reserved MB: 312.475648


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


accuracy_on_validation_set :0.8894

GPU 0 current active MB: 278.28992
GPU 0 current reserved MB: 312.475648

