[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/e2eAIOK/blob/main/demo/ma/distiller/Model_Adapter_Distiller_Walkthrough_VIT_to_ResNet18_CIFAR100_train_with_logits.ipynb)

# Model Adapter Distiller Walkthrough DEMO - train with saved logits
Model Adapter is a convenient framework that can be used to reduce training and inference time, or data labeling cost, by efficiently utilizing advanced public models and datasets. It mainly contains three components serving different cases: Finetuner, Distiller, and Domain Adapter.

Distiller is based on knowledge distillation technology: it transfers knowledge from a heavy model (teacher) to a light one (student) with a different structure. However, the teacher forward pass usually takes a lot of time during distillation. Distiller's logits saving function can save the teacher's predictions in advance, which saves much of that time during student training.
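To make the "save once, reuse later" idea concrete, here is a minimal, library-independent sketch. It assumes a toy linear teacher and random tensors in place of ViT and CIFAR100, and it ignores data augmentation for simplicity. `precompute_teacher_logits` is a hypothetical helper, not part of the e2eAIOK API.

```python
import torch
from torch import nn


@torch.no_grad()  # the teacher is frozen, so no gradients are needed
def precompute_teacher_logits(teacher: nn.Module, images: torch.Tensor, batch_size: int = 64) -> torch.Tensor:
    """Hypothetical helper: run the expensive teacher once and return logits for every sample."""
    teacher.eval()
    chunks = [teacher(images[i:i + batch_size]) for i in range(0, len(images), batch_size)]
    return torch.cat(chunks)


# Toy stand-ins (assumptions): a linear "teacher", 10 random 3x4x4 images, 5 classes
torch.manual_seed(0)
teacher = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 5))
images = torch.randn(10, 3, 4, 4)

saved_logits = precompute_teacher_logits(teacher, images, batch_size=4)

# During student training, an index lookup replaces the teacher forward pass
batch_idx = torch.tensor([2, 7])
teacher_targets = saved_logits[batch_idx]
print(saved_logits.shape)  # torch.Size([10, 5])
```

The teacher cost is paid once per saved sample; every later training step only does a cheap tensor lookup.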

This demo mainly introduces Distiller training with saved logits. In the previous [save logits demo](./Model_Adapter_Distiller_Walkthrough_VIT_to_ResNet18_on_CIFAR100_save_logits.ipynb), we showed how to save the predicted logits from ViT on CIFAR100. Here we show how to use the saved logits to guide the training of the student ResNet18.

To use previously saved logits for backbone training, we only need to change three steps in the [previous pipeline](./Model_Adapter_Distiller_Walkthrough_VIT_to_ResNet18_CIFAR100.ipynb):
- Wrap train_dataset with DataWrapper and set save_logits to False
- When defining the Distiller, set use_saved_logits to True
- When the epoch changes, call dataset.set_epoch(epoch)
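The role of `set_epoch` can be illustrated with a small, hypothetical wrapper. Because random augmentation changes every epoch, this sketch assumes the teacher logits are stored per epoch and per sample (shape `[num_epochs, num_samples, num_classes]`), so the wrapper must know the current epoch to pick the right row. `SavedLogitsDataset` is an illustrative name, not the e2eAIOK class; the real wrapper also loads the recorded augmentation information, and its storage format may differ.

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class SavedLogitsDataset(Dataset):
    """Hypothetical wrapper returning (image, label, teacher_logits) for the current epoch."""

    def __init__(self, base: Dataset, logits_per_epoch: torch.Tensor):
        # logits_per_epoch: [num_epochs, len(base), num_classes] (assumed layout)
        if logits_per_epoch.shape[1] != len(base):
            raise ValueError("one row of saved logits is required per sample")
        self.base = base
        self.logits_per_epoch = logits_per_epoch
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Call at the start of every epoch, before iterating the DataLoader
        self.epoch = epoch

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        image, label = self.base[idx]
        return image, label, self.logits_per_epoch[self.epoch, idx]


# Toy usage: 4 samples, 2 epochs of saved logits, 3 classes
base = TensorDataset(torch.randn(4, 3, 8, 8), torch.tensor([0, 1, 2, 0]))
logits = torch.randn(2, 4, 3)
wrapped = SavedLogitsDataset(base, logits)
wrapped.set_epoch(1)
_, label, teacher_logits = wrapped[2]  # logits saved for sample 2 in epoch 1
```

Forgetting `set_epoch` in such a design would silently pair every epoch's images with epoch 0's teacher logits.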

# Content

* [Model Adapter Distiller Overview](#Model-Adapter-Distiller-Overview)
* [1. Environment Setup](#1.-Environment-Setup)
* [2. Distiller Training with saved logits](#2.-Distiller-Training-with-saved-logits)
    * [2.1 Prepare Data](#2.1-Prepare-Data)
    * [2.2 Create Transferrable Model](#2.2-Create-Transferrable-Model)
    * [2.3 Train and Evaluate](#2.3-Train-and-Evaluate)

# Model Adapter Distiller Overview
Distiller is based on knowledge distillation technology: it transfers knowledge from a heavy model (teacher) to a light one (student) with a different structure. The teacher is a large model pretrained on a specific dataset and contains sufficient knowledge for the task, while the student has a much smaller structure. Distiller trains the student not only on the dataset but also with the help of the teacher's knowledge. With Distiller, we can make use of the knowledge in existing pretrained large models while using much less training time. It can also significantly improve the convergence speed and prediction accuracy of a small model, which is very helpful for inference.

<img src="../imgs/distiller.png" width="60%">
<center>Model Adapter Distiller Structure</center>

# 1. Environment Setup

### (Option 1) Use Pip install - recommended

In [None]:
!pip install e2eAIOK-ModelAdapter --pre

### (Option 2) Use Docker 

Step1. prepare code
   ``` bash
   git clone https://github.com/intel/e2eAIOK.git
   cd e2eAIOK
   git submodule update --init --recursive
   ```
    
Step2. build docker image
   ``` bash
   python3 scripts/start_e2eaiok_docker.py -b pytorch112 --dataset_path ${dataset_path} -w ${host0} ${host1} ${host2} ${host3} --proxy  "http://addr:ip"
   ```
   
Step3. run docker and start conda env
   ``` bash
   sshpass -p docker ssh ${host0} -p 12347
   conda activate pytorch-1.12.0
   ```
  
Step4. Start the jupyter notebook and tensorboard service
   ``` bash
   nohup jupyter notebook --notebook-dir=/home/vmagent/app/e2eaiok --ip=${hostname} --port=8899 --allow-root &
   nohup tensorboard --logdir /home/vmagent/app/data/tensorboard --host=${hostname} --port=6006 & 
   ```
   Now you can visit the demos at `http://${hostname}:8899/` and view the TensorBoard log at `http://${hostname}:6006`.

# 2. Distiller Training with saved logits

Import libraries

In [None]:
import torch
from torchvision import transforms,datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from timm.utils import accuracy
import timm
import transformers
import datetime

## 2.1 Prepare Data
### Prepare transforms and dataset
For the student, we can use the original image size of 32x32.

Note: The data preprocessing for the student and the teacher can differ, but every step involving random augmentation must be identical for both.

In [None]:
IMAGE_MEAN = [0.5, 0.5, 0.5]
IMAGE_STD = [0.5, 0.5, 0.5]

train_transform = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize(IMAGE_MEAN, IMAGE_STD)
])

test_transform = transforms.Compose([
  # no random crop at evaluation time, so validation results are deterministic
  transforms.ToTensor(),
  transforms.Normalize(IMAGE_MEAN, IMAGE_STD)
])
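Why must random augmentation match? The saved logits were produced from one particular augmented view of each image, so the student has to see that same view. One way to make augmentation replayable is to derive all of its randomness from `(epoch, sample index)`. The sketch below re-implements a padded random crop and horizontal flip in plain PyTorch with a per-sample `torch.Generator`. The function name and seeding scheme are assumptions for illustration, not how e2eAIOK records augmentation.

```python
import torch
import torch.nn.functional as F


def replayable_augment(img: torch.Tensor, epoch: int, idx: int,
                       size: int = 32, padding: int = 4, base_seed: int = 0) -> torch.Tensor:
    """Random crop (zero padding) + horizontal flip, fully determined by (epoch, idx).

    img: a [C, H, W] tensor with H == W == size.
    """
    # Hypothetical seeding scheme: unique per (epoch, idx) while idx < 1_000_003
    g = torch.Generator().manual_seed(base_seed + epoch * 1_000_003 + idx)
    padded = F.pad(img, (padding, padding, padding, padding))  # pad left, right, top, bottom
    top = int(torch.randint(0, 2 * padding + 1, (1,), generator=g))
    left = int(torch.randint(0, 2 * padding + 1, (1,), generator=g))
    out = padded[:, top:top + size, left:left + size]
    if torch.rand(1, generator=g).item() < 0.5:
        out = out.flip(-1)  # flip along the width axis
    return out


img = torch.randn(3, 32, 32)
view_a = replayable_augment(img, epoch=0, idx=5)  # e.g. when saving teacher logits
view_b = replayable_augment(img, epoch=0, idx=5)  # e.g. when training the student
```

Both calls produce the same tensor, so logits computed on `view_a` are valid targets for a student trained on `view_b`.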

In [None]:
data_folder='./data' # dataset location
train_set = datasets.CIFAR100(root=data_folder, train=True, download=True, transform=train_transform)
test_set = datasets.CIFAR100(root=data_folder, train=False, download=True, transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


### Wrap dataset with DataWrapper
Wrap the training dataset with DataWrapper, which loads the data augmentation information and the corresponding saved logits. Remember to set the save_logits flag to False.

The logits should have been saved in the previous [save logits demo](./Model_Adapter_Distiller_customized_ResNet50_on_CIFAR100_save_logits.ipynb). If not, you can download them directly from [here](to-be-added) and put them at `logits_path`.

In [None]:
from e2eAIOK.ModelAdapter.engine_core.distiller.utils import logits_wrap_dataset
logits_path = "./data" # path for saved logits
train_set = logits_wrap_dataset(train_set, logits_path=logits_path, num_classes=100, save_logits=False)

### Create dataloader

In [None]:
train_loader = DataLoader(dataset=train_set, batch_size=128, shuffle=True, num_workers=1, drop_last=False)
validate_loader = DataLoader(dataset=test_set, batch_size=128, shuffle=False, num_workers=1, drop_last=False) # no need to shuffle for evaluation

## 2.2 Create Transferrable Model

### Create Backbone model

In [None]:
backbone = timm.create_model('resnet18', pretrained=False, num_classes=100)

(Optional & recommended) Optimized weight initialization can improve early-stage learning.

In [None]:
from e2eAIOK.common.trainer.model.model_utils.model_utils import initWeights
backbone.apply(initWeights)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, m
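`nn.Module.apply(fn)` calls `fn` recursively on every submodule and on the module itself, then returns the module, which is why the cell above prints the ResNet structure. An initialization function for `apply` typically dispatches on the layer type. The sketch below is a hypothetical example of such a function (Kaiming init for convolutions, identity-like BatchNorm, small normal init for the classifier); it is not the actual implementation of `initWeights`.

```python
import torch
from torch import nn


def init_weights_sketch(m: nn.Module) -> None:
    """Hypothetical init function, dispatched on layer type by nn.Module.apply."""
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)   # scale = 1
        nn.init.zeros_(m.bias)    # shift = 0
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)
        nn.init.zeros_(m.bias)


net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 100),
)
returned = net.apply(init_weights_sketch)  # apply returns the same module object
```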

### Create teacher model
We still need to prepare the teacher model [vit_base-224-in21k-ft-cifar100 from HuggingFace](https://huggingface.co/edumunozsala/vit_base-224-in21k-ft-cifar100) for a complete Distiller definition.

In [None]:
from transformers import ViTForImageClassification
teacher_model = ViTForImageClassification.from_pretrained('edumunozsala/vit_base-224-in21k-ft-cifar100')

### Define Distiller
Here we define a distiller using the KD algorithm, which takes a teacher model as input.

If the teacher comes from Hugging Face, set "teacher_type" to a name starting with "huggingface"; otherwise, it is not needed.

Set use_saved_logits to True to load the previously saved logits.

In [None]:
from e2eAIOK.ModelAdapter.engine_core.distiller import KD
distiller = KD(teacher_model, use_saved_logits=True, teacher_type="huggingface_vit")
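For reference, the classic Hinton-style knowledge-distillation loss compares temperature-softened student and teacher distributions with a KL divergence. The sketch below shows it in plain PyTorch, with teacher logits read from a saved tensor instead of a teacher forward pass. The function name and temperature value are assumptions, and the exact formulation inside e2eAIOK's `KD` class may differ.

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor,
            temperature: float = 4.0) -> torch.Tensor:
    """KL(teacher || student) on temperature-softened distributions, scaled by T^2."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # Scaling by T^2 keeps gradient magnitudes comparable across temperatures
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2


torch.manual_seed(0)
saved_teacher_logits = torch.randn(8, 100)              # stand-in for logits loaded from disk
student_logits = torch.randn(8, 100, requires_grad=True)
loss = kd_loss(student_logits, saved_teacher_logits)    # no teacher forward pass needed
loss.backward()
```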

### Make Model transferrable with distiller

In [None]:
from e2eAIOK.ModelAdapter.engine_core.transferrable_model import *
loss_fn = torch.nn.CrossEntropyLoss()
model = make_transferrable_with_knowledge_distillation(backbone,loss_fn,distiller)

## 2.3 Train and Evaluate

### Create optimizer and scheduler

In [None]:
################# create optimizer #################
init_lr = 0.01
weight_decay = 0.005
momentum = 0.9
optimizer = optim.SGD(model.parameters(), lr=init_lr, weight_decay=weight_decay, momentum=momentum)
################# create scheduler #################
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
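`ExponentialLR` multiplies the learning rate by `gamma` on every `scheduler.step()`, so the rate used in epoch k is `init_lr * gamma**k`. With `gamma=0.1`, the rate drops tenfold per epoch, which only matters once `max_epoch` is raised above 1. A quick check with a dummy parameter and the same hyperparameters:

```python
import torch
import torch.optim as optim

# Dummy parameter; same optimizer hyperparameters as the notebook
param = torch.nn.Parameter(torch.zeros(1))
opt = optim.SGD([param], lr=0.01, weight_decay=0.005, momentum=0.9)
sched = optim.lr_scheduler.ExponentialLR(opt, gamma=0.1)

lrs = []
for epoch in range(3):
    lrs.append(opt.param_groups[0]["lr"])  # learning rate used during this epoch
    opt.step()    # optimizer step first, then scheduler step (PyTorch's expected order)
    sched.step()

# lrs[k] equals 0.01 * 0.1**k up to float rounding
print(lrs)
```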

### Create Trainer

In [None]:
max_epoch = 1 # max 1 epoch
print_interval = 10 

In [None]:
class Trainer:
    def __init__(self, model, optimizer, scheduler):
        self._model = model
        self._optimizer = optimizer
        self._scheduler = scheduler
        
    def train(self, train_dataloader, valid_dataloader, max_epoch):
        ''' 
        :param train_dataloader: train dataloader
        :param valid_dataloader: validation dataloader
        :param max_epoch: number of training epochs
        '''
        for epoch in range(0, max_epoch):
            train_dataloader.dataset.set_epoch(epoch) # Update epoch for dataset
            ################## train #####################
            self._model.train()  # set training flag
            for (cur_step,(data, label)) in enumerate(train_dataloader):
                self._optimizer.zero_grad()
                output = self._model(data)
                loss_value = self._model.loss(output, label) # transferrable model has loss attribute
                loss_value.backward() 
                if cur_step%print_interval == 0:
                    batch_acc = accuracy(output.backbone_output,label)[0] # use output.backbone_output instead of output
                    dt = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') # date time
                    print("[{}] epoch {} step {} : total loss {:4f}, training backbone loss {:.4f}, distiller loss {:4f}, training batch acc {:.4f}".format(
                      dt, epoch, cur_step, loss_value.total_loss.item(),loss_value.backbone_loss.item(), loss_value.distiller_loss.item(), batch_acc.item())) 
                self._optimizer.step()
            self._scheduler.step()
            ################## evaluate ######################
            self.evaluate(valid_dataloader)
            
    def evaluate(self, valid_dataloader):
        with torch.no_grad():
            self._model.eval()  
            backbone = self._model.backbone # use backbone in evaluation
            loss_cum = 0.0
            sample_num = 0
            acc_cum = 0.0
            total_step = len(valid_dataloader)
            for (cur_step,(data, label)) in enumerate(valid_dataloader):
                output = backbone(data)
                batch_size = data.size(0)
                sample_num += batch_size
                loss_cum += loss_fn(output, label).item() * batch_size
                acc_cum += accuracy(output, label)[0].item() * batch_size
                if cur_step%print_interval == 0:
                    print(f"step {cur_step}/{total_step}")
            dt = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') # date time
            loss_value = loss_cum/sample_num
            acc_value = acc_cum/sample_num

            print("[{}] evaluation loss {:.4f}, evaluation acc {:.4f}".format(
                dt, loss_value, acc_value))
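`evaluate` weights each batch's loss and accuracy by its size instead of averaging per-batch values. This matters because the last batch is usually smaller: the 10,000 CIFAR100 test images split into 78 batches of 128 plus one batch of 16. A small worked example with hypothetical per-batch accuracies:

```python
# 10,000 test images with batch_size=128 -> 78 full batches + 1 batch of 16
batch_sizes = [128] * 78 + [16]
assert sum(batch_sizes) == 10_000

# Hypothetical accuracies (%): 50% on every full batch, 100% on the small last batch
batch_acc = [50.0] * 78 + [100.0]

naive_mean = sum(batch_acc) / len(batch_acc)  # treats the 16-image batch like a 128-image one
weighted_mean = sum(a * n for a, n in zip(batch_acc, batch_sizes)) / sum(batch_sizes)

print(f"naive {naive_mean:.4f}, weighted {weighted_mean:.4f}")  # naive 50.6329, weighted 50.0800
```

Only the weighted value equals the true fraction of correctly classified images (5,008 of 10,000).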

### Train and evaluate

In [None]:
%%time
trainer = Trainer(model, optimizer, scheduler)
trainer.train(train_loader,validate_loader,max_epoch)

[2023-02-13 07:29:34] epoch 0 step 0 : total loss 3.773066, training backbone loss 6.0263, distiller loss 3.522707, training batch acc 0.7812
[2023-02-13 07:29:35] epoch 0 step 10 : total loss 2.993048, training backbone loss 5.2592, distiller loss 2.741256, training batch acc 1.5625
[2023-02-13 07:29:36] epoch 0 step 20 : total loss 2.633675, training backbone loss 4.6729, distiller loss 2.407095, training batch acc 3.1250
[2023-02-13 07:29:37] epoch 0 step 30 : total loss 2.505098, training backbone loss 4.5768, distiller loss 2.274914, training batch acc 3.9062
[2023-02-13 07:29:38] epoch 0 step 40 : total loss 2.404030, training backbone loss 4.4975, distiller loss 2.171418, training batch acc 2.3438
[2023-02-13 07:29:38] epoch 0 step 50 : total loss 2.420904, training backbone loss 4.4462, distiller loss 2.195874, training batch acc 3.1250
[2023-02-13 07:29:39] epoch 0 step 60 : total loss 2.248125, training backbone loss 4.3429, distiller loss 2.015368, training batch acc 3.9062
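The training log reports three losses. The logged values are consistent with the total loss being a fixed weighted sum, total = 0.1 × backbone loss + 0.9 × distiller loss. These weights are inferred from the numbers printed above, not taken from e2eAIOK documentation, so treat them as an observation about this run. The arithmetic can be checked directly:

```python
# (total loss, backbone loss, distiller loss) copied from the training log above
logged = [
    (3.773066, 6.0263, 3.522707),
    (2.993048, 5.2592, 2.741256),
    (2.633675, 4.6729, 2.407095),
    (2.505098, 4.5768, 2.274914),
    (2.404030, 4.4975, 2.171418),
    (2.420904, 4.4462, 2.195874),
    (2.248125, 4.3429, 2.015368),
]

# Weights inferred from the log (assumption, not read from the library source)
w_backbone, w_distiller = 0.1, 0.9

errors = []
for total, backbone_loss, distiller_loss in logged:
    reconstructed = w_backbone * backbone_loss + w_distiller * distiller_loss
    errors.append(abs(reconstructed - total))

# The backbone loss is printed with only 4 decimals, so a few 1e-6 of rounding error is expected
print(f"max reconstruction error: {max(errors):.1e}")
```

This also explains why the total loss sits much closer to the distiller loss than to the backbone loss: in this run, the teacher's saved logits dominate the training signal.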
