In [1]:
%matplotlib inline


Using a pre-trained Vision Transformer Model
============================================


Vision Transformer models apply the cutting-edge, attention-based
transformer architecture to Computer Vision tasks. Transformers were
introduced in Natural Language Processing, where they have achieved
state-of-the-art (SOTA) results on many tasks. Facebook's Data-efficient
Image Transformers (`DeiT <https://ai.facebook.com/blog/data-efficient-image-transformers-a-promising-new-technique-for-image-classification>`__)
is a Vision Transformer model trained on ImageNet for image
classification.
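To make "applying transformers to images" concrete: a Vision Transformer such as DeiT first cuts each image into fixed-size, non-overlapping patches and treats every flattened patch as one input token. The following is a minimal sketch, assuming a `(B, C, H, W)` image tensor and the 16-pixel patches implied by the `deit_base_patch16_224` model name. `patchify` is a hypothetical helper for illustration, not part of timm or the DeiT code. A 224x224 RGB image yields (224/16)^2 = 196 tokens of 3·16·16 = 768 values each.

```python
import torch


def patchify(images: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Split images of shape (B, C, H, W) into flattened, non-overlapping patches.

    Returns a tensor of shape (B, N, C * patch_size * patch_size), where
    N = (H // patch_size) * (W // patch_size) is the number of patch tokens.
    """
    b, c, h, w = images.shape
    if h % patch_size or w % patch_size:
        raise ValueError("height and width must be divisible by patch_size")
    gh, gw = h // patch_size, w // patch_size
    # (B, C, H, W) -> (B, C, gh, P, gw, P): split each spatial axis into (grid, within-patch)
    x = images.reshape(b, c, gh, patch_size, gw, patch_size)
    # -> (B, gh, gw, C, P, P): move the patch-grid axes to the front (row-major patch order)
    x = x.permute(0, 2, 4, 1, 3, 5)
    # -> (B, N, C*P*P): one flattened vector per patch
    return x.reshape(b, gh * gw, c * patch_size * patch_size)


# A dummy batch containing one 224x224 RGB image.
tokens = patchify(torch.randn(1, 3, 224, 224), patch_size=16)
print(tokens.shape)  # torch.Size([1, 196, 768])
```

In the real model, each flattened patch then passes through a learned linear projection. timm implements that step as a convolution whose kernel size and stride both equal the patch size, which computes the same thing. A learnable class token is prepended to the patch tokens before the transformer layers.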



What is DeiT
---------------------

Convolutional Neural Networks (CNNs) have been the main models for image
classification since deep learning took off in 2012, but CNNs typically
require hundreds of millions of images for training to achieve SOTA
results. DeiT is a vision transformer model that needs much less data
and computing resources for training, yet competes with the leading
CNNs in image classification. This is made possible by two key
components of DeiT:

-  Data augmentation that simulates training on a much larger dataset;
-  Native distillation that allows the transformer network to learn from
   a CNN’s output.
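The distillation component can be sketched with the hard-label variant described in the DeiT paper. The student transformer carries an extra distillation token, whose output head is trained to match the class predicted by the teacher CNN. The usual class token is trained on the true labels, and the two cross-entropy terms are averaged. This is a minimal sketch, assuming you already have the student's two logit tensors and the teacher's logits for a batch. `hard_distillation_loss` is a hypothetical name, and this is not the official DeiT training code, which also offers a soft, KL-divergence-based variant.

```python
import torch
import torch.nn.functional as F


def hard_distillation_loss(
    cls_logits: torch.Tensor,      # student's class-token head output, shape (B, num_classes)
    dist_logits: torch.Tensor,     # student's distillation-token head output, shape (B, num_classes)
    teacher_logits: torch.Tensor,  # CNN teacher's output, shape (B, num_classes)
    labels: torch.Tensor,          # ground-truth class indices, shape (B,)
) -> torch.Tensor:
    # The teacher's "hard" decision: its single most likely class per image.
    # detach() keeps gradients from flowing into the (frozen) teacher.
    teacher_labels = teacher_logits.detach().argmax(dim=1)
    # The class token is supervised by the true labels...
    loss_cls = F.cross_entropy(cls_logits, labels)
    # ...and the distillation token by the teacher's predictions.
    loss_dist = F.cross_entropy(dist_logits, teacher_labels)
    return 0.5 * loss_cls + 0.5 * loss_dist


# Usage with random tensors: a batch of 4 images, 1000 ImageNet classes.
B, K = 4, 1000
labels = torch.randint(0, K, (B,))
loss = hard_distillation_loss(torch.randn(B, K), torch.randn(B, K), torch.randn(B, K), labels)
print(loss.item())
```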

DeiT shows that Transformers can be successfully applied to computer
vision tasks even with limited access to data and resources. For more
details on DeiT, see the
`repo <https://github.com/facebookresearch/deit>`__
and the `paper <https://arxiv.org/abs/2012.12877>`__.




Classifying Images with DeiT
-------------------------------




In [4]:
 !pip install -q timm pandas requests

## Now use DeiT to classify this image:
<img src="https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png">

In [5]:
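from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# The tutorial was written for PyTorch 1.8.0; the output below was produced with 2.6.0.
print(torch.__version__)

# Load the pre-trained DeiT-Base model (16x16 patches, 224x224 input) from Torch Hub.
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()  # put the model in evaluation (inference) mode

# ImageNet evaluation preprocessing: resize, center-crop to 224x224, normalize.
transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

url = "https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png"
img = Image.open(requests.get(url, stream=True).raw).convert("RGB")  # ensure 3 color channels
img = transform(img).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)

with torch.no_grad():  # gradients are not needed for inference
    out = model(img)   # logits over the 1000 ImageNet classes, shape (1, 1000)
clsidx = torch.argmax(out)
print(clsidx.item())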

2.6.0+cu124


Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main


269


The output should be 269, which, according to the
`ImageNet list of class indices to labels <https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a>`__,
maps to ‘timber wolf, grey wolf, gray wolf, Canis lupus’.
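To turn the printed index into a readable label, and to see how confident the model is, you can parse the class list and inspect the top-k softmax probabilities. This is a minimal sketch, assuming the gist's contents are a Python dict literal of the form `{0: 'tench, Tinca tinca', ...}` saved to a local file (the filename in the comment is a placeholder). `load_imagenet_labels` and `top_k_labels` are hypothetical helpers, not part of timm or torchvision.

```python
from __future__ import annotations

import ast

import torch


def load_imagenet_labels(text: str) -> dict[int, str]:
    """Parse a dict literal such as "{0: 'tench, Tinca tinca', ...}" into {index: label}."""
    labels = ast.literal_eval(text)  # safely evaluates literals only, no arbitrary code
    if not isinstance(labels, dict):
        raise ValueError("expected a dict literal mapping class index to label")
    return labels


def top_k_labels(logits: torch.Tensor, labels: dict[int, str], k: int = 5) -> list[tuple[int, str, float]]:
    """Return (index, label, probability) for the k most likely classes of a single image."""
    probs = torch.softmax(logits.detach().reshape(-1), dim=0)  # (1, 1000) -> (1000,)
    values, indices = torch.topk(probs, k)
    return [(i, labels.get(i, f"<unknown {i}>"), p) for i, p in zip(indices.tolist(), values.tolist())]


# In practice, read a local copy of the gist (placeholder path):
# with open("imagenet_labels.txt") as f:
#     labels = load_imagenet_labels(f.read())
labels = load_imagenet_labels("{269: 'timber wolf, grey wolf, gray wolf, Canis lupus'}")

# With `out` from the cell above you would call top_k_labels(out, labels);
# here a dummy logit vector stands in for the model output.
fake_logits = torch.zeros(1, 1000)
fake_logits[0, 269] = 10.0
print(top_k_labels(fake_logits, labels, k=1))
```

Looking at the top five probabilities, not just the argmax, shows whether the model is confidently predicting one class or spreading its probability across several related classes.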




