<a target="_blank" href="https://colab.research.google.com/github/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/getting_started.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [10]:
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

In [19]:
!git clone https://github.com/noahzhy/SALPR.git

Cloning into 'SALPR'...
remote: Enumerating objects: 101, done.
remote: Counting objects: 100% (101/101), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 101 (delta 41), reused 88 (delta 28), pack-reused 0 (from 0)
Receiving objects: 100% (101/101), 6.82 MiB | 24.25 MiB/s, done.
Resolving deltas: 100% (41/41), done.


Note: When running notebooks in this repository with Google Colab, some users may see
the following warning message:

![Colab warning](https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/data/colab_warning.jpg?raw=true)

Please click `Restart Session` and run again.

In [11]:
!pip install -r https://raw.githubusercontent.com/google-ai-edge/ai-edge-torch/main/requirements.txt
!pip install ai-edge-torch-nightly



In [12]:
import numpy as np
import ai_edge_torch
import torch
import torchvision

# Sample PyTorch Model

Instantiate `resnet18` as a sample model from PyTorch's `torchvision` package. We also provide it with a sample input and execute it directly via PyTorch.

In [13]:
resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)
torch_output = resnet18(*sample_inputs)



In [27]:
%cd  /content/SALPR
!python3 eval.py

/content/SALPR
{'model_name': 'tinyLPR', 'lr': 0.0003, 'batch_size': 128, 'epochs': 100, 'eval_freq': 5, 'time_steps': 8, 'blank_id': 0, 'num_classes': 68, 'img_size': [32, 96], 'input_shape': [32, 96, 1], 'seed': 0, 'checkpoint_path': 'backup/m_size_0.9915.pth', 'train': {'maxT': 8, 'image_dir': '/workspace/datasets/lpr/images/train', 'data_aug': True}, 'val': {'maxT': 8, 'image_dir': '/workspace/datasets/lpr/images/val', 'data_aug': False}, 'test': {'maxT': 8, 'image_dir': '/workspace/datasets/lpr/images/test', 'data_aug': False}}
Traceback (most recent call last):
  File "/content/SALPR/eval.py", line 34, in <module>
    edge_model = ai_edge_torch.convert(model, dummy_input)
                 ^^^^^^^^^^^^^
NameError: name 'ai_edge_torch' is not defined


# Conversion
The `convert` function provided by the `ai_edge_torch` package converts a PyTorch model into an on-device (TFLite) model. Conversion also requires sample inputs for the model, which are used for tracing and shape inference.

**Note**: The source PyTorch model needs to be compliant with `torch.export`, which was introduced in PyTorch 2.1.0.

In [14]:
edge_model = ai_edge_torch.convert(resnet18, sample_inputs)

# Inference
Run inference with the TFLite runtime by calling `edge_model` directly on the sample inputs. This API abstracts away many of the details of [TFLite inference in Python](https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_python).

In [15]:
edge_output = edge_model(*sample_inputs)

# Validation
Here, we check that the output of the on-device model created by `ai_edge_torch` matches the output of the original PyTorch model within an absolute tolerance of `1e-5`.

In [16]:
if np.allclose(torch_output.detach().numpy(), edge_output, atol=1e-5):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")

Inference result with Pytorch and TfLite was within tolerance
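The check above reports only pass or fail. When it fails, the size of the discrepancy is usually the first thing to look at. A minimal sketch of a slightly more informative comparison; `compare_outputs` is a hypothetical helper that applies the same test as `np.allclose` (`|candidate - reference| <= atol + rtol * |reference|`) and also returns the largest absolute difference.

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-5, rtol=1e-5):
    """Return (within_tolerance, max_abs_diff) for two output arrays.

    The tolerance test matches np.allclose(candidate, reference, rtol, atol).
    """
    reference = np.asarray(reference)
    candidate = np.asarray(candidate)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(candidate - reference)))
    ok = bool(np.allclose(candidate, reference, rtol=rtol, atol=atol))
    return ok, max_abs_diff
```

In this notebook it could be called as `compare_outputs(torch_output.detach().numpy(), edge_output)`.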


# Serialization
The converted on-device model also provides an `export` method that serializes it to a TFLite flatbuffer file.
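TFLite flatbuffers declare the file identifier `TFL3` in the TFLite schema, and FlatBuffers stores that identifier in bytes 4 to 8 of the buffer, right after the 4-byte root offset. A quick sanity check on the file written by `export` can therefore read those bytes. This is a minimal sketch with a hypothetical helper name; it confirms only the container format, not that the model inside is valid.

```python
TFLITE_FILE_IDENTIFIER = b"TFL3"  # file_identifier declared in the TFLite schema


def looks_like_tflite(path) -> bool:
    """Return True if the file carries the TFLite flatbuffer identifier.

    FlatBuffers places the 4-byte file identifier immediately after the
    4-byte root-table offset, i.e. at bytes 4..8 of the buffer.
    """
    with open(path, "rb") as f:
        header = f.read(8)
    return len(header) == 8 and header[4:8] == TFLITE_FILE_IDENTIFIER
```

Once the export cell has run, `looks_like_tflite('resnet.tflite')` is expected to return `True`.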

In [17]:
edge_model.export('resnet.tflite')

# Download the TFLite flatbuffer, which can be used with the existing TFLite APIs.
# from google.colab import files
# files.download('resnet.tflite')
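Calling `edge_model(*sample_inputs)` hides the usual TFLite interpreter steps. For reference, here is a minimal sketch of running the serialized `resnet.tflite` file directly, assuming TensorFlow is installed and provides `tf.lite.Interpreter` (the standalone LiteRT interpreter offers an equivalent API). `run_tflite` is a hypothetical helper, and it assumes a model with exactly one input and one output, as is the case for the ResNet-18 model here.

```python
import numpy as np
import tensorflow as tf


def run_tflite(model_path: str, x: np.ndarray) -> np.ndarray:
    """Run a single-input, single-output TFLite model on `x`."""
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]
    # Cast to the dtype the model expects; the shape must already match.
    interpreter.set_tensor(input_detail["index"], x.astype(input_detail["dtype"]))
    interpreter.invoke()
    # get_tensor returns a copy of the output buffer.
    return interpreter.get_tensor(output_detail["index"])
```

For example, `run_tflite('resnet.tflite', sample_inputs[0].numpy())` feeds the same tensor that was passed to `edge_model`, so its result can be compared with `edge_output` using `np.allclose`.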

# Visualization
The TFLite flatbuffer can be visualized using the AI Edge Model Explorer.

In [18]:
!pip install ai-edge-model-explorer

import model_explorer
model_explorer.visualize('resnet.tflite')

ℹ️ Please re-run the cell in each new session

Loading extensions...
Loaded 8 extensions:
 - TFLite adapter (Flatbuffer)
 - TFLite adapter (MLIR)
 - TF adapter (MLIR)
 - TF adapter (direct)
 - GraphDef adapter
 - Pytorch adapter (exported program)
 - MLIR adapter
 - JSON adapter

