## MNIST Clan-ssifier ;)

Here we demonstrate how to use the EZKL package to build an MNIST classifier for on-chain hand-drawn digit recognition.
The proofs get submitted to a contract that assigns the user's account to a digit clan (0-9). The contract keeps track of the member count of each clan. The clan with the most members is the winner!

![zk-gaming-diagram-transformed](https://file.notion.so/f/f/f9535faf-4480-4499-9059-a48ba240eaa9/cd13414a-ecd8-4b8f-90a1-8a2311baa278/Untitled.png?id=365d66ee-e653-4ec3-8eb6-6d2b6306455a&table=block&spaceId=f9535faf-4480-4499-9059-a48ba240eaa9&expirationTimestamp=1701568800000&signature=VJ9p3YsOjYjeLxmkVEWOJw_3VmM6IBkTYxMwQUFKeus&downloadName=Untitled.png)
> **A typical ZK application flow**. For all the image classification hackers out there — this is a fairly straightforward example. A user computes a ZKML proof that they have calculated a valid classification of a hand-drawn digit using a LeNet model trained on MNIST. They submit this proof to a verifier contract that governs a set of clans, along with the model's output values (a length-10 tensor whose index with the maximum value represents the prediction), and the clan count is updated according to the LeNet model's prediction.

In [1]:
!pip3 install ezkl==5.0.8


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [3]:
# Detect whether the notebook is running in Google Colab. Only the import is
# guarded: outside Colab it raises ImportError and we rely on a local install.
try:
    import google.colab  # noqa: F401 -- imported only to detect Colab
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

if IN_COLAB:
    import subprocess
    import sys

    # ezkl is pinned to the version installed by the pip cell at the top of
    # the notebook (ezkl==5.0.8); the ezkl calls below were run against it.
    # A failing install now raises instead of being silently ignored.
    for package in ("ezkl==5.0.8", "torch", "torchvision", "tf2onnx", "onnx"):
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# make sure you have the dependencies required here already installed
import ezkl
import os
import json
import time
import random
import logging


# Descriptive logging is enabled: records are printed with level, logger name,
# timestamp and source location (this is the format of the ezkl log lines in
# the cell outputs). Raise the level, e.g. to logging.WARNING, for quieter output.
FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'
logging.basicConfig(format=FORMAT)
logging.getLogger().setLevel(logging.INFO)

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    """LeNet-style digit classifier for single-channel 28x28 images.

    Two (convolution -> sigmoid -> 2x2 average pool) stages feed three fully
    connected layers. The last layer returns 10 raw logits with no activation,
    as expected by CrossEntropyLoss during training.
    """

    def __init__(self):
        super().__init__()
        # Convolutional encoder: 1 -> 6 -> 16 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # Dense head. With a 28x28 input, the two unpadded 5x5 convs and 2x2
        # pools shrink each map 28 -> 24 -> 12 -> 8 -> 4, leaving 16 x 4 x 4.
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # one logit per digit class

    def forward(self, x):
        # Feature extraction: convolution -> sigmoid -> average pool, twice.
        for conv in (self.conv1, self.conv2):
            x = F.avg_pool2d(torch.sigmoid(conv(x)), (2, 2))

        # Keep the batch axis, flatten everything else.
        x = x.view(x.size(0), -1)

        # Two sigmoid-activated hidden layers, then raw logits.
        for fc in (self.fc1, self.fc2):
            x = torch.sigmoid(fc(x))
        return self.fc3(x)


In [9]:
import numpy as np
import os
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import Adam  # Import Adam
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

def normalize_img(image, label):
    """Round every pixel of *image* to the nearest integer; pass *label* through.

    For ToTensor-scaled pixels in [0, 1] this binarizes the image to 0/1,
    mimicking the inputs coming from the drawing board.
    """
    rounded = torch.round(image)
    return rounded, label

# Train LeNet on binarized MNIST, report test accuracy after every epoch and
# save a full-model checkpoint named after that accuracy.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 256
train_dataset = mnist.MNIST(root='./train', train=True, transform=ToTensor(), download=True)
test_dataset = mnist.MNIST(root='./test', train=False, transform=ToTensor(), download=True)
# Reshuffle training batches every epoch; the test order does not matter.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
model = LeNet().to(device)
adam = Adam(model.parameters())  # Adam with its default learning rate (1e-3)
loss_fn = CrossEntropyLoss()
# We train for 20 epochs to limit overhead for CI testing;
# raise this (e.g. to 100) for better accuracy.
all_epoch = 20
prev_acc = 0
# Create the checkpoint directory once, up front.
os.makedirs("models", exist_ok=True)
for current_epoch in range(all_epoch):
    model.train()
    for train_x, train_label in train_loader:
        # Binarize pixels to 0 or 1 to reflect the inputs from the drawing board.
        train_x = train_x.to(device).round()
        train_label = train_label.to(device)
        adam.zero_grad()
        predict_y = model(train_x.float())
        loss = loss_fn(predict_y, train_label.long())
        loss.backward()
        adam.step()

    model.eval()
    all_correct_num = 0
    all_sample_num = 0
    # Evaluation needs no gradients; no_grad avoids building an autograd graph.
    with torch.no_grad():
        for test_x, test_label in test_loader:
            # Same 0/1 binarization as in training.
            test_x = test_x.to(device).round()
            test_label = test_label.to(device)
            predict_y = torch.argmax(model(test_x.float()), dim=-1)
            current_correct_num = predict_y == test_label
            all_correct_num += int(current_correct_num.sum().item())
            all_sample_num += current_correct_num.shape[0]
    acc = all_correct_num / all_sample_num
    print('test accuracy: {:.3f}'.format(acc), flush=True)
    torch.save(model, 'models/mnist_{:.3f}.pkl'.format(acc))
    prev_acc = acc


In [25]:
import os

# Every artifact of the ezkl pipeline is a bare filename, so all of them are
# read from / written to the current working directory.
(
    model_path,           # exported ONNX model
    compiled_model_path,  # circuit produced by ezkl.compile_circuit
    pk_path,              # proving key written by ezkl.setup
    vk_path,              # verification key written by ezkl.setup
    settings_path,        # circuit settings (gen_settings / calibrate_settings)
    srs_path,             # SRS file fetched by ezkl.get_srs
    witness_path,         # witness produced by ezkl.gen_witness
    data_path,            # model input data in JSON form
) = (
    os.path.join(filename)
    for filename in (
        'network_lenet.onnx',
        'network.compiled',
        'key.pk',
        'key.vk',
        'settings.json',
        'kzg.srs',
        'witness.json',
        'input.json',
    )
)

In [26]:
import torch
import json

model.eval()  # Set the model to evaluation mode before exporting

# Use the first training sample as the example input: it fixes the traced
# input shape and is also written out as the input data for the proof.
train_data_point, _ = next(iter(train_dataset))
train_data_point = train_data_point.unsqueeze(0)  # Add a batch dimension

# Put the sample on the same device (CPU or CUDA) as the model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_data_point = train_data_point.to(device)

# Export the model to ONNX format
torch.onnx.export(
    model,
    train_data_point,
    model_path,
    export_params=True,
    opset_version=12,
    do_constant_folding=True,
    input_names=['input_0'],
    output_names=['output'],
)

# Flatten the sample into a plain list for JSON serialization
x = train_data_point.cpu().detach().numpy().reshape([-1]).tolist()
data = {'input_data': [x]}
# Write to data_path, the same file ezkl.gen_witness reads later.
with open(data_path, 'w') as f:
    json.dump(data, f)

print(f"Model exported to {model_path} and input data saved to {data_path}")

Model exported to network_lenet.onnx and input data saved to input.json


In [27]:
import ezkl

# Circuit configuration: the input stays private, model parameters are fixed
# in the circuit, and the model output is a public instance.
run_args = ezkl.PyRunArgs()
run_args.input_visibility = "private"
run_args.param_visibility = "fixed"
run_args.output_visibility = "public"
run_args.num_inner_cols = 2
run_args.variables = [("batch_size", 1)]

# Number of training images used as calibration data
num_data_points = 30

# Take the first `num_data_points` images (labels are not needed); use the
# whole dataset if it is smaller than that.
data_points = [
    train_dataset[i][0] for i in range(min(num_data_points, len(train_dataset)))
]

# Stack the samples along a new leading axis
train_data_batch = torch.stack(data_points)

# Flatten all samples into one list. ezkl splits it into per-sample
# calibration batches (the log reports 30 calibration batches).
x = train_data_batch.cpu().detach().numpy().reshape([-1]).tolist()

data = dict(input_data=[x])

cal_path = os.path.join('cal_data.json')

# Serialize data into file; `with` closes (and flushes) it before ezkl reads it
with open(cal_path, 'w') as f:
    json.dump(data, f)

# NOTE(review): a `!RUST_LOG=...` shell line only sets the variable in a
# throwaway subshell, not in this kernel. If ezkl's log level should change,
# set it in-process instead (e.g. os.environ["RUST_LOG"] = "trace") -- confirm
# that ezkl honours RUST_LOG.

# TODO: Dictionary outputs
res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)
assert res == True

# Calibrate settings against the sample data; "resources" is the calibration
# target and `scales` restricts the candidate scales tried -- TODO confirm
# against the ezkl 5.0.8 docs.
res = ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources", scales=[1, 7])
assert res == True

INFO ezkl.graph.model 2023-12-01 23:20:41,236 model.rs:724 set batch_size to 1
INFO ezkl.graph.model 2023-12-01 23:20:41,248 model.rs:440 [34mmodel has[0m [34m1[0m [34minstances[0m
INFO ezkl.graph.model 2023-12-01 23:20:41,248 model.rs:1330 calculating num of constraints using dummy model layout...
INFO ezkl.graph.model 2023-12-01 23:20:41,326 model.rs:1403 [34mmodel uses[0m [34m152748[0m [34mrows[0m (coord=[33m305497[0m, constants=[31m289450[0m)
INFO ezkl.graph.model 2023-12-01 23:20:41,339 model.rs:724 set batch_size to 1
INFO ezkl.execute 2023-12-01 23:20:41,371 execute.rs:641 num of calibration batches: 30
INFO ezkl.graph.model 2023-12-01 23:20:41,373 model.rs:724 set batch_size to 1
INFO ezkl.graph.model 2023-12-01 23:20:41,385 model.rs:440 [34mmodel has[0m [34m1[0m [34minstances[0m
INFO ezkl.graph.model 2023-12-01 23:20:41,387 model.rs:1330 calculating num of constraints using dummy model layout...
INFO ezkl.graph.model 2023-12-01 23:20:41,471 model.rs:1403 

In [28]:
# Compile the ONNX model and the calibrated settings into the circuit file
# used by the witness, mock, setup and prove steps.
res = ezkl.compile_circuit(
    model_path,
    compiled_model_path,
    settings_path,
)
assert res == True

INFO ezkl.graph.model 2023-12-01 23:22:01,685 model.rs:724 set batch_size to 1


In [29]:
# Fetch the SRS (structured reference string) sized for this circuit's
# settings and store it at srs_path.
res = ezkl.get_srs(srs_path, settings_path)
# Fail fast here, like the other steps do, rather than later in setup/prove.
assert os.path.isfile(srs_path)

INFO ezkl.execute 2023-12-01 23:22:04,616 execute.rs:486 SRS downloaded


In [30]:
# Now generate the witness file from the exported input data and the compiled
# circuit. It is written to witness_path, which is defined together with the
# other artifact paths ('witness.json') rather than re-declared here.
res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

INFO ezkl.graph 2023-12-01 23:22:36,097 mod.rs:706 input scales: [1]


In [32]:
# Run ezkl's mock prover on the witness and compiled circuit as a quick check,
# before any keys are generated or a real proof is produced.
res = ezkl.mock(
    witness_path,
    compiled_model_path,
)
assert res == True

INFO ezkl.execute 2023-12-01 23:23:27,082 execute.rs:909 Mock proof
INFO ezkl.graph.vars 2023-12-01 23:23:27,084 vars.rs:408 number of blinding factors: 5
INFO ezkl.graph.model 2023-12-01 23:23:27,098 model.rs:1020 configuring model
INFO ezkl.graph 2023-12-01 23:23:27,106 mod.rs:1246 circuit size: 
 {
  "num_advice_columns": 6,
  "num_challenges": 0,
  "num_fixed": 6,
  "num_instances": 1,
  "num_selectors": 28
}
INFO ezkl.graph.model 2023-12-01 23:23:27,119 model.rs:1054 model layout...
INFO ezkl.graph.model 2023-12-01 23:23:28,770 model.rs:1137 [34mmodel uses[0m [34m6154[0m [34mrows[0m (coord=[33m12308[0m, constants=[31m7042[0m)


In [33]:

# Key generation: derive the verification key (vk_path) and proving key
# (pk_path) for the compiled circuit from the SRS fetched earlier.
res = ezkl.setup(compiled_model_path, vk_path, pk_path, srs_path)

assert res == True
# Both keys must now exist on disk, and the settings file is still needed downstream.
assert all(
    os.path.isfile(artifact) for artifact in (vk_path, pk_path, settings_path)
)

INFO ezkl.pfsys.srs 2023-12-01 23:23:30,908 srs.rs:23 loading srs from "kzg.srs"
INFO ezkl.execute 2023-12-01 23:23:30,914 execute.rs:2011 downsizing params to 14 logrows
INFO ezkl.graph.vars 2023-12-01 23:23:30,915 vars.rs:408 number of blinding factors: 5
INFO ezkl.graph.model 2023-12-01 23:23:30,916 model.rs:1020 configuring model
INFO ezkl.graph 2023-12-01 23:23:30,917 mod.rs:1246 circuit size: 
 {
  "num_advice_columns": 6,
  "num_challenges": 0,
  "num_fixed": 6,
  "num_instances": 1,
  "num_selectors": 28
}
INFO ezkl.graph.model 2023-12-01 23:23:30,920 model.rs:1054 model layout...
INFO ezkl.graph.model 2023-12-01 23:23:32,703 model.rs:1137 [34mmodel uses[0m [34m6154[0m [34mrows[0m (coord=[33m12308[0m, constants=[31m7042[0m)
INFO ezkl.pfsys 2023-12-01 23:23:33,713 mod.rs:401 VK took 2.797
INFO ezkl.graph.vars 2023-12-01 23:23:33,714 vars.rs:408 number of blinding factors: 5
INFO ezkl.graph.model 2023-12-01 23:23:33,714 model.rs:1020 configuring model
INFO ezkl.graph 20

In [34]:
# Generate a proof from the witness with the proving key and SRS. The returned
# object (public instances plus the serialized proof) is printed, and the proof
# file must exist afterwards.
proof_path = os.path.join('test.pf')

# "single": proof type string passed to ezkl 5.0.8 -- presumably a plain,
# non-aggregated proof; confirm against the ezkl docs.
res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, srs_path, "single")

print(res)
assert os.path.isfile(proof_path)

INFO ezkl.pfsys.srs 2023-12-01 23:23:38,095 srs.rs:23 loading srs from "kzg.srs"
INFO ezkl.execute 2023-12-01 23:23:38,097 execute.rs:2011 downsizing params to 14 logrows
INFO ezkl.pfsys 2023-12-01 23:23:38,098 mod.rs:636 loading proving key from "key.pk"
INFO ezkl.graph.vars 2023-12-01 23:23:38,099 vars.rs:408 number of blinding factors: 5
INFO ezkl.graph.model 2023-12-01 23:23:38,102 model.rs:1020 configuring model
INFO ezkl.graph 2023-12-01 23:23:38,104 mod.rs:1246 circuit size: 
 {
  "num_advice_columns": 6,
  "num_challenges": 0,
  "num_fixed": 6,
  "num_instances": 1,
  "num_selectors": 28
}
INFO ezkl.pfsys 2023-12-01 23:23:38,173 mod.rs:470 proof started...
INFO ezkl.graph.vars 2023-12-01 23:23:38,174 vars.rs:408 number of blinding factors: 5
INFO ezkl.graph.model 2023-12-01 23:23:38,175 model.rs:1020 configuring model
INFO ezkl.graph 2023-12-01 23:23:38,175 mod.rs:1246 circuit size: 
 {
  "num_advice_columns": 6,
  "num_challenges": 0,
  "num_fixed": 6,
  "num_instances": 1,
  

{'instances': [[[14385415396251402209, 2429374486035521128, 12558163205804149944, 2583518171365219058], [4476394681747374096, 9457141985490438420, 9584886409590048210, 451740047718875803], [1408065332295237670, 17849026197112403344, 6623204158280506835, 3378725622546023985], [0, 0, 0, 0], [14385415396251402209, 2429374486035521128, 12558163205804149944, 2583518171365219058], [10902020042510041094, 17381486299841078119, 5900175412809962030, 2475245527108272378], [4476394681747374096, 9457141985490438420, 9584886409590048210, 451740047718875803], [3483395353741361115, 3494632259903994625, 6657987792994187913, 108272644256946680], [4476394681747374096, 9457141985490438420, 9584886409590048210, 451740047718875803], [415066004289224689, 11886516471525959549, 3696305541684646538, 3035258219084094862]]], 'proof': '0fe5a1a2215bd7e416412464805d077ce46ac20c3c64fb41458369fcbc21591a1441735133b06499d06afe62d16db95629afd177ae2bec94fe125d6fb242dd73295baf959583aa18e84b9ae150fd334128f555819877c2101ae5c

In [36]:
# Verify the proof locally (off-chain) using the settings, verification key and SRS.
res = ezkl.verify(proof_path, settings_path, vk_path, srs_path)

assert res == True
print("verified")

INFO ezkl.pfsys.srs 2023-12-01 23:23:44,176 srs.rs:23 loading srs from "kzg.srs"
INFO ezkl.execute 2023-12-01 23:23:44,183 execute.rs:2011 downsizing params to 14 logrows
INFO ezkl.pfsys 2023-12-01 23:23:44,184 mod.rs:614 loading verification key from "key.vk"
INFO ezkl.graph.vars 2023-12-01 23:23:44,184 vars.rs:408 number of blinding factors: 5
INFO ezkl.graph.model 2023-12-01 23:23:44,185 model.rs:1020 configuring model
INFO ezkl.graph 2023-12-01 23:23:44,185 mod.rs:1246 circuit size: 
 {
  "num_advice_columns": 6,
  "num_challenges": 0,
  "num_fixed": 6,
  "num_instances": 1,
  "num_selectors": 28
}
INFO ezkl.execute 2023-12-01 23:23:44,194 execute.rs:1734 verify took 0.5
INFO ezkl.execute 2023-12-01 23:23:44,195 execute.rs:1739 verified: true


verified


We can now create an EVM / `.sol` verifier that can be deployed on chain to verify submitted proofs using a view function.

In [37]:

# Output files for the generated Solidity verifier contract and its ABI.
sol_code_path = 'test.sol'
abi_path = 'test.abi'

res = ezkl.create_evm_verifier(vk_path, srs_path, settings_path, sol_code_path, abi_path)
assert res == True

INFO ezkl.execute 2023-12-01 23:23:46,347 execute.rs:76 checking solc installation..
INFO ezkl.pfsys.srs 2023-12-01 23:23:46,348 srs.rs:23 loading srs from "kzg.srs"
INFO ezkl.execute 2023-12-01 23:23:46,350 execute.rs:2011 downsizing params to 14 logrows
INFO ezkl.pfsys 2023-12-01 23:23:46,351 mod.rs:614 loading verification key from "key.vk"
INFO ezkl.graph.vars 2023-12-01 23:23:46,351 vars.rs:408 number of blinding factors: 5
INFO ezkl.graph.model 2023-12-01 23:23:46,351 model.rs:1020 configuring model
INFO ezkl.graph 2023-12-01 23:23:46,352 mod.rs:1246 circuit size: 
 {
  "num_advice_columns": 6,
  "num_challenges": 0,
  "num_fixed": 6,
  "num_instances": 1,
  "num_selectors": 28
}


## Verify on the EVM

In [21]:
# Deploy the generated Solidity verifier to a local anvil node and read back
# the deployed contract address. Start the node first with:
#   $ anvil -p 3030
# (a "Connection refused" error from deploy_evm means no node is listening).
import json

address_path = os.path.join("address.json")

res = ezkl.deploy_evm(address_path, sol_code_path, 'http://127.0.0.1:3030')
assert res == True

# address_path holds the contract address; drop any trailing whitespace/newline.
with open(address_path, 'r') as file:
    addr = file.read().rstrip()

INFO ezkl.execute 2023-12-01 23:16:32,549 execute.rs:76 checking solc installation..


RuntimeError: Failed to run deploy_evm: error sending request for url (http://127.0.0.1:3030/): error trying to connect: tcp connect error: Connection refused (os error 61)

In [None]:
# Verify the proof on-chain against the verifier contract deployed at `addr`.
# Needs the same local anvil node (started with `anvil -p 3030`).
res = ezkl.verify_evm(proof_path, addr, "http://127.0.0.1:3030")
assert res == True