## Load ONNX model in tinygrad

In [9]:
import onnx
from extra.onnx import get_run_onnx

model = onnx.load("mnist_model.onnx")
# Create a callable object 'run_onnx' that executes the model
run_onnx = get_run_onnx(model)

## Count parameters

In [10]:
from prettytable import PrettyTable
import numpy as np

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    # ONNX uses model.graph.initializer to iterate through the parameters (nodes)
    for node in model.graph.initializer:
        # ONNX uses np.prod() to calculate parameter count by multiplying the dimensions (node.dims = parameter shape)
        num_params = np.prod(node.dims)
        table.add_row([node.name, num_params])
        total_params += num_params
    print(table)
    print(f"Total Trainable Params: {total_params}\n")
    return total_params

count_parameters(model)

+-----------+------------+
|  Modules  | Parameters |
+-----------+------------+
| l1.weight |    288     |
|  l1.bias  |     32     |
| l2.weight |   18432    |
|  l2.bias  |     64     |
| l3.weight |   16000    |
|  l3.bias  |     10     |
+-----------+------------+
Total Trainable Params: 34826



np.int64(34826)

## Get the MNIST dataset

In [11]:
from tinygrad.nn.datasets import mnist

X_train, Y_train, X_test, Y_test = mnist()
print(X_train.shape, X_train.dtype, Y_train.shape, Y_train.dtype)

(60000, 1, 28, 28) dtypes.uchar (60000,) dtypes.uchar


## Final probabilities

In [12]:
# Select the first test image
test_image = X_test[0:1]

# Run the ONNX model using run_onnx function with the test image as input
# The model expects an input with the key "input.1"
onnx_output = run_onnx({"input.1": test_image})

# Get the output tensor (single vector of 10 values, 1 for each digit class)
output_tensor = list(onnx_output.values())[0]

# Apply softmax and convert to numpy
tinygrad_probs = output_tensor.softmax().numpy()

# Print the resulting probabilities
print("tinygrad probabilities:", tinygrad_probs)

tinygrad probabilities: [[2.6173244e-14 4.2464655e-14 6.4881434e-08 4.5528861e-09 2.0712009e-17
  4.9732746e-11 1.3766536e-21 9.9999988e-01 8.9217121e-13 8.5283941e-10]]
