# Converting ONNX models to GoMLX

## Imports and setup

In [2]:
%version
!*rm -f go.work && go work init
!*go work use . "${HOME}/Projects/gonb" "${HOME}/Projects/gomlx" "${HOME}/Projects/onnx-gomlx"
%goworkfix

**GoNB** version [v0.10.6](https://github.com/janpfeifer/gonb/releases/tag/v0.10.6) / Commit: [0e5f587a077810d058202b76a127651a02bd4382](https://github.com/janpfeifer/gonb/tree/0e5f587a077810d058202b76a127651a02bd4382)


	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".
	- Added replace rule for module "onnx-gomlx" to local directory "/home/janpf/Projects/onnx-gomlx".


In [3]:
import (
	"fmt"
	"os"

	"github.com/gogo/protobuf/proto"

	"github.com/janpfeifer/must"
	"github.com/janpfeifer/onnx-gomlx"

	. "github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/data"
	"github.com/gomlx/gomlx/types"
	"github.com/pkg/errors"
)

var _ = Add

## Linear Model

The `linear_regression.onnx` file was created manually in Python, using the `onnx` helper functions:

```python
from onnx import TensorProto
from onnx.checker import check_model
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info

# Inputs: 2-D float tensors with both dimensions left undefined.
X = make_tensor_value_info('X', TensorProto.FLOAT, [None, None])
A = make_tensor_value_info('A', TensorProto.FLOAT, [None, None])
B = make_tensor_value_info('B', TensorProto.FLOAT, [None, None])

# Output: its shape is left undefined.
Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None])

# Create the nodes. MatMul takes 'X' and 'A' as inputs and produces 'XA';
# Add takes 'XA' and 'B' and produces 'Y'. The last argument is the node name.
node1 = make_node('MatMul', ['X', 'A'], ['XA'], 'XA')
node2 = make_node('Add', ['XA', 'B'], ['Y'], 'Y')

# Build the graph from the list of nodes, a name, the list of inputs
# and the list of outputs.
graph = make_graph([node1, node2],  # nodes
                   'lr',  # a name
                   [X, A, B],  # inputs
                   [Y])  # outputs

# Wrap the graph in a model (no metadata in this case), validate it and save it.
onnx_model = make_model(graph)
check_model(onnx_model)
with open("linear_regression.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
```

In [9]:
var modelPath = data.ReplaceTildeInDir("~/work/onnx/linear_regression.onnx") // the hand-built linear model

%%
// read the onnx model
model := must.M1(onnxgomlx.ReadONNXFile(modelPath))
graph := model.Proto.Graph
for _, n := range graph.Node {
    fmt.Printf("%q:\t[%s]\n", n.GetName(), n.GetOpType())
    fmt.Printf("\tInputs: %q\n", n.GetInput())
    fmt.Printf("\tOutputs: %q\n", n.GetOutput())
}

"XA":	[MatMul]
	Inputs: ["X" "A"]
	Outputs: ["XA"]
"Y":	[Add]
	Inputs: ["XA" "B"]
	Outputs: ["Y"]
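
For reference, the two nodes listed above compute `Y = X·A + B`. The following is a minimal plain-Go sketch of that computation, independent of GoMLX and onnx-gomlx. The helpers `matMul` and `addMat` are hypothetical names introduced here for illustration, and the sketch assumes `B` has exactly the same shape as `XA` (no broadcasting) and does no shape checking.

```go
package main

import "fmt"

// matMul returns the matrix product x·a for row-major matrices.
// Assumption: every row of x has len(a) elements; shapes are not checked.
func matMul(x, a [][]float32) [][]float32 {
	out := make([][]float32, len(x))
	for i := range x {
		out[i] = make([]float32, len(a[0]))
		for k := range a {
			for j := range a[k] {
				out[i][j] += x[i][k] * a[k][j]
			}
		}
	}
	return out
}

// addMat returns the element-wise sum of two equally shaped matrices.
func addMat(m, b [][]float32) [][]float32 {
	out := make([][]float32, len(m))
	for i := range m {
		out[i] = make([]float32, len(m[i]))
		for j := range m[i] {
			out[i][j] = m[i][j] + b[i][j]
		}
	}
	return out
}

func main() {
	x := [][]float32{{1, 2}, {3, 4}}
	a := [][]float32{{1}, {1}}
	b := [][]float32{{10}, {20}}

	xa := matMul(x, a) // node "XA": MatMul(X, A)
	y := addMat(xa, b) // node "Y": Add(XA, B)
	fmt.Println(y)     // prints [[13] [27]]
}
```

A converted graph would be expected to produce the same values for these inputs, which makes a hand-computed case like this a convenient sanity check.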


## [Sentence Encoding all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)

The cell below reads the downloaded `model.onnx` file and lists the distinct op types it uses.

In [10]:
var modelPath = data.ReplaceTildeInDir("~/work/onnx/model.onnx") // all-MiniLM-L6-v2

%%
// read the onnx model
model := must.M1(onnxgomlx.ReadONNXFile(modelPath))
graph := model.Proto.Graph
opTypes := types.MakeSet[string]()
for _, n := range graph.Node {
    opTypes.Insert(n.GetOpType())
}
fmt.Printf("%d nodes, %d different ops\n", len(graph.Node), len(opTypes))
fmt.Printf("\tOps: %q\n", opTypes)



780 nodes, 23 different ops
	Ops: map["Add":{} "Cast":{} "Concat":{} "Constant":{} "ConstantOfShape":{} "Div":{} "Equal":{} "Erf":{} "Expand":{} "Gather":{} "MatMul":{} "Mul":{} "Pow":{} "ReduceMean":{} "Reshape":{} "Shape":{} "Slice":{} "Softmax":{} "Sqrt":{} "Sub":{} "Transpose":{} "Unsqueeze":{} "Where":{}]
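
The set above records which ops occur, but not how often. A small stdlib-only sketch of a per-op tally with deterministic (sorted) output follows. `countOps` and `sortedKeys` are hypothetical helpers, and the sample op list is made up for illustration rather than taken from `model.onnx`; in the notebook one would presumably collect `n.GetOpType()` for each node in `graph.Node` instead.

```go
package main

import (
	"fmt"
	"sort"
)

// countOps tallies how many times each op type occurs in the list.
func countOps(opTypes []string) map[string]int {
	counts := make(map[string]int)
	for _, op := range opTypes {
		counts[op]++
	}
	return counts
}

// sortedKeys returns the op names in ascending order, so that the
// printed report does not depend on Go's randomized map iteration.
func sortedKeys(counts map[string]int) []string {
	keys := make([]string, 0, len(counts))
	for op := range counts {
		keys = append(keys, op)
	}
	sort.Strings(keys)
	return keys
}

func main() {
	// Illustrative input only; not the real contents of model.onnx.
	ops := []string{"MatMul", "Add", "MatMul", "Softmax", "Add", "MatMul"}
	counts := countOps(ops)
	for _, op := range sortedKeys(counts) {
		fmt.Printf("%-8s %d\n", op, counts[op])
	}
}
```

Such a tally helps prioritize which ops to convert first, since a handful of them usually account for most of the nodes.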
