In [16]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import json
import asyncio
from sklearn import datasets
from tqdm import tqdm
import os
import time
import tqdm

import ezkl
import os


# Load sklearn's handwritten-digits dataset as float32 features and int64 class ids.
images, labels = datasets.load_digits(return_X_y=True)
images, labels = (
    torch.tensor(images, dtype=torch.float32),
    torch.tensor(labels, dtype=torch.int64),
)

# MLP layer sizes (input_dim must equal the per-sample feature width of `images`).
input_dim, hidden_dim, output_dim = 64, 256, 10
# Parameter count of the two Linear layers: weights plus one bias per output unit.
num_params = (input_dim * hidden_dim + hidden_dim) + (hidden_dim * output_dim + output_dim)

# Rows per batch; the export cell slices this many samples as its example input.
batch_size = 3
print(f'num_params={num_params}')


class MLP(nn.Module):
    """Two-layer perceptron: Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, output_dim).

    Layer sizes are read from the module-level ``input_dim`` / ``hidden_dim`` /
    ``output_dim`` constants at construction time. ``forward`` returns raw
    (unnormalized) logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names fc1/fc2 are kept as-is: they determine the state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return logits

class EvaluateModel(nn.Module):
    """Wrap a classifier so that ``forward(X, y)`` returns its accuracy on ``(X, y)``.

    Packaging the accuracy computation inside a module lets the whole
    "predict then score" step be exported as a single graph whose inputs are
    the samples and labels and whose output is the scalar accuracy.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model  # classifier returning logits, e.g. an MLP instance

    def forward(self, X, y):
        # get_acc is the module-level helper: share of rows whose arg-max matches y.
        return get_acc(self.model(X), y)

def get_acc(logits, labels):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Args:
        logits: class scores, with classes along the last dimension.
        labels: integer class ids; expected to match ``logits`` without its
            last dimension (they are compared element-wise after arg-max).

    Returns:
        A 0-dim float32 tensor in [0, 1] (NaN if there are no rows).
    """
    correct = torch.eq(torch.argmax(logits, dim=-1), labels)
    return correct.to(torch.float32).mean()

# Build the classifier and the accuracy wrapper that is exported and proven below.
# NOTE(review): the MLP is freshly initialised here (no training or seeding in this cell).
mlp = MLP()
eval_mod = EvaluateModel(mlp)

# All artifacts of this run are written under RUN_FOLDER.
RUN_FOLDER = "./test_eval_dir/"

from pathlib import Path
Path(RUN_FOLDER).mkdir(parents=True, exist_ok=True)  # create directory and any intermediate directories

# Folder and file name are passed as separate arguments so the join does not
# depend on RUN_FOLDER ending with a path separator.
model_path = os.path.join(RUN_FOLDER, 'eval.onnx')
compiled_model_path = os.path.join(RUN_FOLDER, 'eval_network.compiled')
pk_path = os.path.join(RUN_FOLDER, 'test.pk')
vk_path = os.path.join(RUN_FOLDER, 'test.vk')
settings_path = os.path.join(RUN_FOLDER, 'settings.json')
witness_path = os.path.join(RUN_FOLDER, 'witness.json')
data_path = os.path.join(RUN_FOLDER, 'test_input.json')
# Keep only the first `batch_size` samples as the example input for the export.
# No dynamic_axes are passed, so the exported input shapes are presumably fixed
# to this batch size — TODO confirm against the generated ONNX graph.
images, labels = images[:batch_size], labels[:batch_size]
print(images.shape)
print(labels.shape)

# Compile the wrapper to TorchScript before handing it to the ONNX exporter.
eval_mod = torch.jit.script(eval_mod)

torch.onnx.export(
    eval_mod,
    (images, labels),                  # example inputs, in forward(X, y) order
    model_path,                        # destination .onnx file
    input_names=['images', 'labels'],
    output_names=['acc'],
    opset_version=11,
    export_params=True,                # embed the parameter values in the file
    do_constant_folding=True,          # fold constant sub-expressions at export time
)

py_run_args = ezkl.PyRunArgs()
py_run_args.input_visibility = "public"
py_run_args.output_visibility = "public"
py_run_args.param_visibility = "fixed"
res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
cal_path = os.path.join(RUN_FOLDER + "calibration.json")
# data = dict(input_data = [torch.rand(*images.shape).detach().numpy().reshape(-1).tolist(),
#                          torch.rand(*labels.shape).detach().numpy().reshape(-1).tolist()])
data = dict(input_data = [images.detach().numpy().reshape(-1).tolist(),
                       labels.detach().numpy().reshape(-1).tolist()])


json.dump(data, open(cal_path, 'w'))
# calibrate
start_cal = time.time()
await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")
end_cal = time.time()
cal_dur = end_cal - start_cal
print('Calibration duration: ' + str(cal_dur))

Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 8 columns for non-linearity table.
Using 8 columns for non-linearity table.
Using 8 columns for non-linearity table.


num_params=19210
torch.Size([3, 64])
torch.Size([3])


Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 8 columns for non-linearity table.
Using 8 columns for non-linearity table.


 <------------- Numerical Fidelity Report (input_scale: 11, param_scale: 11, scale_input_multiplier: 1) ------------->

+------------+--------------+-----------+-----------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+
| mean_error | median_error | max_error | min_error | mean_abs_error | median_abs_error | max_abs_error | min_abs_error | mean_squared_error | mean_percent_error | mean_abs_percent_error |
+------------+--------------+-----------+-----------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+
| 0          | 0            | 0         | 0         | 0              | 0                | 0             | 0             | 0                  | 0   

Calibration duration: 2.7758753299713135


In [17]:
start_comp = time.time()
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
end_comp = time.time()
comp_dur = end_comp - start_comp
print('Circuit compilation duration: ' + str(comp_dur))

assert res == True

srs_start = time.time()
res = await ezkl.get_srs( settings_path)
assert res == True
srs_end = time.time()
srs_dur = srs_end - srs_start
print('SRS duration: ' + str(srs_dur))

Circuit compilation duration: 0.008025169372558594
SRS duration: 0.015050888061523438


In [18]:
# Sanity check of the current example-batch tensors (one shape per line).
print(images.shape, labels.shape, sep='\n')

torch.Size([3, 64])
torch.Size([3])


In [19]:
# Re-initialise the artifact paths (same values as in the export cell).
# Folder and file name are passed separately so the join does not depend on
# RUN_FOLDER ending with a path separator.
model_path = os.path.join(RUN_FOLDER, 'eval.onnx')
compiled_model_path = os.path.join(RUN_FOLDER, 'eval_network.compiled')
pk_path = os.path.join(RUN_FOLDER, 'test.pk')
vk_path = os.path.join(RUN_FOLDER, 'test.vk')
settings_path = os.path.join(RUN_FOLDER, 'settings.json')
                                                                                                               
                                                                                                               

In [21]:
# get test input
images, labels = datasets.load_digits(return_X_y=True)
images = torch.tensor(images, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.int64)
print(images.shape)
print(labels.shape)

total_chunks = images.shape[0] // batch_size # assumes divisible
proof_setup_start = time.time()
res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
    )

assert res == True
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)

proof_setup_end = time.time()
proof_setup_dur = proof_setup_end - proof_setup_start
print('Proof setup duration: ' + str(proof_setup_dur))

#total_chunks = 4 # for testing
for i in tqdm.tqdm(range(total_chunks)):

    witness_path = os.path.join(RUN_FOLDER + f'witness_{i}.json')                                              
    data = dict(input_data = [images[i * batch_size: (i + 1) * batch_size, :].detach().numpy().reshape(-1).tolist(),
                        labels[i * batch_size: (i + 1) * batch_size].detach().numpy().reshape(-1).tolist()])


    data_path = os.path.join(RUN_FOLDER + f'actual_test_input_{i}.json')
    with open(data_path, "w") as f:
        json.dump(data, f)

    witness_start = time.time()
    res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
    assert os.path.isfile(witness_path)
    witness_end = time.time()
    witness_dur = witness_end - witness_start
    print('Witness duration: ' + str(witness_dur))


    proof_path = os.path.join(RUN_FOLDER + f"test_proof_{i}.pf")
    proof_start = time.time()
    # prove
    res = ezkl.prove(
            witness_path,
            compiled_model_path,
            pk_path,
            proof_path,

            "single",
        )

    #print(res)
    proof_end = time.time()
    proof_duration = proof_end - proof_start
    print('Proof duration: ' + str(proof_duration))



torch.Size([1797, 64])
torch.Size([1797])
Proof setup duration: 6.927207946777344


  0%|                                                                                                      | 0/4 [00:00<?, ?it/s]

Witness duration: 0.05954289436340332


 25%|███████████████████████▌                                                                      | 1/4 [00:11<00:33, 11.31s/it]

Proof duration: 11.245236873626709
Witness duration: 0.03202199935913086


 50%|███████████████████████████████████████████████                                               | 2/4 [00:22<00:22, 11.27s/it]

Proof duration: 11.219467401504517
Witness duration: 0.04000139236450195


 75%|██████████████████████████████████████████████████████████████████████▌                       | 3/4 [00:33<00:11, 11.35s/it]

Proof duration: 11.389256715774536
Witness duration: 0.04317736625671387


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:45<00:00, 11.35s/it]

Proof duration: 11.365727663040161





In [23]:
# verify generated proofs: check each chunk's proof against the settings and
# verification key produced by the setup cell.
start_verify = time.time()
for i in range(total_chunks):
    proof_path = os.path.join(RUN_FOLDER, f"test_proof_{i}.pf")
    res = ezkl.verify(
            proof_path,
            settings_path,
            vk_path
        )
    assert res == True
    print("verified")

end_verify = time.time()
verify_dur = end_verify - start_verify
print('Total verify duration: ' + str(verify_dur))

verified
verified
verified
verified
Total verify duration: 0.07295751571655273


In [None]:
# generate witness
# NOTE(review): disabled single-batch version. The code below sits inside a string
# literal and is never executed; the per-chunk loop in cell In [21] calls
# ezkl.gen_witness for every batch instead. Candidate for deletion.
'''
witness_start = time.time()
res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)
witness_end = time.time()
witness_dur = witness_end - witness_start
print('Witness duration: ' + str(witness_dur))
'''

In [None]:
# NOTE(review): disabled key-setup cell. The code below sits inside a string
# literal and is never executed; cell In [21] runs the same ezkl.setup call and
# assertions before its proving loop. Candidate for deletion.
'''
# setup proof
proof_setup_start = time.time()
res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
    )

assert res == True
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)

proof_setup_end = time.time()
proof_setup_dur = proof_setup_end - proof_setup_start
print('Proof setup duration: ' + str(proof_setup_dur))
'''

In [None]:
# NOTE(review): disabled single-proof cell. The code below sits inside a string
# literal and is never executed; the per-chunk loop in cell In [21] calls
# ezkl.prove for every batch (writing test_proof_{i}.pf). Candidate for deletion.
'''
proof_path = os.path.join(RUN_FOLDER + "test_proof.pf")
proof_start = time.time()
# prove
res = ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,

        "single",
    )

print(res)
proof_end = time.time()
proof_duration = proof_end - proof_start
print('Proof duration: ' + str(proof_duration))
'''