# IMPORT LIBRARIES

In [1]:
# Detect Google Colab: the `google.colab` module is only importable there.
try:
    import google.colab  # noqa: F401
except ImportError:
    # Not in Colab: rely on the local installation of ezkl.
    IN_COLAB = False
else:
    IN_COLAB = True
    import subprocess
    import sys

    # Install with the kernel's own interpreter so the packages land in the
    # environment this notebook actually runs in. A failed install raises
    # CalledProcessError here instead of being silently ignored and surfacing
    # later as a confusing ImportError.
    for package in ("ezkl", "onnx"):
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])


# here we create (and potentially train) a model

# make sure you have the dependencies required here already installed
from torch import nn
import ezkl
import os
import json
import torch
from PIL import Image

# DEFINE PATHS

In [2]:
# Create the output directory. exist_ok=True keeps this cell re-runnable when
# the directory is already there (a shell `mkdir` reports an error in that case).
os.makedirs('generating_files', exist_ok=True)

In [3]:
# Every generated artifact lives in one directory; each path is built from it
# so the directory name is written exactly once.
output_dir = 'generating_files'

# Model artifacts
model_path = os.path.join(output_dir, 'network.onnx')
compiled_model_path = os.path.join(output_dir, 'network.compiled')

# Keys and circuit settings
pk_path = os.path.join(output_dir, 'proving_key.pk')
vk_path = os.path.join(output_dir, 'verification_key.vk')
settings_path = os.path.join(output_dir, 'settings.json')

# Per-run inputs and outputs
witness_path = os.path.join(output_dir, 'witness.json')
data_path = os.path.join(output_dir, 'input.json')

# CREATE .ONNX FILE

In [4]:
def divide_image_to_4_parts(x, w, h):
    """Cut a tensor into four quadrants along dimensions 2 and 3.

    The quadrants are sliced from a detached clone of ``x``, so they carry no
    autograd history and do not alias the input's storage.

    Args:
        x: Tensor with at least four dimensions, indexed as ``x[:, :, i, j]``.
        w: Extent used for dimension 2; that dimension is split at ``w // 2``.
        h: Extent used for dimension 3; that dimension is split at ``h // 2``.

    Returns:
        Tuple ``(top_left, top_right, bottom_left, bottom_right)``.
    """
    detached = x.detach().clone()

    # One split point per dimension, each computed exactly once.
    split_dim2 = w // 2
    split_dim3 = h // 2

    dim2_halves = (slice(None, split_dim2), slice(split_dim2, None))
    dim3_halves = (slice(None, split_dim3), slice(split_dim3, None))

    # Outer loop walks dimension 2 (top, bottom), inner loop walks dimension 3
    # (left, right), which yields the documented quadrant order.
    return tuple(
        detached[:, :, rows, cols]
        for rows in dim2_halves
        for cols in dim3_halves
    )

In [5]:
# Define the model.
# Seed the global RNG first so every randomly initialised weight below is reproducible.
torch.manual_seed(42)


class MyModel(nn.Module):
    """Convolution over the full image plus four per-quadrant "commit" means.

    ``forward`` returns a 1-D tensor: the mean of ``commit_conv`` over each
    quadrant (top-left, top-right, bottom-left, bottom-right), followed by the
    flattened output of ``conv`` applied to the whole input.
    """

    def __init__(self):
        super().__init__()

        # Both layers share the same geometry; only their weights differ.
        conv_config = dict(in_channels=1, out_channels=2, kernel_size=5, stride=2)
        self.conv = nn.Conv2d(**conv_config)
        self.commit_conv = nn.Conv2d(**conv_config)

        # Overwrite commit_conv's default initialisation with normally
        # distributed values. Weight shape is (out_channels, in_channels,
        # kernel_size, kernel_size); bias shape is (out_channels,).
        self.commit_conv.weight = nn.Parameter(torch.randn(2, 1, 5, 5))
        self.commit_conv.bias = nn.Parameter(torch.randn(2))

    def forward(self, x):
        quadrants = divide_image_to_4_parts(x, x.shape[2], x.shape[3])

        # One scalar per quadrant: mean of the commit convolution, kept as a
        # 1-element tensor so it can be concatenated.
        quadrant_means = [
            torch.mean(self.commit_conv(quadrant)).unsqueeze(0)
            for quadrant in quadrants
        ]

        # Main convolution over the whole input, flattened to 1-D.
        flat_conv = self.conv(x).flatten()

        # Four means first, then the flattened convolution output.
        return torch.cat((*quadrant_means, flat_conv), dim=0)


circuit = MyModel()

In [6]:
# Dummy input used only for the ONNX export: the exporter has to run the
# model once in order to trace it.
x = torch.rand((1, 1, 28, 28))

In [7]:
# Put the network into inference mode before it is traced.
circuit.eval()

# Keyword options for the export, gathered in one place.
# NOTE(review): axis 0 of 'output' is labelled batch_size, but forward()
# returns a flat 1-D vector — confirm this label is intended.
export_options = dict(
    export_params=True,          # store the parameter weights inside the model file
    opset_version=18,            # ONNX opset to target
    do_constant_folding=False,   # leave constant expressions unfolded
    input_names=['input'],       # name of the graph input
    output_names=['output'],     # name of the graph output
    dynamic_axes={               # axes exported with variable length
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)

# Trace `circuit` on the dummy input `x` and write the graph to model_path.
torch.onnx.export(circuit, x, model_path, **export_options)


In [8]:
# Inspect the exported ONNX file
import onnx

# Read the serialized model back from disk
model = onnx.load(model_path)

# Render the graph as human-readable text, then show it
graph_text = onnx.helper.printable_graph(model.graph)
print(graph_text)

graph main_graph (
  %input[FLOAT, batch_sizex1x28x28]
) initializers (
  %conv.weight[FLOAT, 2x1x5x5]
  %conv.bias[FLOAT, 2]
  %commit_conv.weight[FLOAT, 2x1x5x5]
  %commit_conv.bias[FLOAT, 2]
) {
  %/Constant_output_0 = Constant[value = <Scalar Tensor []>]()
  %/Shape_output_0 = Shape(%input)
  %/Constant_1_output_0 = Constant[value = <Scalar Tensor []>]()
  %/Gather_output_0 = Gather[axis = 0](%/Shape_output_0, %/Constant_1_output_0)
  %/Constant_2_output_0 = Constant[value = <Scalar Tensor []>]()
  %/Shape_1_output_0 = Shape(%input)
  %/Constant_3_output_0 = Constant[value = <Scalar Tensor []>]()
  %/Gather_1_output_0 = Gather[axis = 0](%/Shape_1_output_0, %/Constant_3_output_0)
  %/Constant_4_output_0 = Constant[value = <Scalar Tensor []>]()
  %/Div_output_0 = Div(%/Gather_output_0, %/Constant_4_output_0)
  %/Cast_output_0 = Cast[to = 7](%/Div_output_0)
  %/Cast_1_output_0 = Cast[to = 7](%/Cast_output_0)
  %/Constant_5_output_0 = Constant[value = <Scalar Tensor []>]()
  %/Div_1_ou

# SET UP

In [9]:
# Run arguments handed to ezkl when it generates the circuit settings.
py_run_args = ezkl.PyRunArgs()

run_arg_overrides = {
    "input_visibility": "private",
    "output_visibility": "public",
    "param_visibility": "fixed",  # private by default
    "input_scale": 40,
    "param_scale": 43,
    "scale_rebase_multiplier": 10,
}
for option_name, option_value in run_arg_overrides.items():
    setattr(py_run_args, option_name, option_value)

# Write the settings file for the exported model.
res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)

assert res == True

In [10]:
# Compile the ONNX model into compiled_model_path using the generated settings.
res = ezkl.compile_circuit(
    model_path,
    compiled_model_path,
    settings_path,
)
assert res == True

In [11]:
# srs path
# Presumably fetches the structured reference string (SRS) matching the
# settings file — confirm against the ezkl documentation.
# The call is awaited at top level, so this cell needs a kernel that supports
# top-level `await` (as IPython/Jupyter does).
# NOTE(review): `res` is not checked here, unlike the gen_settings and
# compile_circuit cells.
res = await ezkl.get_srs(settings_path)

In [12]:
# Model input: grayscale digit images from the MNIST dataset
from keras.datasets import mnist

(train_X, train_y), (test_X, test_y) = mnist.load_data()

# Divide the pixel values by 255.0 and insert a channel axis at dim 1, so the
# training images end up shaped (N, 1, H, W).
train_X = (torch.from_numpy(train_X) / 255.0).unsqueeze(1)
print(train_X.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
torch.Size([60000, 1, 28, 28])


In [13]:
# Use the first image in the dataset for the demo, flattened to a plain list
data_array = ((train_X[0]).detach().numpy()).reshape([-1]).tolist()

data = dict(input_data = [data_array])

# Serialize the data to data_path. The context manager guarantees the file is
# flushed and closed before the witness step reads it back.
with open(data_path, 'w') as input_file:
    json.dump(data, input_file)

In [14]:
# now generate the witness file
# Inputs: the JSON input written to data_path and the compiled circuit.
# Output: witness_path, whose existence is asserted below.
# The call is awaited at top level, so this cell needs a kernel that supports
# top-level `await` (as IPython/Jupyter does).

res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
# Only the presence of the witness file is checked; `res` itself is not inspected.
assert os.path.isfile(witness_path)

In [15]:

# Circuit setup: produce the verification key and the proving key for the
# compiled circuit, and time how long it takes.

import time

start_time = time.time()

res = ezkl.setup(compiled_model_path, vk_path, pk_path)

# Setup must report success, and the keys plus the settings file must exist.
assert res == True
for required_path in (vk_path, pk_path, settings_path):
    assert os.path.isfile(required_path)

end_time = time.time()

# Wall-clock duration of the setup step (assertions included)
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.6f} seconds")

Elapsed time: 29.788386 seconds


# PROVING

In [16]:
# Generate a proof from the witness, the compiled circuit and the proving key,
# and time how long it takes.

proof_path = os.path.join('generating_files', 'proof.pf')

start_proving_time = time.time()

res = ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,

        "single",
    )

print(res)
assert os.path.isfile(proof_path)

end_proving_time = time.time()

# Report the duration of THIS step. The previous version printed
# `elapsed_time`, which is the setup duration from the earlier cell.
elapsed_proving_time = end_proving_time - start_proving_time
print(f"Elapsed time: {elapsed_proving_time:.6f} seconds")

{'instances': [['000000000024a5b871d21db21819f20b00000000000000000000000000000000', '0000000000f8a7edcd55eafa12d4831900000000000000000000000000000000', '0000000000a0b4262d1b8e482ed8631400000000000000000000000000000000', '00000000005170d7cd3293a1431fd11900000000000000000000000000000000', '010000f093f5e1439170b979d82c881b5d588181b64550b829a031e1724e6430', '010000f093f5e1439170b979d82c881b5d588181b64550b829a031e1724e6430', '010000f093f5e1439170b979d82c881b5d588181b64550b829a031e1724e6430', '010000f093f5e1439170b979d82c881b5d588181b64550b829a031e1724e6430', '010000f093f5e1439170b979d82c881b5d588181b64550b829a031e1724e6430', '010000f093f5e1439170b979d82c881b5d588181b64550b829a031e1724e6430', '010000f093f5e1439170b979d82c881b5d588181b64550b829a031e1724e6430', '010000f093f5e1439170b979d82c881b5d588181b64550b829a031e1724e6430', '010000f093f5e1439170b979d82c881b5d588181b64550b829a031e1724e6430', '010000f093f5e1439170b979d82c881b5d588181b64550b829a031e1724e6430', '010000f093f5e1439170b979d82c881

# VERIFY

In [17]:
# Verify the proof against the settings and the verification key, and time it.
start_verify_time = time.time()

res = ezkl.verify(proof_path, settings_path, vk_path)

# Verification must report success before we announce it.
assert res == True
print("verified")

end_verify_time = time.time()

# Wall-clock duration of the verification step
elapsed_verify_time = end_verify_time - start_verify_time
print(f"Elapsed time: {elapsed_verify_time:.6f} seconds")

verified
Elapsed time: 0.163738 seconds


# Verifier translates "instances" back to "float"

In [18]:
# First public instance of the proof: the mean of the commit-convolutional
# layer over the 1st part (top-left) of the image, as a field element.
top_left_mean_felt = '000000000024a5b871d21db21819f20b00000000000000000000000000000000'
# NOTE(review): 126 is presumably the fixed-point scale of the model output —
# confirm against the generated settings file.
output_scale = 126
print(ezkl.felt_to_float(top_left_mean_felt, output_scale))
# the result will be 0.18665149127670336

0.18665149127670336
