# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import json
import asyncio
from sklearn import datasets
from tqdm import tqdm
import os
import time

import ezkl
import os


# Load sklearn's handwritten-digits dataset as (features, targets) and turn it
# into tensors: float32 pixel features for the network, int64 class ids for
# the labels.
images, labels = (
    torch.tensor(array, dtype=dtype)
    for array, dtype in zip(
        datasets.load_digits(return_X_y=True),
        (torch.float32, torch.int64),
    )
)


# Network dimensions: 64 input features, one hidden layer, 10 output logits.
input_dim = 64
hidden_dim = 256
output_dim = 10

# Parameter count of the two fully connected layers: each layer has
# fan_in * fan_out weights plus fan_out biases, i.e. (fan_in + 1) * fan_out.
num_params = sum(
    (fan_in + 1) * fan_out
    for fan_in, fan_out in ((input_dim, hidden_dim), (hidden_dim, output_dim))
)
print(f'num_params={num_params}')


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear, producing raw logits.

    Layer sizes default to the module-level ``input_dim``, ``hidden_dim`` and
    ``output_dim`` (read when the instance is constructed), so ``MLP()``
    builds exactly the same network as before. Pass explicit sizes to build
    a differently shaped network without touching the globals.

    Args:
        in_features: size of each input vector (default: ``input_dim``).
        hidden_features: width of the hidden layer (default: ``hidden_dim``).
        out_features: number of output logits (default: ``output_dim``).
    """

    def __init__(self, in_features=None, hidden_features=None, out_features=None):
        super().__init__()
        # Fall back to the notebook-level constants so existing callers of
        # MLP() keep their behavior.
        if in_features is None:
            in_features = input_dim
        if hidden_features is None:
            hidden_features = hidden_dim
        if out_features is None:
            out_features = output_dim
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Return unnormalized logits (no softmax) for the batch ``x``."""
        return self.fc2(F.relu(self.fc1(x)))

class EvaluateModel(nn.Module):
    """Wrap a classifier so its forward pass returns accuracy on (X, y).

    Exporting this wrapper to ONNX captures both the wrapped model's forward
    pass and the argmax/compare/mean accuracy computation done by ``get_acc``.
    """

    def __init__(self, model):
        super().__init__()
        # Registered as a submodule so its parameters are exported too.
        self.model = model

    def forward(self, X, y):
        """Return the fraction of rows of ``X`` the model classifies as ``y``."""
        return get_acc(self.model(X), y)

def get_acc(logits, labels):
    """Return classification accuracy as a 0-dim float tensor.

    A row counts as correct when the index of its largest logit (along
    dim 1) equals the corresponding entry of ``labels``.
    """
    hits = logits.argmax(dim=1).eq(labels)
    return hits.float().mean()

mlp = MLP()
eval_mod = EvaluateModel(mlp)

# Keep only the first two samples. torch.onnx.export below is called without
# dynamic_axes, so the exported graph (and the ezkl circuit built from it) is
# fixed to this batch size.
images = images[:2, :]
labels = labels[:2]

# All artifacts live under RUN_FOLDER; join properly instead of concatenating,
# so a folder without a trailing separator still works.
RUN_FOLDER = "./"
model_path = os.path.join(RUN_FOLDER, 'eval.onnx')
compiled_model_path = os.path.join(RUN_FOLDER, 'eval_network.compiled')
pk_path = os.path.join(RUN_FOLDER, 'test.pk')
vk_path = os.path.join(RUN_FOLDER, 'test.vk')
settings_path = os.path.join(RUN_FOLDER, 'settings.json')
witness_path = os.path.join(RUN_FOLDER, 'witness.json')

# Export the accuracy computation (model forward + argmax/compare/mean) so
# that ezkl can build a circuit for it.
torch.onnx.export(
    eval_mod,                           # the model/module to be exported
    (images, labels),                   # example inputs
    model_path,                         # the file name to save the ONNX model
    export_params=True,                 # store the parameter weights inside the model file
    opset_version=11,                   # the ONNX version to export the model to
    do_constant_folding=True,           # whether to execute constant folding for optimization
    input_names=['images', 'labels'],   # input names
    output_names=['acc'],               # output name
)

py_run_args = ezkl.PyRunArgs()
py_run_args.input_visibility = "public"
py_run_args.output_visibility = "public"
py_run_args.param_visibility = "fixed"
res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
cal_path = os.path.join(RUN_FOLDER + "calibration.json")
data = dict(input_data = [torch.rand(*images.shape).detach().numpy().reshape(-1).tolist(),
                         torch.rand(*labels.shape).detach().numpy().reshape(-1).tolist()])

json.dump(data, open(cal_path, 'w'))
await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True

res = await ezkl.get_srs( settings_path)
asseert res == True


data = dict(input_data = [images.detach().numpy().reshape(-1).tolist(),
                       labels.detach().numpy().reshape(-1).tolist()])

data_path = os.path.join(RUN_FOLDER + 'actual_test_input.json')
with open(data_path, "w") as f:
    json.dump(data, f)


res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert res == True
assert os.path.isfile(witness_path)

res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
    )

assert res == True
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)

proof_path = os.path.join(RUN_FOLDER + "test_proof.pf")
res = ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,

        "single",
    )

