# Converting ONNX models to GoMLX

## Imports and setup

In [1]:
%version
!*rm -f go.work && go work init
!*go work use . "${HOME}/Projects/gonb" "${HOME}/Projects/gomlx" "${HOME}/Projects/onnx-gomlx"
%goworkfix

**GoNB** version [v0.10.6](https://github.com/janpfeifer/gonb/releases/tag/v0.10.6) / Commit: [0e5f587a077810d058202b76a127651a02bd4382](https://github.com/janpfeifer/gonb/tree/0e5f587a077810d058202b76a127651a02bd4382)


	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".
	- Added replace rule for module "github.com/gomlx/onnx-gomlx" to local directory "/home/janpf/Projects/onnx-gomlx".


In [2]:
import (
	"fmt"
	"os"

	"github.com/gogo/protobuf/proto"
	"github.com/janpfeifer/must"

	"github.com/gomlx/gomlx/backends"
    . "github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/data"
	"github.com/gomlx/gomlx/types"
    "github.com/gomlx/onnx-gomlx/onnx"
	"github.com/pkg/errors"

	_ "github.com/gomlx/gomlx/backends/xla"
)

// Keeps the dot-import of the graph package in use even when a cell references none of its symbols.
var _ = Add

// backend is created with backends.New() and passed to context.ExecOnce by the cells below.
var backend = backends.New()


## Linear Model

The `linear_regression.onnx` model was created manually using Python (see the accompanying `onnx-py.ipynb` notebook):

```python
feature_dim = 5
X = make_tensor_value_info('X', TensorProto.FLOAT, ["batch_size", feature_dim])
Y = make_tensor_value_info('Y', TensorProto.FLOAT, ["batch_size"])
A_initializer = onnx.helper.make_tensor('A', TensorProto.FLOAT, [feature_dim], [100.0, 10.0, 1.0, 0.1, 0.01])
B_initializer = onnx.helper.make_tensor('B', TensorProto.FLOAT, [], [7000.0])
node1 = make_node('MatMul', ['X', 'A'], ['XA'], 'XA')
node2 = make_node('Add', ['XA', 'B'], ['Y'], 'Y')
graph = make_graph([node1, node2], 'lr', [X], [Y], initializer=[A_initializer, B_initializer])
onnx_model = make_model(graph)
check_model(onnx_model)
with open("linear_regression.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
```

In [3]:
var modelPath = data.ReplaceTildeInDir("~/work/onnx/linear_regression.onnx") // Linear model created by the accompanying onnx-py.ipynb notebook.

%%
// Read the ONNX model and print its summary, variables and graph:
model := must.M1(onnx.ReadFile(modelPath))
fmt.Printf("%s\n", model)
must.M(model.PrintVariables(os.Stdout))
fmt.Println()
must.M(model.PrintGraph(os.Stdout))
fmt.Println()

// Convert ONNX variables to a GoMLX context (which stores variables) and list them:
ctx := context.New()
must.M(model.VariablesToContext(ctx))
fmt.Printf("Variables loaded from %q:\n", modelPath)
for v := range ctx.IterVariables() {
    fmt.Printf("\t- %s: %s\n", v.ScopeAndName(), v.Value().GoStr())
}
fmt.Println()

// Execute the converted graph with GoMLX/XLA, feeding x as the ONNX input "X":
gomlxFn := func(ctx *context.Context, x *Node) *Node {
    return model.CallGraph(ctx, x.Graph(), map[string]*Node{"X": x})[0]
}
x := [][]float32{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}
fmt.Println("Example invocation:")
fmt.Printf("\tX*A+B=%v\n", context.ExecOnce(backend, ctx, gomlxFn, x).GoStr())

ONNX Model:
	# inputs:	1
		[#0] X: (Float32) [batch_size, 5]
	# outputs:	1
		[#0] Y: (Float32) [batch_size]
	# nodes:	2
	# tensors (variables):	2
	# sparse tensors (variables):	0
	Op types:	[]string{"Add", "MatMul"}
	IR Version:	10
	Operator Sets:	[v21]

2 tensors (variables):
	"A": (Float32)[5]
	"B": (Float32)
0 sparse tensors (variables)

Model Graph "lr":
"XA":	[MatMul]
	Inputs: ["X" "A"]
	Outputs: ["XA"]
"Y":	[Add]
	Inputs: ["XA" "B"]
	Outputs: ["Y"]

Variables loaded from "/home/janpf/work/onnx/linear_regression.onnx":
	- /ONNX/A: (Float32)[5]: []float32{100, 10, 1, 0.1, 0.01}
	- /ONNX/B: float32(7000)

Example invocation:
	X*A+B=(Float32)[2]: []float32{7123.45, 7679}


## [Sentence Encoding all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)

The model is loaded from the `model.onnx` file downloaded from the Hugging Face repository linked above.

In [64]:
// miniLM reads the all-MiniLM-L6-v2 ONNX model from ~/work/onnx/model.onnx,
// loads its variables into a new GoMLX context and executes it once on two
// example sentences whose tokenization was done in Python and pasted below.
//
// If verbose is true, a summary of the model is printed to stdout and the
// model's variables and graph are written to "graph.txt" for debugging.
//
// targetOutputs is forwarded to model.CallGraph, so a node output name such as
// "/embeddings/word_embeddings/Gather_output_0" can be requested to inspect an
// intermediate value. Only the first output is returned.
//
// Any error panics (via must).
func miniLM(verbose bool, targetOutputs ...string) *tensors.Tensor {
	// Local path: don't mutate the package-level modelPath declared in another cell.
	onnxPath := data.ReplaceTildeInDir("~/work/onnx/model.onnx") // all-MiniLM-L6-v2
	model := must.M1(onnx.ReadFile(onnxPath))

	if verbose {
		fmt.Printf("%s\n", model)

		// The detailed variables and graph are large, so they go to a file
		// instead of the cell output.
		func() {
			debugFile := must.M1(os.Create("graph.txt"))
			// Close the file even if one of the prints below panics.
			defer func() { must.M(debugFile.Close()) }()
			must.M(model.PrintVariables(debugFile))
			must.M1(fmt.Fprintln(debugFile))
			must.M(model.PrintGraph(debugFile))
			must.M1(fmt.Fprintln(debugFile))
		}()
	}

	// Convert ONNX variables to a GoMLX context (which stores variables):
	ctx := context.New()
	must.M(model.VariablesToContext(ctx))

	// gomlxFn builds the model graph. Its parameter order (inputIDs,
	// attentionMask, tokenTypeIDs) fixes the order of the ExecOnce arguments below.
	gomlxFn := func(ctx *context.Context, inputIDs, attentionMask, tokenTypeIDs *Node) *Node {
		outputs := model.CallGraph(ctx, inputIDs.Graph(), map[string]*Node{
			"input_ids":      inputIDs,
			"attention_mask": attentionMask,
			"token_type_ids": tokenTypeIDs,
		}, targetOutputs...)
		return outputs[0]
	}

	sentences := []string{
		"This is an example sentence",
		"Each sentence is converted",
	}
	// Encoding to tokens done in Python and pasted here.
	inputIDs := [][]int64{
		{101, 2023, 2003, 2019, 2742, 6251, 102},
		{101, 2169, 6251, 2003, 4991, 102, 0},
	}
	tokenTypeIDs := [][]int64{
		{0, 0, 0, 0, 0, 0, 0},
		{0, 0, 0, 0, 0, 0, 0},
	}
	// The last position of the second sentence is padding (token 0), hence mask 0.
	attentionMask := [][]int64{
		{1, 1, 1, 1, 1, 1, 1},
		{1, 1, 1, 1, 1, 1, 0},
	}

	fmt.Println("Example invocation:")
	fmt.Printf("\tall-MiniLM-L6-v2(%#v)=\n", sentences)
	var embeddings *tensors.Tensor
	err := exceptions.TryCatch[error](func() {
		// Arguments must follow gomlxFn's parameter order.
		embeddings = context.ExecOnce(backend, ctx, gomlxFn, inputIDs, attentionMask, tokenTypeIDs)
	})
	must.M(err)
	return embeddings
}

In [75]:
%env GOMLX_BACKEND="xla:cuda"
%%
embeddings := miniLM(true, "/embeddings/word_embeddings/Gather_output_0")
fmt.Printf("%s", embeddings)

Set: GOMLX_BACKEND="xla:cuda"
Example invocation:
	all-MiniLM-L6-v2([]string{"This is an example sentence", "Each sentence is converted"})=
[2][7][384]float32{
 {{-0.0176, -0.0076, 0.0471, ..., -0.0545, 0.0076, -0.0617},
  {-0.0019, -0.0074, -0.0267, ..., -0.0010, 0.0197, 0.1891},
  {-0.0208, -0.0279, -0.0515, ..., 0.0207, 0.0377, 0.0610},
  ...,
  {-0.0086, 0.0210, -0.0081, ..., -0.0094, 0.0128, 0.1179},
  {-0.0039, -0.0451, 0.0609, ..., -0.0773, 0.0549, -0.0595},
  {0.0332, -0.0085, -0.0400, ..., 0.0207, -0.0034, -0.0004}},
 {{-0.0176, -0.0076, 0.0471, ..., -0.0545, 0.0076, -0.0617},
  {0.0485, 0.0471, -0.0329, ..., -0.1225, -0.0107, 0.0374},
  {-0.0039, -0.0451, 0.0609, ..., -0.0773, 0.0549, -0.0595},
  ...,
  {0.1095, -0.0033, -0.1182, ..., 0.0625, 0.0140, -0.0135},
  {0.0332, -0.0085, -0.0400, ..., 0.0207, -0.0034, -0.0004},
  {-0.0200, -0.0034, -0.0147, ..., 0.0381, -0.0054, 0.0311}}}

In [72]:
%%
// Build a [3, 7, 384] float32 tensor holding 1, 2, 3, ... to see how large tensors are printed.
shape := shapes.Make(dtypes.F32, 3, 7, 384)
numElements := shape.Size()
fmt.Printf("size=%d\n", numElements)
flatValues := xslices.Iota(float32(1), numElements)
iotaTensor := tensors.FromFlatDataAndDimensions(flatValues, shape.Dimensions...)
fmt.Printf("%s\n", iotaTensor)

size=8064
[3][7][384]float32{
 {{1.0000, 2.0000, 3.0000, ..., 382.0000, 383.0000, 384.0000},
  {385.0000, 386.0000, 387.0000, ..., 766.0000, 767.0000, 768.0000},
  {769.0000, 770.0000, 771.0000, ..., 1150.0000, 1151.0000, 1152.0000},
  ...,
  {1537.0000, 1538.0000, 1539.0000, ..., 1918.0000, 1919.0000, 1920.0000},
  {1921.0000, 1922.0000, 1923.0000, ..., 2302.0000, 2303.0000, 2304.0000},
  {2305.0000, 2306.0000, 2307.0000, ..., 2686.0000, 2687.0000, 2688.0000}},
 ...,
 {{5377.0000, 5378.0000, 5379.0000, ..., 5758.0000, 5759.0000, 5760.0000},
  {5761.0000, 5762.0000, 5763.0000, ..., 6142.0000, 6143.0000, 6144.0000},
  {6145.0000, 6146.0000, 6147.0000, ..., 6526.0000, 6527.0000, 6528.0000},
  ...,
  {6913.0000, 6914.0000, 6915.0000, ..., 7294.0000, 7295.0000, 7296.0000},
  {7297.0000, 7298.0000, 7299.0000, ..., 7678.0000, 7679.0000, 7680.0000},
  {7681.0000, 7682.0000, 7683.0000, ..., 8062.0000, 8063.0000, 8064.0000}}}
