In [2]:
# export model to onnx format without weights and biases

import sys
sys.path.insert(0, '../src') # to be able to import functions from src

import torch
from unet import *

# network hyperparameters
n_feat = 64 # 64 hidden dimension feature
n_cfeat = 5 # context vector is of size 5
height = 16 # 16x16 image
in_channels=3 # rgb

m = ContextUnet(in_channels, n_feat=n_feat, n_cfeat=n_cfeat, height=height)
samples = torch.randn(9, 3, height, height) 
t = torch.tensor([1 / 500])[:, None, None, None]
torch.onnx.export(m,         # model being run 
         (samples, t),       # model input (or a tuple for multiple inputs) 
         "mymodel.onnx",       # where to save the model  
         export_params=True,  # store the trained parameter weights inside the model file 
         opset_version=10,    # the ONNX version to export the model to 
         do_constant_folding=True,  # whether to execute constant folding for optimization 
         input_names = ['samples', 'timestamp'],   # the model's input names 
         output_names = ['modelOutput'], # the model's output names 
) 

import onnx
onnx_model = onnx.load("mymodel.onnx")
onnx.checker.check_model(onnx_model)

In [3]:
import onnxruntime as ort
import numpy as np

# Load the ONNX model
onnx_model_path = "mymodel.onnx"
ort_session = ort.InferenceSession(onnx_model_path)

# Prepare input data
samples = np.random.randn(9, 3, 16, 16).astype(np.float32)  # match the shape and data type of your input
t = np.array([1 / 500], dtype=np.float32).reshape(1, 1, 1, 1)  # match the shape and data type of your input

# Run the model
ort_inputs = {
    'samples': samples,
    'timestamp': t
}
ort_outs = ort_session.run(None, ort_inputs)

# Output
model_output = ort_outs[0]
print("Model output shape:", model_output.shape)
print("Model output:", model_output)


Model output shape: (9, 3, 16, 16)
Model output: [[[[ 5.92298627e-01  6.90294087e-01  8.16784322e-01 ...  3.84475201e-01
     4.00609106e-01  1.24330327e-01]
   [ 1.28828660e-01  4.37980622e-01  4.48588639e-01 ...  2.30109394e-01
     3.26428860e-01  2.37278566e-01]
   [ 4.56582129e-01  3.86125833e-01  3.90646636e-01 ...  7.28621006e-01
    -1.96697697e-01  2.56734997e-01]
   ...
   [ 2.56967843e-01  5.45957088e-01  2.14910045e-01 ... -2.23727703e-01
     1.00483537e-01  1.13288477e-01]
   [ 2.76653886e-01  1.05547279e-01  3.31719071e-01 ...  1.17423698e-01
    -7.99360126e-02  7.07309470e-02]
   [-1.03338212e-01  4.02984358e-02  3.21439117e-01 ... -7.60961175e-02
     1.32259667e-01 -7.89057016e-02]]

  [[ 1.78859502e-01 -2.05364972e-01  1.34391665e-01 ... -1.08623803e-01
     1.49976552e-01 -3.28602940e-02]
   [ 1.29437521e-01  4.53616858e-01 -3.57999861e-01 ... -4.65477765e-01
    -3.07404876e-01  1.63260832e-01]
   [-1.35039300e-01 -4.28656638e-01  5.84466755e-01 ... -9.36597660e-0

