## EZKL GCN Notebook

In [4]:
!pip install torch-scatter torch-sparse torch-geometric

Collecting torch-scatter
  Downloading torch_scatter-2.1.1.tar.gz (107 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.6/107.6 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-sparse
  Downloading torch_sparse-0.6.17.tar.gz (209 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.2/209.2 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m44.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: torch-scatter, torch-sparse, torch-geometric
  Building wheel for torch-sc

In [3]:
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

# check if notebook is in colab
try:
    # install ezkl
    import google.colab
    import subprocess
    import sys
    for e in ["ezkl", "onnx", "torch", "torchvision", "torch-scatter", "torch-sparse", "torch-geometric"]:
        subprocess.check_call([sys.executable, "-m", "pip", "install", e])

# rely on local installation of ezkl if the notebook is not in colab
except:
    pass

In [5]:
import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[2, 1, 3],
                           [0, 0, 2]], dtype=torch.long)
x = torch.tensor([[1], [1], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
data

Data(x=[3, 1], edge_index=[2, 3])

In [11]:
import torch
import math
from torch_geometric.nn import MessagePassing
from torch.nn.modules.module import Module

def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)

class GCNConv(Module):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__()  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.lin.weight)
        zeros(self.lin.bias)

    def forward(self, x, adj_t, deg):
        x = self.lin(x)
        adj_t = self.normalize_adj(adj_t, deg)
        x = adj_t @ x

        return x

    def normalize_adj(self, adj_t, deg):
        deg.masked_fill_(deg == 0, 1.)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == 1, 0.)
        adj_t = adj_t *  deg_inv_sqrt.view(-1, 1) # N, 1
        adj_t = adj_t *  deg_inv_sqrt.view(1, -1) # 1, N

        return adj_t

## Train Pipeline

In [7]:
import os
import os.path as osp
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

path = osp.join(os.getcwd(), 'data', 'Cora')
dataset = Planetoid(path, 'Cora')

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [18]:
import time

from torch import tensor
from torch.optim import Adam

# define num feat to use for training here
num_feat = 10

def run(dataset, model, runs, epochs, lr, weight_decay, early_stopping):

    val_losses, accs, durations = [], [], []
    for _ in range(runs):
        data = dataset[0]
        data = data.to(device)

        model.to(device).reset_parameters()
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_start = time.perf_counter()

        best_val_loss = float('inf')
        test_acc = 0
        val_loss_history = []

        for epoch in range(1, epochs + 1):
            train(model, optimizer, data)
            eval_info = evaluate(model, data)
            eval_info['epoch'] = epoch

            if eval_info['val_loss'] < best_val_loss:
                best_val_loss = eval_info['val_loss']
                test_acc = eval_info['test_acc']

            val_loss_history.append(eval_info['val_loss'])
            if early_stopping > 0 and epoch > epochs // 2:
                tmp = tensor(val_loss_history[-(early_stopping + 1):-1])
                if eval_info['val_loss'] > tmp.mean().item():
                    break

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_end = time.perf_counter()

        val_losses.append(best_val_loss)
        accs.append(test_acc)
        durations.append(t_end - t_start)

    loss, acc, duration = tensor(val_losses), tensor(accs), tensor(durations)

    print('Val Loss: {:.4f}, Test Accuracy: {:.3f} ± {:.3f}, Duration: {:.3f}'.
          format(loss.mean().item(),
                 acc.mean().item(),
                 acc.std().item(),
                 duration.mean().item()))


def train(model, optimizer, data):
    model.train()
    optimizer.zero_grad()

    E = data.edge_index.size(1)
    N = data.x.size(0)
    x = data.x[:, :num_feat]
    adj_t = torch.sparse_coo_tensor(data.edge_index, torch.ones(E), size=(N, N)).to_dense().T
    deg = torch.sum(adj_t, dim=1)
    out = model(x, adj_t, deg)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()


def evaluate(model, data):
    model.eval()

    with torch.no_grad():

        E = data.edge_index.size(1)
        N = data.x.size(0)
        x = data.x[:, :num_feat]
        adj_t = torch.sparse_coo_tensor(data.edge_index, torch.ones(E), size=(N, N)).to_dense().T
        deg = torch.sum(adj_t, dim=1)
        logits = model(x, adj_t, deg)

    outs = {}
    for key in ['train', 'val', 'test']:
        mask = data['{}_mask'.format(key)]
        loss = F.nll_loss(logits[mask], data.y[mask]).item()
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()

        outs['{}_loss'.format(key)] = loss
        outs['{}_acc'.format(key)] = acc

    return outs

In [19]:
runs = 10
epochs = 200
lr = 0.01
weight_decay = 0.0005
early_stopping = 10
hidden = 16
dropout = 0.5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Net(torch.nn.Module):
    def __init__(self, dataset, num_feat):
        super(Net, self).__init__()
        # self.conv1 = GCNConv(dataset.num_features, hidden)
        self.conv1 = GCNConv(num_feat, hidden)
        self.conv2 = GCNConv(hidden, dataset.num_classes)


    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, x, adj_t, deg):
        x = F.relu(self.conv1(x, adj_t, deg))
        x = F.dropout(x, p=dropout, training=self.training)
        x = self.conv2(x, adj_t, deg)
        return F.log_softmax(x, dim=1)

model = Net(dataset, num_feat)
run(dataset, model, runs, epochs, lr, weight_decay, early_stopping)

Val Loss: 1.5943, Test Accuracy: 0.350 ± 0.012, Duration: 25.914


## EZKL Setup

In [20]:
import os
import ezkl


model_path = os.path.join('network.onnx')
compiled_model_path = os.path.join('network.compiled')
pk_path = os.path.join('test.pk')
vk_path = os.path.join('test.vk')
settings_path = os.path.join('settings.json')
srs_path = os.path.join('kzg.srs')
witness_path = os.path.join('witness.json')
data_path = os.path.join('input.json')



In [21]:
# Downsample graph
num_node = 5

# filter edges so that we only bring adjacencies among downsampled node
filter_row = []
filter_col = []
row, col = dataset[0].edge_index
for idx in range(row.size(0)):
    if row[idx] < num_node and col[idx] < num_node:
        filter_row.append(row[idx])
        filter_col.append(col[idx])
filter_edge_index = torch.stack([torch.tensor(filter_row), torch.tensor(filter_col)])
num_edge = len(filter_row)


x = dataset[0].x[:num_node, :num_feat]
edge_index = filter_edge_index

adj_t = torch.sparse_coo_tensor(edge_index, torch.ones(num_edge), size=(num_node, num_node)).to_dense().T
deg = torch.sum(adj_t, dim=1)


In [22]:
import json

# Flips the neural net into inference mode
model.eval()
model.to('cpu')

# No dynamic axis for GNN batch
torch.onnx.export(model,               # model being run
                      (x, adj_t, deg),               # model input (or a tuple for multiple inputs)
                      model_path,            # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=11,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names = ['x', 'edge_index'],   # the model's input names
                      output_names = ['output']) # the model's output names

verbose: False, log level: Level.ERROR



  _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)


In [23]:
torch_out = model(x, adj_t, deg)
x_shape = x.shape
adj_t_shape=adj_t.shape
deg_shape=deg.shape

x = ((x).detach().numpy()).reshape([-1]).tolist()
adj_t = ((adj_t).detach().numpy()).reshape([-1]).tolist()
deg = ((deg).detach().numpy()).reshape([-1]).tolist()

data = dict(input_shapes=[x_shape, adj_t_shape, deg_shape],
            input_data=[x, adj_t, deg],
            output_data=[((torch_out).detach().numpy()).reshape([-1]).tolist()])
json.dump(data, open(data_path, 'w'))

In [24]:
!RUST_LOG=trace
import ezkl

run_args = ezkl.PyRunArgs()
run_args.input_scale = 5
run_args.param_scale = 5
# TODO: Dictionary outputs
res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)
assert res == True

res = await ezkl.calibrate_settings(data_path, model_path, settings_path, "resources")
assert res == True

In [25]:
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True

In [26]:
# srs path
res = ezkl.get_srs(srs_path, settings_path)

In [27]:
# now generate the witness file

res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

In [28]:
# HERE WE SETUP THE CIRCUIT PARAMS
# WE GOT KEYS
# WE GOT CIRCUIT PARAMETERS
# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK



res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
        srs_path,
    )

assert res == True
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)

In [29]:
# GENERATE A PROOF


proof_path = os.path.join('test.pf')

res = ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,
        srs_path,
        "single",
    )

print(res)
assert os.path.isfile(proof_path)

{'instances': [[[18394180685272187193, 2843742363274879387, 6495664164330342886, 3072557160466273026], [18394180685272187193, 2843742363274879387, 6495664164330342886, 3072557160466273026], [18394180685272187193, 2843742363274879387, 6495664164330342886, 3072557160466273026], [18394180685272187193, 2843742363274879387, 6495664164330342886, 3072557160466273026], [18394180685272187193, 2843742363274879387, 6495664164330342886, 3072557160466273026], [18394180685272187193, 2843742363274879387, 6495664164330342886, 3072557160466273026], [18394180685272187193, 2843742363274879387, 6495664164330342886, 3072557160466273026], [18394180685272187193, 2843742363274879387, 6495664164330342886, 3072557160466273026], [18394180685272187193, 2843742363274879387, 6495664164330342886, 3072557160466273026], [18394180685272187193, 2843742363274879387, 6495664164330342886, 3072557160466273026], [18394180685272187193, 2843742363274879387, 6495664164330342886, 3072557160466273026], [18394180685272187193, 2843

In [30]:
# VERIFY IT

res = ezkl.verify(
        proof_path,
        settings_path,
        vk_path,
        srs_path,
    )

assert res == True
print("verified")

verified
