# DeiT

Vision transformers (ViT) have achieved state-of-the-art performance on ImageNet without using convolutions. However, ViT only reached this level when pre-trained on a large private labelled image dataset with extensive computing resources. In their paper, “Training data-efficient image transformers & distillation through attention”, Hugo Touvron, Matthieu Cord, et al. proposed DeiT, a convolution-free transformer network that achieves a top-1 accuracy of 83.1% on ImageNet with no external data. DeiT also introduces a teacher-student strategy specific to transformers that relies on a distillation token. This token works much like the class token already used in transformer networks: it interacts with the patch tokens through self-attention, but its output is trained to reproduce the teacher's prediction rather than the ground-truth label.
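
The distillation objective can be sketched in a few lines. This is a minimal NumPy sketch, assuming the hard-label variant described in the paper: the class-token head is trained with cross-entropy against the ground-truth label, the distillation-token head is trained against the teacher's argmax prediction, and the two terms are weighted equally. The function names (`hard_distillation_loss`, `predict`) are hypothetical and not part of the DeiT codebase. At test time the sketch fuses the two heads by adding their softmax outputs, which is the late-fusion scheme the paper reports.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row-wise max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def cross_entropy(logits, targets):
    # Mean negative log-likelihood of the integer class targets
    logp = log_softmax(logits)
    return -logp[np.arange(len(targets)), targets].mean()

def hard_distillation_loss(cls_logits, dist_logits, labels, teacher_logits):
    # Hypothetical helper: the class-token head learns from the true labels,
    # the distillation-token head learns from the teacher's hard decision
    teacher_labels = teacher_logits.argmax(axis=-1)
    return 0.5 * cross_entropy(cls_logits, labels) + 0.5 * cross_entropy(dist_logits, teacher_labels)

def predict(cls_logits, dist_logits):
    # Hypothetical helper: late fusion by summing the two heads' softmax outputs
    probs = np.exp(log_softmax(cls_logits)) + np.exp(log_softmax(dist_logits))
    return probs.argmax(axis=-1)

# Usage with random logits for a batch of 4 images and 10 classes
rng = np.random.default_rng(0)
cls_logits, dist_logits, teacher_logits = (rng.normal(size=(4, 10)) for _ in range(3))
labels = rng.integers(0, 10, size=4)
print(hard_distillation_loss(cls_logits, dist_logits, labels, teacher_logits))
print(predict(cls_logits, dist_logits))
```

Using the teacher's argmax instead of its full probability distribution is what makes this the "hard" variant; the paper also describes a soft variant based on a KL-divergence term, which this sketch does not cover.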

To read more about it, please refer to [this](https://analyticsindiamag.com/introducing-deit-data-efficient-image-transformers/) article.

# Image Classification with a pre-trained DeiT model

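Install PyTorch Image Models (timm) along with the other dependencies, then restart the kernel so the newly installed packages can be imported.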

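```python
# Upgrade pip, then install the libraries used in this notebook
# (the deprecated `sklearn` package name is replaced by `scikit-learn`)
!python -m pip install pip --upgrade --user -q --no-warn-script-location
!python -m pip install numpy pandas seaborn matplotlib scipy statsmodels scikit-learn tensorflow keras opencv-python pillow scikit-image torch torchvision \
     tqdm requests --user -q --no-warn-script-location

# The DeiT models are built on top of timm; pin the version used by this notebook
!python -m pip install timm==0.3.2 --user -q

# Restart the kernel so the freshly installed packages are picked up
import IPython
IPython.Application.instance().kernel.do_shutdown(True)
```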


Download the ImageNet class labels and read them into a list.

```python
# Download the 1,000 ImageNet class names (one per line)
!wget -q https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

# Read the ImageNet categories into a list indexed by class id
with open("imagenet_classes.txt", "r") as f:
    imagenet_categories = [s.strip() for s in f.readlines()]
```

Import the necessary libraries and classes.

```python
from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
import torch
import timm
import torchvision
import torchvision.transforms as T
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# Inference only: disable gradient tracking globally (the ';' suppresses cell output)
torch.set_grad_enabled(False);
```

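Create the data transform expected by DeiT: resize the shorter side to 256 pixels, center-crop to 224×224, and normalize with the ImageNet mean and standard deviation.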

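```python
transform = T.Compose([
    T.Resize(256, interpolation=3),  # shorter side -> 256 px; 3 = bicubic (PIL.Image.BICUBIC)
    T.CenterCrop(224),               # keep the central 224x224 region
    T.ToTensor(),                    # HWC uint8 in [0, 255] -> CHW float in [0, 1]
    T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),  # per-channel (x - mean) / std
])
```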
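
The last three steps of this pipeline are simple array operations. The following is a minimal NumPy sketch of what `CenterCrop`, `ToTensor` and `Normalize` do to an image that has already been resized, assuming the mean and standard deviation values below match timm's `IMAGENET_DEFAULT_MEAN` and `IMAGENET_DEFAULT_STD`. The helper names are hypothetical and only illustrate the arithmetic; the notebook itself should keep using the torchvision transform.

```python
import numpy as np

# Assumed to equal timm's IMAGENET_DEFAULT_MEAN / IMAGENET_DEFAULT_STD (RGB order)
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def center_crop(img_hwc, size=224):
    # Hypothetical helper: take the central size x size window of an HWC array
    h, w = img_hwc.shape[:2]
    top = int(round((h - size) / 2.0))
    left = int(round((w - size) / 2.0))
    return img_hwc[top:top + size, left:left + size]

def to_normalized_chw(img_hwc):
    # ToTensor: HWC values in [0, 255] -> CHW floats in [0, 1]
    x = img_hwc.astype(np.float32).transpose(2, 0, 1) / 255.0
    # Normalize: per-channel (x - mean) / std, broadcast over height and width
    return (x - MEAN[:, None, None]) / STD[:, None, None]

def preprocess(img_hwc):
    # Hypothetical helper: crop, then convert and normalize
    return to_normalized_chw(center_crop(img_hwc))

# Usage: a white image that has already been resized to 256 x 341
white = np.full((256, 341, 3), 255, dtype=np.uint8)
out = preprocess(white)
print(out.shape)      # channels-first, ready to be batched
print(out[:, 0, 0])   # (1 - mean) / std for each channel
```

Normalizing with the statistics used during training matters because the pretrained weights expect inputs that are roughly zero-mean and unit-variance per channel.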

Load the pre-trained DeiT-Base model (`deit_base_patch16_224`) from TorchHub and fetch an image to run inference on.

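```python
# Load DeiT-Base (16x16 patches, 224x224 input) with ImageNet-pretrained weights
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval();  # inference mode: affects layers such as dropout

# Fetch a sample image from the COCO 2017 validation set and make sure it is RGB
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw).convert("RGB")
# display the image
im
```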
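
The model name `deit_base_patch16_224` encodes its input geometry: 224×224 images are cut into non-overlapping 16×16 patches, and each patch becomes one token. A short worked example of the resulting sequence length follows; the function names are hypothetical and only do the arithmetic. The `extra_tokens` parameter is 1 for this model (the class token) and would be 2 for the distilled variants such as `deit_base_distilled_patch16_224`, which add a distillation token.

```python
def vit_token_count(image_size=224, patch_size=16, extra_tokens=1):
    # Hypothetical helper: number of patches along one side of the image
    patches_per_side = image_size // patch_size
    # Total patches, plus the class token (and a distillation token if present)
    return patches_per_side ** 2 + extra_tokens

def patch_dim(patch_size=16, channels=3):
    # Hypothetical helper: raw values in one flattened RGB patch before projection
    return patch_size * patch_size * channels

print(vit_token_count())                # 14 * 14 patches + class token
print(vit_token_count(extra_tokens=2))  # distilled variant: + distillation token
print(patch_dim())
```

Because self-attention cost grows with the square of the sequence length, the patch size is the main knob that trades accuracy for compute in these models.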

Transform the image and perform inference.

```python
# Transform the original image and add a batch dimension: (1, 3, 224, 224)
img = transform(im).unsqueeze(0)

# Compute the logits for the 1,000 ImageNet classes
out = model(img)

# Convert them into probabilities and drop the batch dimension
scores = torch.nn.functional.softmax(out, dim=-1)[0]

# Get the five highest-scoring classes and their probabilities
topk_scores, topk_label = torch.topk(scores, k=5, dim=-1)
for i in range(5):
    pred_name = imagenet_categories[topk_label[i].item()]
    print(f"Prediction {i + 1}: {pred_name:<25} score: {topk_scores[i].item():.3f}")
```
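
The softmax-then-top-k step does not depend on PyTorch. A framework-free sketch in NumPy follows, assuming a 1-D vector of logits for a single image and a list of class names such as `imagenet_categories`; `top_k_predictions` is a hypothetical helper, not part of timm or torchvision.

```python
import numpy as np

def top_k_predictions(logits, labels, k=5):
    # Numerically stable softmax: subtract the max logit before exponentiating
    shifted = logits - logits.max()
    probs = np.exp(shifted) / np.exp(shifted).sum()
    # Sort class indices by probability, highest first, and keep the first k
    top = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in top]

# Usage with three toy classes
for name, p in top_k_predictions(np.array([0.0, 2.0, 1.0]), ["cat", "dog", "bird"], k=2):
    print(f"{name:<10} {p:.3f}")
```

Softmax is monotonic, so the ranking is the same as ranking the raw logits; the softmax is only needed to report the scores as probabilities.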