# _Knowledge Distillation for Compression of ResNet50_
***
### _Santiago Giron_ <br>

In the following notebook, I examine my initial results from applying _Knowledge Distillation_ to residual networks for image classification in PyTorch. In Knowledge Distillation, a deep _teacher network_ is used to help train a shallower _student network_.

## __ResNet50__

As the teacher network, I've selected _ResNet50_ pretrained on ImageNet, which I then fine-tuned on CIFAR10.
> The model achieves an accuracy of 96.81% on the CIFAR10 test set after training for 25 epochs.

In [1]:
```python
import torch
import torch.nn as nn
from torchvision import datasets, models
from torchvision import transforms as T


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ImageNet normalization statistics, matching the ImageNet-pretrained weights.
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
# CIFAR10 test split, upsampled from 32x32 and center-cropped to 224x224.
valset = datasets.CIFAR10(".", train=False, download=True,
                          transform=T.Compose([
                              T.Resize(256),
                              T.CenterCrop(224),
                              T.ToTensor(),
                              normalize,
                          ]))

dataloader = torch.utils.data.DataLoader(valset, batch_size=32, shuffle=False, num_workers=4)

# ResNet50 with a 10-way classification head, loaded with the fine-tuned CIFAR10 weights.
teacher_model = models.resnet50(pretrained=False)
num_ftrs = teacher_model.fc.in_features
teacher_model.fc = nn.Linear(num_ftrs, len(valset.classes))
teacher_model.load_state_dict(torch.load("resnet50.pt", map_location=device))
teacher_model.eval()
teacher_model.to(device)

def get_model_acc(model, dataloader):
    """Print the top-1 accuracy of `model` over every sample in `dataloader`."""
    model.eval()
    running_corrects = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels).item()

    acc = running_corrects / len(dataloader.dataset)
    print(f'Accuracy on validation set: {100. * acc:.4f}%')

get_model_acc(teacher_model, dataloader)
```

```
Files already downloaded and verified
Accuracy on validation set: 96.8100%
```


<br>To measure the latency of ResNet50, I profile a single forward pass of the model on a dummy input: one random 3×224×224 image.

In [2]:
```python
import torch.autograd.profiler as profiler

def get_latency_data(model):
    # A single random 3x224x224 image serves as dummy input.
    input_batch = torch.randn(1, 3, 224, 224)
    input_batch = input_batch.to(device)
    model.to(device)
    model.eval()

    with torch.no_grad():
        model(input_batch)  # warm-up pass, excluded from the profile

        # Record CUDA kernel times as well when a GPU is available.
        with profiler.profile(record_shapes=False, use_cuda=torch.cuda.is_available()) as prof:
            with profiler.record_function("model_inference"):
                model(input_batch)

    print(prof.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=10))

get_latency_data(teacher_model)
```

```
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  model_inference        17.30%       2.437ms        99.80%      14.063ms      14.063ms     666.241us         4.71%      14.133ms      14.133ms             1  
                     aten::conv2d         2.42%     341.473us        40.20%       5.665ms     106.892us     205.049us         1.45%       8.546ms     161.238us            53  
                aten::convolution         2.41%     339.169us        37.78%       5.324ms     100.449us     392.100us  
```
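A single profiled forward pass can be sensitive to noise from kernel selection, GPU clock scaling, and other processes. As a complementary check, the sketch below averages wall-clock latency over many passes. `measure_latency` is a hypothetical helper, and the warm-up and run counts are arbitrary assumptions rather than settings used in this notebook.

```python
import time

import torch
import torch.nn as nn


def measure_latency(model: nn.Module, input_shape=(1, 3, 224, 224),
                    warmup: int = 10, runs: int = 100) -> float:
    """Return the mean forward-pass latency in milliseconds (hypothetical helper)."""
    device = next(model.parameters()).device
    x = torch.randn(*input_shape, device=device)
    model.eval()
    with torch.no_grad():
        # Warm-up passes let cuDNN select kernels and fill caches before timing.
        for _ in range(warmup):
            model(x)
        if device.type == "cuda":
            torch.cuda.synchronize(device)  # drain queued kernels before starting the clock
        start = time.perf_counter()
        for _ in range(runs):
            model(x)
        if device.type == "cuda":
            torch.cuda.synchronize(device)  # CUDA launches are asynchronous; wait for them
        elapsed = time.perf_counter() - start
    return 1000.0 * elapsed / runs
```

For example, `measure_latency(teacher_model)` and `measure_latency(distil_model)` could be compared directly once both models are on the same device. Synchronizing before reading the clock matters on a GPU because kernel launches return before the work finishes.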

## __ResNet18__

As the student network I have chosen _ResNet18_, also pretrained on ImageNet. It is significantly shallower than ResNet50 (18 vs. 50 layers) and has less capacity: ResNet50 has a top-1 error rate of 23.85% on ImageNet, whereas ResNet18 has a top-1 error rate of 30.24%.

After fine-tuning the pretrained ResNet18 model on CIFAR10 for 25 epochs without distillation, I evaluate its accuracy.
> The fine-tuned ResNet18 network achieves an accuracy of 94.95% on the validation set.

In [3]:
```python
# ResNet18 with a 10-way head, loaded with the baseline CIFAR10 fine-tuned weights.
student_model = models.resnet18(pretrained=False)
num_ftrs = student_model.fc.in_features
student_model.fc = nn.Linear(num_ftrs, len(valset.classes))
student_model.load_state_dict(torch.load("train_resnet18_25.pt", map_location=device))
student_model.eval()
student_model.to(device)

get_model_acc(student_model, dataloader)
```

```
Accuracy on validation set: 94.9500%
```


## _Optimized Model_
To optimize the model, I apply Knowledge Distillation in the form proposed by Geoffrey Hinton et al. in _Distilling the Knowledge in a Neural Network_. ResNet50 serves as the large "cumbersome" model that guides the training of the "small" model, ResNet18. I initialized ResNet18 with its ImageNet-pretrained weights, then trained it with Knowledge Distillation on CIFAR10 for 25 epochs. The Knowledge Distillation loss is defined as:<br>
$$
\mathcal{L}(x;W) = \alpha \cdot \mathcal{H}\big(y, \sigma(z_s; T=1)\big) + \beta \cdot \mathcal{H}\big(\sigma(z_c; T=\tau), \sigma(z_s; T=\tau)\big)
$$
(source: https://intellabs.github.io/distiller/knowledge_distillation.html)

Here $x$ is the input, $W$ are the small model's parameters, and $y$ is the ground-truth label. $\mathcal{H}$ is the cross-entropy loss and $\sigma$ is the softmax function parameterized by the temperature $T$, i.e. $\sigma(z; T)_i = \exp(z_i/T) / \sum_j \exp(z_j/T)$; higher temperatures yield softer probability distributions. $z_s$ and $z_c$ are the logits of the "small" model and the "cumbersome" model, respectively. The first term on the right is the cross-entropy between the small model's output and the ground-truth label, and the second term is the cross-entropy between the temperature-softened outputs of the large and small models. The total loss is a weighted combination of the two terms with coefficients $\alpha$ and $\beta$.<br>
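To make the effect of the temperature concrete, here is a small pure-Python sketch of the temperature-scaled softmax $\sigma(z; T)$ applied to three hypothetical logits. The logit values are illustrative only and do not come from the trained models.

```python
import math


def softmax(logits, T=1.0):
    """Temperature-scaled softmax: exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [8.0, 2.0, 1.0]  # hypothetical logits for three classes
print([round(p, 3) for p in softmax(logits, T=1.0)])   # [0.997, 0.002, 0.001]
print([round(p, 3) for p in softmax(logits, T=10.0)])  # [0.489, 0.268, 0.243]
```

At $T = 1$ almost all probability mass sits on the top class, while at $T = 10$ the relative ranking of the other classes becomes visible. This softened information is what the second term of the loss transfers from teacher to student.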

I trained ResNet18 using the Knowledge Distillation loss presented above, replacing the cross-entropy $\mathcal{H}$ with the Kullback-Leibler (KL) divergence. Since $\mathrm{KL}(p \,\|\, q) = \mathcal{H}(p, q) - \mathcal{H}(p)$ and the entropy $\mathcal{H}(p)$ of a fixed target distribution does not depend on $W$, this substitution leaves the gradients with respect to $W$ unchanged. I used $\alpha = 0.2$ with $\beta = 1 - \alpha$; Hinton et al. report better results with $\alpha$ significantly smaller than $\beta$. After some experimentation with the temperature, I found that $\tau = 10$ produced adequate results.<br>
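The loss and a single training step can be sketched in PyTorch as follows, with $\alpha = 0.2$ and $\tau = 10$ as above. This is a minimal sketch under two assumptions not confirmed by this notebook: the KL divergence is used only for the soft (second) term, and that term is multiplied by $\tau^2$, as Hinton et al. recommend to keep its gradient magnitude comparable across temperatures. `distillation_loss` and `distillation_step` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels,
                      alpha: float = 0.2, tau: float = 10.0):
    """alpha * CE(y, softmax(z_s)) + (1 - alpha) * tau^2 * KL(softmax(z_c/tau) || softmax(z_s/tau))."""
    # Hard-label term: ordinary cross-entropy at T = 1.
    hard = F.cross_entropy(student_logits, labels)
    # Soft-label term: F.kl_div expects log-probabilities as input and probabilities as target.
    soft = F.kl_div(F.log_softmax(student_logits / tau, dim=1),
                    F.softmax(teacher_logits / tau, dim=1),
                    reduction="batchmean")
    # Assumption: tau**2 scaling as suggested by Hinton et al.
    beta = 1.0 - alpha
    return alpha * hard + beta * (tau ** 2) * soft


def distillation_step(student, teacher, optimizer, inputs, labels,
                      alpha: float = 0.2, tau: float = 10.0) -> float:
    """Run one optimization step on the student while the teacher stays frozen."""
    teacher.eval()
    student.train()
    with torch.no_grad():  # the teacher only provides targets
        teacher_logits = teacher(inputs)
    student_logits = student(inputs)
    loss = distillation_loss(student_logits, teacher_logits, labels, alpha, tau)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

In a full training loop, `distillation_step` would be called once per CIFAR10 batch with `teacher_model` as the teacher and an ImageNet-initialized ResNet18 as the student.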

Below is a comparison of the number of model parameters in the teacher and student networks.

In [4]:
```python
resnet50_params = sum(param.numel() for param in teacher_model.parameters())
print(f'ResNet50 parameters: {resnet50_params:,}')
```

```
ResNet50 parameters: 23,528,522
```


In [5]:
```python
resnet18_params = sum(param.numel() for param in student_model.parameters())
print(f'ResNet18 parameters: {resnet18_params:,}')
print(f'{100 * (resnet18_params/resnet50_params):.2f}% of the parameters of ResNet50.')
```

```
ResNet18 parameters: 11,181,642
47.52% of the parameters of ResNet50.
```


>ResNet18 has __52.48% fewer parameters__ than ResNet50.

In [6]:
```python
# ResNet18 with a 10-way head, loaded with the weights trained by Knowledge Distillation.
distil_model = models.resnet18(pretrained=False)
num_ftrs = distil_model.fc.in_features
distil_model.fc = nn.Linear(num_ftrs, len(valset.classes))
distil_model.load_state_dict(torch.load("distil_resnet18_25.pt", map_location=device))
distil_model.eval()
distil_model.to(device)

get_model_acc(distil_model, dataloader)
```

```
Accuracy on validation set: 95.4000%
```


In [7]:
```python
get_latency_data(distil_model)
```

```
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  model_inference        18.29%       1.044ms        99.18%       5.661ms       5.661ms     482.272us         7.86%       5.924ms       5.924ms             1  
                     aten::conv2d         2.54%     145.160us        41.53%       2.370ms     118.511us      97.505us         1.59%       3.828ms     191.378us            20  
                aten::convolution         2.44%     139.185us        38.98%       2.225ms     111.253us      98.559us  
```

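The distilled ResNet18 model achieves an accuracy of 95.40% on CIFAR10, 0.45 percentage points above the ResNet18 fine-tuned without distillation (94.95%) and 1.41 points below the ResNet50 teacher (96.81%).
> Based on the total CUDA time of a single profiled forward pass (5.924 ms vs. 14.133 ms), inference is roughly __2.4x__ faster than ResNet50.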
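The speedup follows directly from the `model_inference` rows of the two profiler tables above. A short sketch of the arithmetic, using only the numbers already printed there:

```python
# "CUDA total" and "CPU total" of the model_inference row in each profiler table (ms).
resnet50 = {"cuda_total_ms": 14.133, "cpu_total_ms": 14.063}
distilled_resnet18 = {"cuda_total_ms": 5.924, "cpu_total_ms": 5.661}

for key in ("cuda_total_ms", "cpu_total_ms"):
    speedup = resnet50[key] / distilled_resnet18[key]
    print(f"{key}: {speedup:.2f}x")
# cuda_total_ms: 2.39x
# cpu_total_ms: 2.48x
```

Because each figure comes from a single profiled pass, the ratio is indicative rather than exact; averaging over repeated runs would give a more stable estimate.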