# AIOK Model Adapter Distiller Customized DEMO - Save logits
Model Adapter is a convenient framework that can reduce training and inference time, or data labeling cost, by efficiently utilizing publicly available advanced models and datasets. It mainly contains three components serving different cases: Finetuner, Distiller, and Domain Adapter.

This demo introduces the logits saving function of Distiller. Taking image classification as an example, it shows how to use Distiller to save logits from a pretrained ResNet50 model. These logits will be used to guide the training of ResNet18 in the next [demo](./Model_Adapter_Distiller_customized_ResNet18_CIFAR100_train_with_logits).

# Content

* [Model Adapter Distiller Overview](#Model-Adapter-Distiller-Overview)
* [Environment Setup](#Environment-Setup)
* [Save Logits with Distiller](#Save-Logits-with-Distiller)
    * [Prepare Data](#Prepare-Data)
    * [Create Distiller](#Create-Distiller)
    * [Save Logits](#Save-Logits)

# Model Adapter Distiller Overview
Distiller is based on knowledge distillation: it transfers knowledge from a heavy model (the teacher) to a light one (the student), which can have a different structure. The teacher is a large model pretrained on a specific dataset and contains sufficient knowledge for the task, while the student has a much smaller structure. Distiller trains the student not only on the dataset but also with the help of the teacher's knowledge. This way we can make use of the knowledge in existing large pretrained models while spending much less training time. Distillation can also significantly improve the convergence speed and prediction accuracy of the small model, which is very helpful when that small model is used for inference.

<img src="../imgs/distiller.png" width="60%">
<center>Model Adapter Distiller Structure</center>

However, the teacher's forward pass usually takes a lot of time during distillation. With the logits saving function in Distiller, we can save the teacher's predictions in advance, so the teacher does not have to be run again during student training, which saves a lot of time. In this notebook, we show how to use the logits saving function, again taking ResNet50 on CIFAR100 as an example.

To enable the logits saving function, only two steps need to be added to the original pipeline:
- Wrap `train_dataset` with DataWrapper
- Call `prepare_logits()` of the Distiller
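The time saving comes from running the teacher once over the data instead of once per student epoch. The toy sketch below counts teacher forward calls in both settings. `ToyTeacher`, `train_online`, `prepare_cache`, and `train_from_cache` are hypothetical names for illustration, not part of the e2eAIOK API, and the sketch deliberately ignores data augmentation, which the real pipeline has to record (the job of DataWrapper).

```python
from typing import Dict, List


class ToyTeacher:
    """Hypothetical stand-in for an expensive teacher; it only counts forward calls."""

    def __init__(self) -> None:
        self.forward_calls = 0

    def __call__(self, sample: float) -> List[float]:
        self.forward_calls += 1
        return [sample, 2 * sample, -sample]  # fake 3-class logits


def train_online(teacher: ToyTeacher, dataset: List[float], epochs: int) -> None:
    """Plain distillation: the teacher runs on every sample in every epoch."""
    for _ in range(epochs):
        for x in dataset:
            _ = teacher(x)  # a real student would consume these logits here


def prepare_cache(teacher: ToyTeacher, dataset: List[float]) -> Dict[int, List[float]]:
    """Run the teacher once and store its logits keyed by sample index."""
    return {idx: teacher(x) for idx, x in enumerate(dataset)}


def train_from_cache(cache: Dict[int, List[float]], dataset: List[float], epochs: int) -> None:
    """Student training reads logits from the cache; the teacher is never called."""
    for _ in range(epochs):
        for idx in range(len(dataset)):
            _ = cache[idx]  # a real student would consume these logits here


dataset = [0.1, 0.2, 0.3, 0.4]

online = ToyTeacher()
train_online(online, dataset, epochs=5)

cached = ToyTeacher()
cache = prepare_cache(cached, dataset)
train_from_cache(cache, dataset, epochs=5)

print(online.forward_calls, cached.forward_calls)  # 20 4
```

With 5 student epochs over 4 samples, the online setup runs the teacher 20 times while the cached setup runs it only 4 times; the gap grows linearly with the number of student epochs.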

# Environment Setup

1. Prepare the code
    ``` bash
    git clone https://github.com/intel/e2eAIOK.git
    cd e2eAIOK
    git submodule update --init --recursive
    ```
2. Build the Docker image
   ``` bash
   cd Dockerfile-ubuntu18.04 && docker build -t e2eaiok-pytorch112 . -f DockerfilePytorch112 && cd ..
   ```
3. Run the Docker container
   ``` bash
   docker run -it --name model_adapter --shm-size=10g --privileged --network host \
   -v ${dataset_path}:/home/vmagent/app/data  \
   -v `pwd`:/home/vmagent/app/e2eaiok \
   -w /home/vmagent/app/e2eaiok e2eaiok-pytorch112 /bin/bash 
   ```
4. Activate the conda environment and install e2eAIOK
   ```bash
   conda activate pytorch-1.12.0
   python setup.py sdist && pip install dist/e2eAIOK-*.*.*.tar.gz
   ```
5. Start the Jupyter Notebook and TensorBoard services
   ``` bash
   nohup jupyter notebook --notebook-dir=/home/vmagent/app/e2eaiok --ip=${hostname} --port=8899 --allow-root &
   nohup tensorboard --logdir /home/vmagent/app/data/tensorboard --host=${hostname} --port=6006 & 
   ```
   Now you can visit the demos at `http://${hostname}:8899/` and view the TensorBoard logs at `http://${hostname}:6006`.

# Save Logits with Distiller

Import libraries

In [1]:
```python
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import timm
import sys, os
```

## Prepare Data
### Prepare transforms and dataset
Because the teacher model was pretrained on larger images, the 32\*32 CIFAR100 images are scaled up to 112\*112.

In [2]:
```python
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343) # mean for 3 channels
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)  # std for 3 channels

train_transform = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.Resize(112), 
  transforms.ToTensor(),
  transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
])
```
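Because `RandomCrop` and `RandomHorizontalFlip` make random choices per sample, the teacher sees a different view of each image every epoch, and a saved logit is only valid for the exact view it was computed on. The arithmetic below, written as a hypothetical helper `crop_geometry`, shows how many views this pipeline can produce. It assumes torchvision's behavior of choosing the crop offset uniformly from 0 to (padded size - crop size), inclusive.

```python
def crop_geometry(image_size: int, crop_size: int, padding: int, resize_to: int):
    """Hypothetical helper: geometry of RandomCrop(crop_size, padding) + flip + Resize."""
    padded = image_size + 2 * padding        # image size after padding on every side
    max_offset = padded - crop_size          # crop's top/left offset lies in [0, max_offset]
    positions = (max_offset + 1) ** 2        # distinct crop positions (rows x columns)
    views = positions * 2                    # each position can also be flipped or not
    scale = resize_to / crop_size            # upscaling factor applied for the teacher
    return padded, max_offset, positions, views, scale


print(crop_geometry(image_size=32, crop_size=32, padding=4, resize_to=112))
# (40, 8, 81, 162, 3.5)
```

So each training image has up to 162 geometric variants before resizing, which is why the augmentation applied during the teacher pass must be recorded for the saved logits to be reused.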

In [3]:
```python
data_folder = '/home/vmagent/app/data/dataset/cifar'  # dataset location
train_set = datasets.CIFAR100(root=data_folder, train=True, download=True, transform=train_transform)
```

Files already downloaded and verified


### Wrap dataset with DataWrapper
Wrap the training dataset with DataWrapper (here via `logits_wrap_dataset`), which saves the data augmentation information of each sample while the teacher model runs its forward pass.

In [4]:
```python
from e2eAIOK.ModelAdapter.engine_core.distiller.utils import logits_wrap_dataset
logits_path = "/home/vmagent/app/data/model/demo/distiller/cifar100_kd_res50PretrainI21k/logits_demo" # path to save the logits
train_set = logits_wrap_dataset(train_set, logits_path=logits_path, num_classes=100, save_logits=True)
```
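One general way to make saved logits reusable despite random augmentation is to record the randomness used for each sample and replay it later. The sketch below illustrates this idea on toy 1-D data. `ReplayableDataset` and `augment` are hypothetical, and this is an assumption about the general technique, not a description of how e2eAIOK's DataWrapper actually stores its augmentation information.

```python
import random
from typing import Dict, List, Optional, Tuple


def augment(sample: List[int], rng: random.Random) -> List[int]:
    """Toy augmentation: random flip followed by a random circular shift."""
    out = sample[::-1] if rng.random() < 0.5 else list(sample)
    shift = rng.randint(0, len(out) - 1)
    return out[shift:] + out[:shift]


class ReplayableDataset:
    """Hypothetical wrapper that records one seed per (epoch, index) while saving
    and reuses the recorded seed afterwards, so augmentation can be replayed."""

    def __init__(self, data: List[List[int]], save: bool,
                 seeds: Optional[Dict[Tuple[int, int], int]] = None) -> None:
        self.data = data
        self.save = save
        self.seeds: Dict[Tuple[int, int], int] = {} if seeds is None else seeds
        self.epoch = 0

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> List[int]:
        key = (self.epoch, idx)
        if self.save:
            self.seeds[key] = random.randrange(2 ** 32)  # fresh randomness, recorded
        rng = random.Random(self.seeds[key])  # same seed -> same augmentation
        return augment(self.data[idx], rng)


data = [[1, 2, 3, 4], [5, 6, 7, 8]]

# Teacher pass: augmentation is random, but every seed is recorded.
saver = ReplayableDataset(data, save=True)
teacher_views = [saver[i] for i in range(len(saver))]

# Student pass: replay the recorded seeds to get identical inputs.
replayer = ReplayableDataset(data, save=False, seeds=saver.seeds)
student_views = [replayer[i] for i in range(len(replayer))]

print(teacher_views == student_views)  # True
```

Storing a seed per sample is much cheaper than storing the augmented images themselves, while still letting the student see exactly the input the teacher's logits were computed on.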

### Create dataloader

Note: the logits of all samples must be saved, without any sampling. Make sure "channel_last" and "sampler" are disabled in the dataloader (and keep `drop_last=False`); otherwise some samples will be missing when the saved logits are used later.

In [5]:
```python
train_loader = DataLoader(dataset=train_set, batch_size=128, shuffle=True, num_workers=1, drop_last=False)
```
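The `drop_last=False` setting matters because the CIFAR100 training split has 50,000 images, which is not a multiple of 128. The hypothetical helper below computes the number of batches and the number of samples that would never get logits. With `drop_last=False` it yields 391 batches, matching the `save .../391` progress log of `prepare_logits()`.

```python
def loader_stats(num_samples: int, batch_size: int, drop_last: bool):
    """Hypothetical helper: (number of batches, samples never yielded) for a DataLoader."""
    if drop_last:
        batches = num_samples // batch_size                      # incomplete last batch is dropped
    else:
        batches = (num_samples + batch_size - 1) // batch_size   # ceiling division
    yielded = min(num_samples, batches * batch_size)
    return batches, num_samples - yielded


# CIFAR100 has 50,000 training images; the demo uses batch_size=128.
print(loader_stats(50_000, 128, drop_last=False))  # (391, 0)
print(loader_stats(50_000, 128, drop_last=True))   # (390, 80)
```

With `drop_last=True`, 80 images would never pass through the teacher and would have no saved logits for student training.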

## Create Distiller

### Create teacher model
To use the distiller, we need a teacher model to guide the training. Download the teacher model ResNet50 pretrained on CIFAR100 from [here]() and put it at `${dataset_path}/model/demo/baseline/cifar100_res50PretrainI21k/cifar100_res50_pretrain_imagenet21k.pth`.

In [6]:
```python
pretrain_model = "/home/vmagent/app/data/model/demo/baseline/cifar100_res50PretrainI21k/cifar100_res50_pretrain_imagenet21k.pth"
teacher_model = timm.create_model('resnet50', pretrained=False, num_classes=100)
# map_location="cpu" lets the checkpoint load even if it was saved from a GPU
teacher_model.load_state_dict(torch.load(pretrain_model, map_location="cpu"), strict=True)
```

<All keys matched successfully>

### Create Distiller with KD type
Here we define a distiller using the KD algorithm; it takes the teacher model as input.

In [7]:
```python
from e2eAIOK.ModelAdapter.engine_core.distiller import KD
distiller = KD(teacher_model)
```
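For reference, the classic knowledge distillation loss of Hinton et al. combines a temperature-softened KL divergence between teacher and student outputs with the usual cross-entropy on hard labels. The sketch below is a plain-Python version for a single sample; `temperature=4.0` and `alpha=0.9` are illustrative assumptions, and e2eAIOK's `KD` class may use different defaults, weighting, or reductions. Note that `teacher_logits` is the only thing the loss needs from the teacher, which is why precomputed logits can stand in for a live teacher forward pass.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; subtracting the max keeps exp() stable."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits: List[float], teacher_logits: List[float], label: int,
            temperature: float = 4.0, alpha: float = 0.9) -> float:
    """Classic Hinton-style KD loss for one sample (illustrative defaults)."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL(teacher || student) on the softened distributions; zero-probability terms contribute 0
    kl = sum(pt * (math.log(pt) - math.log(ps))
             for pt, ps in zip(p_teacher, p_student) if pt > 0)
    # Ordinary cross-entropy against the hard label, at temperature 1
    ce = -math.log(softmax(student_logits)[label])
    # The T^2 factor keeps soft-target gradients on a scale comparable to the CE term
    return alpha * temperature ** 2 * kl + (1 - alpha) * ce


teacher = [2.0, 1.0, 0.1]        # stand-in for logits read back from the saved file
good_student = [2.0, 1.0, 0.1]   # matches the teacher exactly
bad_student = [0.1, 1.0, 2.0]    # disagrees with the teacher and the label
print(kd_loss(good_student, teacher, label=0) < kd_loss(bad_student, teacher, label=0))  # True
```

When the student matches the teacher exactly, the KL term vanishes and only the weighted cross-entropy remains.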

## Save Logits

Call `prepare_logits()` of the distiller to run the teacher over the training data and save its logits.

In [8]:
```python
distiller.prepare_logits(train_loader, epochs=1)
```

2023-02-07 08:34:20 save 0/391
2023-02-07 08:34:32 save 10/391
2023-02-07 08:34:43 save 20/391
2023-02-07 08:34:54 save 30/391
2023-02-07 08:35:05 save 40/391
2023-02-07 08:35:17 save 50/391
2023-02-07 08:35:28 save 60/391
2023-02-07 08:35:39 save 70/391
2023-02-07 08:35:50 save 80/391
2023-02-07 08:36:02 save 90/391
2023-02-07 08:36:13 save 100/391
2023-02-07 08:36:24 save 110/391
2023-02-07 08:36:35 save 120/391
2023-02-07 08:36:46 save 130/391
2023-02-07 08:36:59 save 140/391
2023-02-07 08:37:25 save 150/391
2023-02-07 08:37:51 save 160/391
2023-02-07 08:38:19 save 170/391
2023-02-07 08:38:42 save 180/391
2023-02-07 08:39:10 save 190/391
2023-02-07 08:39:33 save 200/391
2023-02-07 08:39:58 save 210/391
2023-02-07 08:40:15 save 220/391
2023-02-07 08:40:27 save 230/391
2023-02-07 08:40:39 save 240/391
2023-02-07 08:40:51 save 250/391
2023-02-07 08:41:07 save 260/391
2023-02-07 08:41:20 save 270/391
2023-02-07 08:41:33 save 280/391
2023-02-07 08:41:47 save 290/391
2023-02-07 08:42:00 s