In [4]:
# From:
# https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

import torch.onnx

from s2cell_ml import S2CellClassifierTask

from datasets import World1, Img2LocCombined
world1 = World1()

# Initialize model with weights from checkpoint
model = S2CellClassifierTask.load_from_checkpoint(
    "checkpoints/s2cell_ml_efn/efn_v2_s2_train1.ckpt",
    model_name="efn_v2_s2",
    label_mapping=world1.label_mapping,
    overfit=False,
)

model.eval()

S2CellClassifierTask(
  (model): EfficientNet(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): FusedMBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        )
        (1): FusedMBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): Batch

In [6]:
x = torch.randn(1, 3, 224, 224, requires_grad=True, device=model.device)
torch_out = model(x)

torch.onnx.export(
    model,
    x,
    "exports/s2cell_ml_efn_v2_s2_train1.onnx",
    export_params=True,
    opset_version=10,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

In [7]:
import onnx

onnx_model = onnx.load("exports/s2cell_ml_efn_v2_s2_train1.onnx")
onnx.checker.check_model(onnx_model)

In [11]:
import onnxruntime

ort_session = onnxruntime.InferenceSession(
    "exports/s2cell_ml_efn_v2_s2_train1.onnx",
    providers=["CPUExecutionProvider"],
)

ort_inputs = {ort_session.get_inputs()[0].name: x.detach().cpu().numpy()}
ort_outputs = ort_session.run(None, ort_inputs)

ort_outputs[0].shape

(1, 930)

In [14]:
import numpy
numpy.testing.assert_allclose(torch_out.detach().cpu().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)