In [None]:
# Install required Python libraries into the environment of the *running* kernel.
# Invoking pip via sys.executable (rather than whatever "pip" is first on PATH)
# guarantees the packages are installed for the interpreter executing this notebook.
# NOTE(review): versions are unpinned; pin them for reproducible runs.
import subprocess
import sys

try:
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "torch", "onnx", "numpy", "pillow"]
    )
except (subprocess.CalledProcessError, OSError) as exc:
    # Best-effort install (e.g. offline machines where the packages already exist).
    # Only installation failures are caught; Ctrl-C / SystemExit still propagate.
    print(f"Ensure all dependencies are installed. ({exc})")

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from PIL import Image

In [None]:
# LeNet-5 Model Definition
class LeNet(nn.Module):
    """LeNet-5 style convolutional network producing 10 class logits.

    Spatial trace for a 1x28x28 input (5x5 convs, no padding; 2x2 pools):
        28 -conv-> 24 -pool-> 12 -conv-> 8 -pool-> 4
    so the flattened feature vector has 16 * 4 * 4 = 256 entries. The
    flatten size is hard-coded, so other input sizes either raise or
    silently change the batch dimension in ``view(-1, ...)``.
    """

    def __init__(self):
        super().__init__()
        # Registration order matters: it fixes both the state_dict / parameter
        # order and the order in which default initialisation draws from the RNG.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw (un-normalised) logits of shape (batch, 10)."""
        # Feature extractor: two conv -> ReLU -> max-pool stages.
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            x = pool(F.relu(conv(x)))
        # Kept as view(-1, 256), exactly as before, so the traced graph is unchanged.
        x = x.view(-1, 16 * 4 * 4)
        # Classifier head: two hidden ReLU layers, then a linear output layer.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

In [None]:
# Export the (untrained) LeNet model to ONNX, tracing it with a dummy input.
# Seeding the RNG makes the randomly initialised weights reproducible: re-running
# this cell rewrites lenet.onnx with identical weights, so a proof produced from
# an earlier run still corresponds to the model used during verification.
SEED = 0
torch.manual_seed(SEED)

dummy_input = torch.randn(1, 1, 28, 28).float()  # Batch size 1, 1 channel, 28x28 image
print(f"Dummy input shape: {dummy_input.shape}")

# Initialize the LeNet model. eval() makes the inference intent explicit; LeNet
# has no dropout/batch-norm layers, so it does not change the computed outputs.
model = LeNet()
model.eval()

# Export the LeNet model
onnx_path = "lenet.onnx"
torch.onnx.export(
    model, dummy_input, onnx_path,
    input_names=["input"], output_names=["output"],
    opset_version=13,
)

print(f"ONNX model exported to {onnx_path}")

In [None]:
# Load and preprocess the MNIST sample image
image_path = "../../models/data/1052.png"
# Open with a context manager so the file handle is closed deterministically;
# convert()/resize() return new in-memory images that outlive the `with` block.
with Image.open(image_path) as img:
    image = img.convert("L").resize((28, 28))  # Grayscale, 28x28 pixels

# Scale 8-bit pixel values to float32 in [0, 1].
image_data = np.array(image).astype(np.float32) / 255.0
assert image_data.shape == (28, 28), f"unexpected image shape {image_data.shape}"

# Flatten the 28x28 pixel data row-major into a single 784-element Python list.
flattened_data = image_data.flatten().tolist()

# Wrap the flattened array in an outer list to make it [[...]].
# NOTE(review): presumably mina-zkml-cli reshapes this to the model's
# (1, 1, 28, 28) "input" tensor — confirm against the CLI's input format.
data = [flattened_data]

# Save the JSON file (pretty-printed for readability)
data_path = "input.json"
with open(data_path, "w") as f:
    json.dump(data, f, indent=4)

print(f"Input data saved to {data_path}")

In [None]:
import subprocess

# Paths
onnx_path = "lenet.onnx"
proof_path = "proof.json"
output_path = "output.json"

# Command for proof generation.
# The CLI binary must be built first with `cargo build --release`, so that
# ../../target/release/mina-zkml-cli exists relative to this notebook.
cmd = [
    "../../target/release/mina-zkml-cli", "proof",
    "-m", onnx_path,
    "-i", data_path,
    "-o", proof_path,
    "--input-visibility", "public",
    "--output-visibility", "public",
]

# Run the command
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
print(result.stdout)
if result.returncode != 0:
    # Fail loudly: merely printing would let the following cells silently reuse
    # a stale proof.json left over from an earlier run.
    raise RuntimeError(f"Error generating proof: {result.stderr}")
print(f"Proof successfully generated at {proof_path}")

In [None]:
# Create a public output file from the proof: copy the "output" field of
# proof.json into output.json, which the verification command receives via -o.
with open(proof_path, "r") as proof_file:
    proof_data = json.load(proof_file)

# Fail loudly rather than printing: otherwise a stale output.json from a
# previous run would be passed to the verification step without any warning.
if not isinstance(proof_data, dict) or "output" not in proof_data:
    raise ValueError(f"No 'output' field found in {proof_path}")

output_data = proof_data["output"]

# Save the output data to output.json
with open(output_path, "w") as output_file:
    json.dump(output_data, output_file, indent=4)

print(f"Output data successfully saved to {output_path}")

In [None]:
# Command for proof verification: passes the model (-m), the public input (-i),
# the proof (-p) and the public output file (-o) to the CLI's "verify" subcommand.
cmd = [
    "../../target/release/mina-zkml-cli", "verify",
    "-m", onnx_path,
    "-i", data_path,
    "-p", proof_path,
    "-o", output_path,
    "--input-visibility", "public",
    "--output-visibility", "public",
]

# Run the command
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
print(result.stdout)
if result.returncode != 0:
    # Surface a failed verification as an error instead of a printed message.
    raise RuntimeError(f"Error verifying proof: {result.stderr}")
# NOTE(review): success is judged from the exit code alone; confirm that
# mina-zkml-cli exits non-zero when a proof does not verify.
print(f"Proof successfully verified at {proof_path}")