In [2]:
from torch import nn
import torch.nn.functional as F
import ezkl
import os
import sys
import json
import torch

In [3]:
sys.path.append(os.path.abspath(os.path.join('../..', 'src')))

from mimc import hash_model_weights, Fr

In [69]:
from typing import List, Tuple
from flwr.common import NDArrays

class AggregateModel(nn.Module):
    def preprocess(self, results: List[Tuple[NDArrays, int]]) -> Tuple[torch.Tensor, torch.Tensor]:
        weights, num_examples = zip(*results)

        client_tensors = []
        for i, client_weights in enumerate(weights):
            dims = []
            flatten_weights = []
            for layer in client_weights:
                dims.append(len(layer))
                flatten_weights.extend(layer)
            
            tensor = torch.tensor(flatten_weights)
            client_tensors.append(tensor)

        self.layers_dims = dims

        return torch.stack(client_tensors), torch.tensor(num_examples, dtype=torch.float32)

    def calculate_weights_hash(self, weights: torch.Tensor, num_examples: torch.Tensor) -> int:
        clients_hash_sum = 0
        n_clients = num_examples.numel()
        for i in range(n_clients):
            client_hash = hash_model_weights((weights[i], num_examples[i]))
            clients_hash_sum = int(Fr(clients_hash_sum + client_hash))
        
        return clients_hash_sum

    def forward(self, weights: torch.Tensor, num_examples: torch.Tensor) -> Tuple[torch.Tensor, str]:
        """Weighted average of model params"""
        num_examples_total = torch.sum(num_examples)
        weights_prime = (num_examples @ weights) / num_examples_total

        clients_hash_sum = str(self.calculate_weights_hash(weights, num_examples))

        return weights_prime, clients_hash_sum

In [70]:
# Tried to simplify the model hoping that it will improve the circuit compilation (it won't)
class SimpleAggregateModel(nn.Module):
    def forward(self, weights: torch.Tensor, num_examples: torch.Tensor) -> torch.Tensor:
        """Weighted average of model params"""
        num_examples_total = torch.sum(num_examples)
        weights_prime = (num_examples @ weights) / num_examples_total

        return weights_prime

In [71]:
circuit = AggregateModel()

In [72]:
with open("inputs/input_3.json", "r") as f:
    results = json.load(f)

results["input_data"]

[[[[0.055081795901060104,
    -0.006253935396671295,
    0.079899862408638,
    -0.05143433064222336,
    0.12056254595518112,
    -0.09760936349630356,
    -0.08089195191860199,
    -0.14716407656669617,
    -0.02868165634572506,
    0.06175040453672409,
    -0.022865556180477142,
    -0.007218115963041782,
    -0.19143353402614594,
    -0.057629745453596115,
    0.014571104198694229,
    -0.06167624518275261,
    -0.08106157183647156,
    -0.032106779515743256,
    0.020989682525396347,
    0.06874813139438629,
    -0.04862296208739281,
    -0.11871006339788437,
    0.0054029920138418674,
    0.09648309648036957,
    -0.05246738716959953,
    0.009629609994590282,
    -0.15317456424236298,
    -0.0894351527094841,
    -0.0031356036197394133,
    -0.004426474682986736,
    0.009687334299087524,
    -0.18223179876804352,
    -0.12819969654083252,
    -0.10110396146774292,
    -0.02951064705848694,
    -0.1534486562013626,
    -0.006374645512551069,
    -0.10561978071928024,
    -0.0730

In [73]:
inputs = circuit.preprocess(results["input_data"])
inputs

(tensor([[ 0.0551, -0.0063,  0.0799,  ..., -0.0307,  0.0433, -0.1751],
         [ 0.0548, -0.0088,  0.0768,  ..., -0.0319,  0.0425, -0.1763]]),
 tensor([50000., 50000.]))

### Verify hash sum

In [74]:
res = circuit.forward(*inputs)
res

torch.float32 torch.float32 torch.float32


(tensor([ 0.0549, -0.0075,  0.0783,  ..., -0.0313,  0.0429, -0.1757]),
 '18422080934942585953252598756928440043225150370187337684073461925761226736920')

In [75]:
client_1 = inputs[0][0], inputs[1][0]
hash_1 = hash_model_weights(client_1)

client_2 = inputs[0][1], inputs[1][1]
hash_2 = hash_model_weights(client_2)

hash_sum = int(Fr(hash_1 + hash_2))

assert res[1] == str(hash_sum)

client_1, hash_1, client_2, hash_2, hash_sum

((tensor([ 0.0551, -0.0063,  0.0799,  ..., -0.0307,  0.0433, -0.1751]),
  tensor(50000.)),
 15800721470892511143794912100930039860095342683798810904654329577788420712297,
 (tensor([ 0.0548, -0.0088,  0.0768,  ..., -0.0319,  0.0425, -0.1763]),
  tensor(50000.)),
 2621359464050074809457686655998400183129807686388526779419132347972806024623,
 18422080934942585953252598756928440043225150370187337684073461925761226736920)

### Prepare the circuit and inputs

In [76]:
data_path = os.path.join('zk', 'input.json')

x = inputs

w = ((x[0]).detach().numpy()).reshape([-1]).tolist()
n = ((x[1]).detach().numpy()).reshape([-1]).tolist()
data = dict(input_data = (w, n))
json.dump( data, open(data_path, 'w'))

In [81]:
model_path = os.path.join('zk', 'network.onnx')

circuit = SimpleAggregateModel()  # tested different aggregators (same results)

circuit.eval()
torch.onnx.export(circuit,               # model being run
                      x,                   # model input (or a tuple for multiple inputs)
                      model_path,            # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=10,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names = ['input'],   # the model's input names
                      output_names = ['output'], # the model's output names
)

### Setup

In [82]:
settings_path = os.path.join('zk', 'settings.json')
compiled_model_path = os.path.join('zk', 'network.compiled')

srs_path = os.path.join('zk', 'kzg.srs')
pk_path = os.path.join('zk', 'test.pk')
vk_path = os.path.join('zk', 'test.vk')

py_run_args = ezkl.PyRunArgs()
py_run_args.input_visibility = "private"
py_run_args.output_visibility = "public"
py_run_args.param_visibility = "fixed" # private by default

res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
assert res

res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res

res = await ezkl.get_srs(settings_path)
assert res

res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
        # srs_path,
    )
assert res
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)

### Prove

In [83]:
witness_path = os.path.join('zk', 'witness.json')

res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert res and os.path.isfile(witness_path)

proof_path = os.path.join('zk', 'test.pf')
res = ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,
        "single",
    )
print(res)
assert os.path.isfile(proof_path)

value (0) out of range: (64, 192)


RuntimeError: Failed to generate witness: [graph] [halo2] General synthesis error

### Verify

In [None]:
res = ezkl.verify(
        proof_path,
        settings_path,
        vk_path,
    )
if res:
    print("legit proof verified")
else:
    print("legit proof failed")

legit proof verified


### Malicious proof

In [None]:
mal_model_path = os.path.join('zk', 'mal_network.onnx')

class MaliciousAggregateModel(AggregateModel):
    def calculate_weights_hash(self, weights: torch.Tensor, num_examples: torch.Tensor) -> int:
        return super().calculate_weights_hash(weights, num_examples)
    
    def forward(self, weights: torch.Tensor, num_examples: torch.Tensor) -> Tuple[torch.Tensor, str]:
        weights_prime, clients_hash_sum = super().forward(weights, num_examples)
        weights_prime *= 0.99  # modify the aggregated weights
        return weights_prime, clients_hash_sum 

mal_circuit = MaliciousAggregateModel()

res = mal_circuit.forward(*inputs)
print(res)

mal_circuit.eval()
torch.onnx.export(mal_circuit,               # model being run
                      x,                   # model input (or a tuple for multiple inputs)
                      mal_model_path,            # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=10,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names = ['input'],   # the model's input names
                      output_names = ['output'], # the model's output names
)

(tensor([3.6300, 3.9600, 4.2900, 4.6200, 4.9500, 5.2800]), '6197898982643862682635299448059270293950200210132496661978496667184822303567')


In [48]:
compiled_mal_model_path = os.path.join('zk', 'mal_network.compiled')
# mal_settings_path = os.path.join('zk', 'mal_settings.json')

# res = ezkl.gen_settings(mal_model_path, settings_path, py_run_args=py_run_args)
# assert res == True

res = ezkl.compile_circuit(mal_model_path, compiled_mal_model_path, settings_path)
assert res

In [49]:
mal_witness_path = os.path.join('zk', 'mal_witness.json')

res = await ezkl.gen_witness(data_path, compiled_mal_model_path, mal_witness_path)
assert res and os.path.isfile(mal_witness_path)

mal_proof_path = os.path.join('zk', 'mal_test.pf')
res = ezkl.prove(
        mal_witness_path,
        compiled_model_path,
        pk_path,
        mal_proof_path,
        "single",
    )
print(res)
assert os.path.isfile(mal_proof_path)

{'instances': [['fc01000000000000000000000000000000000000000000000000000000000000', '1c02000000000000000000000000000000000000000000000000000000000000', '3c02000000000000000000000000000000000000000000000000000000000000', '5b02000000000000000000000000000000000000000000000000000000000000', '7b02000000000000000000000000000000000000000000000000000000000000', '9b02000000000000000000000000000000000000000000000000000000000000']], 'proof': '0x167a5552c2516573975bcb2e023f92861c94978b415366302ca528c9bc8cc80e2f0b3bb08721a7d7291d4eb59952b397bca099339c36019c68feb0466f24b872292225533de048f5a743ac890cfa94585f09b46c3aa2fc6a733a7b0ae04d04b12c9a79ee4e1e819ae1f1ca360847601a22bacc461d18bde9b894d651b4a47d8c2f28a3c2ab418998c6a18027d8d28f447f0a506a3f6acc66101e0650f3e4f83b11316a959cbba6b0924499b612440c5dc98fc24eda57569df9c8e1bcfed44787022ed9680b5c2459a7fec6249a7a07af374951b3ede110730fa6cf5aad7f5c572c3da9d0abba6173ce00b05f3914ba7503c320e2cb5a8c879fddb0f03bd1d7c924f23863fa8adb0671efe15cd26dd47192685a7843745f86c2

In [50]:
# file that was edited to match outputs from legit witness file
mal_witness_tampered_path = os.path.join('zk', 'mal_witness_tampered.json')

mal_proof_tampered_path = os.path.join('zk', 'mal_test_tampered.pf')
res = ezkl.prove(
        mal_witness_tampered_path,
        compiled_model_path,
        pk_path,
        mal_proof_tampered_path,
        "single",
    )
print(res)
assert os.path.isfile(mal_proof_tampered_path)

{'instances': [['ef01000000000000000000000000000000000000000000000000000000000000', '1c02000000000000000000000000000000000000000000000000000000000000', '4902000000000000000000000000000000000000000000000000000000000000', '7602000000000000000000000000000000000000000000000000000000000000', 'a302000000000000000000000000000000000000000000000000000000000000', 'd002000000000000000000000000000000000000000000000000000000000000']], 'proof': '0x0dae73b6a7a046f5a0dc017944cdad1fca637207b28433dfb4d9ce0df6ee56a7199101f05086eecf24723112d59df26a75566a0ff64ff7185b8f06286ca496e4094f69ca394b00cd420d4bcbfefaab44b21e14cf35ebbdcf8606cb2247d4f8b2268f1d8811549fb6c3f586c8d1fe68f0ebed9834c3d5cdcedab432031eef00b32059a49212fe5ba297874661ebc1936f4745f3e29bcb0c24ecd77a8f2cfcb831138f0b763cd65d911593a7d5ad350b4cfa0ea158a8bd22ec160d8b9a1c08b5b12eda9d93db418d23280c8028bfbd1d2bd74aa578bcfb60e74d466de5a5ef31ab24ac758ec6e832bba88971b6aff58a4abc5d91a8df615f5c2fafbf3a6fc92bf712ace70466837bda020e25587db644f846ef31b1c6e60bacfe

In [51]:
res = ezkl.verify(
        mal_proof_path,
        settings_path,
        vk_path,
    )
if res:
    print("mal proof verified")
else:
    print("mal proof failed")

RuntimeError: Failed to run verify: [halo2] The constraint system is not satisfied

In [None]:
res = ezkl.verify(
        mal_proof_tampered_path,
        settings_path,
        vk_path,
    )
if res:
    print("tampered proof verified")
else:
    print("tampered proof failed")

RuntimeError: Failed to run verify: [halo2] The constraint system is not satisfied

https://arc.net/l/quote/ikvvnwpp

Soundness is the quality of the verifier (or parties representing the verifier) knowing that if a proof passes, it is more than likely a true statement. In some cases, such as those in underconstrained circuits, bad proofs can be generated that fool the verifier into passing a false statement. In this case, the vulnerability is not in the machine learning model itself, but in the SNARK constructed by ezkl.

ezkl is a compiler, so eventually should be less susceptible to such issues than a hand-written circuit, but it is still under active development.