# MLP Benchmark with ezkl Integration

This notebook benchmarks an MLP implemented in PyTorch whose structure mirrors an MLP defined with o1js in TypeScript. The model is exported to ONNX and integrated with ezkl for generating and verifying zero-knowledge proofs (ZKPs). The notebook automates every step that precedes proving: settings generation, circuit compilation, key setup, and witness generation. The proof itself is then produced with `ezkl prove`.

In [ ]:
# Install necessary packages
!pip install torch numpy onnx ezkl

In [ ]:
# Import necessary libraries
import inspect
import json
import os

import ezkl
import numpy as np
import onnx
import torch
import torch.nn as nn

## Set Experiment Parameters

In [ ]:
# Experiment index: the MLP below is built with depth = 2 ** exp_num hidden
# layers, and this run's artifacts are written under mlp<exp_num>/.
exp_num: int = 5  # Adjust depth complexity here

In [ ]:
# Create the per-experiment output directory (portable; no error if it exists).
os.makedirs(f"mlp{exp_num}", exist_ok=True)

In [ ]:
# All artifacts for this experiment live under mlp<exp_num>/; the input data
# file is shared across experiments and read from the working directory.
exp_dir = f"mlp{exp_num}"
model_path = os.path.join(exp_dir, "mlp.onnx")
compiled_model_path = os.path.join(exp_dir, "model.compiled")
pk_path = os.path.join(exp_dir, "pk.key")
vk_path = os.path.join(exp_dir, "test.vk")
settings_path = os.path.join(exp_dir, "settings.json")

witness_path = os.path.join(exp_dir, "witness.json")
data_path = "input.json"
# The ONNX export target and the model file ezkl consumes are the same file.
onnx_path = model_path

## Define the MLP Model

In [ ]:
# Define the MLP class with perceptron layers
class MLP(nn.Module):
    """Stack of ``depth`` ReLU-activated ``width -> width`` linear layers
    followed by a linear read-out to a single value.

    Every ``nn.Linear`` weight matrix is initialised to zero; biases keep
    PyTorch's default initialisation. Consequently each layer's output does
    not depend on its input, and the network output equals ``output.bias``
    for every input row.

    Args:
        depth: Number of hidden layers (0 yields just the read-out layer).
        width: Feature size of the input and of every hidden layer.
            Defaults to 5, the size used throughout this notebook.
    """

    def __init__(self, depth, width=5):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
        self.output = nn.Linear(width, 1)
        self.relu = nn.ReLU()
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # Zero the weights only; init functions already run without
                # autograd tracking, so `.data` is not needed.
                # NOTE(review): presumably chosen to mirror the o1js reference
                # model — confirm before benchmarking trained weights.
                nn.init.zeros_(m.weight)

    def forward(self, x):
        """Apply each hidden layer followed by ReLU, then the read-out layer."""
        for layer in self.layers:
            x = self.relu(layer(x))
        x = self.output(x)
        return x

## Initialize the Model

In [ ]:
# Build the network; its hidden-layer count grows exponentially with exp_num
# (e.g. exp_num = 5 gives 32 hidden layers).
depth = 2 ** exp_num
model = MLP(depth)


## Load Input Data from input.json

In [ ]:
# Load input data from a JSON file
def load_input_data(json_file):
    """Read the ``"input_data"`` entry of a JSON file as a float32 tensor.

    Args:
        json_file: Path to a JSON object whose ``"input_data"`` key holds a
            (possibly nested) list of numbers.

    Returns:
        A ``torch.float32`` tensor with ``requires_grad=True``.
    """
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Force float32: integer JSON values would otherwise produce an int64
    # tensor, which cannot require gradients (torch.tensor raises) and would
    # not match the float32 parameters of nn.Linear.
    return torch.tensor(data["input_data"], dtype=torch.float32, requires_grad=True)

# Read from data_path so the forward pass and ezkl.gen_witness use the same file.
input_data = load_input_data(data_path)

## Perform Forward Pass

In [ ]:
# Run the loaded input through the network once and display the result.
output = model(input_data)
print(f"Model Output: {output}")

Model Output: tensor([[0.0209],
        [0.0209]], grad_fn=<AddmmBackward0>)


## Export the Model to ONNX Format

In [ ]:
# Serialise the PyTorch model to ONNX for ezkl. The exporter traces `model`
# on `input_data`; dimension 0 of the input and output is declared dynamic
# under the name 'batch_size'.
torch.onnx.export(
    model,
    input_data,
    onnx_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    opset_version=10,          # ONNX operator-set version to target
    export_params=True,        # embed the parameter values in the file
    do_constant_folding=True,  # fold constant sub-expressions at export time
)

print(f"Model exported to {onnx_path}")

Model exported to mlp5/mlp.onnx


## Validate the ONNX Model

In [ ]:
# Re-read the exported file and run ONNX's checker over it; check_model
# raises on an invalid model, so reaching the print means validation passed.
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("ONNX model is valid")

ONNX model is valid


## Generate Settings for ezkl

In [ ]:
# Generate ezkl circuit settings for the ONNX model and write them to
# settings_path. Fail loudly (asserts are stripped under `python -O`).
res = ezkl.gen_settings(model_path, settings_path)
if res is not True:
    raise RuntimeError(f"ezkl.gen_settings failed (returned {res!r})")

## Compile the Circuit

In [ ]:
# Compile the ONNX model with the generated settings into compiled_model_path.
# Fail loudly (asserts are stripped under `python -O`).
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
if res is not True:
    raise RuntimeError(f"ezkl.compile_circuit failed (returned {res!r})")

## Generate Proving and Verification Keys

In [ ]:
# Generate the SRS (Structured Reference String) needed for zk-SNARKs
res = ezkl.get_srs(settings_path)
assert res == True

In [ ]:
# Generate the proving and verification keys from the compiled circuit,
# then confirm every artifact the later steps rely on exists on disk.
res = ezkl.setup(compiled_model_path, vk_path, pk_path)
if res is not True:
    raise RuntimeError(f"ezkl.setup failed (returned {res!r})")
missing = [p for p in (vk_path, pk_path, settings_path) if not os.path.isfile(p)]
if missing:
    raise FileNotFoundError(f"expected ezkl artifacts not found: {missing}")

## Generate the Witness File

In [ ]:
# Generate the witness file for proving
res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

## Running the Proving Command

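Finally, generate the proof with the following command. Its paths assume the experiment's artifacts are under `models/mlp/mlp$i/`, with `$i` standing for the experiment number. This notebook writes them to `mlp{exp_num}/`, so adjust the paths to match where you run the command: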
```bash
ezkl prove --witness models/mlp/mlp$i/witness.json --pk-path models/mlp/mlp$i/pk.key --compiled-circuit models/mlp/mlp$i/model.compiled --proof-path models/mlp/mlp$i/proof.json
```
This command generates the proof from the witness file, proving key, and compiled circuit, and writes it to `proof.json`.