In [1]:
# Import libraries

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import json
import asyncio
from sklearn import datasets
from tqdm import tqdm
import os
import time

# Environment switches set before ezkl is imported.
# NOTE(review): ENABLE_ICICLE_GPU is presumably read by ezkl's native backend -- confirm;
# RUST_BACKTRACE=full asks Rust code for complete backtraces on panic.
os.environ['ENABLE_ICICLE_GPU'] = 'true'
os.environ['RUST_BACKTRACE'] = 'full'

# check if notebook is in colab
try:
    import google.colab  # noqa: F401 -- imported only to detect the Colab runtime
except ImportError:
    # rely on local installation of ezkl if the notebook is not in colab
    pass
else:
    # In Colab: install the packages this notebook needs into the running interpreter.
    # Installation failures are no longer swallowed; they surface as CalledProcessError.
    import subprocess
    import sys
    for package in ("ezkl", "onnx", "nest_asyncio"):
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

import ezkl
import os
import nest_asyncio

# Patch asyncio so the `loop.run_until_complete(...)` calls made by later cells can be used
# from inside the notebook -- presumably because the kernel already runs an event loop; confirm.
nest_asyncio.apply()

# Task definition: load the digits data and split it by label parity

print('Data information')
print('=' * 20)

# Features become float32 and labels int64, the dtypes used by the model and the loss below.
images, labels = (
    torch.tensor(array, dtype=dtype)
    for array, dtype in zip(datasets.load_digits(return_X_y=True), (torch.float32, torch.int64))
)
print(f'#samples={len(images)}, image_size={images[0].shape}')

# Remainder of each label modulo 2: 1 for odd digits, 0 for even digits.
parity = labels % 2

mask = parity == 1
odd_num_images = images[mask]
odd_num_labels = labels[mask]
print(f'#odd_num_images={len(odd_num_images)}')

mask = parity == 0
even_num_images = images[mask]
even_num_labels = labels[mask]
print(f'#even_num_images={len(even_num_images)}')

# Neural network definition: widths of the input, hidden and output layers
input_dim = 64
hidden_dim = 256
output_dim = 10

# Every linear layer contributes fan_in * fan_out weights plus fan_out biases.
num_params = sum(
    (fan_in + 1) * fan_out
    for fan_in, fan_out in ((input_dim, hidden_dim), (hidden_dim, output_dim))
)
print(f'num_params={num_params}')


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear, returning unnormalised class scores.

    Args:
        in_features: width of the input vectors; ``None`` uses the notebook-level ``input_dim``.
        hidden_features: width of the hidden layer; ``None`` uses the notebook-level ``hidden_dim``.
        out_features: number of output scores; ``None`` uses the notebook-level ``output_dim``.
    """

    def __init__(self, in_features=None, hidden_features=None, out_features=None):
        super().__init__()
        # Fall back to the notebook-level sizes so existing `MLP()` calls behave as before.
        in_features = input_dim if in_features is None else in_features
        hidden_features = hidden_dim if hidden_features is None else hidden_features
        out_features = output_dim if out_features is None else out_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply fc1, a ReLU, then fc2 to `x` and return the result (no softmax)."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


def get_loss(model, data, labels):
    """Mean negative log-likelihood of `labels` under the softmax of `model(data)`.

    The model output is log-softmaxed along dim 1; for every row the entry at
    that row's label is picked, and the negated mean of those entries is returned.
    """
    log_probs = F.log_softmax(model(data), dim=1)
    # One column index per row, shaped (-1, 1) so gather reads a single entry from each row.
    label_columns = labels.view(-1, 1)
    picked = log_probs.gather(1, label_columns)
    return -picked.mean()


def get_grad(model, data, labels):
    """Back-propagate the `get_loss` of `(data, labels)` through `model`.

    Gradients are cleared first, then the loss is back-propagated. The return
    value is `model.parameters()` itself (not a list of gradients); after the
    backward pass each parameter carries its gradient in `.grad`.
    """
    model.zero_grad()
    get_loss(model, data, labels).backward()
    return model.parameters()


def get_acc(logits, labels):
    """Fraction of rows whose arg-max over dim 1 equals the label, as a Python float."""
    hits = (logits.argmax(dim=1) == labels).float()
    return hits.mean().item()


def train(model, x, y, lr=0.003, num_epochs=10):
    """Fit `model` in place with Adam, using all of `(x, y)` as a single batch.

    Each of the `num_epochs` iterations performs one optimizer step on the
    `get_loss` of the whole batch. The same `model` object is returned.
    """
    adam = optim.Adam(model.parameters(), lr=lr)

    def step():
        # One full-batch update: clear old gradients, back-propagate, apply.
        adam.zero_grad()
        get_loss(model, x, y).backward()
        adam.step()

    for _ in range(num_epochs):
        step()
    return model


# Test: an MLP whose parameters are overwritten from one flat random vector
mlp = MLP()
rand_params = torch.randn(num_params) * 0.01

# Consecutive chunks of the flat vector, in the order fc1.weight, fc1.bias, fc2.weight, fc2.bias.
layout = [
    (mlp.fc1.weight, (hidden_dim, input_dim)),
    (mlp.fc1.bias, (hidden_dim,)),
    (mlp.fc2.weight, (output_dim, hidden_dim)),
    (mlp.fc2.bias, (output_dim,)),
]
start = 0
for param, shape in layout:
    stop = start + param.numel()
    param.data = rand_params[start:stop].view(shape)
    start = stop


Data information
#samples=1797, image_size=torch.Size([64])
#odd_num_images=906
#even_num_images=891
num_params=19210


In [2]:
print(images.shape)

# Baseline: loss and accuracy of the randomly parameterised MLP on the full dataset.
logits = mlp(images)
loss = get_loss(mlp, images, labels)
acc = get_acc(logits, labels)
print(f'Random MLP: loss={loss.item():.4f}, acc={acc:.2f}')

# Train 2 seed MLPs: mlp1 sees only odd digits, mlp2 only even digits
mlp1, mlp2 = MLP(), MLP()
mlp1 = train(mlp1, odd_num_images, odd_num_labels)
mlp2 = train(mlp2, even_num_images, even_num_labels)

models = [mlp1, mlp2]
model_names = ['mlp1', 'mlp2']

# Each seed model is scored on the odd subset, the even subset and the full dataset.
eval_sets = [
    ('d_odd', odd_num_images, odd_num_labels),
    ('d_even', even_num_images, even_num_labels),
    ('d0-9', images, labels),
]
for model, model_name in zip(models, model_names):
    for name, x, y in eval_sets:
        logits = model(x)
        acc = get_acc(logits, y)
        print(f'{model_name} acc@{name}={acc:.2f}')
    print('-' * 10)



torch.Size([1797, 64])
Random MLP: loss=2.3013, acc=0.08
mlp1 acc@d_odd=0.95
mlp1 acc@d_even=0.00
mlp1 acc@d0-9=0.48
----------
mlp2 acc@d_odd=0.00
mlp2 acc@d_even=0.96
mlp2 acc@d0-9=0.47
----------


In [4]:
class Slerp(nn.Module):
    """Spherical linear interpolation between two tensors along their last dimension.

    The angle between `x` and `y` is measured on their normalised copies, while
    the interpolation weights are applied to the original (unnormalised) tensors.
    When the sine of that angle does not exceed `eps`, the module returns the
    plain linear blend ``(1 - val) * x + val * y`` instead.

    Args:
        eps: threshold on sin(angle) below which the linear blend is used.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, val, x, y):
        """Interpolate from `x` (val == 0) to `y` (val == 1).

        Args:
            val: interpolation weight, broadcast against the per-vector angle.
            x: first endpoint.
            y: second endpoint, same shape as `x`.

        Returns:
            Tensor with the interpolated values.
        """
        norm_x = F.normalize(x, dim=-1)
        norm_y = F.normalize(y, dim=-1)
        dot = torch.sum(norm_x * norm_y, dim=-1, keepdim=True)
        # Clamp so rounding error cannot push the cosine outside acos' domain.
        omega = torch.acos(torch.clamp(dot, -1.0, 1.0))
        sin_omega = torch.sin(omega)

        use_slerp = sin_omega > self.eps
        # Divide by 1 wherever the slerp branch is discarded anyway: dividing by a
        # zero sine would create NaN/inf there, and although torch.where drops those
        # values in the forward pass, they still turn the gradient of `val` into NaN.
        safe_sin = torch.where(use_slerp, sin_omega, torch.ones_like(sin_omega))
        scale_x = torch.sin((1.0 - val) * omega) / safe_sin
        scale_y = torch.sin(val * omega) / safe_sin

        slerp_out = scale_x * x + scale_y * y
        lerp_out = (1.0 - val) * x + val * y
        return torch.where(use_slerp, slerp_out, lerp_out)

from pathlib import Path

# Folder that `calibrate_operation` writes the merge circuit's artifacts into.
RUN_FOLDER = './test_merge_dir/'
Path(RUN_FOLDER).mkdir(exist_ok=True, parents=True)  # also creates missing parent directories

# Module instance later passed to the calibration helper as the operation to export.
slerp = Slerp()

async def calibrate_operation(op_name, example_inputs, calibration_inputs, input_names, operation_fn, output_names):
    """Export `operation_fn` to ONNX and prepare an ezkl circuit plus key pair for it.

    Steps, in order: `torch.onnx.export`, `ezkl.gen_settings`,
    `ezkl.calibrate_settings` on `calibration_inputs`, `ezkl.compile_circuit`,
    `ezkl.get_srs` and `ezkl.setup`. Every artifact is written into the
    notebook-level `RUN_FOLDER` (looked up when the function runs) and is named
    with `op_name` as prefix.

    Args:
        op_name: prefix used for all generated file names.
        example_inputs: inputs handed to `torch.onnx.export` for the export.
        calibration_inputs: value stored under the "input_data" key of the
            calibration JSON file.
        input_names: ONNX input names.
        operation_fn: the module to export.
        output_names: ONNX output names.

    Raises:
        AssertionError: if `gen_settings`, `compile_circuit` or `setup` does not return True.
    """
    def artifact(suffix):
        # Join folder and file name properly instead of relying on RUN_FOLDER ending in a slash.
        return os.path.join(RUN_FOLDER, op_name + suffix)

    model_path = artifact('.onnx')
    compiled_model_path = artifact('_network.compiled')
    pk_path = artifact('_test.pk')
    vk_path = artifact('_test.vk')
    settings_path = artifact('_settings.json')
    cal_path = artifact('_calibration.json')

    torch.onnx.export(
        operation_fn,                # the model/module to be exported
        example_inputs,              # example inputs
        model_path,                  # the file name to save the ONNX model
        export_params=True,          # store the trained parameter weights inside the model file
        opset_version=11,            # the ONNX version to export the model to
        do_constant_folding=True,    # whether to execute constant folding for optimization
        input_names=input_names,     # input names
        output_names=output_names    # output name
    )

    py_run_args = ezkl.PyRunArgs()
    py_run_args.input_visibility = "public"
    py_run_args.output_visibility = "public"
    py_run_args.param_visibility = "fixed" # "fixed" for params means that the committed to params are used for all proofs

    res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
    assert res == True, f'ezkl.gen_settings failed for {model_path}'

    # Close the calibration file before ezkl reads it back from disk.
    with open(cal_path, 'w') as cal_file:
        json.dump(dict(input_data=calibration_inputs), cal_file)

    await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")

    res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
    assert res == True, f'ezkl.compile_circuit failed for {model_path}'

    # srs path
    print(f'await get_srs({settings_path}')
    res = await ezkl.get_srs(settings_path)
    print(f'after await get_srs({settings_path}')
    print(f'ezkl.setup({compiled_model_path},{vk_path},{pk_path}')
    res = ezkl.setup(
            compiled_model_path,
            vk_path,
            pk_path,
        )

    assert res == True, f'ezkl.setup failed for {compiled_model_path}'


In [None]:
# Initialize the elites with seed models: each MLP becomes one flat parameter vector
param_keys = ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
mlp1_param, mlp2_param = (
    torch.concatenate([seed.state_dict()[key].view(-1) for key in param_keys])
    for seed in (mlp1, mlp2)
)
elites = torch.stack([mlp1_param, mlp2_param])

# calibrate circuit
val = torch.rand(1)
loop = asyncio.get_event_loop()
calibrate_start = time.time()

op_name = 'merge'
operation_fn = slerp
input_names = ['val', 'parent_1', 'parent_2']  # input names
output_names = ['merged_weights']        # output name
example_inputs = (val, mlp1_param, mlp2_param)

# Calibration uses the sampled `val` but random stand-ins with the parents' shapes.
random_parents = [
    torch.rand(*parent.shape).detach().numpy().reshape(-1).tolist()
    for parent in (mlp1_param, mlp2_param)
]
calibration_inputs = [[val.item()], *random_parents]

loop.run_until_complete(
    calibrate_operation(op_name, example_inputs, calibration_inputs, input_names, operation_fn, output_names)
)

calibrate_end = time.time()
calibration_duration = calibrate_end - calibrate_start
print(f'Calibration duration: {calibration_duration}')

value (2082345784286450043322368) out of range: (0, 0)
forward pass failed: "failed to forward: [halo2] General synthesis error"
value (2082345784286450043322368) out of range: (0, 0)
forward pass failed: "failed to forward: [halo2] General synthesis error"
value (2082345784286450043322368) out of range: (0, 0)
forward pass failed: "failed to forward: [halo2] General synthesis error"
value (2082345784286450043322368) out of range: (0, 0)
forward pass failed: "failed to forward: [halo2] General synthesis error"
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity

await get_srs(./test_merge_dir/merge_settings.json


Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.


In [11]:
verify_start = time.time()
best_ind = top_inds[0]
await verify_only_best_path_proofs(merge_idx_and_p_inds, best_ind)
verify_end = time.time()

verify_duration = verify_end - verify_start
print('Proof verification duration (best path):' + str(verify_duration))

[(1, 1, 1)]
Model 1 from iteration 4 has parents 1 and 1 from iteration 3
Iteration 4
{(1, 0, 1)}
number of parents to verify 1
{(1, 0, 1)}
Model 1 from iteration 3 has parents 0 and 1 from iteration 2
Iteration 3
{(1, 0, 0), (0, 3, 0)}
number of parents to verify 2
{(1, 0, 0), (0, 3, 0)}
Model 1 from iteration 2 has parents 0 and 0 from iteration 1
{(1, 0, 0), (0, 3, 0)}
Model 0 from iteration 2 has parents 3 and 0 from iteration 1
Iteration 2
{(3, 3, 1), (0, 1, 1)}
number of parents to verify 2
{(3, 3, 1), (0, 1, 1)}
Model 3 from iteration 1 has parents 3 and 1 from iteration 0
{(3, 3, 1), (0, 1, 1)}
Model 0 from iteration 1 has parents 1 and 1 from iteration 0
Iteration 1
{(3, 1, 0), (1, 1, 0)}
number of parents to verify 2
{(3, 1, 0), (1, 1, 0)}
Model 3 from iteration 0 has parents 1 and 0 from iteration -1
{(3, 1, 0), (1, 1, 0)}
Model 1 from iteration 0 has parents 1 and 0 from iteration -1
Iteration 0
{(0, 3, 2), (1, 1, 1)}
number of parents to verify 2
Aggregating ./test_merge_d

thread '<unnamed>' panicked at /root/.cargo/git/checkouts/halo2-9de98af521c882c2/8cfca22/halo2_proofs/src/dev.rs:457:13:
row=2097146, usable_rows=0..2097146, k=21
stack backtrace:
   0:     0x7c3295b8c925 - <std::sys::backtrace::BacktraceLock::print::DisplayBacktrace as core::fmt::Display>::fmt::hca750ad87bb2f1d4
   1:     0x7c3295bbae9b - core::fmt::write::h133a0eb20f0a6a5d
   2:     0x7c3295b887ef - std::io::Write::write_fmt::h0b1c7497ddea4e96
   3:     0x7c3295b8dca1 - std::panicking::default_hook::{{closure}}::h7dd45b5804215332
   4:     0x7c3295b8d97c - std::panicking::default_hook::haed8ee3169af9669
   5:     0x7c3295b8e371 - std::panicking::rust_panic_with_hook::h4bf66cb658082ab2
   6:     0x7c3295b8e1d7 - std::panicking::begin_panic_handler::{{closure}}::h7a19fc32e0e387d7
   7:     0x7c3295b8cde9 - std::sys::backtrace::__rust_end_short_backtrace::h6117389318248cc0
   8:     0x7c3295b8de64 - rust_begin_unwind
   9:     0x7c3295bb7e13 - core::panicking::panic_fmt::h1817a57f977b85

PanicException: row=2097146, usable_rows=0..2097146, k=21

In [81]:
## EVALUATION
# Number of leading samples kept further down in this cell; they serve as both the
# example inputs and the calibration inputs of the evaluation circuit.
batch_size = 3

class EvaluateModel(nn.Module):
    """Wrap a classifier so that its forward pass yields accuracy instead of logits.

    `forward(X, y)` runs the wrapped model on `X` and passes the resulting
    logits, together with `y`, to the notebook-level `get_acc`.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, X, y):
        return get_acc(self.model(X), y)

def get_acc(logits, labels):
    """Accuracy of arg-max predictions over the last dim, returned as a 0-dim tensor.

    This redefinition replaces the same-named helper from the first cell once
    this cell has run; unlike that one it does not call `.item()`, so the
    result stays a tensor.
    """
    is_correct = logits.argmax(dim=-1) == labels
    return is_correct.float().mean()

# Wrap the model rebuilt from the first elite's flat parameters in the accuracy module.
# (`reroll_params` is defined in a cell outside this excerpt.)
best_model = reroll_params(elites[0])
eval_mod = EvaluateModel(best_model)

# From here on `calibrate_operation` writes its artifacts into the evaluation folder.
RUN_FOLDER = "./test_eval_dir/"

from pathlib import Path
Path(RUN_FOLDER).mkdir(exist_ok=True, parents=True)  # also creates missing parent directories

# NOTE(review): this rebinds the notebook-wide `images`/`labels` to their first
# `batch_size` rows, so earlier cells re-run after this one only see that subset.
images, labels = images[:batch_size, :], labels[:batch_size]
print(images.shape)
print(labels.shape)

op_name = 'eval'
input_names = ['images', 'labels']
output_names = ['acc']
example_inputs = (images, labels)
# The calibration data is the example batch itself, flattened to plain Python lists.
calibration_inputs = [
    tensor.detach().numpy().reshape(-1).tolist()
    for tensor in example_inputs
]

loop = asyncio.get_event_loop()
loop.run_until_complete(
    calibrate_operation(op_name, example_inputs, calibration_inputs, input_names, eval_mod, output_names)
)


Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 7 columns for non-linearity table.
Using 7 columns for non-linearity table.
Using 7 columns for non-linearity table.


torch.Size([3, 64])
torch.Size([3])


Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 7 columns for non-linearity table.
Using 7 columns for non-linearity table.


 <------------- Numerical Fidelity Report (input_scale: 11, param_scale: 11, scale_input_multiplier: 1) ------------->

+----------------+----------------+----------------+----------------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+
| mean_error     | median_error   | max_error      | min_error      | mean_abs_error | median_abs_error | max_abs_error | min_abs_error | mean_squared_error | mean_percent_error | mean_abs_percent_error |
+----------------+----------------+----------------+----------------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+
| -0.00048828125 | -0.00048828125 | -0.00048828125 | -0.00048828125 | 0.00048828125  | 0.0004882812