# Converting ONNX models to GoMLX

## Imports and setup

In [6]:
%version
!*rm -f go.work && go work init
!*go work use . "${HOME}/Projects/gonb" "${HOME}/Projects/gomlx" "${HOME}/Projects/onnx-gomlx"
%goworkfix

**GoNB** version [v0.10.6](https://github.com/janpfeifer/gonb/releases/tag/v0.10.6) / Commit: [0e5f587a077810d058202b76a127651a02bd4382](https://github.com/janpfeifer/gonb/tree/0e5f587a077810d058202b76a127651a02bd4382)


	- Replace rule for module "github.com/gomlx/onnx-gomlx" to local directory "/home/janpf/Projects/onnx-gomlx" already exists.
	- Replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx" already exists.
	- Replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb" already exists.


In [13]:
import (
	"fmt"
	"log"
	"os"

	"github.com/gogo/protobuf/proto"
	"github.com/janpfeifer/must"

	"github.com/gomlx/exceptions"
	"github.com/gomlx/gomlx/backends"
	. "github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/data"
	"github.com/gomlx/gomlx/types"
	"github.com/gomlx/onnx-gomlx/onnx"
	"github.com/pkg/errors"

	_ "github.com/gomlx/gomlx/backends/xla"
)

var (
	// Reference a symbol from the dot-imported graph package so the import always counts as used.
	_       = Add
	backend = backends.New()
)


## Linear Model

The `linear_regression.onnx` file was built by hand in Python (see the accompanying `onnx-py.ipynb` notebook):

```python
from onnx import TensorProto
from onnx.checker import check_model
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info

# Graph inputs and output: names, element types and (symbolic) shapes.
X = make_tensor_value_info('X', TensorProto.FLOAT, ["batch_size", "feature_dim"])
A = make_tensor_value_info('A', TensorProto.FLOAT, ["feature_dim"])
B = make_tensor_value_info('B', TensorProto.FLOAT, [])
Y = make_tensor_value_info('Y', TensorProto.FLOAT, ["batch_size"])

# Two nodes: XA = MatMul(X, A), then Y = Add(XA, B).
node1 = make_node('MatMul', ['X', 'A'], ['XA'], 'XA')
node2 = make_node('Add', ['XA', 'B'], ['Y'], 'Y')
graph = make_graph([node1, node2],  # nodes
                   'lr',  # a name
                   [X, A, B],  # inputs
                   [Y])  # outputs
onnx_model = make_model(graph)
check_model(onnx_model)
with open("linear_regression.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
```
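
In matrix terms, this graph computes a batched dot product followed by a scalar offset. ONNX `MatMul` follows NumPy `matmul` semantics, so multiplying the 2-D `X` by the 1-D `A` contracts the shared `feature_dim` axis and yields a vector of length `batch_size`; `Add` then broadcasts the scalar `B` over that vector:

$$
Y_i = \sum_{j=1}^{\text{feature\_dim}} X_{ij}\, A_j + B, \qquad i = 1, \dots, \text{batch\_size}
$$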

In [17]:
var modelPath = data.ReplaceTildeInDir("~/work/onnx/linear_regression.onnx") // hand-built linear regression model

%%
// Read and print the onnx model:
model := must.M1(onnx.ReadFile(modelPath))
fmt.Printf("%s\n", model)

// Print graph:
graph := model.Proto.Graph
for _, n := range graph.Node {
    fmt.Printf("%q:\t[%s]\n", n.GetName(), n.GetOpType())
    fmt.Printf("\tInputs: %q\n", n.GetInput())
    fmt.Printf("\tOutputs: %q\n", n.GetOutput())
}

// Try to execute it with GoMLX/XLA:
gomlxFn := func(x, a, b *Node) *Node {
    return model.BuildGraph(nil, []*Node{x, a, b})[0]
}

x := [][]float32{{1.0, 2.0, 3.0}, {100.0, 101.0, 102.0}}
a := []float32{1.0, 0.1, 0.01}
b := float32(1000.0)
fmt.Printf("\nX*A+B=%v\n", ExecOnce(backend, gomlxFn, x, a, b))


ONNX Model:
	# inputs:	3
		[#0] X: (Float32) [batch_size, feature_dim]
		[#1] A: (Float32) [feature_dim]
		[#2] B: (Float32)
	# outputs:	1
		[#0] Y: (Float32) [batch_size]
	# nodes:	2
	Op types:	[]string{"Add", "MatMul"}
	IR Version:	10
	Operator Sets:	[v21]

"XA":	[MatMul]
	Inputs: ["X" "A"]
	Outputs: ["XA"]
"Y":	[Add]
	Inputs: ["XA" "B"]
	Outputs: ["Y"]

X*A+B=(Float32)[2]: [1001.23 1111.12]


## [Sentence Encoding all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)

The model is read from the downloaded `model.onnx` file; the cell below prints its summary and then every node in its graph.

In [9]:
var modelPath = data.ReplaceTildeInDir("~/work/onnx/model.onnx") // all-MiniLM-L6-v2

%%
// Read and print the onnx model:
model := must.M1(onnx.ReadFile(modelPath))
// model.String() may panic, so capture any panic with exceptions.Try and report it.
var modelStr string
err := exceptions.Try(func() { modelStr = model.String() })
if err != nil {
    log.Fatalf("Error: %+v", err)
}
fmt.Printf("%s\n", modelStr)

// Print graph:
graph := model.Proto.Graph
for _, n := range graph.Node {
    fmt.Printf("%q:\t[%s]\n", n.GetName(), n.GetOpType())
    fmt.Printf("\tInputs: %q\n", n.GetInput())
    fmt.Printf("\tOutputs: %q\n", n.GetOutput())
}

ONNX Model:
	Producer:	pytorch / 2.5.0
	# inputs:	3
		[#0] input_ids: (Int64) [batch_size, sequence_length]
		[#1] attention_mask: (Int64) [batch_size, sequence_length]
		[#2] token_type_ids: (Int64) [batch_size, sequence_length]
	# outputs:	1
		[#0] last_hidden_state: (Float32) [batch_size, sequence_length, 384]
	# nodes:	780
	# tensors (variables):	101
		"embeddings.word_embeddings.weight": (Float32)[30522 384]
		"embeddings.position_embeddings.weight": (Float32)[512 384]
		"embeddings.token_type_embeddings.weight": (Float32)[2 384]
		"embeddings.LayerNorm.weight": (Float32)[384]
		"embeddings.LayerNorm.bias": (Float32)[384]
		"encoder.layer.0.attention.self.query.bias": (Float32)[384]
		"encoder.layer.0.attention.self.key.bias": (Float32)[384]
		"encoder.layer.0.attention.self.value.bias": (Float32)[384]
		"encoder.layer.0.attention.output.dense.bias": (Float32)[384]
		"encoder.layer.0.attention.output.LayerNorm.weight": (Float32)[384]
		"encoder.layer.0.attention.output.LayerNorm.b