In [1]:
# Import libraries

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import json
import asyncio
from sklearn import datasets
from tqdm import tqdm
import os
import time

# Ask Rust code (presumably ezkl's native core) to print full backtraces on panics/errors.
os.environ['RUST_BACKTRACE'] = 'full'

# check if notebook is in colab
try:
    import google.colab  # noqa: F401 -- imported only to detect the Colab runtime
except ImportError:
    # not in Colab: rely on a local installation of ezkl / onnx
    pass
else:
    # in Colab: install ezkl and onnx into the interpreter running this kernel.
    # Failures here are NOT swallowed, so a broken install surfaces immediately.
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ezkl"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "onnx"])

import ezkl
import os

# Neural network definition: input_dim -> hidden_dim -> output_dim
input_dim = 64
hidden_dim = 256
output_dim = 10
# each nn.Linear(fan_in, fan_out) holds fan_in * fan_out weights plus fan_out biases
num_params = sum(
    (fan_in + 1) * fan_out
    for fan_in, fan_out in ((input_dim, hidden_dim), (hidden_dim, output_dim))
)
print(f'num_params={num_params}')


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Layer sizes default to the notebook-level ``input_dim`` / ``hidden_dim`` /
    ``output_dim`` (looked up when the model is constructed). Pass them
    explicitly to build a differently sized network. Parameter names are
    always ``fc1.*`` / ``fc2.*``.
    """

    def __init__(self, in_features=None, hidden_features=None, out_features=None):
        super().__init__()
        # fall back to the notebook globals only when a size is not given
        in_features = input_dim if in_features is None else in_features
        hidden_features = hidden_dim if hidden_features is None else hidden_features
        out_features = output_dim if out_features is None else out_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply fc1, ReLU, then fc2 along the last dimension of ``x``."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Seed torch's default RNG so the parent models (and the torch.rand samples drawn
# in the calibration cells) are identical on every Restart & Run All.
torch.manual_seed(0)

mlp1 = MLP()
mlp2 = MLP()


def flatten_params(model, keys=('fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias')):
    """Return a model's parameters as one 1-D tensor.

    Tensors come from ``model.state_dict()`` (detached from autograd by default)
    and are concatenated in ``keys`` order, so every parent shares the same flat
    layout.
    """
    state = model.state_dict()
    return torch.cat([state[key].view(-1) for key in keys])


# get flat parameters of seed models to merge
mlp1_param = flatten_params(mlp1)
mlp2_param = flatten_params(mlp2)
# both parents must share the layout whose size was computed up front
assert mlp1_param.numel() == mlp2_param.numel() == num_params


num_params=19210


In [2]:
# Model merge class definitions

class Slerp(nn.Module):
    """Spherical linear interpolation between two flat parameter vectors.

    ``forward(val, x, y)`` measures the angle between the L2-normalised inputs
    along the last dimension. In exact arithmetic, val = 0 yields x and
    val = 1 yields y. When the sine of that angle is not above 1e-6 (parallel
    or anti-parallel inputs), it returns plain linear interpolation
    ``(1 - val) * x + val * y`` instead.
    """

    def __init__(self):
        super().__init__()

    def forward(self, val, x, y):
        unit_x = F.normalize(x, dim=-1)
        unit_y = F.normalize(y, dim=-1)
        # clamp keeps acos inside its domain despite rounding error
        cos_angle = torch.clamp((unit_x * unit_y).sum(dim=-1, keepdim=True), -1.0, 1.0)
        angle = torch.acos(cos_angle)
        sin_angle = torch.sin(angle)

        spherical = (torch.sin((1.0 - val) * angle) / sin_angle) * x \
            + (torch.sin(val * angle) / sin_angle) * y
        linear = (1.0 - val) * x + val * y
        # both branches are computed; the division above is only trusted when sin_angle is non-negligible
        return torch.where(sin_angle > 1e-6, spherical, linear)

class WeightedAvg(nn.Module):
    """Linear blend ``val * x + (1 - val) * y`` of two parameter vectors.

    Note the convention: val = 1 returns x and val = 0 returns y. This is the
    opposite weighting from Slerp's linear fallback, which weights y by val.
    """

    def __init__(self):
        super().__init__()

    def forward(self, val, x, y):
        weight_x = val
        weight_y = 1 - val
        return weight_x * x + weight_y * y

# single stateless instances of each merge operation, exported by the calibration cells
slerp, weighted_avg = Slerp(), WeightedAvg()

# every artifact of the runs below (ONNX models, settings, calibration data, keys) lives here
RUN_FOLDER = './test_merge_dir/'

from pathlib import Path
# idempotent: no error if the directory (or any parent) already exists
Path(RUN_FOLDER).mkdir(exist_ok=True, parents=True)

# calibration function for either merge operation
async def calibrate_operation(op_name, example_inputs, calibration_inputs, input_names, operation_fn, output_names):
    """Export a merge operation to ONNX, then generate, calibrate and compile its ezkl circuit.

    Every artifact is written to RUN_FOLDER, prefixed with ``op_name``.

    Args:
        op_name: file-name prefix for all artifacts (e.g. 'slerp').
        example_inputs: tuple of tensors used to trace ``operation_fn`` for the ONNX export.
        calibration_inputs: one flat Python list per graph input; written under
            ``input_data`` to the calibration JSON handed to ``ezkl.calibrate_settings``.
        input_names / output_names: names given to the ONNX graph inputs / outputs.
        operation_fn: the nn.Module to export.

    Returns:
        (model_path, compiled_model_path, settings_path, vk_path, pk_path). The vk/pk
        paths are only named here; this function does not create those files.

    Raises:
        AssertionError: if ``ezkl.gen_settings`` or ``ezkl.compile_circuit`` reports failure.
    """
    def artifact(suffix):
        # join properly instead of relying on RUN_FOLDER's trailing separator
        return os.path.join(RUN_FOLDER, op_name + suffix)

    model_path = artifact('.onnx')
    compiled_model_path = artifact('_network.compiled')
    pk_path = artifact('_test.pk')
    vk_path = artifact('_test.vk')
    settings_path = artifact('_settings.json')
    cal_path = artifact('_calibration.json')

    torch.onnx.export(
        operation_fn,
        example_inputs,
        model_path,
        export_params=True,          # operation_fn's own parameters (none for Slerp / WeightedAvg)
        opset_version=11,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
    )

    py_run_args = ezkl.PyRunArgs()
    py_run_args.input_visibility = "public"
    py_run_args.output_visibility = "public"
    py_run_args.param_visibility = "fixed"  # "fixed" for params means that the committed to params are used for all proofs

    res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
    assert res, f'ezkl.gen_settings failed for {model_path}'

    # close (and flush) the calibration file before ezkl reads it
    with open(cal_path, 'w') as cal_file:
        json.dump(dict(input_data=calibration_inputs), cal_file)

    # "resources" is ezkl's calibration target — TODO confirm semantics for the installed ezkl version
    await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")

    res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
    assert res, f'ezkl.compile_circuit failed for {model_path}'
    return model_path, compiled_model_path, settings_path, vk_path, pk_path



In [3]:
# calibrate weighted average operation
val = torch.rand(1)
loop = asyncio.get_event_loop()
calibrate_start = time.time()

op_name = 'weighted_avg'
example_inputs = (val, mlp1_param, mlp2_param)

# inputs to merge function are the randomly sampled weight value, and the flattened parameters of each of the two parent models to be merged
calibration_inputs = [[val.item()], 
                              torch.rand(*mlp1_param.shape).detach().numpy().reshape(-1).tolist(), 
                              torch.rand(*mlp2_param.shape).detach().numpy().reshape(-1).tolist()]
input_names = ['val', 'parent_1', 'parent_2']  # input names
operation_fn = weighted_avg
output_names = ['merged_weights']        # output name

model_path, compiled_model_path, settings_path, vk_path, pk_path = await calibrate_operation(op_name, example_inputs, calibration_inputs, input_names, operation_fn, output_names)

calibrate_end = time.time()
calibration_duration = calibrate_end - calibrate_start
print('Calibration duration for weighted average: ' + str(calibration_duration))

Using 2 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 3 columns for non-linearity table.


 <------------- Numerical Fidelity Report (input_scale: 13, param_scale: 13, scale_input_multiplier: 10) ------------->

+------------------+----------------+---------------+----------------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+
| mean_error       | median_error   | max_error     | min_error      | mean_abs_error | median_abs_error | max_abs_error | min_abs_error | mean_squared_error | mean_percent_error | mean_abs_percent_error |
+------------------+----------------+---------------+----------------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+
| 0.00000018319736 | 0.000008374453 | 0.00007992983 | -0.00007939339 | 0.000021565987 | 0.000008374453   | 0.00007992983 | 0            

Calibration duration for weighted average: 13.178054809570312


In [4]:
# srs path
# Fetch the structured reference string (SRS) ezkl needs for the circuit described by settings_path.
print(f'get_srs({settings_path})')
res = ezkl.get_srs(settings_path)
print(f'after get_srs({settings_path})')

get_srs(./test_merge_dir/weighted_avg_settings.json
after get_srs(./test_merge_dir/weighted_avg_settings.json


In [5]:
# Run ezkl key setup for the compiled circuit, writing the verifying key to vk_path and the proving key to pk_path.
print(f'ezkl.setup({compiled_model_path}, {vk_path}, {pk_path})')
res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
    )

assert res, f'ezkl.setup failed for {compiled_model_path}'

ezkl.setup(./test_merge_dir/weighted_avg_network.compiled, ./test_merge_dir/weighted_avg_test.vk, ./test_merge_dir/weighted_avg_test.pk


In [6]:
# calibrate circuit
val = torch.rand(1)
loop = asyncio.get_event_loop()
calibrate_start = time.time()

# calibrate slerp
op_name = 'slerp'
example_inputs = (val, mlp1_param, mlp2_param)

# inputs to merge function are the randomly sampled weight value, and the flattened parameters of each of the two parent models to be merged
calibration_inputs = [[val.item()], 
                              torch.rand(*mlp1_param.shape).detach().numpy().reshape(-1).tolist(), 
                              torch.rand(*mlp2_param.shape).detach().numpy().reshape(-1).tolist()]
input_names = ['val', 'parent_1', 'parent_2']  # input names
operation_fn = slerp
output_names = ['merged_weights']        # output name

model_path, compiled_model_path, settings_path, vk_path, pk_path = await calibrate_operation(op_name, example_inputs, calibration_inputs, input_names, operation_fn, output_names)

calibrate_end = time.time()
calibration_duration = calibrate_end - calibrate_start
print('Calibration duration for slerp: ' + str(calibration_duration))

value (1443106651889401625313280) out of range: (0, 0)
forward pass failed: "failed to forward: [halo2] General synthesis error"
value (1443106651889401625313280) out of range: (0, 0)
forward pass failed: "failed to forward: [halo2] General synthesis error"
value (1443106651889401625313280) out of range: (0, 0)
forward pass failed: "failed to forward: [halo2] General synthesis error"
value (1443106651889401625313280) out of range: (0, 0)
forward pass failed: "failed to forward: [halo2] General synthesis error"
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity

Calibration duration for slerp: 81.44143986701965


In [7]:
# srs path
# Fetch the structured reference string (SRS) ezkl needs for the circuit described by settings_path.
print(f'get_srs({settings_path})')
res = ezkl.get_srs(settings_path)
print(f'after get_srs({settings_path})')

get_srs(./test_merge_dir/slerp_settings.json
after get_srs(./test_merge_dir/slerp_settings.json


In [None]:
# Run ezkl key setup for the compiled circuit, writing the verifying key to vk_path and the proving key to pk_path.
print(f'ezkl.setup({compiled_model_path}, {vk_path}, {pk_path})')
res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
    )

assert res, f'ezkl.setup failed for {compiled_model_path}'

Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
