In [25]:
# imports
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import Subset, DataLoader
import numpy as np

In [26]:
# The ten CIFAR-100 classes (five flowers, five trees) that the model is trained on.
classes_to_filter = ["orchid", "poppy", "rose", "sunflower", "tulip", "maple_tree", "oak_tree", "palm_tree", "pine_tree", "willow_tree"]

# Convert images to tensors, then normalise every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

cifar100_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
cifar100_classes = cifar100_train.classes

# Class name -> original CIFAR-100 label index.
class_to_idx = {cls_name: idx for idx, cls_name in enumerate(cifar100_classes)}

# Ordered list: entry i is the original label index of classes_to_filter[i].
filtered_class_indices = [class_to_idx[class_name] for class_name in classes_to_filter]
print(filtered_class_indices)

# The membership test below runs once per training sample, so use a set for O(1)
# lookups; the ordered list above is kept unchanged for anything that needs positions.
filtered_class_index_set = set(filtered_class_indices)

# Positions (within the full training set) of the samples whose label is one of the kept classes.
filtered_cifar100_train = [idx for idx, label in enumerate(cifar100_train.targets) if label in filtered_class_index_set]

batch_size = 64
filtered_dataset = Subset(cifar100_train, filtered_cifar100_train)
dataloader = DataLoader(filtered_dataset, batch_size=batch_size, shuffle=True)

Files already downloaded and verified
[54, 62, 70, 82, 92, 47, 52, 56, 59, 96]


In [27]:
# check if notebook is in colab
try:
    # install ezkl
    # Importing google.colab raises ImportError outside of a Colab runtime.
    import google.colab
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ezkl"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "onnx"])

# rely on local installation of ezkl if the notebook is not in colab
# Only ImportError is caught: a failed pip install (CalledProcessError) or a
# KeyboardInterrupt must surface here instead of being silently discarded.
except ImportError:
    pass

In [28]:
from torch import nn
import torch.optim as optim
import ezkl
import os
import json
import torch
from sklearn.metrics import accuracy_score

In [29]:
class Model(nn.Module):
    """Small CNN classifier: two strided 5x5 convolutions followed by two linear layers."""

    def __init__(self, num_classes):
        """Build the layers.

        Args:
            num_classes: number of logits produced by the final linear layer.
        """
        super().__init__()

        # Both convolutions keep 3 channels and halve the resolution via stride 2.
        self.conv1 = nn.Conv2d(3, 3, kernel_size=5, stride=2)
        self.conv2 = nn.Conv2d(3, 3, kernel_size=5, stride=2)

        # A single stateless activation module shared by every layer.
        self.relu = nn.ReLU()

        # For 32x32 inputs the two convolutions leave a 3 x 5 x 5 feature map.
        flattened_features = 3 * 5 * 5
        self.d1 = nn.Linear(flattened_features, 32)
        self.d2 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of images ``x``."""
        features = self.relu(self.conv1(x))
        features = self.relu(self.conv2(features))

        # Keep the batch dimension and collapse everything else into one vector.
        hidden = self.relu(self.d1(features.flatten(start_dim=1)))

        return self.d2(hidden)

In [30]:
# Training hyperparameters.
num_classes = len(classes_to_filter)  # one output logit per retained class
num_epochs = 10
# Same value the DataLoader was already constructed with; re-assigning it here
# does not rebuild the loader.
batch_size = 64
learning_rate = 1e-3

In [31]:
model = Model(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# The dataset still yields the original CIFAR-100 label indices, while the model
# only has num_classes outputs. Derive the remapping once from the same list that
# was used to filter the dataset (original index -> position in classes_to_filter)
# rather than re-creating a hard-coded dict for every batch.
mapping = {original: new for new, original in enumerate(filtered_class_indices)}

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    predictions = []
    labels = []

    for images, targets in dataloader:
        # Translate the original labels of this batch into the 0..num_classes-1 range.
        targets = torch.tensor([mapping[int(number)] for number in targets], dtype=torch.long)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Predicted class = index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)
        predictions.extend(predicted.tolist())
        labels.extend(targets.tolist())

    # Accuracy is measured on the training batches themselves, collected over the epoch.
    accuracy = accuracy_score(labels, predictions)
    print(f'Epoch [{epoch + 1}/{num_epochs}] - Loss: {total_loss / len(dataloader):.4f} - Accuracy: {accuracy * 100:.2f}%')

# Save the trained model (optional)
torch.save(model.state_dict(), 'cifar100_cnn_model.pth')

Epoch [1/10] - Loss: 2.0796 - Accuracy: 21.14%
Epoch [2/10] - Loss: 1.7451 - Accuracy: 33.10%
Epoch [3/10] - Loss: 1.6687 - Accuracy: 35.20%
Epoch [4/10] - Loss: 1.6123 - Accuracy: 39.18%
Epoch [5/10] - Loss: 1.5798 - Accuracy: 40.74%
Epoch [6/10] - Loss: 1.5564 - Accuracy: 42.04%
Epoch [7/10] - Loss: 1.5288 - Accuracy: 43.78%
Epoch [8/10] - Loss: 1.5148 - Accuracy: 43.14%
Epoch [9/10] - Loss: 1.4938 - Accuracy: 43.96%
Epoch [10/10] - Loss: 1.4821 - Accuracy: 45.26%


In [32]:
# Locations of every artifact written or read by the cells below
# (all relative to the current working directory).
data_path = 'input.json'
model_path = 'network.onnx'
settings_path = 'settings.json'
compiled_model_path = 'network.compiled'
srs_path = 'kzg.srs'
witness_path = 'witness.json'
pk_path = 'test.pk'
vk_path = 'test.vk'

In [33]:


# After training, export to onnx (network.onnx) and create a data file (input.json)
# A single random 3x32x32 sample scaled by 0.1; it is used both as the example
# input for the ONNX export and as the content of the data file.
x = 0.1*torch.rand(1,*[3, 32, 32], requires_grad=True)

# Flips the neural net into inference mode
model.eval()

# Export the model
torch.onnx.export(model,               # model being run
                  x,                   # model input (or a tuple for multiple inputs)
                  model_path,            # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

# Flatten the sample into a plain list of floats so it can be written as JSON.
data_array = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(input_data = [data_array])

# Serialize data into file. The context manager guarantees the handle is flushed
# and closed before later cells read data_path.
with open(data_path, 'w') as data_file:
    json.dump(data, data_file)


verbose: False, log level: Level.ERROR



In [34]:
py_run_args = ezkl.PyRunArgs()
py_run_args.input_visibility = "public"
py_run_args.output_visibility = "public"
py_run_args.param_visibility = "private"

res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)

assert res == True

res = await ezkl.calibrate_settings(data_path, model_path, settings_path, "resources")
assert res == True

In [35]:
# Compile the ONNX model into a circuit using the calibrated settings.
res = ezkl.compile_circuit(
    model_path,
    compiled_model_path,
    settings_path,
)
assert res == True

In [36]:
# srs path
res = ezkl.get_srs(srs_path, settings_path)
# The result was previously never checked; verify that the SRS file really
# exists so that a failed fetch is reported here rather than later in setup.
assert os.path.isfile(srs_path)

In [37]:
# Generate the witness file from the sample input and the compiled circuit.
res = ezkl.gen_witness(
    data_path,
    compiled_model_path,
    witness_path,
)
# Success is checked through the presence of the output file.
assert os.path.isfile(witness_path)

In [38]:
# Run setup to produce the verification key and the proving key for the compiled circuit.
res = ezkl.setup(compiled_model_path, vk_path, pk_path, srs_path)
assert res == True

# Both keys and the settings file must be on disk before proving.
assert all(os.path.isfile(path) for path in (vk_path, pk_path, settings_path))

In [39]:
# File the proof is written to.
proof_path = 'test.pf'

# Produce a proof from the witness with the proving key; "single" is the proof type passed to ezkl.
proof = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, srs_path, "single")

print(proof)
assert os.path.isfile(proof_path)

{'instances': [[[7959790035488735211, 12951774245394433045, 16242874202584236123, 560012691975822483], [3898461358030585804, 15381148731429954174, 10354293334678834451, 3143530863341041542], [3483395353741361115, 3494632259903994625, 6657987792994187913, 108272644256946680], [6425625360762666998, 7924344314350639699, 14762033076929465436, 2023505479389396574], [11443185389230096326, 16446406505298427670, 4454117921868872420, 668285336232769164], [3483395353741361115, 3494632259903994625, 6657987792994187913, 108272644256946680], [12436184717236109307, 3962172157175319849, 7381016538464732718, 1011752739694698287], [0, 0, 0, 0], [15919580070977470422, 7456804417079314474, 14039004331458920631, 1120025383951644967], [5432626032756654017, 1961834588764195904, 11835134460333605139, 1680038075927467451], [3483395353741361115, 3494632259903994625, 6657987792994187913, 108272644256946680], [3898461358030585804, 15381148731429954174, 10354293334678834451, 3143530863341041542], [642562536076266

In [40]:
# Check the proof against the verification key.
res = ezkl.verify(proof_path, settings_path, vk_path, srs_path)
assert res == True

print("verified")

verified


# Deploy on chain

In [41]:
# check if notebook is in colab
try:
    import google.colab
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "solc-select"])
    !solc-select install 0.8.20
    !solc-select use 0.8.20
    !solc --version

# rely on local installation if the notebook is not in colab
except:
    pass

Installing solc '0.8.20'...
Version '0.8.20' installed.
Switched global version to 0.8.20
solc, the solidity compiler commandline interface
Version: 0.8.20+commit.a1b79de6.Linux.g++


In [42]:
# Output locations for the generated Solidity verifier and its ABI.
sol_code_path = 'Verifier.sol'
abi_path = 'Verifier.abi'

# Generate the EVM verifier contract from the verification key.
res = ezkl.create_evm_verifier(vk_path, srs_path, settings_path, sol_code_path, abi_path)
assert res == True

assert os.path.isfile(sol_code_path)

In [43]:
# Flatten the proof's public instances and convert each field element
# (a vector of u64 limbs) into a single Python integer.
onchain_input_array = [
    ezkl.vecu64_to_int(field_element)
    for value in proof["instances"]
    for field_element in value
]

# This will be the values you use onchain
# copy them over to remix and see if they verify
# What happens when you change a value?
print("pubInputs: ", onchain_input_array)
print("proof: ", "0x" + proof["proof"])

pubInputs:  [4, 10, 7, 2, 11, 7, 1, 0, 8, 12, 7, 10, 2, 7, 3, 8, 1, 5, 4, 6, 1, 2, 11, 9, 3, 3, 1, 3, 8, 7, 3, 13, 6, 7, 12, 6, 9, 7, 11, 6, 9, 0, 8, 11, 6, 8, 8, 3, 0, 5, 6, 2, 2, 12, 1, 4, 2, 2, 3, 6, 7, 3, 5, 2, 3, 12, 12, 0, 10, 2, 11, 4, 6, 1, 3, 8, 7, 11, 13, 3, 2, 7, 1, 8, 4, 11, 9, 6, 12, 2, 6, 7, 1, 12, 2, 9, 11, 4, 11, 4, 8, 10, 4, 9, 2, 9, 1, 4, 6, 1, 9, 6, 8, 6, 4, 7, 7, 3, 4, 4, 13, 9, 6, 11, 10, 0, 1, 5, 12, 3, 5, 11, 9, 10, 6, 9, 1, 5, 3, 6, 6, 6, 7, 11, 9, 2, 8, 5, 4, 4, 10, 10, 0, 11, 11, 6, 3, 1, 9, 8, 6, 12, 10, 4, 10, 8, 8, 0, 3, 11, 11, 8, 6, 5, 10, 1, 10, 4, 6, 6, 8, 9, 11, 8, 10, 8, 1, 11, 4, 11, 13, 12, 3, 2, 8, 1, 1, 7, 10, 1, 4, 2, 5, 4, 0, 11, 4, 12, 7, 6, 1, 0, 11, 10, 11, 1, 4, 2, 9, 9, 3, 13, 4, 1, 1, 1, 10, 12, 4, 11, 4, 11, 12, 1, 11, 5, 12, 13, 7, 3, 7, 3, 10, 1, 10, 3, 5, 9, 9, 8, 12, 1, 2, 5, 2, 7, 13, 8, 5, 6, 10, 12, 10, 9, 2, 1, 2, 9, 6, 2, 8, 11, 7, 1, 10, 8, 3, 11, 5, 11, 5, 5, 1, 5, 12, 7, 3, 10, 1, 10, 8, 9, 10, 2, 11, 2, 10, 9, 5, 1, 13, 6, 7,