
## Model architecture and training

In [1]:
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

# Detect whether the notebook is running on Google Colab. Only the detection
# import is guarded: a failing `pip install` on Colab must surface as an error
# instead of being swallowed by a bare `except:`.
try:
    import google.colab  # noqa: F401 -- imported only to detect the Colab runtime
    IN_COLAB = True
except ImportError:
    # rely on local installation of ezkl if the notebook is not in colab
    IN_COLAB = False

if IN_COLAB:
    import subprocess
    import sys

    # Install into the kernel's own interpreter so the packages are importable here.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ezkl"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "onnx"])

In [2]:
# Seed torch's global RNG so the (untrained) weights and the dummy input are
# reproducible under Restart & Run All; everything downstream (ONNX export,
# settings, witness, proof) depends on them.
SEED = 0
torch.manual_seed(SEED)

# Single-layer LSTM: input feature size 3, hidden/output size 3.
model = nn.LSTM(3, 3)
# Dummy input used only to trace the ONNX export. With nn.LSTM's default
# batch_first=False, a 2-D tensor is read as (seq_len, features) with no batch
# dimension, i.e. here one time step of 3 features.
x = torch.randn(1, 3)

# this is where you'd train the model

## Proving and verifying with EZKL

In [3]:

import os 
import ezkl


# Every artifact this notebook reads or writes, relative to the current
# working directory.
model_path = 'network.onnx'               # exported ONNX graph
compiled_model_path = 'network.compiled'  # circuit compiled by ezkl.compile_circuit
pk_path = 'test.pk'                       # proving key written by ezkl.setup
vk_path = 'test.vk'                       # verification key written by ezkl.setup
settings_path = 'settings.json'           # settings from gen_settings / calibrate_settings

witness_path = 'witness.json'             # witness file path
data_path = 'input.json'                  # sample input used for calibration and the witness



In [4]:

import json 


# Put the network in inference mode and keep its weights on the CPU for export.
model.eval()
model.to('cpu')

# Export the model to ONNX. Axis 0 of input/output is declared dynamic under the
# name "batch_size"; because `x` is a 2-D tensor fed to an LSTM with the default
# batch_first=False, that axis is actually the sequence (time-step) axis. Its
# concrete size is pinned later through run_args.variables.
torch.onnx.export(
    model,                      # model being run
    x,                          # dummy input used for tracing
    model_path,                 # where to save the model (file path or file-like object)
    export_params=True,         # store the trained parameter weights inside the model file
    opset_version=10,           # the ONNX opset to export with
    do_constant_folding=True,   # fold constant subgraphs during export
    input_names=['input'],      # the model's input names
    output_names=['output'],    # the model's output names
    dynamic_axes={'input': {0: 'batch_size'},    # variable-length axes
                  'output': {0: 'batch_size'}},
)

# Sample input: a sequence of SEQ_LEN steps with 3 features each.
SEQ_LEN = 10
shape = (SEQ_LEN, 3)
x = torch.randn(*shape)

# Flatten to a plain list of floats and wrap it in the {'input_data': [...]}
# layout written to data_path.
data_array = x.detach().numpy().reshape([-1]).tolist()
data_json = dict(input_data=[data_array])

print(data_json)

# Serialize data into file. The context manager closes (and flushes) the handle
# deterministically instead of relying on garbage collection, before ezkl reads it.
with open(data_path, 'w') as data_file:
    json.dump(data_json, data_file)


{'input_data': [[1.6220875978469849, 0.3597317636013031, 0.8345175981521606, 1.9685660600662231, 0.45953869819641113, 0.35968947410583496, 0.7673524022102356, -0.058587852865457535, 0.3262140154838562, -0.33208921551704407, -0.6318570375442505, 1.1284105777740479, 0.46596136689186096, 1.4800872802734375, 1.363403558731079, 0.035137902945280075, -0.641409695148468, -0.05915956571698189, 0.5282636880874634, 0.9504408240318298, 0.40337294340133667, -0.1421440988779068, 2.072631359100342, -1.0321298837661743, -0.8974454402923584, 0.7173476219177246, 0.5790023803710938, -0.8475150465965271, -0.7141340374946594, -1.978621006011963]]}




In [5]:


# Give the symbolic "batch_size" axis of the exported graph a concrete size.
# In this notebook that axis is the sequence length.
run_args = ezkl.PyRunArgs()
run_args.variables = [("batch_size", SEQ_LEN)]

# Generate the initial circuit settings from the ONNX graph.
# TODO: Dictionary outputs
res = ezkl.gen_settings(
    model_path,
    settings_path,
    py_run_args=run_args,
)
assert res == True

# First calibration pass over the sample input, targeting "resources".
res = ezkl.calibrate_settings(
    data_path,
    model_path,
    settings_path,
    "resources",
)
assert res == True


Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns 

In [6]:
cal_path = os.path.join("calibration.json")

# Calibration data: N_CALIBRATION_SAMPLES independent random sequences of
# `shape`, concatenated into one flat list.
# NOTE(review): this yields N_CALIBRATION_SAMPLES times more values than one
# model input and presumably relies on ezkl splitting calibration data into
# input-sized chunks; confirm against the ezkl calibrate_settings docs.
N_CALIBRATION_SAMPLES = 10
data_array = torch.randn(N_CALIBRATION_SAMPLES, *shape).detach().numpy().reshape([-1]).tolist()

data = dict(input_data=[data_array])

# Serialize data into file; the handle is closed before ezkl reads it.
with open(cal_path, 'w') as cal_file:
    json.dump(data, cal_file)

# Re-calibrate on the larger sample. Check the result like every other ezkl
# step instead of only displaying it.
res = ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")
assert res == True

Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 6 columns for non-linearity table.
Using 6 columns for non-linearity table.
Using 6 columns for non-linearity table.
Using 6 columns for non-linearity table.
Using 12 columns for non-linearity table.
Using 12 columns for non-linearity table.
Using 12 columns for non-linearity table.
Using 12 columns for non-linearity table.
Using 12 col

True

In [7]:
# Compile the ONNX model into a circuit using the calibrated settings.
res = ezkl.compile_circuit(
    model_path, compiled_model_path, settings_path
)
assert res == True

In [8]:
# Obtain the structured reference string (SRS) for the circuit described by the
# settings file (no explicit SRS path is passed; ezkl chooses where it lives).
# NOTE(review): the return value is not checked because its type varies across
# ezkl versions; confirm before adding an assertion.
res = ezkl.get_srs(settings_path)

In [9]:
# Generate the witness for the sample input in data_path. The output location
# is the witness_path set in the configuration cell rather than a value
# silently redefined here, so the configuration cell stays the single source of truth.
res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

In [10]:
# Sanity-check the witness against the compiled circuit with ezkl.mock before
# running key setup and proving.
res = ezkl.mock(
    witness_path, compiled_model_path
)
assert res == True

In [11]:

# Key generation: write the verification key (vk_path) and the proving key
# (pk_path) for the compiled circuit.
res = ezkl.setup(compiled_model_path, vk_path, pk_path)

assert res == True
# Both keys must now exist on disk. settings.json was written by the earlier
# settings/calibration steps and is needed by the verification step.
for artifact in (vk_path, pk_path, settings_path):
    assert os.path.isfile(artifact)

In [12]:
# Produce a proof from the witness and proving key and write it to proof_path.
# NOTE(review): "single" selects the proof type, presumably a single
# (non-aggregated) proof; confirm in the ezkl docs.
proof_path = 'test.pf'

res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, "single")

print(res)
assert os.path.isfile(proof_path)

{'instances': [['82ffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'd100000000000000000000000000000000000000000000000000000000000000', '78ffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '18ffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '0d01000000000000000000000000000000000000000000000000000000000000', '5bffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'cefeffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '7f01000000000000000000000000000000000000000000000000000000000000', 'beffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '55ffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'e801000000000000000000000000000000000000000000000000000000000000', '7700000000000000000000000000000000000000000000000000000000000000', '1affffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '1302000000000000000000000000000000000000000000000000000000000000', '5bffffef93f5e1439170b97948e8332

In [13]:
# Check the proof against the circuit settings and the verification key.
res = ezkl.verify(proof_path, settings_path, vk_path)

assert res == True
print("verified")

verified
