# o1js Benchmark

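This notebook benchmarks an MLP model implemented in PyTorch that mirrors the structure of an MLP model defined using o1js in TypeScript. The model is exported to ONNX format and then run through the ezkl pipeline: settings generation, circuit compilation, SRS retrieval, witness generation, and key setup. The `EXP_NUM` environment variable must be set before running; the network has `2**EXP_NUM` hidden layers.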

In [None]:
# Install necessary packages
!pip install torch numpy onnx ezkl

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import numpy as np
import onnx
import os
import json
import ezkl

In [None]:
# Benchmark size exponent, supplied through the EXP_NUM environment variable.
# It is kept as a string because it is used verbatim in the `mlp{exp_num}`
# directory name; the network depth is later computed as 2**int(exp_num).
try:
    exp_num = os.environ["EXP_NUM"]
except KeyError:
    raise KeyError(
        "EXP_NUM environment variable is not set; "
        "set it to a non-negative integer before running this notebook"
    ) from None

# The value is interpolated into a shell command and into file paths, and is
# later passed to int(), so fail early on anything but plain decimal digits.
if not (exp_num.isascii() and exp_num.isdigit()):
    raise ValueError(f"EXP_NUM must be a non-negative integer, got {exp_num!r}")

In [None]:
# Create the per-experiment output directory; no error if it already exists.
# os.makedirs is used instead of `! mkdir -p` so the step is cross-platform
# and exp_num is never interpolated into a shell command line.
os.makedirs(f'mlp{exp_num}', exist_ok=True)

In [None]:
# Every per-experiment artefact lives in one directory, mlp<EXP_NUM>/.
model_dir = f'mlp{exp_num}'

# ONNX file produced by the export step and consumed by the ezkl steps.
model_path = os.path.join(model_dir, 'mlp.onnx')
compiled_model_path = os.path.join(model_dir, 'model.compiled')
pk_path = os.path.join(model_dir, 'pk.key')
vk_path = os.path.join(model_dir, 'test.vk')
settings_path = os.path.join(model_dir, 'settings.json')

witness_path = os.path.join(model_dir, 'witness.json')
# The input file is not per-experiment: it is read from the working directory.
data_path = 'input.json'
# Same file as model_path; defined as an alias so the export target and the
# file ezkl reads cannot drift apart.
onnx_path = model_path


## Define the MLP Model

In [None]:
# PyTorch counterpart of the o1js MLP, built with the same layer structure
class MLP(nn.Module):
    """Stack of ``depth`` Linear(5, 5) + ReLU layers followed by a Linear(5, 1) head.

    The weight matrix of every linear layer is zeroed at construction time;
    biases keep whatever ``nn.Linear`` initialised them to.
    """

    def __init__(self, depth):
        """Create ``depth`` hidden layers, the output head and the activation."""
        super().__init__()
        hidden_layers = [nn.Linear(5, 5) for _ in range(depth)]
        self.layers = nn.ModuleList(hidden_layers)
        self.output = nn.Linear(5, 1)
        self.relu = nn.ReLU()
        # Zero the weights of each hidden layer and of the head; biases are untouched.
        for linear in (*self.layers, self.output):
            nn.init.zeros_(linear.weight.data)

    def forward(self, x):
        """Apply each hidden layer followed by ReLU, then the output head."""
        hidden = x
        for linear in self.layers:
            hidden = self.relu(linear(hidden))
        return self.output(hidden)

## Initialize the Model and Set Parameters

In [None]:
# Initialize the model.
# EXP_NUM is an exponent: the network is built with 2**EXP_NUM hidden layers.
exponent = int(exp_num)
if exponent < 0:
    # 2**negative evaluates to a float, which MLP's range(depth) cannot accept.
    raise ValueError(f"EXP_NUM must be non-negative, got {exp_num!r}")
depth = 2**exponent

# MLP zeroes its weights but leaves the biases at nn.Linear's random default,
# so seed the generator to make the printed output and exported model repeatable.
torch.manual_seed(0)
model = MLP(depth=depth)

# NOTE: the hand-written weights and biases of the o1js example are not applied
# here; the model keeps the initialisation done in MLP.__init__ (all weights zero).

## Perform Forward Pass

In [None]:
# Create input data (same as in o1js example) by reading ./input.json.
# A context manager is used so the file handle is closed as soon as the JSON
# has been parsed instead of being left open for the life of the kernel.
with open("input.json", 'r') as input_file:
    data = json.load(input_file)

# Convert the 'input_data' entry to a torch tensor
input_data = torch.tensor(data['input_data'], requires_grad=True)

# Perform forward pass through the network
output = model(input_data)
print("Model Output:", output)

## Export the Model to ONNX Format

In [None]:
# Export the model to ONNX format.
# Keyword options for the exporter, gathered in one place.
export_options = {
    "export_params": True,          # store the parameter values inside the model file
    "opset_version": 10,            # ONNX opset to target
    "do_constant_folding": True,    # run the constant-folding optimisation
    "input_names": ['input'],       # name given to the graph input
    "output_names": ['output'],     # name given to the graph output
    # axis 0 of both tensors is exported as a variable-length 'batch_size' axis
    "dynamic_axes": {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
}

# Export `model`, using `input_data` as the example input, to the file at `onnx_path`.
torch.onnx.export(model, input_data, onnx_path, **export_options)

print(f"Model exported to {onnx_path}")

## Validate the ONNX Model

In [None]:
# Load the file that was just exported and run it through the ONNX checker.
# The confirmation message is printed only if neither call raised.
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print('ONNX model is valid')

In [None]:
# Generate the ezkl settings file for the exported ONNX model.
# TODO: Dictionary outputs
res = ezkl.gen_settings(model_path, settings_path)
assert res == True, "ezkl.gen_settings did not report success"
# Later steps read settings_path, so confirm here that it was actually written.
assert os.path.isfile(settings_path), f"settings file was not written to {settings_path}"

In [None]:

# res = ezkl.calibrate_settings(data_path, model_path, settings_path, "resources")
# assert res == True

In [None]:
# Compile the ONNX model into an ezkl circuit using the generated settings.
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True, "ezkl.compile_circuit did not report success"
# Witness generation and setup both read compiled_model_path; fail here if it is missing.
assert os.path.isfile(compiled_model_path), f"compiled circuit was not written to {compiled_model_path}"


In [None]:
# srs path
# Obtain the SRS for the circuit described by the settings file.
# NOTE(review): the return value is stored but never checked, and some ezkl
# releases expose get_srs as a coroutine that must be awaited; confirm against
# the installed ezkl version.
res = ezkl.get_srs(settings_path)

In [None]:
# Now generate the witness file from the input data and the compiled circuit.
# gen_witness is awaited at the top level of the cell.

res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
# The returned value is not inspected; success is judged by the witness file existing.
assert os.path.isfile(witness_path)

In [None]:

# Set up the circuit parameters: ezkl.setup is given the compiled circuit plus
# the destination paths vk_path and pk_path for the generated keys.
res = ezkl.setup(compiled_model_path, vk_path, pk_path)

assert res == True
# Both key files and the settings file must be present on disk at this point.
for required_path in (vk_path, pk_path, settings_path):
    assert os.path.isfile(required_path)