<a target="_blank" href="https://colab.research.google.com/github/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/getting_started.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

Note: When running notebooks in this repository with Google Colab, some users may see
the following warning message:

![Colab warning](https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/data/colab_warning.jpg?raw=true)

Please click `Restart Session` and run again.

In [2]:
!pip install -r https://raw.githubusercontent.com/google-ai-edge/ai-edge-torch/main/requirements.txt
!pip install ai-edge-torch



In [3]:
import numpy as np
import ai_edge_torch
import torch
import torchvision



# Sample PyTorch Model

Instantiate `mobilenet_v2` as a sample model from PyTorch's `torchvision` repository (loaded via `torch.hub`). We also create a random sample input and run the model on it directly in PyTorch to obtain a reference output.

In [4]:
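# Load a pretrained MobileNetV2 from the torchvision repository via torch.hub.
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
# Switch to inference mode (disables dropout, uses running batch-norm statistics).
model.eval()
# One random 224x224 RGB image in NCHW layout, wrapped in a tuple of positional args.
sample_inputs = (torch.randn(1, 3, 224, 224),)
# Reference output computed directly in PyTorch.
torch_output = model(*sample_inputs)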



# Conversion
The `convert` function provided by the `ai_edge_torch` package converts a PyTorch model into an on-device model. The conversion also requires sample inputs for the model, which are used for tracing and shape inference.

**Note**: The source PyTorch model must be compliant with `torch.export`, which was introduced in PyTorch 2.1.0.

In [5]:
edge_model = ai_edge_torch.convert(model, sample_inputs)

# Inference
Run inference with the TFLite runtime by calling `edge_model` directly with the inputs. This API abstracts away many of the details of [TFLite inference in Python](https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_python).

In [6]:
edge_output = edge_model(*sample_inputs)
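For reference, the standard TFLite Python inference steps that calling `edge_model` spares you look roughly like the sketch below. It is not part of `ai_edge_torch`. It uses the `tf.lite.Interpreter` methods `allocate_tensors`, `get_input_details`, `set_tensor`, `invoke`, `get_output_details` and `get_tensor`. The helper name `run_tflite_model` is hypothetical. The sketch also assumes that the interpreter's input order matches the order of the positional inputs.

```python
import numpy as np


def run_tflite_model(interpreter, *inputs):
    """Hypothetical helper: run one inference with a tf.lite.Interpreter-like object.

    Assumes `interpreter` exposes allocate_tensors, get_input_details,
    set_tensor, invoke, get_output_details and get_tensor, and that its
    inputs are listed in the same order as `inputs`.
    """
    # Allocate memory for all tensors of the loaded model.
    interpreter.allocate_tensors()

    # Copy each input array into the matching input tensor, casting to its dtype.
    input_details = interpreter.get_input_details()
    if len(inputs) != len(input_details):
        raise ValueError(f"expected {len(input_details)} inputs, got {len(inputs)}")
    for detail, value in zip(input_details, inputs):
        interpreter.set_tensor(detail["index"], np.asarray(value, dtype=detail["dtype"]))

    # Execute the graph.
    interpreter.invoke()

    # Read back every output tensor; unwrap when there is only one.
    outputs = [interpreter.get_tensor(d["index"]) for d in interpreter.get_output_details()]
    return outputs[0] if len(outputs) == 1 else outputs
```

Suppose TensorFlow is installed and the model has been exported to `model.tflite` (see Serialization below). A call might then look like `run_tflite_model(tf.lite.Interpreter(model_path='model.tflite'), sample_inputs[0].numpy())`.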

# Validation
Here, we verify that the output of the on-device model created by `ai_edge_torch` matches the output of the original PyTorch model within tolerance.

In [7]:
if np.allclose(torch_output.detach().numpy(), edge_output, atol=1e-5):
    print("Inference result with PyTorch and TFLite was within tolerance")
else:
    print("Something went wrong with the PyTorch --> TFLite conversion")

Inference result with PyTorch and TFLite was within tolerance
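`np.allclose` only returns a boolean, so a failed check does not say how large the discrepancy is. The hypothetical helper `compare_outputs` below applies the same per-element test as `np.allclose`, namely `|a - b| <= atol + rtol * |b|`, where `b` is the second argument. It also reports the maximum absolute error and the number of mismatched elements. Unlike `np.allclose`, it assumes both arrays already have the same shape instead of broadcasting them.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-5, atol=1e-5):
    """Hypothetical helper: allclose-style comparison plus error statistics."""
    reference = np.asarray(reference)
    candidate = np.asarray(candidate)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")

    abs_err = np.abs(reference - candidate)
    # Same per-element criterion as np.allclose(reference, candidate, rtol, atol).
    within = abs_err <= atol + rtol * np.abs(candidate)
    return {
        "max_abs_err": float(abs_err.max()),
        "num_mismatched": int((~within).sum()),
        "ok": bool(within.all()),
    }
```

For example, `compare_outputs(torch_output.detach().numpy(), edge_output)` returns a dict whose `ok` field should agree with the check above, since both use `rtol=1e-5` and `atol=1e-5`.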


# Serialization
The on-device model also provides an `export` method that serializes the model to a TFLite FlatBuffer file.

In [8]:
edge_model.export('model.tflite')

# Optionally download the TFLite flatbuffer, which can be used with existing TFLite APIs.
# from google.colab import files
# files.download('model.tflite')
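As a quick sanity check on the exported file, you can inspect its FlatBuffer header. A FlatBuffer begins with a 4-byte offset to the root table, followed by an optional 4-byte file identifier. The TFLite schema declares that identifier as `TFL3`. The helper `looks_like_tflite` below is hypothetical. It only checks the header and does not validate the rest of the model.

```python
# File identifier declared in the TFLite FlatBuffer schema.
TFLITE_FILE_IDENTIFIER = b"TFL3"


def looks_like_tflite(path):
    """Hypothetical helper: True if the file carries the TFLite FlatBuffer identifier."""
    with open(path, "rb") as f:
        header = f.read(8)
    # Bytes 0-3: root table offset; bytes 4-7: file identifier.
    return len(header) == 8 and header[4:8] == TFLITE_FILE_IDENTIFIER
```

After the export cell above has run, `looks_like_tflite('model.tflite')` should return `True`.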

# Visualization
The TFLite flatbuffer can be visualized using the AI Edge Model Explorer.

In [9]:
!pip install ai-edge-model-explorer

import model_explorer
model_explorer.visualize('model.tflite')
