## Optimisation for Mobile

This notebook contains the tools needed to save and optimise a trained PyTorch model for mobile deployment.

In [None]:
import sys
sys.path.insert(0, "../src")

In [None]:
import os
import torch
from torch.utils.data import DataLoader
import torchvision

from config import *
from utils import *
from model import MODELS, FinetunedImageClassifier
from transform import ImageTransformer
from data import ImageDataset

## Load Trained Classifier

In [None]:
# specify which model to use
MODEL = "resnet18"
assert MODEL in MODELS, f"Specified model has to be one of {list(MODELS.keys())}"

In [None]:
# specify the paths of the most recently trained model
model_path = os.path.join(MODEL_PATH, MODEL, f"{MODEL}.pt")
config_path = os.path.join(MODEL_PATH, MODEL, "config.json")
transforms_path = os.path.join(MODEL_PATH, MODEL, "transforms.pkl")

In [None]:
# load transform
transform = load_pickle(transforms_path)

In [None]:
# load model
config = load_json(config_path)
class2id = config['class2id']
id2class = {i:c for c,i in class2id.items()}
model = FinetunedImageClassifier(**config)
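model.load_state_dict(torch.load(model_path, map_location="cpu"))  # load weights onto the CPU
model.eval()  # inference mode: disables dropout, uses BatchNorm running statistics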

## Inference on Test Split

In [None]:
# define test split and loader
test_data = ImageDataset(split="test", include_classes=list(class2id.keys()), ratio=1.0)
test_loader = DataLoader(test_data, 16)

# load batch of 16 images
images, labels = next(iter(test_loader))

# predict (no gradients are needed for inference)
with torch.no_grad():
    logits = model(transform(images))
preds = logits.argmax(-1)

# show images with ground truth and predicted labels
show_images(images, titles=[f"True: {id2class[t.item()]}\nPred: {id2class[p.item()]}" for t, p in zip(labels, preds)], show=True)

In [None]:
# inference on dummy
dummy = torch.rand(3, 224, 224).unsqueeze(0)
logits = model(dummy)

logits.shape

## Optimisation

---

From here on, we follow the docs on [PyTorch Mobile](https://pytorch.org/mobile). PyTorch supports deploying trained machine learning models on mobile devices (iOS and Android). The docs summarise the necessary steps as follows:

Once a PyTorch model has been trained or retrained, or a pre-trained model is available, follow the recipes outlined in this summary so that mobile apps can use the model:

1. **Fusing.** To fuse a list of PyTorch modules into a single module to reduce the model size before quantisation, read the [Fuse Modules recipe](https://pytorch.org/tutorials/recipes/fuse.html).
2. **Quantisation.** To reduce the model size and make it run faster without losing much accuracy, read the [Quantization Recipe](https://pytorch.org/tutorials/recipes/quantization.html).
3. **TorchScript.** To convert the model to TorchScript and optionally optimise it for mobile apps, read the [Script and Optimize for Mobile Recipe](https://pytorch.org/tutorials/recipes/script_optimized.html).

### Fusing

---

Following the [PyTorch Fusing Recipe](https://pytorch.org/tutorials/recipes/fuse.html): model fusing is done before model quantisation. It combines multiple PyTorch modules (for example a convolution, a batch norm and a ReLU) into a single module to reduce the model's size and memory footprint. According to the recipe, this may also make the model **run faster** and **improve its accuracy**.

_Note: Fusing is skipped at this point._

In [None]:
# skipped
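
For reference, here is a minimal sketch of what fusing looks like on a toy block. `ConvBlock` is a hypothetical module invented for this illustration, not the notebook's classifier; which layers can be fused in `FinetunedImageClassifier` depends on its architecture. The sketch assumes `torch.quantization.fuse_modules`, which expects the model to be in eval mode when a batch norm is folded into a convolution.

```python
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Hypothetical toy block: Conv -> BatchNorm -> ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, bias=False)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


block = ConvBlock()
block(torch.rand(4, 3, 32, 32))  # one training-mode pass so BatchNorm has non-trivial statistics
block.eval()                     # fusing Conv + BatchNorm requires eval mode

# fuse the three sub-modules (referenced by attribute name) into one
fused = torch.quantization.fuse_modules(block, [["conv", "bn", "relu"]], inplace=False)

x = torch.rand(1, 3, 32, 32)
with torch.no_grad():
    outputs_match = torch.allclose(block(x), fused(x), atol=1e-5)

print(fused)          # `bn` and `relu` are replaced by Identity placeholders
print(outputs_match)  # the fused block should compute the same function
```

The fused block keeps the same attribute names, so code that calls the model does not need to change.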

### Quantisation

---

Following the [PyTorch Quantisation Recipe](https://pytorch.org/tutorials/recipes/quantization.html): quantisation converts a model's weights and activations from the default 32-bit float representation to 8-bit integers. This reduces the model's size to roughly 1/4 (25%) of its original size and, according to the recipe, speeds up inference by about 2-4x while maintaining equal or similar model accuracy.
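
To make the idea concrete, the following is a minimal pure-Python sketch of affine (scale and zero-point) quantisation to 8 bits. It is an illustration only: the helper names `quantize` and `dequantize` are made up for this example, and PyTorch's observers choose the scale and zero point from calibration data in ways this sketch does not reproduce.

```python
import struct


def quantize(values, num_bits=8):
    """Map floats to unsigned integers in [0, 2**num_bits - 1]."""
    qmin, qmax = 0, 2 ** num_bits - 1
    lo = min(min(values), 0.0)  # the range must contain zero
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # avoid a zero scale for all-zero input
    zero_point = round(qmin - lo / scale)     # integer that represents 0.0
    q = [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Recover approximate floats from the integer representation."""
    return [(x - zero_point) * scale for x in q]


weights = [-1.0, -0.25, 0.0, 0.5, 1.0]
q, scale, zero_point = quantize(weights)
restored = dequantize(q, scale, zero_point)
max_error = max(abs(a - b) for a, b in zip(weights, restored))

print(q, restored)
print(f"max rounding error: {max_error:.5f} (scale: {scale:.5f})")

# one int8 value takes a quarter of the bytes of one float32 value
size_ratio = struct.calcsize("b") / struct.calcsize("f")
print(f"size ratio int8 / float32: {size_ratio}")
```

The rounding error per value is bounded by the scale, which is why accuracy usually degrades only slightly, and the 1/4 size figure follows directly from storing 1 byte instead of 4 per value.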

There are generally four approaches to mobile quantisation:

1. Use Pretrained Quantised Models: This approach is easy but only works for a subset of the models on Torchvision's Model Hub. There is support for `MobileNet v2`, `ResNet 18`, `ResNet 50`, `Inception v3`, `GoogLeNet` and some more.

2. Post Training Dynamic Quantisation: Not yet supported for convolutional layers in CNNs and therefore disregarded here.

3. Post Training Static Quantisation: Converts all weights and activations to the smaller data type after training is completed, using a calibration pass over representative data to determine the activation ranges. This approach is arguably the easiest to implement for a custom model.

4. Quantisation-aware Training: Inserts fake quantisation for all weights and activations during training. Often used in CNN architectures.

In [None]:
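# post training static quantisation

model.eval()
backend = "qnnpack"  # for ARM CPUs (for x86 architectures, choose "fbgemm")
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
qmodel = torch.quantization.prepare(model, inplace=False)  # insert observers

# calibrate the observers on a few representative batches (here: the test split)
with torch.no_grad():
    for i, (batch, _) in enumerate(test_loader):
        if i == 10:
            break
        qmodel(transform(batch))

qmodel = torch.quantization.convert(qmodel, inplace=False)  # swap in quantised modules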
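
Eager-mode static quantisation generally expects two things from the model: `QuantStub`/`DeQuantStub` modules that mark where tensors enter and leave the quantised region, and a calibration pass between `prepare` and `convert`. Whether `FinetunedImageClassifier` already contains such stubs is not shown in this notebook, so the sketch below uses a hypothetical `TinyNet` and random calibration data purely for illustration. It also falls back to `fbgemm` if `qnnpack` is not available on the machine running the notebook.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model with explicit quantisation boundaries."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> int8
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> float

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        return self.dequant(x)


# pick an engine that this PyTorch build supports
engines = torch.backends.quantized.supported_engines
backend = "qnnpack" if "qnnpack" in engines else "fbgemm"
torch.backends.quantized.engine = backend

net = TinyNet().eval()
net.qconfig = torch.quantization.get_default_qconfig(backend)

# 1. insert observers, 2. calibrate, 3. convert
prepared = torch.quantization.prepare(net, inplace=False)
with torch.no_grad():
    for _ in range(4):
        prepared(torch.rand(2, 3, 32, 32))  # random data stands in for real images
quantized = torch.quantization.convert(prepared, inplace=False)

with torch.no_grad():
    out = quantized(torch.rand(1, 3, 32, 32))

print(quantized)
print(out.shape, out.dtype)
```

In practice the calibration batches should come from real training or validation images, since the observed activation ranges determine the quantisation parameters.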

### TorchScript

---

Following the [PyTorch Script and Optimize for Mobile Recipe](https://pytorch.org/tutorials/recipes/script_optimized.html): to run a fused and quantised model in high-performance C++ environments (such as iOS and Android), the model needs to be converted to `TorchScript`. It can then optionally be optimised further.

There are two basic ways to convert a PyTorch model to TorchScript:

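1. The Trace Method: Runs the model once on a dummy input and records the executed operations. Only works reliably if the model does not have any data-dependent control flow.

2. The Script Method: Compiles the model's Python source directly, so control flow is preserved. This is the method used here.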
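
The difference between the two methods shows up as soon as a model branches on its input. The sketch below uses a hypothetical `Gate` module, invented for this illustration, to show that tracing bakes in the branch taken for the example input while scripting keeps both branches.

```python
import torch
import torch.nn as nn


class Gate(nn.Module):
    """Hypothetical toy module with a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return torch.zeros_like(x)


gate = Gate().eval()
pos, neg = torch.ones(3), -torch.ones(3)

# tracing records only the operations executed for `pos` (the `x * 2` branch);
# PyTorch emits a TracerWarning about the data-dependent condition
traced = torch.jit.trace(gate, pos)

# scripting compiles the source, so the `if` statement is preserved
scripted = torch.jit.script(gate)

traced_out = traced(neg)      # silently applies `x * 2` to the negative input
scripted_out = scripted(neg)  # takes the `zeros_like` branch, like the original module

print(traced_out, scripted_out, gate(neg))
```

For a plain feed-forward CNN without branching, both methods should yield equivalent results.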

In [None]:
# convert to torchscript
# note: this scripts the float `model`; pass `qmodel` instead to export the quantised model
torchscript_model = torch.jit.script(model)

In [None]:
from torch.utils.mobile_optimizer import optimize_for_mobile

# optimise for mobile
optimised_torchscript_model = optimize_for_mobile(torchscript_model)

In [None]:
# save model to disk
torchscript_model_path = os.path.join(MODEL_PATH, MODEL, f"{MODEL}.pth")

optimised_torchscript_model.save(torchscript_model_path)

## Save as PyTorch Lite

---

The `.ptl` file is the format loaded by the PyTorch Mobile lite interpreter on the device.

In [None]:
# save as pytorch lite
lite_model_path = os.path.join(MODEL_PATH, MODEL, f"{MODEL}.ptl")
optimised_torchscript_model._save_for_lite_interpreter(lite_model_path)
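
The export steps of this section can be checked end to end on a toy model before running them on the real classifier. The sketch below is an assumption-laden illustration: the `nn.Sequential` toy model, the file names and the temporary directory are all made up for this example. It scripts and optimises the model, saves both formats, and verifies that the reloaded TorchScript file reproduces the original outputs.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

# hypothetical toy model standing in for the trained classifier
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 4),
).eval()

scripted = torch.jit.script(toy)
optimised = optimize_for_mobile(scripted)

with tempfile.TemporaryDirectory() as tmp:
    ts_path = os.path.join(tmp, "toy.pt")     # TorchScript archive
    lite_path = os.path.join(tmp, "toy.ptl")  # lite interpreter format
    optimised.save(ts_path)
    optimised._save_for_lite_interpreter(lite_path)

    # file sizes in MB, computed the same way as in this notebook
    sizes_mb = {
        name: round(os.path.getsize(path) / 1000 ** 2, 3)
        for name, path in [("torchscript", ts_path), ("lite", lite_path)]
    }

    # reload the saved TorchScript file and compare against the original model
    reloaded = torch.jit.load(ts_path)
    x = torch.rand(1, 3, 32, 32)
    with torch.no_grad():
        outputs_match = torch.allclose(toy(x), reloaded(x), atol=1e-4)

print(sizes_mb)
print(outputs_match)
```

A small tolerance is used in the comparison because the mobile-optimised operators may differ from the reference ones in the last few floating-point digits.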

In [None]:
# compare the sizes of the three saved models
print(f"Original Model Size: {round(os.path.getsize(model_path) / 1000 ** 2, 1)} MB")
print(f"TorchScript Model Size: {round(os.path.getsize(torchscript_model_path) / 1000 ** 2, 1)} MB")
print(f"PyTorch Lite Model Size: {round(os.path.getsize(lite_model_path) / 1000 ** 2, 1)} MB")