In [4]:
###### From https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
# -*- coding: utf-8 -*-
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

# E is the number of epochs
# N is the number of samples; B is the batch size
# D_in is the input dimension; H is the hidden dimension; D_out is the output dimension.
E, N, B, D_in, H, D_out = 5, 64, 8, 1000, 100, 10

# Seed torch's RNGs (CPU and every CUDA device) so the random data below -- and
# anything drawn after it in the notebook, such as initial weights -- is
# reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

## Use the first GPU when one is available, otherwise fall back to the CPU
dtype = torch.float
dev = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in, device=dev, dtype=dtype)
y = torch.randn(N, D_out, device=dev, dtype=dtype)

# Wrap the tensors so the training loop can iterate over mini-batches of size B
# (the final batch would be smaller if N were not a multiple of B).
train_ds = TensorDataset(x, y)
train_dl = DataLoader(train_ds, batch_size=B)

# Define the model with the torch.nn package: a two-layer MLP
# (D_in -> H -> ReLU -> D_out), then move its parameters onto `dev`.
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=D_in, out_features=H),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=H, out_features=D_out),
)
model = model.to(dev)

# Squared error, summed (not averaged) over every element of a batch.
loss_fn = torch.nn.MSELoss(reduction="sum")

# torch.optim offers many optimization algorithms; this uses Adam. The
# optimizer updates exactly the tensors handed to it as `params` -- here,
# every learnable parameter of the model.
learning_rate = 1e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

## Train the model
# Put the model in training mode explicitly: a later cell switches it to eval
# mode, and re-running this cell afterwards would otherwise train in eval mode.
model.train()
for epoch in range(E):
    epoch_loss = 0.0  # sum of the per-batch losses seen during this epoch
    for xb, yb in train_dl:
        # Forward pass: compute predicted y by passing the batch to the model.
        y_pred = model(xb)

        # Compute the loss for this batch.
        loss = loss_fn(y_pred, yb)

        # Gradients accumulate in the parameters' .grad buffers on every
        # .backward() call (see torch.autograd.backward), so clear the ones
        # left over from the previous step before computing new ones.
        optimizer.zero_grad()

        # Backward pass: gradient of the loss with respect to the model
        # parameters.
        loss.backward()

        # Update the parameters using the freshly computed gradients.
        optimizer.step()

        # .item() extracts a plain Python float, so the running total does not
        # keep the autograd graph of every batch alive.
        epoch_loss += loss.item()

    # Report progress once per epoch so divergence or a stalled loss is visible.
    print(f"epoch {epoch + 1}/{E}: total loss {epoch_loss:.4f}")

In [5]:
# Switch to inference mode before exporting; Module.eval() is documented as
# equivalent to train(False), and both return the module itself (shown below).
model.train(False)

Sequential(
  (0): Linear(in_features=1000, out_features=100, bias=True)
  (1): ReLU()
  (2): Linear(in_features=100, out_features=10, bias=True)
)

In [6]:
# Generate an example input for the model so that it can be exported via
# tracing. It only needs the right feature width, dtype and device; axis 0 is
# exported as the dynamic dimension 'B' (see dynamic_axes below). Export does
# not need gradients, so the input does not track them.
ex = torch.randn(B, D_in, device=dev, dtype=dtype)

# Reference PyTorch output, later compared against ONNX Runtime. Only the
# values are needed, so skip building an autograd graph.
with torch.no_grad():
    torch_out = model(ex)

# Export the model
torch.onnx.export(model,                     # model being run
                  ex,                        # model input (or a tuple for multiple inputs)
                  "threelayer2.onnx",        # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX operator-set (opset) version to target
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'B'},    # variable length axes
                                'output' : {0 : 'B'}})

In [8]:
# Use ONNX runtime for inferencing

import onnxruntime
import numpy as np

# Load the exported graph. Since ONNX Runtime 1.9, builds that ship more than
# one execution provider (e.g. onnxruntime-gpu) raise unless `providers` is
# passed explicitly; the CPU provider is available in every build.
ort_session = onnxruntime.InferenceSession(
    "threelayer2.onnx", providers=["CPUExecutionProvider"]
)

def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the host.

    Args:
        tensor: any torch.Tensor, with or without gradient tracking, on any
            device.

    Returns:
        numpy.ndarray with the tensor's values.
    """
    # detach() is harmless on tensors that don't track gradients, so a single
    # path serves both cases; .cpu() is a no-op for tensors already on the CPU.
    host_tensor = tensor.detach().cpu()
    return host_tensor.numpy()

# Run ONNX Runtime on the very same example input that produced `torch_out`,
# feeding it to the session's first (and only) graph input by name.
ort_inputs = {
    ort_session.get_inputs()[0].name: to_numpy(ex),
}
# Passing None as the output list asks ONNX Runtime for every graph output.
ort_outs = ort_session.run(None, ort_inputs)

# The two backends must agree up to floating-point round-off.
np.testing.assert_allclose(
    to_numpy(torch_out),
    ort_outs[0],
    rtol=1e-03,
    atol=1e-05,
)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Exported model has been tested with ONNXRuntime, and the result looks good!
