In [8]:
import torch
import torch.nn as nn

In [9]:
class SimpleCNN(nn.Module):
    """Four-stage convolutional image classifier.

    Every stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, 2)``. Channels grow 3 -> 16 -> 32 -> 64 -> 128 while each
    spatial side is halved (with flooring) four times, e.g.
    200 -> 100 -> 50 -> 25 -> 12. The flattened feature map then goes through
    ``fc1 -> ReLU -> Dropout(0.5) -> fc2``.

    Attribute names determine the ``state_dict`` keys, so they must not be
    renamed if existing checkpoints are to keep loading.

    Args:
        num_classes: Number of output logits produced by ``fc2``. Defaults to 8.
    """

    # 128 channels * 12 * 12 positions: what remains of a 200x200 input after
    # four 2x2 max-poolings. This is the input width expected by ``fc1``.
    FLAT_FEATURES = 128 * 12 * 12  # == 18432

    def __init__(self, num_classes=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.batchNorm1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.batchNorm2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.batchNorm3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.batchNorm4 = nn.BatchNorm2d(128)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(self.FLAT_FEATURES, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape ``(N, 3, H, W)`` whose pooled feature map holds
                ``FLAT_FEATURES`` values per sample (e.g. ``H = W = 200``).

        Returns:
            Tensor of shape ``(N, num_classes)`` with the raw outputs of
            ``fc2``; no softmax is applied.

        Raises:
            RuntimeError: If the per-sample feature count does not match the
                input width of ``fc1``.
        """
        stages = (
            (self.conv1, self.batchNorm1, self.pool1),
            (self.conv2, self.batchNorm2, self.pool2),
            (self.conv3, self.batchNorm3, self.pool3),
            (self.conv4, self.batchNorm4, self.pool4),
        )
        for conv, norm, pool in stages:
            x = pool(torch.relu(norm(conv(x))))

        # Flatten everything except the batch dimension. A `view(-1, F)` would
        # let a wrongly sized input be silently re-chunked into a different
        # number of "samples"; keeping N fixed makes fc1 reject it instead.
        x = torch.flatten(x, 1)

        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

In [10]:
# Load the trained parameters onto the CPU. The checkpoint is a plain state_dict
# (it is passed straight to load_state_dict below), so restrict unpickling to
# tensors and containers with weights_only=True rather than running arbitrary
# pickled code from the file.
state_dict = torch.load(
    "poke_model_v2_10epoch.pth",
    map_location=torch.device("cpu"),
    weights_only=True,
)

# Rebuild the architecture and copy the parameters in; load_state_dict is strict
# by default, so missing or unexpected keys raise an error.
model = SimpleCNN()
model.load_state_dict(state_dict)

<All keys matched successfully>

In [11]:
# Inference mode: BatchNorm uses its running statistics and Dropout is disabled.
model.eval()

# Dummy batch of one 3-channel 200x200 image. A privately seeded Generator makes
# the reference output reproducible without touching the global RNG state.
sample_inputs = (
    torch.randn(1, 3, 200, 200, generator=torch.Generator().manual_seed(0)),
)

# Reference PyTorch output; no autograd graph is needed for a plain forward pass.
with torch.no_grad():
    torch_output = model(*sample_inputs)

In [None]:
import numpy
import ai_edge_torch

# Put the module in inference mode, then convert it using the sample inputs.
model.eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)

ModuleNotFoundError: No module named 'torch_xla'

In [None]:
# Run the converted model on the same sample inputs that were fed to the PyTorch model.
# NOTE(review): the result is passed directly to numpy.allclose in a later cell,
# so it is presumably array-like — confirm against the ai_edge_torch API.
edge_output = edge_model(*sample_inputs)

In [None]:
# Element-wise comparison of the PyTorch reference against the converted
# model's output, using the same absolute and relative tolerance.
outputs_match = numpy.allclose(
    torch_output.detach().numpy(),
    edge_output,
    rtol=1e-5,
    atol=1e-5,
)

print(
    "Inference result with Pytorch and TfLite was within tolerance"
    if outputs_match
    else "Something wrong with Pytorch --> TfLite"
)

In [None]:
# Export the converted model to the relative path 'model.tflite'
# (resolved against the kernel's current working directory).
edge_model.export('model.tflite')