[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/e2eAIOK/blob/main/demo/ma/distiller/Model_Adapter_Distiller_Walkthrough_VIT_to_ResNet18_CIFAR100.ipynb)

# Model Adapter Distiller Walkthrough DEMO
Model Adapter is a convenient framework that reduces training and inference time, as well as data labeling cost, by efficiently reusing public pretrained models and datasets from many domains. It contains three components that serve different cases: Finetuner, Distiller, and Domain Adapter. 

This demo introduces the usage of Distiller. Taking image classification as an example, it shows how to distill knowledge from ViT to ResNet18 on the CIFAR100 dataset by integrating Distiller into a general training pipeline. A simplified built-in demo is available [here](./Model_Adapter_Distiller_builtin_VIT_to_ResNet18_CIFAR100.ipynb).

# Content

* [Overview](#Overview)
    * [Model Adapter Distiller Overview](#Model-Adapter-Distiller-Overview)
* [Getting Started](#Getting-Started)
    * [1. Environment Setup](#1.-Environment-Setup)
    * [2. Data Prepare](#2.-Data-Prepare)
    * [3. Model Prepare](#3.-Model-Prepare)
    * [4. Train](#4.-Train)

# Overview

## Model Adapter Distiller Overview
Distiller is based on knowledge distillation, which transfers knowledge from a heavy model (teacher) to a light one (student) that may have a different structure. The teacher is a large model pretrained on a specific dataset, so it contains sufficient knowledge for the task, while the student model is much smaller. Distiller trains the student not only on the dataset but also with the help of the teacher's knowledge. With Distiller, we can make use of the knowledge in existing pretrained large models while spending much less training time. It can also significantly improve the convergence speed and prediction accuracy of a small model, which is very helpful for inference.

<img src="../imgs/distiller.png" width="60%">
<center>Model Adapter Distiller Structure</center>
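To make the idea concrete, here is a minimal, framework-free sketch of a commonly used knowledge distillation loss: a hard-label cross-entropy term on the student plus a KL-divergence term between temperature-softened teacher and student distributions. The function names, the temperature of 4.0, and the weight `alpha=0.9` are illustrative assumptions, not the code or defaults of the e2eAIOK `KD` class.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(student_logits, label):
    """Hard-label loss: negative log-probability of the true class."""
    return -math.log(softmax(student_logits)[label])

def distillation_loss(student_logits, teacher_logits, temperature=4.0):
    """KL(teacher || student) on softened distributions, scaled by T^2."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student))
    return kl * temperature ** 2

def kd_total_loss(student_logits, teacher_logits, label, alpha=0.9, temperature=4.0):
    """Weighted sum of the hard-label term and the teacher-guided term."""
    hard = cross_entropy(student_logits, label)
    soft = distillation_loss(student_logits, teacher_logits, temperature)
    return (1 - alpha) * hard + alpha * soft

# Toy logits for a 3-class problem (hypothetical values)
student = [1.0, 0.5, -0.5]
teacher = [3.0, 1.0, -2.0]
print(kd_total_loss(student, teacher, label=0))
```

The distillation term is zero when the student reproduces the teacher's distribution exactly, so the student is rewarded for matching the teacher's relative confidence across all classes, not only the top one.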

# Getting Started

## 1. Environment Setup

### (Option 1) Use Pip install
We can install the ModelAdapter module directly from Intel® End-to-End AI Optimization Kit with the following command.

In [None]:
!pip install e2eAIOK-ModelAdapter --pre

### (Option 2) Use Docker 

We can also use Docker, which contains a complete environment.

Step1. prepare code
   ``` bash
   git clone https://github.com/intel/e2eAIOK.git
   cd e2eAIOK
   git submodule update --init --recursive
   ```
    
Step2. build docker image
   ``` bash
   python3 scripts/start_e2eaiok_docker.py -b pytorch112 --dataset_path ${dataset_path} -w ${host0} ${host1} ${host2} ${host3} --proxy  "http://addr:ip"
   ```
   
Step3. run docker and start conda env
   ``` bash
   sshpass -p docker ssh ${host0} -p 12347
   conda activate pytorch-1.12.0
   ```
  
Step4. Start the jupyter notebook and tensorboard service
   ``` bash
   nohup jupyter notebook --notebook-dir=/home/vmagent/app/e2eaiok --ip=${hostname} --port=8899 --allow-root &
   nohup tensorboard --logdir /home/vmagent/app/data/tensorboard --host=${hostname} --port=6006 & 
   ```
   Now you can visit the demos at `http://${hostname}:8899/` and see the TensorBoard logs at `http://${hostname}:6006`.

## 2. Data Prepare

Let's import some required modules.

In [None]:
import torch
from torchvision import transforms,datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from timm.utils import accuracy
import timm
import transformers
import datetime

First, let's define the transforms for the dataset, which are needed to augment and preprocess the input images. 

In [None]:
IMAGE_MEAN = [0.5, 0.5, 0.5]
IMAGE_STD = [0.5, 0.5, 0.5]

train_transform = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.Resize(224),  # the pretrained model was trained on larger images, so scale 32x32 up to 224x224
  transforms.ToTensor(),
  transforms.Normalize(IMAGE_MEAN, IMAGE_STD)
])

test_transform = transforms.Compose([
  # no random augmentation at evaluation time, so results are deterministic
  transforms.Resize(224),  # the pretrained model was trained on larger images, so scale 32x32 up to 224x224
  transforms.ToTensor(),
  transforms.Normalize(IMAGE_MEAN, IMAGE_STD)
])
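As a small worked example of what the last step does: `ToTensor()` scales 8-bit pixel values to the range [0, 1], and `Normalize` then computes `(x - mean) / std` per channel. With a mean and standard deviation of 0.5, this maps pixels to the range [-1, 1]. The helper name below is illustrative only.

```python
IMAGE_MEAN = [0.5, 0.5, 0.5]
IMAGE_STD = [0.5, 0.5, 0.5]

def normalize_pixel(value, mean, std):
    """Per-channel normalization as applied by transforms.Normalize: (x - mean) / std."""
    return (value - mean) / std

# Darkest, middle, and brightest pixel values after ToTensor()
for v in (0.0, 0.5, 1.0):
    print(v, "->", normalize_pixel(v, IMAGE_MEAN[0], IMAGE_STD[0]))
# 0.0 -> -1.0
# 0.5 -> 0.0
# 1.0 -> 1.0
```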

Then let's define the CIFAR100 dataset and download it with the torchvision library.

In [None]:
data_folder='./data' # dataset location
train_set = datasets.CIFAR100(root=data_folder, train=True, download=True, transform=train_transform)
test_set = datasets.CIFAR100(root=data_folder, train=False, download=True, transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


Finally, we define the dataloaders. You can change `batch_size` and `num_workers` to suit your own environment.

In [None]:
train_loader = DataLoader(dataset=train_set, batch_size=128, shuffle=True, num_workers=1, drop_last=False)
validate_loader = DataLoader(dataset=test_set, batch_size=128, shuffle=False, num_workers=1, drop_last=False)  # no need to shuffle for evaluation
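To get a feel for how `batch_size` affects the run, the sketch below computes how many steps one epoch takes. It assumes the standard CIFAR100 split of 50,000 training and 10,000 test images; the helper name is illustrative.

```python
import math

def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return num_samples // batch_size        # incomplete last batch is discarded
    return math.ceil(num_samples / batch_size)  # incomplete last batch is kept

train_steps = steps_per_epoch(50_000, 128)
valid_steps = steps_per_epoch(10_000, 128)
print(train_steps, valid_steps)  # 391 79
```

With `drop_last=False`, the final training batch holds only 80 samples and the final validation batch only 16. At roughly 7 seconds per step, as the timestamps in the training log in section 4 suggest, 391 steps take about 45 minutes.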

## 3. Model Prepare

**Prepare Student Model**

First, we create a ResNet18 without pretrained weights as the backbone model.

In [None]:
backbone = timm.create_model('resnet18', pretrained=False, num_classes=100)

(Optional but recommended) Apply an optimized weight initialization, which can improve early learning.

In [None]:
from e2eAIOK.common.trainer.model.model_utils.model_utils import initWeights
backbone.apply(initWeights)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, m

**Prepare teacher model**

To use Distiller, we need to prepare a teacher model to guide the training. Here we select the pretrained [vit_base-224-in21k-ft-cifar100 from HuggingFace](https://huggingface.co/edumunozsala/vit_base-224-in21k-ft-cifar100).

In [None]:
from transformers import ViTForImageClassification
teacher_model = ViTForImageClassification.from_pretrained('edumunozsala/vit_base-224-in21k-ft-cifar100')

**Define Distiller**

Here we define a distiller using the KD algorithm, which takes a teacher model as input. If the teacher comes from Hugging Face, specify a `teacher_type` whose name starts with "huggingface"; otherwise this argument is not needed.

In [None]:
from e2eAIOK.ModelAdapter.engine_core.distiller import KD
distiller= KD(teacher_model,teacher_type="huggingface_vit")

**Wrap model with Model Adapter**

Call the *make_transferrable_with_knowledge_distillation()* function, which takes the backbone model, a loss function, and the distiller as input. The output model is able to perform knowledge distillation.

In [None]:
from e2eAIOK.ModelAdapter.engine_core.transferrable_model import *
loss_fn = torch.nn.CrossEntropyLoss()
model = make_transferrable_with_knowledge_distillation(backbone,loss_fn,distiller)

## 4. Train

**Create optimizer**

Create the optimizer. Here we choose SGD.

In [None]:
init_lr = 0.01
weight_decay = 0.0001
momentum = 0.9
optimizer = optim.SGD(model.parameters(),lr=init_lr, weight_decay=weight_decay,momentum=momentum)

**Create scheduler**

Here we choose an *ExponentialLR* scheduler. You can switch to another scheduler for your own task.

In [None]:
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
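As a quick illustration of this schedule, `ExponentialLR` multiplies the learning rate by `gamma` after every `scheduler.step()`, so the rate at epoch `n` is `init_lr * gamma ** n`. The helper below only reproduces that arithmetic and is not part of the training pipeline.

```python
init_lr = 0.01
gamma = 0.1

def exponential_lr(initial_lr, gamma, epoch):
    """Learning rate in effect during the given epoch (0-based)."""
    return initial_lr * gamma ** epoch

for epoch in range(4):
    print(f"epoch {epoch}: lr = {exponential_lr(init_lr, gamma, epoch):.6f}")
```

With `gamma=0.1` the rate drops tenfold every epoch, which is aggressive for longer runs. In this one-epoch demo only the initial rate of 0.01 is used during training.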

**Create Trainer**

Create a simple *Trainer*, which contains *train()* and *evaluate()* functions for this simple ResNet18 training task.

In [None]:
max_epoch = 1 # max 1 epoch
print_interval = 10 

class Trainer:
    def __init__(self, model, optimizer, scheduler):
        self._model = model
        self._optimizer = optimizer
        self._scheduler = scheduler
        
    def train(self, train_dataloader, valid_dataloader, max_epoch):
        ''' 
        :param train_dataloader: train dataloader
        :param valid_dataloader: validation dataloader
        :param max_epoch: number of epochs to train
        '''
        for epoch in range(0, max_epoch):
            ################## train #####################
            self._model.train()  # set training flag
            for (cur_step,(data, label)) in enumerate(train_dataloader):
                self._optimizer.zero_grad()
                output = self._model(data)
                loss_value = self._model.loss(output, label) # transferrable model has loss attribute
                loss_value.backward() 
                if cur_step%print_interval == 0:
                    batch_acc = accuracy(output.backbone_output,label)[0] # use output.backbone_output instead of output
                    dt = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') # date time
                    print("[{}] epoch {} step {} : total loss {:4f}, training backbone loss {:.4f}, distiller loss {:4f}, training batch acc {:.4f}".format(
                      dt, epoch, cur_step, loss_value.total_loss.item(),loss_value.backbone_loss.item(), loss_value.distiller_loss.item(), batch_acc.item())) 
                self._optimizer.step()
            self._scheduler.step()
            ################## evaluate ######################
            self.evaluate(valid_dataloader)
            
    def evaluate(self, valid_dataloader):
        with torch.no_grad():
            self._model.eval()  
            backbone = self._model.backbone # use backbone in evaluation
            loss_cum = 0.0
            sample_num = 0
            acc_cum = 0.0
            total_step = len(valid_dataloader)
            for (cur_step,(data, label)) in enumerate(valid_dataloader):
                output = backbone(data)
                batch_size = data.size(0)
                sample_num += batch_size
                loss_cum += loss_fn(output, label).item() * batch_size
                acc_cum += accuracy(output, label)[0].item() * batch_size
                if cur_step%print_interval == 0:
                    print(f"step {cur_step}/{total_step}")
            dt = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') # date time
            loss_value = loss_cum/sample_num
            acc_value = acc_cum/sample_num

            print("[{}] evaluation loss {:.4f}, evaluation acc {:.4f}".format(
                dt, loss_value, acc_value))
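Note that `evaluate()` multiplies each batch metric by the batch size before accumulating and divides by the total sample count at the end. The sketch below, with made-up per-batch accuracies, shows why this matters when the last batch is smaller than the others.

```python
def weighted_mean(batch_values, batch_sizes):
    """Sample-weighted mean of per-batch metrics, as accumulated in evaluate()."""
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)

# Hypothetical per-batch accuracies (in percent) and batch sizes
batch_acc = [50.0, 50.0, 100.0]
batch_sizes = [128, 128, 16]

weighted = weighted_mean(batch_acc, batch_sizes)
naive = sum(batch_acc) / len(batch_acc)
print(f"sample-weighted: {weighted:.2f}, naive per-batch mean: {naive:.2f}")
```

The naive per-batch mean lets the small final batch count as much as a full one, which inflates the result here. The sample-weighted mean equals the accuracy over all samples.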

**Train and Evaluate**

Let's start the trainer, train for 1 epoch, and evaluate the final accuracy.

This will take some time (~45 min), so take a break and get a coffee!

Here we train for only one epoch as a quick test; you may expect a result with accuracy around 0.15~0.17.

You can further optimize and accelerate training by using the logits saving function; refer to the [logits saving demo](Model_Adapter_Distiller_customized_ResNet18_CIFAR100_train_with_logits.ipynb) and the [training with saved logits demo](./Model_Adapter_Distiller_customized_ResNet18_CIFAR100_train_with_logits.ipynb) for more details.

In [None]:
%%time
trainer = Trainer(model, optimizer, scheduler)
trainer.train(train_loader,validate_loader,max_epoch)

[2023-02-13 05:49:36] epoch 0 step 0 : total loss 3.099823, training backbone loss 5.4708, distiller loss 2.836383, training batch acc 0.0000
[2023-02-13 05:50:46] epoch 0 step 10 : total loss 2.362927, training backbone loss 4.5393, distiller loss 2.121112, training batch acc 1.5625
[2023-02-13 05:51:57] epoch 0 step 20 : total loss 2.394127, training backbone loss 4.4548, distiller loss 2.165160, training batch acc 2.3438
[2023-02-13 05:53:06] epoch 0 step 30 : total loss 2.356652, training backbone loss 4.4124, distiller loss 2.128232, training batch acc 3.9062
[2023-02-13 05:54:14] epoch 0 step 40 : total loss 2.320713, training backbone loss 4.4267, distiller loss 2.086718, training batch acc 4.6875
[2023-02-13 05:55:21] epoch 0 step 50 : total loss 2.227285, training backbone loss 4.2046, distiller loss 2.007580, training batch acc 6.2500
[2023-02-13 05:56:33] epoch 0 step 60 : total loss 2.237528, training backbone loss 4.3497, distiller loss 2.002843, training batch acc 5.4688
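The three loss columns in the log above appear to be related by a fixed weighted sum. The weights 0.1 and 0.9 below are an assumption inferred from the logged numbers, not values confirmed from the e2eAIOK source. The sketch checks that assumption against the first three log lines.

```python
# (total_loss, backbone_loss, distiller_loss) copied from the first three log lines
log_rows = [
    (3.099823, 5.4708, 2.836383),
    (2.362927, 4.5393, 2.121112),
    (2.394127, 4.4548, 2.165160),
]

BACKBONE_WEIGHT = 0.1   # assumption inferred from the log
DISTILLER_WEIGHT = 0.9  # assumption inferred from the log

for total, backbone, distiller in log_rows:
    reconstructed = BACKBONE_WEIGHT * backbone + DISTILLER_WEIGHT * distiller
    print(f"logged total {total:.4f}, reconstructed {reconstructed:.4f}")
```

If this reading is right, the teacher-guided term dominates the training signal in this run, while the backbone's own cross-entropy contributes about a tenth of the weight.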
