# Converting ONNX models to GoMLX

## Imports and setup

In [1]:
%version
!*rm -f go.work && go work init
!*go work use . "${HOME}/Projects/gonb" "${HOME}/Projects/gomlx" "${HOME}/Projects/onnx-gomlx" "${HOME}/Projects/go-huggingface"
%goworkfix

**GoNB** version [v0.10.6](https://github.com/janpfeifer/gonb/releases/tag/v0.10.6) / Commit: [0e5f587a077810d058202b76a127651a02bd4382](https://github.com/janpfeifer/gonb/tree/0e5f587a077810d058202b76a127651a02bd4382)


	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".
	- Added replace rule for module "github.com/gomlx/onnx-gomlx" to local directory "/home/janpf/Projects/onnx-gomlx".
	- Added replace rule for module "github.com/gomlx/go-huggingface" to local directory "/home/janpf/Projects/go-huggingface".


In [2]:
import (
	"fmt"
	"os"

	"github.com/gogo/protobuf/proto"
	"github.com/janpfeifer/must"

	"github.com/gomlx/gomlx/backends"
    . "github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/data"
	"github.com/gomlx/gomlx/types"
	"github.com/gomlx/gopjrt/dtypes"
    "github.com/gomlx/onnx-gomlx/onnx"
    "github.com/gomlx/go-huggingface/hub"
    "github.com/gomlx/go-huggingface/tokenizers"
	"github.com/pkg/errors"

	_ "github.com/gomlx/gomlx/backends/xla"
)

var (
    _ = Add
    backend = backends.New()
)


## Linear Model

The `linear_regression.onnx` model was created manually using Python (see the accompanying `onnx-py.ipynb` notebook):

```python
feature_dim = 5
X = make_tensor_value_info('X', TensorProto.FLOAT, ["batch_size", feature_dim])
Y = make_tensor_value_info('Y', TensorProto.FLOAT, ["batch_size"])
A_initializer = onnx.helper.make_tensor('A', TensorProto.FLOAT, [feature_dim], [100.0, 10.0, 1.0, 0.1, 0.01])
B_initializer = onnx.helper.make_tensor('B', TensorProto.FLOAT, [], [7000.0])
node1 = make_node('MatMul', ['X', 'A'], ['XA'], 'XA')
node2 = make_node('Add', ['XA', 'B'], ['Y'], 'Y')
graph = make_graph([node1, node2], 'lr', [X], [Y], initializer=[A_initializer, B_initializer])
onnx_model = make_model(graph)
check_model(onnx_model)
with open("linear_regression.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
```

In [3]:
// Path to the linear regression model created in Python (see the description above).
var modelPath = data.ReplaceTildeInDir("~/work/onnx/linear_regression.onnx")

%%
// Read and print the onnx model:
model := must.M1(onnx.ReadFile(modelPath))
fmt.Printf("%s\n", model)
must.M(model.PrintVariables(os.Stdout))
fmt.Println()
must.M(model.PrintGraph(os.Stdout))
fmt.Println()

// Convert ONNX variables to GoMLX context (which stores variables):
ctx := context.New()
must.M(model.VariablesToContext(ctx))
fmt.Printf("Variables loaded from %q:\n", modelPath)
for v := range ctx.IterVariables() {
    fmt.Printf("\t- %s: %s\n", v.ScopeAndName(), v.Value().GoStr())
}
fmt.Println()

// Execute it with GoMLX/XLA: CallGraph builds the ONNX computation in x's graph,
// feeding x as the ONNX input "X"; the first (and only) ONNX output is returned.
gomlxFn := func(ctx *context.Context, x *Node) *Node {
    return model.CallGraph(ctx, x.Graph(), map[string]*Node{"X": x})[0]
}
x := [][]float32{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}
fmt.Println("Example invocation:")
fmt.Printf("\tX*A+B=%v\n", context.ExecOnce(backend, ctx, gomlxFn, x).GoStr())

ONNX Model:
	# inputs:	1
		[#0] X: (Float32) [batch_size, 5]
	# outputs:	1
		[#0] Y: (Float32) [batch_size]
	# nodes:	2
	# tensors (variables):	2
	# sparse tensors (variables):	0
	Op types:	[]string{"Add", "MatMul"}
	IR Version:	10
	Operator Sets:	[v21]

2 tensors (variables):
	"A": (Float32)[5]
	"B": (Float32)
0 sparse tensors (variables)

Model Graph "lr":
"XA":	[MatMul]
	Inputs: ["X" "A"]
	Outputs: ["XA"]
"Y":	[Add]
	Inputs: ["XA" "B"]
	Outputs: ["Y"]

Variables loaded from "/home/janpf/work/onnx/linear_regression.onnx":
	- /ONNX/A: (Float32)[5]: []float32{100, 10, 1, 0.1, 0.01}
	- /ONNX/B: float32(7000)

Example invocation:
	X*A+B=(Float32)[2]: []float32{7123.45, 7679}


## Updating a Model

One of the main uses of `onnx-gomlx` is to convert a model to GoMLX and then fine-tune it.
The fine-tuned (or simply updated) model can be used as a regular GoMLX model, or its variables
can be written back to the original ONNX model -- for instance, if inference is done somewhere else.

Example:

1. We load the linear ONNX model.
2. We update the bias variable "B" from 7000 to 8000.
3. We save the changed ONNX model (to a new file, with a `~` suffix appended to the original file name).
4. We re-run the example above using the changed model, and observe that B is now 8000.

In [8]:
// Path to the linear regression model used in the previous section.
var modelPath = data.ReplaceTildeInDir("~/work/onnx/linear_regression.onnx")

%%
// 1. Load linear model:
model := must.M1(onnx.ReadFile(modelPath))

// 2. Convert ONNX variables to GoMLX context (which stores variables) and update "B" to 8000.
ctx := context.New()
must.M(model.VariablesToContext(ctx))
bVar := ctx.In(onnx.ModelScope).GetVariable("B")
fmt.Printf("B value was %s\n", bVar.Value())
bVar.SetValue(tensors.FromValue(float32(8000)))

// 3. Copy the context variables back into the ONNX model, and save it to a separate
// file (modelPath with a "~" suffix), so the original model file is left untouched.
updatedModelPath := modelPath + "~"
must.M(model.ContextToONNX(ctx))
must.M(model.SaveToFile(updatedModelPath))

// 4. Re-run example on updated model, re-read from disk into a fresh context:
model = must.M1(onnx.ReadFile(updatedModelPath))
ctx = context.New()
must.M(model.VariablesToContext(ctx))
bVar = ctx.In(onnx.ModelScope).GetVariable("B")
fmt.Printf("B value now is %s\n", bVar.Value())
gomlxFn := func(ctx *context.Context, x *Node) *Node {
    return model.CallGraph(ctx, x.Graph(), map[string]*Node{"X": x})[0]
}
x := [][]float32{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}
fmt.Println("Example invocation:")
// Print with GoStr(), the same way as the linear model example above.
fmt.Printf("\tX*A+B=%v\n", context.ExecOnce(backend, ctx, gomlxFn, x).GoStr())

B value was float32(7000.0000)
B value now is float32(8000.0000)
Example invocation:
	X*A+B=[2]float32{8123.4502, 8679.0000}


## Simple LSTM model

The model was created with the Python code:

```python

```

In [6]:
%%
// Read and print the onnx model (path built the same way as in the other cells):
modelPath := data.ReplaceTildeInDir("~/work/onnx/test_lstm.onnx")
model := must.M1(onnx.ReadFile(modelPath))
fmt.Printf("%s\n", model)
must.M(model.PrintVariables(os.Stdout))
fmt.Println()
must.M(model.PrintGraph(os.Stdout))
fmt.Println()

// Convert ONNX variables to GoMLX context (which stores variables):
ctx := context.New()
must.M(model.VariablesToContext(ctx))
fmt.Printf("Variables loaded from %q:\n", modelPath)
for v := range ctx.IterVariables() {
    fmt.Printf("\t- %s: %s\n", v.ScopeAndName(), v.Shape())
}
fmt.Println()

// Execute it with GoMLX/XLA, feeding x as the ONNX input named "input":
gomlxFn := func(ctx *context.Context, x *Node) *Node {
    return model.CallGraph(ctx, x.Graph(), map[string]*Node{"input": x})[0]
}
// A batch with a single sequence of 7 token IDs.
x := [][]int32{{0, 1, 2, 3, 4, 5, 6}}
fmt.Printf("lstm(x) = \t%v\n", context.ExecOnce(backend, ctx, gomlxFn, x).GoStr())

ONNX Model:
	Producer:	pytorch / 2.5.0
	# inputs:	1
		[#0] "input": (Int32) [1, sequence_length]
	# outputs:	1
		[#0] "output": (Float32) [1, 3]
	# nodes:	15
	# tensors (variables):	6
	# sparse tensors (variables):	0
	Op types:	[]string{"Concat", "Constant", "ConstantOfShape", "Gather", "Gemm", "LSTM", "Shape", "Squeeze", "Transpose", "Unsqueeze"}
	IR Version:	9
	Operator Sets:	[v20]

6 tensors (variables):
	"embedding.weight": (Float32)[30522 5]
	"fc.weight": (Float32)[3 11]
	"fc.bias": (Float32)[3]
	"onnx::LSTM_108": (Float32)[1 44 5]
	"onnx::LSTM_109": (Float32)[1 44 11]
	"onnx::LSTM_110": (Float32)[1 88]
0 sparse tensors (variables)

Model Graph "main_graph":
"/embedding/Gather":	[Gather]
	Inputs: ["embedding.weight" "input"]
	Outputs: ["/embedding/Gather_output_0"]
"/lstm/Shape":	[Shape]
	Inputs: ["/embedding/Gather_output_0"]
	Outputs: ["/lstm/Shape_output_0"]
"/lstm/Constant":	[Constant]
	Inputs: []
	Outputs: ["/lstm/Constant_output_0"]
	Attributes: value (TENSOR: (Int64))
"/lst

In [4]:
!ls

graph.dot
graph.svg
graph.txt
linear_regression.onnx
model.onnx
model_shapes_fixed.onnx
model_shapes.txt
modified_model.onnx
onnx-go.ipynb
onnx-py.ipynb
test_lstm.onnx
test_.onnx


## [Sentence Encoding all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)

From the downloaded file `model.onnx`.

In [5]:
// HuggingFace repository IDs of the ONNX models exercised in the cells below.
var (
    miniID        = "sentence-transformers/all-MiniLM-L6-v2"
    bertID        = "KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english"
    protectID     = "protectai/deberta-v3-base-zeroshot-v1-onnx"
    robertaID     = "SamLowe/roberta-base-go_emotions-onnx"
    distillbertID = "onnxport/distilbert-base-uncased-onnx"
)

// onnxTest downloads the ONNX model of the HuggingFace repository modelID, loads its
// variables into a new GoMLX context, and executes it on a fixed pair of example
// sentences whose token IDs are hard-coded below.
//
// The model file is looked up as "model.onnx" and then as "onnx/model.onnx". If
// neither exists in the repository, onnxTest panics.
//
// If verbose is true, it prints a summary of the model. It also writes two debugging
// files to the current directory: "graph.txt" (variables and graph) and "graph.dot"
// (Graphviz version of the graph).
//
// targetOutputs is forwarded to model.CallGraph. Presumably it selects which ONNX
// node outputs are returned instead of the model's declared outputs (TODO confirm).
// The returned tensors are the outputs of the execution. Any error causes a panic.
func onnxTest(modelID string, verbose bool, targetOutputs ...string) []*tensors.Tensor {
    repo := hub.New(modelID)
    modelFileName := "model.onnx"
    if !repo.HasFile(modelFileName) {
        modelFileName = "onnx/model.onnx"
        if !repo.HasFile(modelFileName) {
            // Fail right away with a clear message, instead of attempting to download a
            // file the repository doesn't have.
            panic(fmt.Errorf("could not find \"model.onnx\" or \"onnx/model.onnx\" in repo %q", modelID))
        }
    }
    modelPath := must.M1(repo.DownloadFile(modelFileName))
    model := must.M1(onnx.ReadFile(modelPath))

    if verbose {
        fmt.Printf("\n%s\n", modelID)
        fmt.Printf("%s\n", model)

        // Create a file for debugging with detailed graph. The blank lines separate the
        // sections inside the file, so they are written to the file, not to stdout.
        debugFile := must.M1(os.Create("graph.txt"))
        must.M(model.PrintVariables(debugFile))
        fmt.Fprintln(debugFile)
        must.M(model.PrintGraph(debugFile))
        fmt.Fprintln(debugFile)
        must.M(debugFile.Close())

        dotFile := must.M1(os.Create("graph.dot"))
        must.M(model.PrintGraphviz(dotFile))
        must.M(dotFile.Close())
    }

    // Convert ONNX variables to GoMLX context (which stores variables):
    ctx := context.New()
    must.M(model.VariablesToContext(ctx))

    // Execute it with GoMLX/XLA:
    sentences := []string{
        "This is an example sentence",
        "Each sentence is converted"}

    // Encoding to tokens done in Python and pasted here.
    inputIDs := [][]int64{
        {101, 2023, 2003, 2019, 2742, 6251, 102},
        {101, 2169, 6251, 2003, 4991, 102, 0}}
    tokenTypeIDs := [][]int64{
        {0, 0, 0, 0, 0, 0, 0},
        {0, 0, 0, 0, 0, 0, 0}}
    attentionMask := [][]int64{
        {1, 1, 1, 1, 1, 1, 1},
        {1, 1, 1, 1, 1, 1, 0}}
    if verbose {
        fmt.Println("Example invocation:")
        fmt.Printf("\tsentences: %q\n", sentences)
    }
    var embeddings []*tensors.Tensor
    // Convert panics raised while building/executing the graph into an error.
    err := exceptions.TryCatch[error](func() {
        embeddings = context.ExecOnceN(
            backend, ctx, func(ctx *context.Context, inputs []*Node) []*Node {
                inputsMap := map[string]*Node{
                    "input_ids":      inputs[0],
                    "attention_mask": inputs[1]}
                // Only models declaring 3 inputs take "token_type_ids".
                if model.NumInputs() == 3 {
                    inputsMap["token_type_ids"] = inputs[2]
                }
                return model.CallGraph(ctx, inputs[0].Graph(), inputsMap, targetOutputs...)
            }, inputIDs, attentionMask, tokenTypeIDs)
    })
    must.M(err)
    return embeddings
}

// p runs modelID (non-verbose), asking only for the ONNX output named target, and
// prints the first returned tensor under a "<target>:" header.
func p(modelID, target string) {
    results := onnxTest(modelID, false, target)
    fmt.Printf("\n%s:\n%s\n", target, results[0])
}

### Last layer result

In [6]:
%%
// Run the verbose test on DistilBERT and print its first output.
fmt.Printf("\toutput #0: %s\n", onnxTest(distillbertID, true)[0])

Downloaded 1/1 files, 266 MB downloaded         

onnxport/distilbert-base-uncased-onnx
ONNX Model:
	Producer:	pytorch / 1.12.0
	# inputs:	2
		[#0] "input_ids": (Int64) [batch, sequence]
		[#1] "attention_mask": (Int64) [batch, sequence]
	# outputs:	1
		[#0] "last_hidden_state": (Float32) [batch, sequence, 768]
	# nodes:	593
	# tensors (variables):	106
	# sparse tensors (variables):	0
	Op types:	[]string{"Add", "Cast", "Concat", "Constant", "Div", "Equal", "Erf", "Expand", "Gather", "Identity", "MatMul", "Mul", "Pow", "ReduceMean", "Reshape", "Shape", "Slice", "Softmax", "Sqrt", "Sub", "Transpose", "Unsqueeze", "Where"}
	IR Version:	6
	Operator Sets:	[v11]



Example invocation:
	sentences: ["This is an example sentence" "Each sentence is converted"]
	ouput #0: [2][7][768]float32{
 {{-0.2348, -0.2102, -0.0227, ..., -0.0947, 0.0382, 0.4671},
  {-0.7355, -0.5579, -0.3003, ..., -0.3094, 0.3252, 0.4205},
  {-0.6125, -0.3603, -0.0217, ..., -0.1959, -0.0726, 1.0924},
  ...,
  {-0.5531, 0.046

### Parse ONNX models from several repos

In [8]:
import (
	"github.com/gomlx/go-huggingface/hub"
)

// modelIDs lists the HuggingFace repositories whose ONNX models are parsed and executed below.
var modelIDs = []string{
    "sentence-transformers/all-MiniLM-L6-v2",
    "protectai/deberta-v3-base-zeroshot-v1-onnx",
    "KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english",
    "KnightsAnalytics/distilbert-NER",
    "SamLowe/roberta-base-go_emotions-onnx",
}

// hfAuthToken is read from the HF_TOKEN environment variable (empty if unset) and is
// passed to the HuggingFace hub when downloading.
var hfAuthToken = os.Getenv("HF_TOKEN")

// sentences are the example sentences. The loop below feeds hard-coded token IDs,
// presumably derived from these, rather than tokenizing them itself.
var sentences = []string{
    "This is an example sentence",
    "Each sentence is converted"}

%%
// For each repository: download its ONNX model and print its inputs and outputs.
// Then load its variables into a GoMLX context, run it on a fixed pair of
// pre-tokenized sentences, and print the first output.
for _, modelID := range modelIDs {
    fmt.Printf("\n%s:\n", modelID)
    repo := hub.New(modelID).WithAuth(hfAuthToken)
    must.M(repo.DownloadInfo(false)) // Not needed, but reports an error if repo info could not be downloaded.

    // The ONNX file is either at the root of the repository or under "onnx/".
    modelFileName := "model.onnx"
    if !repo.HasFile(modelFileName) {
        modelFileName = "onnx/model.onnx"
        if !repo.HasFile(modelFileName) {
            // Fail right away with a clear message, instead of attempting to download a
            // file the repository doesn't have.
            panic(fmt.Errorf("could not find \"model.onnx\" or \"onnx/model.onnx\" in repo %q", modelID))
        }
    }
    modelPath := must.M1(repo.DownloadFile(modelFileName))
    model := must.M1(onnx.ReadFile(modelPath))
    inputNames, _ := model.Inputs()
    fmt.Printf("\tInputs: %q\n", inputNames)
    outputNames, outputShapes := model.Outputs()
    for ii, n := range outputNames {
        fmt.Printf("\tOutput #%d: %s - %s\n", ii, n, outputShapes[ii])
    }

    // Convert ONNX variables to GoMLX context (which stores variables):
    ctx := context.New()
    must.M(model.VariablesToContext(ctx))

    // TODO: use github.com/gomlx/go-huggingface to tokenize according to the model.
    // These hard-coded IDs are the same as in onnxTest. They may not match the
    // vocabulary of every model in modelIDs.
    inputIDs := [][]int64{
        {101, 2023, 2003, 2019, 2742, 6251, 102},
        {101, 2169, 6251, 2003, 4991, 102, 0}}
    tokenTypeIDs := [][]int64{
        {0, 0, 0, 0, 0, 0, 0},
        {0, 0, 0, 0, 0, 0, 0}}
    attentionMask := [][]int64{
        {1, 1, 1, 1, 1, 1, 1},
        {1, 1, 1, 1, 1, 1, 0}}

    embeddings := context.ExecOnce(
        backend, ctx,
        func(ctx *context.Context, inputs []*Node) *Node {
            inputsMap := map[string]*Node{
                "input_ids":      inputs[0],
                "attention_mask": inputs[1]}
            // Only models declaring 3 inputs take "token_type_ids".
            if len(inputNames) == 3 {
                inputsMap["token_type_ids"] = inputs[2]
            }
            modelOutputs := model.CallGraph(ctx, inputs[0].Graph(), inputsMap)
            return modelOutputs[0]
        }, inputIDs, attentionMask, tokenTypeIDs)
    fmt.Printf("\tEmbeddings:\t%s\n", embeddings)
}


sentence-transformers/all-MiniLM-L6-v2:
	Inputs: ["input_ids" "attention_mask" "token_type_ids"]
	Output #0: last_hidden_state - (Float32) [batch_size, sequence_length, 384]
	Embeddings:	[2][7][384]float32{
 {{0.0366, -0.0162, 0.1682, ..., 0.0554, -0.1644, -0.2967},
  {0.7239, 0.6399, 0.1888, ..., 0.5946, 0.6206, 0.4897},
  {0.0064, 0.0203, 0.0448, ..., 0.3464, 1.3170, -0.1670},
  ...,
  {0.1479, -0.0643, 0.1457, ..., 0.8837, -0.3316, 0.2975},
  {0.5212, 0.6563, 0.5607, ..., -0.0399, 0.0412, -1.4036},
  {1.0824, 0.7140, 0.3986, ..., -0.2301, 0.3243, -1.0313}},
 {{0.2802, 0.1165, -0.0418, ..., 0.2711, -0.1685, -0.2961},
  {0.8729, 0.4545, -0.1091, ..., 0.1365, 0.4580, -0.2042},
  {0.4752, 0.5731, 0.6304, ..., 0.6526, 0.5612, -1.3268},
  ...,
  {0.6113, 0.7920, -0.4685, ..., 0.0854, 1.0592, -0.2983},
  {0.4115, 1.0946, 0.2385, ..., 0.8984, 0.3684, -0.7333},
  {0.1374, 0.5555, 0.2678, ..., 0.5426, 0.4665, -0.5284}}}

protectai/deberta-v3-base-zeroshot-v1-onnx:
	Inputs: ["input_ids" "atte