In [None]:
# Environment setup: system packages first, then the flatc binary, then the
# Python conversion toolchain.
# Refresh the apt index and install the python3-pip and python-is-python3 packages.
!sudo apt-get -y update
!sudo apt-get -y install python3-pip
!sudo apt-get -y install python-is-python3
# Download the prebuilt flatc archive attached to the onnx2tf 1.16.31 release,
# unpack it, mark the binary executable and move it into /usr/bin.
!wget https://github.com/PINTO0309/onnx2tf/releases/download/1.16.31/flatc.tar.gz \
  && tar -zxvf flatc.tar.gz \
  && sudo chmod +x flatc \
  && sudo mv flatc /usr/bin/
# Install the Python packages as a single `&&` chain: the first failing install
# stops everything after it. onnx_graphsurgeon is fetched from NVIDIA's index.
# The protobuf/h5py/psutil/ml_dtypes pins come last in the chain, so they are
# applied after onnx2tf and its dependencies are installed.
# NOTE(review): presumably these trailing pins exist to override versions pulled
# in by the earlier installs — confirm before reordering.
!pip install -U pip \
  && pip install tensorflow==2.15.0 \
  && pip install -U onnx==1.15.0 \
  && python -m pip install onnx_graphsurgeon \
        --index-url https://pypi.ngc.nvidia.com \
  && pip install -U onnxruntime==1.16.3 \
  && pip install -U onnxsim==0.4.33 \
  && pip install -U simple_onnx_processing_tools \
  && pip install -U onnx2tf \
  && pip install -U protobuf==3.20.3 \
  && pip install -U h5py==3.7.0 \
  && pip install -U psutil==5.9.5 \
  && pip install -U ml_dtypes==0.2.0

In [None]:
# Install the PyTorch-side dependencies with the %pip magic so they land in the
# environment of the running kernel rather than whatever `pip` is first on PATH.
# NOTE(review): timm and torch are unpinned, unlike the packages in the setup
# cell — pin them once the working versions are known.
%pip install timm torch

In [None]:
import timm
onnx_model_path = "deit_onnx_intermediate_model.onnx"
tf_model_path = "deit_tf_intermediate"
import deit
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_

def gelu():
    """Build a GELU activation module that uses the tanh approximation.

    Returns:
        nn.GELU: a newly constructed module created with ``approximate='tanh'``.
    """
    tanh_gelu = nn.GELU(approximate='tanh')
    return tanh_gelu

class GELUapprx(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        """Gate ``x`` element-wise by ``sigmoid(x * 1.702)`` and return the result."""
        scaled = x * 1.702
        gate = torch.sigmoid(scaled)
        return x * gate

def deit_tiny_distilled_patch16_224(pretrained=True, **kwargs):
    """Build the tiny distilled DeiT model (patch size 16) and optionally load weights.

    The model is created through ``deit.DistilledVisionTransformer`` with
    ``num_classes=0`` and ``GELUapprx`` as its activation layer.
    NOTE(review): ``num_classes=0`` presumably removes the classification head so
    the model returns features — confirm against the ``deit`` module.

    Args:
        pretrained: when true, download the published checkpoint (hash-checked,
            mapped to CPU) and load its ``"model"`` entry into the network.
        **kwargs: forwarded unchanged to ``deit.DistilledVisionTransformer``.

    Returns:
        The constructed model, with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = deit.DistilledVisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        num_classes=0,
        act_layer=GELUapprx,
        **kwargs,
    )
    model.default_cfg = _cfg()

    if not pretrained:
        return model

    weights_url = "https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth"
    state = torch.hub.load_state_dict_from_url(
        url=weights_url,
        map_location="cpu",
        check_hash=True,
    )
    # strict=False: checkpoint keys that do not line up with this model are
    # tolerated instead of raising.
    model.load_state_dict(state["model"], strict=False)
    return model

# Random dummy batch: one 3-channel 224x224 image with values drawn from [0, 1).
sample_input = torch.rand(1, 3, 224, 224)
class CustomDeiTWithPreprocessing(nn.Module):
    """DeiT model with the input normalization built into the module.

    ``forward`` runs the input through a ``BatchNorm2d(3)`` layer, subtracts a
    per-channel mean, divides by the per-channel standard deviation, and then
    calls the DeiT model.

    The channel statistics are registered as non-persistent buffers. They
    therefore follow ``.to()`` / ``.cuda()`` / ``.double()`` together with the
    rest of the module (plain tensor attributes would stay behind and cause
    device or dtype mismatches), while the ``state_dict`` keys stay unchanged.

    NOTE(review): the statistics match the widely used ImageNet mean/std scaled
    by 255, which presumably means inputs are expected in the 0-255 range —
    confirm.
    NOTE(review): ``self.normalization`` is a default-initialised BatchNorm2d
    that is never trained here; in eval mode it is close to an identity, in
    train mode it normalizes with batch statistics. Confirm it is intended in
    addition to the explicit mean/std step.
    """

    def __init__(self, **kwargs):
        """Create the preprocessing layers and the pretrained DeiT model.

        Args:
            **kwargs: accepted for call compatibility; not used.
        """
        super().__init__()

        # Preprocessing layers
        self.normalization = nn.BatchNorm2d(3)
        # Buffers (not parameters): moved/cast with the module, never trained.
        # persistent=False keeps them out of state_dict, as before.
        self.register_buffer(
            "mean",
            torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]),
            persistent=False,
        )
        self.register_buffer(
            "variance",
            torch.tensor([(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2]),
            persistent=False,
        )

        # DeiT model
        self.deit_model = deit_tiny_distilled_patch16_224()

    def forward(self, x):
        """Normalize ``x`` and return the DeiT model's output for it.

        Args:
            x: input tensor with 3 channels on dimension 1 (the statistics are
                reshaped to ``(1, 3, 1, 1)`` for broadcasting).

        Returns:
            Whatever ``self.deit_model`` returns for the normalized input.
        """
        # Apply normalization
        x = self.normalization(x)

        # Apply mean and variance scaling, broadcast over batch and spatial dims
        mean = self.mean.view(1, 3, 1, 1)
        std = torch.sqrt(self.variance.view(1, 3, 1, 1))
        x = (x - mean) / std

        # Forward pass through DeiT model
        return self.deit_model(x)


# Create an instance of the custom model and put it in inference mode, so that
# BatchNorm uses its running statistics (and any dropout is disabled) both while
# tracing and for any later use of `custom_model`.
custom_model = CustomDeiTWithPreprocessing()
custom_model.eval()

# Trace the model with the dummy input and write the ONNX graph to disk.
torch.onnx.export(
    custom_model,           # PyTorch model
    sample_input,           # Example input tensor used for tracing
    onnx_model_path,        # Output file (eg. 'output_model.onnx')
    opset_version=14,       # ONNX operator set version
    input_names=['input'],
    output_names=['output']
)




In [None]:

# Recursively archive /content/saved_model into /content/file.zip.
# NOTE(review): saved_model looks like the output of the onnx2tf cell further
# down (the archive contains deit_onnx_intermediate_model_float32.keras), so on
# a fresh kernel that cell presumably has to run first — confirm the cell order.
!zip -r /content/file.zip /content/saved_model

  adding: content/saved_model/ (stored 0%)
  adding: content/saved_model/variables/ (stored 0%)
  adding: content/saved_model/variables/variables.index (deflated 33%)
  adding: content/saved_model/variables/variables.data-00000-of-00001 (deflated 83%)
  adding: content/saved_model/saved_model.pb (deflated 9%)
  adding: content/saved_model/deit_onnx_intermediate_model_float32.keras

In [None]:
# Send /content/file.zip to the user's browser as a download.
# This relies on the google.colab package, so it only works inside Colab.
from google.colab import files
files.download("/content/file.zip")

In [None]:
# Convert the exported ONNX model with the onnx2tf CLI; -i names the input file.
# NOTE(review): -oh5 presumably requests an additional Keras H5 export — confirm
# against the onnx2tf CLI help.
!onnx2tf -i deit_onnx_intermediate_model.onnx -oh5

In [None]:
# Compile both quantized TFLite models for the Edge TPU in a single invocation.
# NOTE(review): neither .tflite file is produced by a cell visible in this
# notebook, and edgetpu_compiler itself is installed by a later cell — confirm
# the intended run order and where the quantized models come from.
# NOTE(review): -sa is presumably -s (show the per-operator mapping) combined
# with -a (allow multiple Edge TPU subgraphs) — confirm against the compiler docs.
! edgetpu_compiler  quantized_temporal_extractor.tflite  quantized_spatial_extractor.tflite -sa

Edge TPU Compiler version 16.0.384591198
Started a compilation timeout timer of 180 seconds.

Models compiled successfully in 22300 ms.

Input model: quantized_temporal_extractor.tflite
Input size: 2.57MiB
Output model: quantized_temporal_extractor_edgetpu.tflite
Output size: 3.92MiB
On-chip memory used for caching model parameters: 2.33MiB
On-chip memory remaining for caching model parameters: 0.00B
Off-chip memory used for streaming uncached model parameters: 319.62KiB
Number of Edge TPU subgraphs: 9
Total number of operations: 205
Operation log: quantized_temporal_extractor_edgetpu.log

Model successfully compiled but not all operations are supported by the Edge TPU. A percentage of the model will instead run on the CPU, which is slower. If possible, consider updating your model to use only operations supported by the Edge TPU. For details, visit g.co/coral/model-reqs.
Number of operations that will run on Edge TPU: 185
Number of operations that will run on CPU: 20

Operator        

In [None]:
# Register the Coral apt repository and install the Edge TPU compiler.
# Step 1: add the repository signing key.
# NOTE(review): apt-key is deprecated on recent Debian/Ubuntu releases; it still
# printed "OK" in the recorded run, so it is left unchanged here.
! curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -

# Step 2: add the coral-edgetpu-stable package source.
! echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list

# Step 3: refresh the index and install. -y answers apt's confirmation prompts
# so the cell cannot block waiting for input, matching the apt-get calls in the
# first setup cell.
! sudo apt-get -y update

! sudo apt-get -y install edgetpu-compiler

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  2659  100  2659    0     0  35209      0 --:--:-- --:--:-- --:--:-- 35453
OK
deb https://packages.cloud.google.com/apt coral-edgetpu-stable main
Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [119 kB]
Get:5 https://packages.cloud.google.com/apt coral-edgetpu-stable InRelease [6,332 B]
Get:6 http://security.ubuntu.com/ubuntu jammy-security InRelease [110 kB]
Get:7 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [737 kB]
Get:8 http://archive.ubuntu.com/ubuntu 