# Converting ONNX models to GoMLX

## Imports and setup

In [1]:
%version
!*rm -f go.work && go work init
!*go work use . "${HOME}/Projects/gonb" "${HOME}/Projects/gomlx" "${HOME}/Projects/onnx-gomlx"
%goworkfix

**GoNB** version [v0.10.6](https://github.com/janpfeifer/gonb/releases/tag/v0.10.6) / Commit: [0e5f587a077810d058202b76a127651a02bd4382](https://github.com/janpfeifer/gonb/tree/0e5f587a077810d058202b76a127651a02bd4382)


	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".
	- Added replace rule for module "github.com/gomlx/onnx-gomlx" to local directory "/home/janpf/Projects/onnx-gomlx".


In [160]:
import (
	"fmt"
	"os"

	"github.com/gogo/protobuf/proto"
	"github.com/janpfeifer/must"

	"github.com/gomlx/gomlx/backends"
    . "github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/data"
	"github.com/gomlx/gomlx/types"
	"github.com/gomlx/gopjrt/dtypes"
    "github.com/gomlx/onnx-gomlx/onnx"
	"github.com/pkg/errors"

	_ "github.com/gomlx/gomlx/backends/xla"
)

// Notebook-wide globals shared by the cells below.
var (
    // Reference Add so the dot-import above counts as used even when a cell calls no
    // graph op directly.
    // NOTE(review): Add is presumably provided by the dot-imported graph package — confirm.
    _ = Add
    // backend is passed to context.ExecOnce / context.ExecOnceN by the cells below.
    // NOTE(review): backends.New() presumably picks the implementation from the
    // GOMLX_BACKEND environment variable (a later cell sets it with %env) — confirm.
    backend = backends.New()
)


## Linear Model

The `linear_regression.onnx` model was created manually in Python (see the accompanying `onnx-py.ipynb` notebook):

```python
feature_dim = 5
X = make_tensor_value_info('X', TensorProto.FLOAT, ["batch_size", feature_dim])
Y = make_tensor_value_info('Y', TensorProto.FLOAT, ["batch_size"])
A_initializer = onnx.helper.make_tensor('A', TensorProto.FLOAT, [feature_dim], [100.0, 10.0, 1.0, 0.1, 0.01])
B_initializer = onnx.helper.make_tensor('B', TensorProto.FLOAT, [], [7000.0])
node1 = make_node('MatMul', ['X', 'A'], ['XA'], 'XA')
node2 = make_node('Add', ['XA', 'B'], ['Y'], 'Y')
graph = make_graph([node1, node2], 'lr', [X], [Y], initializer=[A_initializer, B_initializer])
onnx_model = make_model(graph)
check_model(onnx_model)
with open("linear_regression.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
```

In [3]:
// Location of the hand-built linear regression model described above.
var modelPath = data.ReplaceTildeInDir("~/work/onnx/linear_regression.onnx")

%%
// Load the ONNX file, then dump its summary, variables and graph to stdout.
model := must.M1(onnx.ReadFile(modelPath))
fmt.Printf("%s\n", model)
must.M(model.PrintVariables(os.Stdout))
fmt.Println()
must.M(model.PrintGraph(os.Stdout))
fmt.Println()

// Copy the ONNX tensors into a GoMLX context (which stores variables) and list them.
ctx := context.New()
must.M(model.VariablesToContext(ctx))
fmt.Printf("Variables loaded from %q:\n", modelPath)
for variable := range ctx.IterVariables() {
    name, value := variable.ScopeAndName(), variable.Value().GoStr()
    fmt.Printf("\t- %s: %s\n", name, value)
}
fmt.Println()

// Run the converted graph on a small batch with GoMLX/XLA.
batch := [][]float32{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}
linearFn := func(execCtx *context.Context, features *Node) *Node {
    outputs := model.CallGraph(execCtx, features.Graph(), map[string]*Node{"X": features})
    return outputs[0]
}
fmt.Println("Example invocation:")
result := context.ExecOnce(backend, ctx, linearFn, batch)
fmt.Printf("\tX*A+B=%v\n", result.GoStr())

ONNX Model:
	# inputs:	1
		[#0] X: (Float32) [batch_size, 5]
	# outputs:	1
		[#0] Y: (Float32) [batch_size]
	# nodes:	2
	# tensors (variables):	2
	# sparse tensors (variables):	0
	Op types:	[]string{"Add", "MatMul"}
	IR Version:	10
	Operator Sets:	[v21]

2 tensors (variables):
	"A": (Float32)[5]
	"B": (Float32)
0 sparse tensors (variables)

Model Graph "lr":
"XA":	[MatMul]
	Inputs: ["X" "A"]
	Outputs: ["XA"]
"Y":	[Add]
	Inputs: ["XA" "B"]
	Outputs: ["Y"]

Variables loaded from "/home/janpf/work/onnx/linear_regression.onnx":
	- /ONNX/A: (Float32)[5]: []float32{100, 10, 1, 0.1, 0.01}
	- /ONNX/B: float32(7000)

Example invocation:
	X*A+B=(Float32)[2]: []float32{7123.45, 7679}


## [Sentence Encoding all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)

The model is loaded from the downloaded file `model.onnx`.

In [205]:
// miniLM runs the all-MiniLM-L6-v2 ONNX model (read from ~/work/onnx/model.onnx) on two
// hard-coded, pre-tokenized example sentences and returns the resulting tensors.
//
// On every call the model file is read again, its variables are converted into a fresh
// GoMLX context, and the graph is executed once on the notebook-wide `backend`.
//
// Parameters:
//   - verbose: if true, prints the model summary and the example sentences to stdout,
//     and writes a dump of the variables and graph to "graph.txt" plus a Graphviz
//     description to "graph.dot" in the current directory.
//   - targetOutputs: forwarded to model.CallGraph to select which outputs are returned.
//     NOTE(review): when empty, CallGraph presumably returns the model's declared
//     outputs — confirm.
//
// Failures are not returned: every error is handed to must.M / must.M1, including the
// error collected by exceptions.TryCatch around the graph execution.
func miniLM(verbose bool, targetOutputs ...string) []*tensors.Tensor {
    // Keep the path local: assigning to the notebook-wide `modelPath` would silently
    // retarget that global from the linear model to this one.
    modelPath := data.ReplaceTildeInDir("~/work/onnx/model.onnx") // all-MiniLM-L6-v2
    model := must.M1(onnx.ReadFile(modelPath))

    if verbose {
        fmt.Printf("%s\n", model)

        // Detailed dump for debugging. The blank separator lines are written to the
        // file itself (not to stdout), so the sections are separated in graph.txt.
        debugFile := must.M1(os.Create("graph.txt"))
        must.M(model.PrintVariables(debugFile))
        must.M1(fmt.Fprintln(debugFile))
        must.M(model.PrintGraph(debugFile))
        must.M1(fmt.Fprintln(debugFile))
        must.M(debugFile.Close())

        // Graphviz description of the same graph.
        dotFile := must.M1(os.Create("graph.dot"))
        must.M(model.PrintGraphviz(dotFile))
        must.M(dotFile.Close())
    }

    // Convert ONNX variables to GoMLX context (which stores variables):
    ctx := context.New()
    must.M(model.VariablesToContext(ctx))

    // Example sentences; only used for the verbose printout, the model is fed the
    // token ids below.
    sentences := []string{
        "This is an example sentence",
        "Each sentence is converted"}

    // Tokenization was done in Python and the results pasted here.
    // NOTE(review): the trailing 0 id / 0 mask of the second row looks like padding — confirm.
    inputIDs := [][]int64{
        {101, 2023, 2003, 2019, 2742, 6251, 102},
        {101, 2169, 6251, 2003, 4991, 102, 0}}
    tokenTypeIDs := [][]int64{
        {0, 0, 0, 0, 0, 0, 0},
        {0, 0, 0, 0, 0, 0, 0}}
    attentionMask := [][]int64{
        {1, 1, 1, 1, 1, 1, 1},
        {1, 1, 1, 1, 1, 1, 0}}
    if verbose {
        fmt.Println("Example invocation:")
        fmt.Printf("\tall-MiniLM-L6-v2(%#v)=\n", sentences)
    }

    // Execute it with GoMLX/XLA. The positions in `inputs` follow the argument order
    // given to ExecOnceN: inputIDs, attentionMask, tokenTypeIDs.
    var embeddings []*tensors.Tensor
    err := exceptions.TryCatch[error](func() {
        embeddings = context.ExecOnceN(
            backend, ctx, func(ctx *context.Context, inputs []*Node) []*Node {
                return model.CallGraph(ctx, inputs[0].Graph(), map[string]*Node{
                    "input_ids":      inputs[0],
                    "attention_mask": inputs[1],
                    "token_type_ids": inputs[2]}, targetOutputs...)
            }, inputIDs, attentionMask, tokenTypeIDs)
    })
    must.M(err)
    return embeddings
}

// p evaluates the model (non-verbose) for the single output named target and prints
// the name followed by the first returned tensor.
func p(target string) {
    results := miniLM(false, target)
    first := results[0]
    fmt.Printf("\n%s:\n", target)
    fmt.Printf("%s\n", first)
}

In [207]:
%env GOMLX_BACKEND="xla:cuda"
%%
// Print one intermediate output of the model with p(). The commented-out calls name
// other outputs that can be printed the same way.
// p("/embeddings/Slice_output_0")
// p("/embeddings/position_embeddings/Gather_output_0")
// p("token_type_ids")
// p("embeddings.token_type_embeddings.weight")
// p("/embeddings/token_type_embeddings/Gather_output_0")
// p("/embeddings/Add_output_0")
p("/embeddings/Add_1_output_0")

Set: GOMLX_BACKEND="xla:cuda"

/embeddings/Add_1_output_0:
[2][7][384]float32{
 {{-0.0886, -0.0368, 0.0180, ..., 0.0261, 0.0912, -0.0152},
  {-0.0200, -0.0014, -0.0177, ..., 0.0204, 0.0522, 0.1991},
  {-0.0196, -0.0336, -0.0319, ..., 0.0203, 0.0709, 0.0644},
  ...,
  {-0.0253, 0.0408, 0.0125, ..., -0.0270, 0.0377, 0.1133},
  {-0.0140, -0.0275, 0.0796, ..., -0.0748, 0.0774, -0.0657},
  {0.0318, -0.0032, -0.0210, ..., 0.0387, 0.0191, -0.0059}},
 {{-0.0886, -0.0368, 0.0180, ..., 0.0261, 0.0912, -0.0152},
  {0.0304, 0.0531, -0.0238, ..., -0.1011, 0.0218, 0.0473},
  {-0.0027, -0.0508, 0.0805, ..., -0.0777, 0.0881, -0.0560},
  ...,
  {0.0928, 0.0165, -0.0976, ..., 0.0449, 0.0390, -0.0182},
  {0.0231, 0.0090, -0.0213, ..., 0.0232, 0.0191, -0.0066},
  {-0.0213, 0.0019, 0.0043, ..., 0.0561, 0.0170, 0.0256}}}


In [175]:
// parseShapes checks the shapes produced by miniLM against the expectations listed in
// "model_shapes.txt" (read from the current directory).
//
// Each line of that file with exactly three tab-separated fields is interpreted as
// <node output name> \t <dtype name> \t <dimensions>, where the dimensions are a
// comma-separated list optionally wrapped in parentheses, e.g. "(2, 7, 384)" or "()".
// Lines with any other number of fields are skipped.
//
// For every entry it prints the name and expected shape, evaluates that output with
// miniLM, and prints a "*** got shape" line when the shapes differ.
// Note that miniLM reads and converts the model again on each call, so this runs the
// whole pipeline once per listed output.
//
// It panics (through must.M1) if the file cannot be read or a dtype/dimension cannot
// be parsed.
func parseShapes() {
    contents := string(must.M1(os.ReadFile("model_shapes.txt")))
    for _, line := range strings.Split(contents, "\n") {
        // Tolerate CRLF line endings: a stray '\r' would otherwise stick to the
        // closing parenthesis and break the dimension parsing.
        line = strings.TrimSuffix(line, "\r")
        parts := strings.Split(line, "\t")
        if len(parts) != 3 {
            continue
        }
        nodeOutputName := parts[0]
        dtype := must.M1(dtypes.DTypeString(parts[1]))
        dimsStr := strings.Trim(strings.TrimSpace(parts[2]), "()")
        var dims []int
        for _, dimStr := range strings.Split(dimsStr, ",") {
            // Trim before testing for emptiness, so whitespace-only entries
            // (e.g. from "(2, )") are skipped instead of failing in Atoi.
            dimStr = strings.TrimSpace(dimStr)
            if dimStr == "" {
                continue
            }
            dims = append(dims, must.M1(strconv.Atoi(dimStr)))
        }
        wantShape := shapes.Make(dtype, dims...)
        fmt.Printf("%s: %s\n", nodeOutputName, wantShape)
        output := miniLM(false, nodeOutputName)[0]
        if !output.Shape().Equal(wantShape) {
            fmt.Printf("\t*** got shape %s\n", output.Shape())
        }
    }
}
%%
// Compare the shape of every output listed in model_shapes.txt with what miniLM returns.
parseShapes()

/Shape_output_0: (Int64)[2]
/Constant_output_0: (Int64)
/Gather_output_0: (Int64)
onnx::Slice_110: (Int64)[1 512]
/embeddings/Constant_output_0: (Int64)[1]
/embeddings/Constant_1_output_0: (Int64)[1]
/embeddings/Constant_2_output_0: (Int64)[1]
/embeddings/Unsqueeze_output_0: (Int64)[1]
/embeddings/Constant_3_output_0: (Int64)[1]
/embeddings/Slice_output_0: (Int64)[1 7]
/embeddings/word_embeddings/Gather_output_0: (Float32)[2 7 384]
/embeddings/token_type_embeddings/Gather_output_0: (Float32)[2 7 384]
/embeddings/Add_output_0: (Float32)[2 7 384]
/embeddings/position_embeddings/Gather_output_0: (Float32)[1 7 384]
/embeddings/Add_1_output_0: (Float32)[2 7 384]
/embeddings/LayerNorm/ReduceMean_output_0: (Float32)[2 7 1]
/embeddings/LayerNorm/Sub_output_0: (Float32)[2 7 384]
/embeddings/LayerNorm/Constant_output_0: (Float32)
/embeddings/LayerNorm/Pow_output_0: (Float32)[2 7 384]
/embeddings/LayerNorm/ReduceMean_1_output_0: (Float32)[2 7 1]
/embeddings/LayerNorm/Constant_1_output_0: (Float32

### Last layer result

In [208]:
%%
// Run the model verbosely with no explicit target outputs and print the first returned tensor.
fmt.Printf("%s\n", miniLM(true)[0])

ONNX Model:
	Producer:	pytorch / 2.5.0
	# inputs:	3
		[#0] input_ids: (Int64) [batch_size, sequence_length]
		[#1] attention_mask: (Int64) [batch_size, sequence_length]
		[#2] token_type_ids: (Int64) [batch_size, sequence_length]
	# outputs:	1
		[#0] last_hidden_state: (Float32) [batch_size, sequence_length, 384]
	# nodes:	780
	# tensors (variables):	101
	# sparse tensors (variables):	0
	Op types:	[]string{"Add", "Cast", "Concat", "Constant", "ConstantOfShape", "Div", "Equal", "Erf", "Expand", "Gather", "MatMul", "Mul", "Pow", "ReduceMean", "Reshape", "Shape", "Slice", "Softmax", "Sqrt", "Sub", "Transpose", "Unsqueeze", "Where"}
	IR Version:	7
	Operator Sets:	[v14]



Example invocation:
	all-MiniLM-L6-v2([]string{"This is an example sentence", "Each sentence is converted"})=
[2][7][384]float32{
 {{0.0366, -0.0162, 0.1682, ..., 0.0554, -0.1644, -0.2967},
  {0.7239, 0.6399, 0.1888, ..., 0.5946, 0.6206, 0.4897},
  {0.0064, 0.0203, 0.0448, ..., 0.3464, 1.3170, -0.1670},
  ...,
  {0.1479, 