# Save logits from ViT with Distiller on CIFAR100
Distiller can transfer knowledge from a heavy model (teacher) to a light one (student) with a different structure.

* The teacher is a large model pretrained on a specific dataset, so it contains sufficient knowledge for the task, while the student model has a much smaller structure. Distiller trains the student not only on the dataset but also with the help of the teacher's knowledge.
* Distiller can make use of the knowledge in existing large pretrained models while using much less training time. It can also significantly improve the convergence speed and prediction accuracy of a small model, which is very helpful for inference.
![Distiller](../doc/imgs/distiller.png)
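As a rough illustration of what "training with the help of the teacher's knowledge" usually means, here is a minimal pure-Python sketch of the classic soft-target distillation loss: hard-label cross-entropy blended with a KL term between temperature-softened teacher and student distributions. The `temperature` and `alpha` defaults are illustrative assumptions; the `KD` distiller used later may compute its loss differently.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits: List[float], teacher_logits: List[float], label: int,
            temperature: float = 4.0, alpha: float = 0.5) -> float:
    """Blend hard-label cross-entropy with a soft-target KL term (illustrative defaults)."""
    # Hard-label cross-entropy on the student's ordinary (T=1) predictions
    p_student = softmax(student_logits)
    ce = -math.log(p_student[label])

    # KL(teacher || student) on temperature-softened distributions
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(t * (math.log(t) - math.log(s)) for t, s in zip(p_t, p_s) if t > 0)

    # Scaling by T^2 keeps the soft-target term's gradient magnitude comparable across temperatures
    return (1 - alpha) * ce + alpha * (temperature ** 2) * kl


loss = kd_loss([2.0, 0.5, -1.0], [3.0, 1.0, -2.0], label=0)
print(loss)
```

When the student's logits already match the teacher's, the KL term vanishes and only the weighted cross-entropy remains.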

However, teacher forwarding usually takes a lot of time during distillation. With Distiller's logits-saving function, the teacher's predictions can be saved in advance, which saves a lot of time during student training.
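The time saving comes from replacing repeated teacher forward passes with lookups. A minimal sketch of that idea, assuming a hypothetical `CountingTeacher` stand-in and an in-memory dict keyed by sample index (Distiller's real storage format is not shown here):

```python
from typing import Callable, Dict, List, Sequence


class CountingTeacher:
    """Hypothetical stand-in for a heavy teacher model; counts forward passes."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, x: List[float]) -> List[float]:
        self.calls += 1
        return [2.0 * v for v in x]  # toy "logits"


def precompute_teacher_logits(
    samples: Sequence[List[float]], teacher: Callable[[List[float]], List[float]]
) -> Dict[int, List[float]]:
    """Run the teacher once per sample and cache its logits by sample index."""
    return {idx: teacher(x) for idx, x in enumerate(samples)}


def student_epochs(cache: Dict[int, List[float]], num_samples: int, epochs: int) -> int:
    """Simulate student epochs that only read cached logits; returns the lookup count."""
    lookups = 0
    for _ in range(epochs):
        for idx in range(num_samples):
            _ = cache[idx]  # a dictionary lookup instead of a teacher forward pass
            lookups += 1
    return lookups


data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
teacher = CountingTeacher()
cache = precompute_teacher_logits(data, teacher)
lookups = student_epochs(cache, len(data), epochs=5)
print(teacher.calls, lookups)  # 3 15
```

This toy ignores data augmentation. With random augmentation the teacher's logits depend on the augmented view, which is why the pipeline below also records augmentation information.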

## Notebook Content
In this notebook, we show how to use the logits-saving function, again taking ViT as an example.

To enable logits saving, we only need to add two steps to the original pipeline:
- Wrap train_dataset with DataWrapper
- Call prepare_logits() in Distiller and exit

Note: all data must be saved without any sampling. Make sure "channel_last" and any "sampler" are disabled in the dataloader; this avoids losing data when the saved logits are used later.
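Why sampling matters: if a sampler feeds only a subset of indices, the skipped samples never get teacher logits, and a later student run that visits them finds nothing saved. A minimal sketch of a coverage check, assuming a hypothetical list of the indices that were actually saved:

```python
from typing import Iterable, List


def missing_indices(dataset_len: int, saved_indices: Iterable[int]) -> List[int]:
    """Return dataset indices for which no teacher logits were saved."""
    saved = set(saved_indices)
    return [i for i in range(dataset_len) if i not in saved]


# A full pass in any order (e.g. shuffle=True) covers everything...
full_pass = [3, 0, 4, 1, 2]
# ...while a subsampling sampler (here: every other index) leaves gaps
subsampled = [0, 2, 4]

print(missing_indices(5, full_pass))   # []
print(missing_indices(5, subsampled))  # [1, 3]
```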

## Environment Setup

In [1]:
```python
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.optim as optim
import timm
import transformers
import datetime
```

In [2]:
```python
import sys
sys.path.append("/home/vmagent/app/e2eAIOK/e2eAIOK/ModelAdapter/")
from ModelAdapter.engine_core.transferrable_model import make_transferrable_with_knowledge_distillation
from ModelAdapter.engine_core.distiller import KD
from ModelAdapter.engine_core.distiller.utils import logits_wrap_dataset
```

## Prepare Data
### Define Data Preprocessor for teacher
For the teacher, the pretrained model was trained on much larger images, so the 32x32 CIFAR100 images are upscaled to 224x224.

In [3]:
```python
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)  # mean for 3 channels
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)  # std for 3 channels

train_transform = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.Resize(224),
  transforms.ToTensor(),
  transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
])
```
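For reference, `transforms.Normalize` maps each channel value x to (x - mean) / std, and `ToTensor()` has already scaled 8-bit pixel values into [0, 1]. A small worked sketch of that arithmetic for a single RGB pixel, using the CIFAR100 statistics above:

```python
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)


def normalize_pixel(rgb):
    """Per-channel (x - mean) / std, the same formula transforms.Normalize applies."""
    return tuple((x - m) / s for x, m, s in zip(rgb, CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD))


# A mid-gray 8-bit pixel (128, 128, 128) becomes 128/255 per channel after ToTensor()
gray = tuple(128 / 255 for _ in range(3))
print(normalize_pixel(gray))
```

A pixel equal to the dataset mean maps to 0 in every channel, and one standard deviation above the mean maps to 1.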

### Prepare and wrap dataset

In [4]:
```python
batch_size = 128
num_workers = 1  # number of data-loading worker processes
data_folder = './dataset'  # dataset location
train_set = datasets.CIFAR100(root=data_folder, train=True, download=True, transform=train_transform)
```

```
Files already downloaded and verified
```


Wrap the train dataset with DataWrapper (here via logits_wrap_dataset), which helps save the data augmentation information.

In [5]:
```python
logits_path = './logits1'  # path to save the logits
save_logits = True  # save logits
train_set = logits_wrap_dataset(train_set, logits_path=logits_path, num_classes=100, save_logits=save_logits)
```
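Saved logits are only valid for the exact augmented view the teacher saw, because RandomCrop and RandomHorizontalFlip draw new parameters every time. One way to make a view reproducible is to record a per-sample seed and replay it later. The sketch below is a hypothetical illustration of that idea (the helper names are made up), not DataWrapper's actual implementation:

```python
import random
from typing import List, Tuple


def sample_aug_params(rng: random.Random, size: int = 32, padding: int = 4) -> Tuple[int, int, bool]:
    """Draw parameters for RandomCrop(size, padding) + RandomHorizontalFlip."""
    padded = size + 2 * padding
    top = rng.randint(0, padded - size)   # crop offset along height, inclusive range 0..8
    left = rng.randint(0, padded - size)  # crop offset along width
    flip = rng.random() < 0.5             # flip with probability 0.5
    return top, left, flip


def record(num_samples: int, epoch_seed: int) -> Tuple[List[int], List[Tuple[int, int, bool]]]:
    """Teacher pass: derive one seed per sample and the augmentation params it produces."""
    master = random.Random(epoch_seed)
    seeds = [master.randrange(2**31) for _ in range(num_samples)]
    params = [sample_aug_params(random.Random(s)) for s in seeds]
    return seeds, params


def replay(seeds: List[int]) -> List[Tuple[int, int, bool]]:
    """Student pass: regenerate identical params from the stored seeds."""
    return [sample_aug_params(random.Random(s)) for s in seeds]


seeds, teacher_params = record(num_samples=4, epoch_seed=0)
assert replay(seeds) == teacher_params  # same views, so the saved logits still match
```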

Create a dataloader with the wrapped dataset.

In [6]:
```python
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=False)
```
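With `drop_last=False`, the final partial batch is kept, so every one of the 50,000 CIFAR100 training images gets a teacher forward pass. A quick check of the batch arithmetic with `batch_size = 128`:

```python
import math

num_train = 50_000  # CIFAR100 training set size
batch_size = 128

full_batches = num_train // batch_size              # batches of exactly 128 samples
last_batch = num_train - full_batches * batch_size  # leftover samples in the final batch
batches_kept = math.ceil(num_train / batch_size)    # batch count with drop_last=False
dropped_if_drop_last = last_batch                   # samples drop_last=True would skip

print(full_batches, last_batch, batches_kept)  # 390 80 391
```

With `drop_last=True`, those 80 samples would get no saved logits in that epoch.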

## Create Distiller and save logits
When defining the distiller, set teacher_type to a name starting with "huggingface" if the teacher model comes from Hugging Face; otherwise, it does not need to be set.

In [7]:
```python
%%time
loss_fn = torch.nn.CrossEntropyLoss()
teacher_model = transformers.ViTForImageClassification.from_pretrained('edumunozsala/vit_base-224-in21k-ft-cifar100')
distiller = KD(teacher_model, teacher_type="huggingface_vit_base-224-in21k-ft-cifar100")  # set teacher_type only if the model comes from Hugging Face
```

```
CPU times: user 3.05 s, sys: 353 ms, total: 3.4 s
Wall time: 21 s
```
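One plausible reason for the "huggingface" prefix: Hugging Face classification models such as `ViTForImageClassification` return an output object whose logits live in its `.logits` attribute, whereas timm models typically return the logits tensor directly. The sketch below is a hypothetical illustration of such a prefix-based switch (`FakeHFOutput` and `extract_logits` are made-up names), not Distiller's code:

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class FakeHFOutput:
    """Stand-in for a Hugging Face model output that carries a .logits field."""
    logits: List[float]


def extract_logits(output: Any, teacher_type: Optional[str] = None) -> Any:
    """Hypothetical dispatch: unwrap .logits only for 'huggingface*' teachers."""
    if teacher_type is not None and teacher_type.startswith("huggingface"):
        return output.logits
    return output  # e.g. timm models typically return the logits directly


hf_out = FakeHFOutput(logits=[0.1, 2.3, -0.7])
print(extract_logits(hf_out, "huggingface_vit_base-224-in21k-ft-cifar100"))  # [0.1, 2.3, -0.7]
print(extract_logits([0.4, 0.6]))  # [0.4, 0.6]
```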


Call prepare_logits() to save the logits

In [8]:
```python
%%time
max_epoch = 1
distiller.prepare_logits(train_loader, max_epoch, device="cpu")
```

```
Epoch 0 took 3.130923271179199 seconds
CPU times: user 1min 43s, sys: 21.3 s, total: 2min 4s
Wall time: 3.13 s
```
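Because augmentation changes every epoch, `max_epoch` appears to control how many epochs of teacher logits are prepared (the log above shows a single "Epoch 0"). A minimal sketch of a per-epoch save/load layout, assuming one JSON file per epoch under `logits_path`; the file naming and format are hypothetical, and Distiller's on-disk format may differ:

```python
import json
import os
import tempfile
from typing import Dict, List


def save_epoch_logits(logits_path: str, epoch: int, logits: Dict[int, List[float]]) -> str:
    """Write one epoch's {sample_index: logits} mapping to its own JSON file."""
    os.makedirs(logits_path, exist_ok=True)
    path = os.path.join(logits_path, f"epoch_{epoch}.json")  # hypothetical naming
    with open(path, "w") as f:
        json.dump({str(k): v for k, v in logits.items()}, f)  # JSON keys must be strings
    return path


def load_epoch_logits(logits_path: str, epoch: int) -> Dict[int, List[float]]:
    """Read one epoch's logits back, restoring integer sample indices."""
    with open(os.path.join(logits_path, f"epoch_{epoch}.json")) as f:
        return {int(k): v for k, v in json.load(f).items()}


with tempfile.TemporaryDirectory() as tmp:
    max_epoch = 2
    for epoch in range(max_epoch):
        # In a real pipeline these values would come from the teacher forward pass
        fake = {i: [float(epoch), float(i)] for i in range(3)}
        save_epoch_logits(tmp, epoch, fake)
    restored = load_epoch_logits(tmp, 1)

print(restored[2])  # [1.0, 2.0]
```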
