In [1]:
# From:
# https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

import torch.nn as nn
import torch.onnx

from s2cell_ml import S2CellClassifierTask

from datasets import World1, Img2LocCombined
world1 = World1()

# Initialize model with weights from checkpoint
model = S2CellClassifierTask.load_from_checkpoint(
    "checkpoints/s2cell_ml_tvit/f1_01514_2024_08.ckpt",
    model_name="tinyvit_21m_224",
    label_mapping=world1.label_mapping,
    overfit=False,
    export=True,
)

model.eval()

  from .autonotebook import tqdm as notebook_tqdm


Model: tinyvit_21m_224


S2CellClassifierTask(
  (model): TinyVit(
    (patch_embed): PatchEmbed(
      (conv1): ConvNorm(
        (conv): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (act): GELU(approximate='none')
      (conv2): ConvNorm(
        (conv): Conv2d(48, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (stages): Sequential(
      (0): ConvLayer(
        (blocks): Sequential(
          (0): MBConv(
            (conv1): ConvNorm(
              (conv): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (act1): GELU(approximate='none')
            (conv2): ConvNorm(
              (conv): Conv2d(384, 

In [3]:
x = torch.randn(1, 3, 224, 224, requires_grad=True, device=model.device)
torch_out = model(x)

torch.onnx.export(
    model,
    x,
    "exports/s2cell_ml_tvit_release0.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

In [4]:
import onnx

onnx_model = onnx.load("exports/s2cell_ml_tvit_release0.onnx")
onnx.checker.check_model(onnx_model)

In [5]:
import onnxruntime

ort_session = onnxruntime.InferenceSession(
    "exports/s2cell_ml_tvit_release0.onnx",
    providers=["CPUExecutionProvider"],
)

ort_inputs = {ort_session.get_inputs()[0].name: x.detach().cpu().numpy()}
ort_outputs = ort_session.run(None, ort_inputs)

ort_outputs[0].shape

(1, 930)

In [6]:
import numpy
numpy.testing.assert_allclose(torch_out.detach().cpu().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)

AssertionError: 
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 120 / 930 (12.9%)
Max absolute difference: 0.00059986
Max relative difference: 0.00410528
 x: array([[4.879482e-02, 2.033514e-02, 1.398582e-01, 1.716241e-02,
        1.246499e-03, 1.205014e-01, 2.431710e-02, 1.292941e-02,
        3.916854e-04, 9.244920e-03, 3.085332e-03, 8.270741e-03,...
 y: array([[4.883209e-02, 2.033532e-02, 1.398694e-01, 1.717880e-02,
        1.248211e-03, 1.207024e-01, 2.433154e-02, 1.294973e-02,
        3.929436e-04, 9.265512e-03, 3.088057e-03, 8.274287e-03,...