# Converting ONNX models to GoMLX

## Imports and setup

In [1]:
%version
!*rm -f go.work && go work init
!*go work use . "${HOME}/Projects/gonb" "${HOME}/Projects/gomlx" "${HOME}/Projects/onnx-gomlx"
%goworkfix

**GoNB** version [v0.10.6](https://github.com/janpfeifer/gonb/releases/tag/v0.10.6) / Commit: [0e5f587a077810d058202b76a127651a02bd4382](https://github.com/janpfeifer/gonb/tree/0e5f587a077810d058202b76a127651a02bd4382)


	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".
	- Added replace rule for module "github.com/gomlx/onnx-gomlx" to local directory "/home/janpf/Projects/onnx-gomlx".


In [2]:
import (
	"fmt"
	"os"

	"github.com/gogo/protobuf/proto"
	"github.com/janpfeifer/must"

	"github.com/gomlx/gomlx/backends"
    . "github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/data"
	"github.com/gomlx/gomlx/types"
    "github.com/gomlx/onnx-gomlx/onnx"
	"github.com/pkg/errors"

	_ "github.com/gomlx/gomlx/backends/xla"
)

var (
    _ = Add
    backend = backends.New()
)


## Linear Model

The `linear_regression.onnx` was created manually using python (see accompanying `onnx-py.ipynb` notebook):

```python
feature_dim = 5
X = make_tensor_value_info('X', TensorProto.FLOAT, ["batch_size", feature_dim])
Y = make_tensor_value_info('Y', TensorProto.FLOAT, ["batch_size"])
A_initializer = onnx.helper.make_tensor('A', TensorProto.FLOAT, [feature_dim], [100.0, 10.0, 1.0, 0.1, 0.01])
B_initializer = onnx.helper.make_tensor('B', TensorProto.FLOAT, [], [7000.0])
node1 = make_node('MatMul', ['X', 'A'], ['XA'], 'XA')
node2 = make_node('Add', ['XA', 'B'], ['Y'], 'Y')
graph = make_graph([node1, node2], 'lr', [X], [Y], initializer=[A_initializer, B_initializer])
onnx_model = make_model(graph)
check_model(onnx_model)
with open("linear_regression.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
```

In [3]:
var modelPath = data.ReplaceTildeInDir("~/work/onnx/linear_regression.onnx") // all-MiniLM-L6-v2

%%
// Read and print the onnx model:
model := must.M1(onnx.ReadFile(modelPath))
fmt.Printf("%s\n", model)
must.M(model.PrintVariables(os.Stdout))
fmt.Println()
must.M(model.PrintGraph(os.Stdout))
fmt.Println()

// Convert ONNX variables to GoMLX context (which stores variables):
ctx := context.New()
must.M(model.VariablesToContext(ctx))
fmt.Printf("Variables loaded from %q:\n", modelPath)
for v := range ctx.IterVariables() {
    fmt.Printf("\t- %s: %s\n", v.ScopeAndName(), v.Value().GoStr())
}
fmt.Println()

// Execute it with GoMLX/XLA:
gomlxFn := func(ctx *context.Context, x *Node) *Node {
    return model.CallGraph(ctx, x.Graph(), map[string]*Node{"X": x})[0]
}
x := [][]float32{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}
fmt.Println("Example invocation:")
fmt.Printf("\tX*A+B=%v\n", context.ExecOnce(backend, ctx, gomlxFn, x).GoStr())

ONNX Model:
	# inputs:	1
		[#0] X: (Float32) [batch_size, 5]
	# outputs:	1
		[#0] Y: (Float32) [batch_size]
	# nodes:	2
	# tensors (variables):	2
	# sparse tensors (variables):	0
	Op types:	[]string{"Add", "MatMul"}
	IR Version:	10
	Operator Sets:	[v21]

2 tensors (variables):
	"A": (Float32)[5]
	"B": (Float32)
0 sparse tensors (variables)

Model Graph "lr":
"XA":	[MatMul]
	Inputs: ["X" "A"]
	Outputs: ["XA"]
"Y":	[Add]
	Inputs: ["XA" "B"]
	Outputs: ["Y"]

Variables loaded from "/home/janpf/work/onnx/linear_regression.onnx":
	- /ONNX/A: (Float32)[5]: []float32{100, 10, 1, 0.1, 0.01}
	- /ONNX/B: float32(7000)

Example invocation:
	X*A+B=(Float32)[2]: []float32{7123.45, 7679}


## [Sentence Enconding all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)

From the downloaded file `model.onnx`.

In [6]:

func miniLM() {
    modelPath = data.ReplaceTildeInDir("~/work/onnx/model.onnx") // all-MiniLM-L6-v2
    // Read and print the onnx model:
    model := must.M1(onnx.ReadFile(modelPath))
    fmt.Printf("%s\n", model)
    
    // Create a file for debugging with detailed graph.
    debugFile := must.M1(os.Create("graph.txt"))
    must.M(model.PrintVariables(debugFile))
    fmt.Println()
    must.M(model.PrintGraph(debugFile))
    fmt.Println()
    must.M(debugFile.Close())
    
    // Convert ONNX variables to GoMLX context (which stores variables):
    ctx := context.New()
    must.M(model.VariablesToContext(ctx))
    
    // Execute it with GoMLX/XLA:
    gomlxFn := func(ctx *context.Context, inputIDs, attentionMask, tokenTypeIDs *Node) *Node {
        outputs := model.CallGraph(ctx, inputIDs.Graph(), map[string]*Node{
            "input_ids": inputIDs,
            "attention_mask": attentionMask,
            "token_type_ids": tokenTypeIDs})
        return outputs[0]
    }
    sentences := []string{
        "This is an example sentence", 
        "Each sentence is converted"}
    // Encoding to tokens done in Python and pasted here.
    inputIDs := [][]int64{
        {101, 2023, 2003, 2019, 2742, 6251,  102},
        { 101, 2169, 6251, 2003, 4991,  102,    0}}
    tokenTypeIDs := [][]int64{
        {0, 0, 0, 0, 0, 0, 0},
        {0, 0, 0, 0, 0, 0, 0}}
    attentionMask := [][]int64{
        {1, 1, 1, 1, 1, 1, 1},
        {1, 1, 1, 1, 1, 1, 0}}
    fmt.Println("Example invocation:")
    fmt.Printf("\tall-MiniLM-L6-v2(%#v)=\n", sentences)
    var embeddings *tensors.Tensor
    err := exceptions.TryCatch[error](func() {
        embeddings = context.ExecOnce(backend, ctx, gomlxFn, inputIDs, tokenTypeIDs, attentionMask)
    })
    must.M(err)
    fmt.Printf("%s", embeddings.GoStr())
}

In [7]:
%env GOMLX_BACKEND="xla:cuda"
%%
miniLM()

Set: GOMLX_BACKEND="xla:cuda"
ONNX Model:
	Producer:	pytorch / 2.5.0
	# inputs:	3
		[#0] input_ids: (Int64) [batch_size, sequence_length]
		[#1] attention_mask: (Int64) [batch_size, sequence_length]
		[#2] token_type_ids: (Int64) [batch_size, sequence_length]
	# outputs:	1
		[#0] last_hidden_state: (Float32) [batch_size, sequence_length, 384]
	# nodes:	780
	# tensors (variables):	101
	# sparse tensors (variables):	0
	Op types:	[]string{"Add", "Cast", "Concat", "Constant", "ConstantOfShape", "Div", "Equal", "Erf", "Expand", "Gather", "MatMul", "Mul", "Pow", "ReduceMean", "Reshape", "Shape", "Slice", "Softmax", "Sqrt", "Sub", "Transpose", "Unsqueeze", "Where"}
	IR Version:	7
	Operator Sets:	[v14]



Example invocation:
	all-MiniLM-L6-v2([]string{"This is an example sentence", "Each sentence is converted"})=
(Float32)[2 7 384]: [][][]float32{{{0.03582147, 0.04050019, -0.12638353, -0.24778308, 0.21231164, 0.44405022, -0.1843396, 0.24394046, -0.22653413, -0.14214146, -0.21293658, 0.4322553,