In [1]:
import torch
from torch import nn
import onnx
import onnxruntime as ort

from model import COREModel

In [2]:
import sys
print(sys.version)
print(torch.__version__)

3.9.6 (default, Oct 18 2022, 12:41:40) 
[Clang 14.0.0 (clang-1400.0.29.202)]
1.13.1


In [3]:
x_train = torch.ones([1,2], dtype=torch.long)

### Load big model that uses a torch.tril function

In [4]:
model = COREModel()
model.eval()
print(model.forward(x_train))

tensor([[ 2.2729, 14.2857,  1.7739,  0.5872, -1.4148]], grad_fn=<DivBackward0>)


# Export big model in onnx format to disk

In [5]:
bigmodel_onnx_filename = 'bigmodel.onnx'
torch.onnx.export(
    model,
    x_train,
    bigmodel_onnx_filename,
    input_names=['x'],
    output_names=['output'],
)

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


### Print the Onnx Nodes in the big model. Confirming it contains Trilu nodes

In [6]:
onnx.load(bigmodel_onnx_filename)

ir_version: 7
producer_name: "pytorch"
producer_version: "1.13.1"
graph {
  node {
    output: "/item_embedding/Constant_output_0"
    name: "/item_embedding/Constant"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 7
        raw_data: "\000\000\000\000\000\000\000\000"
      }
      type: TENSOR
    }
  }
  node {
    input: "item_embedding.weight"
    input: "x"
    output: "/item_embedding/Gather_output_0"
    name: "/item_embedding/Gather"
    op_type: "Gather"
  }
  node {
    output: "/net/Constant_output_0"
    name: "/net/Constant"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 7
        raw_data: "\000\000\000\000\000\000\000\000"
      }
      type: TENSOR
    }
  }
  node {
    input: "x"
    input: "/net/Constant_output_0"
    output: "/net/Greater_output_0"
    name: "/net/Greater"
    op_type: "Greater"
  }
  node {
    output: "/net/Constant_1_output_0"
    name: "/net/Constant_1"
    op_t

### Loading the Onnx model

`
[ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node with name '/net/Trilu'
`

In [7]:
# When loading neural graph, Onnx will crash with a Trilu Node NOT_IMPLEMENTED error
ort_sess = ort.InferenceSession(bigmodel_onnx_filename, providers=['CPUExecutionProvider'])
key = {'x': x_train.numpy()}
ort_sess.run(None, key)

NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node with name '/net/Trilu'

### Demonstrate that a nested model with torch.tril() working with Onnx

In [8]:
class SubModel(nn.Module):
    def forward(self, x):
        super(SubModel, self).__init__()
        return torch.tril(x)
    
class TrilModel(nn.Module):
    def __init__(self):
        super(TrilModel, self).__init__()
        self.net = SubModel()
    
    def forward(self, x):
        return self.net.forward(x)


In [9]:
model2 = TrilModel()
model2.eval()
model2.forward(x_train)

tensor([[1, 0]])

In [10]:
torch.onnx.export(
    model2,
    x_train,
    'mymodel2.onnx',
    input_names=['x'],
    output_names=['output'],
)

In [11]:
onnx.load('mymodel2.onnx')

ir_version: 7
producer_name: "pytorch"
producer_version: "1.13.1"
graph {
  node {
    output: "/Constant_output_0"
    name: "/Constant"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 7
        raw_data: "\000\000\000\000\000\000\000\000"
      }
      type: TENSOR
    }
  }
  node {
    input: "x"
    input: "/Constant_output_0"
    output: "output"
    name: "/Trilu"
    op_type: "Trilu"
    attribute {
      name: "upper"
      i: 0
      type: INT
    }
  }
  name: "torch_jit"
  input {
    name: "x"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  output {
    name: "output"
    type {
      tensor_type {
        elem_type: 7
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
}
opset_impor

In [12]:
ort_sess = ort.InferenceSession('mymodel2.onnx', providers=['CPUExecutionProvider'])
key = {'x': x_train.numpy()}
ort_sess.run(None, key)

[array([[1, 0]], dtype=int64)]