In [1]:
from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(torch.__version__)
# Originally written for PyTorch 1.8.0; the saved output of this notebook shows 1.9.0+cu111.

1.9.0+cu111


In [2]:
# Fetch the pretrained DeiT-base (patch 16, 224x224 input) model through torch.hub
# and switch it to evaluation mode, since the rest of the notebook only runs inference.
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

# Preprocessing pipeline: resize to 256, center-crop to 224, convert to a tensor,
# then normalize with the ImageNet default mean/std imported from timm.
# InterpolationMode.BICUBIC replaces the bare integer 3 (PIL's BICUBIC constant):
# torchvision deprecates integer interpolation codes in favour of the enum.
transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])


Downloading: "https://github.com/facebookresearch/deit/archive/main.zip" to /home/curt/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /home/curt/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth


HBox(children=(FloatProgress(value=0.0, max=346319111.0), HTML(value='')))






In [3]:
# Download the demo image. A timeout keeps the cell from hanging forever on a stalled
# connection, and raise_for_status() surfaces HTTP errors here instead of as a confusing
# "cannot identify image file" failure inside PIL.
image_url = "https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png"
response = requests.get(image_url, stream=True, timeout=30)
response.raise_for_status()

# PNG files may carry an alpha channel or a palette; force 3-channel RGB so the result
# matches the ImageNet normalization statistics applied by `transform`.
pil_image = Image.open(response.raw).convert("RGB")

# `img` is the preprocessed input reused by later cells; [None,] prepends a batch axis.
img = transform(pil_image)[None,]

# Inference only: no_grad() avoids recording an autograd graph for this forward pass.
with torch.no_grad():
    out = model(img)

# Index of the highest-scoring entry in the model output.
clsidx = torch.argmax(out)
print(clsidx.item())

269


## Scripting DeiT

In [4]:
# Compile the eager model to TorchScript, then serialize the scripted module to disk.
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, "fbdeit_scripted.pt")

## Quantizing DeiT

In [5]:
# Quantization backend: 'fbgemm' is meant for server inference, 'qnnpack' for mobile.
# 'fbgemm' is kept here because switching to 'qnnpack' made the quantized model
# much slower when run in this notebook.
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

# Dynamically quantize only the nn.Linear layers to int8 weights.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)

# Script the quantized model and serialize it.
scripted_quantized_model = torch.jit.script(quantized_model)
torch.jit.save(scripted_quantized_model, "fbdeit_scripted_quantized.pt")

In [7]:
# Sanity check: run the scripted, quantized model on the same input and print
# the index of its highest-scoring output entry.
out = scripted_quantized_model(img)
clsidx = out.argmax()
print(clsidx.item())

269


## Optimizing DeiT

In [8]:
from torch.utils.mobile_optimizer import optimize_for_mobile

# Run the mobile optimization pass over the scripted, quantized model and
# serialize the resulting module.
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
torch.jit.save(
    optimized_scripted_quantized_model,
    "fbdeit_optimized_scripted_quantized.pt",
)

In [9]:
# Sanity check: the mobile-optimized model should still pick the same index
# for the same input.
out = optimized_scripted_quantized_model(img)
clsidx = out.argmax()
print(clsidx.item())

269


  return forward_call(*input, **kwargs)


In [10]:
# Export the optimized model in the lite-interpreter format and load it back
# so it can be included in the comparison further down.
# NOTE(review): torch.jit.load presumably loads the file through the regular
# TorchScript runtime rather than the lite interpreter — confirm this is the
# intended way to measure the "lite" model.
lite_model_path = "fbdeit_optimized_scripted_quantized_lite.ptl"
optimized_scripted_quantized_model._save_for_lite_interpreter(lite_model_path)
ptl = torch.jit.load(lite_model_path)

## Compare inference time

In [13]:
# Model variants to profile, as (report label, callable) pairs in reporting order.
# Keeping them in one list replaces five copy-pasted profile/print pairs, so adding
# or removing a variant is a one-line change.
profiled_variants = [
    ("original model", model),
    ("scripted model", scripted_model),
    ("scripted & quantized model", scripted_quantized_model),
    ("scripted & quantized & optimized model", optimized_scripted_quantized_model),
    ("lite model", ptl),
]

for use_cuda in [False, True]:
    # Asking the profiler for CUDA events only makes sense when CUDA exists;
    # skip that pass on CPU-only machines instead of failing or warning.
    if use_cuda and not torch.cuda.is_available():
        print('use_cuda', use_cuda, '(skipped: CUDA is not available)')
        continue

    print('use_cuda', use_cuda)

    # Profile one forward pass per variant; all measurements for this pass are
    # collected before anything is printed, matching the original cell.
    timings_ms = []
    for label, variant in profiled_variants:
        with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
            out = variant(img)
        # The reported figure is the profiler's total self CPU time, divided by 1000
        # to express it in milliseconds.
        timings_ms.append((label, prof.self_cpu_time_total / 1000))

    for label, elapsed_ms in timings_ms:
        print("{}: {:.2f}ms".format(label, elapsed_ms))

use_cuda False
original model: 226.90ms
scripted model: 245.94ms
scripted & quantized model: 157.23ms
scripted & quantized & optimized model: 266.45ms
lite model: 238.65ms
use_cuda True
original model: 243.89ms
scripted model: 253.77ms
scripted & quantized model: 162.95ms
scripted & quantized & optimized model: 236.08ms
lite model: 233.64ms
