Simple MNIST Classifier

In [None]:
# Detect Google Colab and, only there, install the notebook's dependencies.
try:
    import google.colab  # noqa: F401 -- only importable inside Colab
except ImportError:
    # Not running in Colab: rely on the local installation of ezkl
    # (and the other dependencies imported below).
    pass
else:
    import subprocess
    import sys

    # Install with the kernel's own interpreter so the packages land in the
    # environment this notebook runs in. A failed install raises
    # CalledProcessError instead of being silently swallowed, which would
    # otherwise surface later as a confusing import error.
    for package in ("ezkl", "tensorflow_datasets", "tf2onnx", "onnx"):
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# make sure you have the dependencies required here already installed
import ezkl
import os
import json
import time
import random
import logging

import tensorflow as tf
from tensorflow.keras.layers import *
import tensorflow as tf
import tensorflow_datasets as tfds

# Descriptive logging: level, logger name, timestamp and call site on every
# record, with the root logger set to INFO.
FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'
logging.basicConfig(format=FORMAT)
logging.root.setLevel(logging.INFO)


In [None]:

# Load the MNIST train/test splits as (image, label) pairs, together with the
# dataset metadata (used later for the number of training examples).
mnist_splits, ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
ds_train, ds_test = mnist_splits


In [None]:
def normalize_img(image, label):
  """Scale an image to `float32` in [0, 1] and round each pixel to 0 or 1.

  The label is passed through unchanged.
  """
  scaled = tf.cast(image, tf.float32) / 255.
  binarized = tf.round(scaled)
  return binarized, label

# Training input pipeline: normalize, cache, shuffle over the whole training
# split, batch by 128, and prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
# Evaluation input pipeline: normalize, batch by 128, cache the batches, and
# prefetch. No shuffling is applied to the test split.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
# Small fully-connected classifier: flatten -> 128-unit ReLU layer -> 10 outputs.
# The last layer has no activation, so the model emits raw logits.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(10))

# from_logits=True matches the activation-free output layer above.
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Train for 6 epochs, validating on the test pipeline after each epoch.
model.fit(ds_train, epochs=6, validation_data=ds_test)


In [None]:
import os

# Locations (relative to the working directory) of every artifact that the
# rest of the notebook writes or reads.
model_path = 'network.onnx'                # exported ONNX model
compiled_model_path = 'network.compiled'   # compiled circuit
pk_path = 'key.pk'                         # proving key
vk_path = 'key.vk'                         # verification key
settings_path = 'settings.json'            # circuit settings
srs_path = 'kzg.srs'                       # structured reference string
witness_path = 'witness.json'              # generated witness
data_path = 'input.json'                   # model input used for the proof


In [None]:
# Sanity-check the shape of a single input sample. Only one batch is needed,
# so pull it from the iterator instead of materializing every batch of the
# training set with list(ds_train).
print(next(iter(ds_train))[0].numpy()[0:1].shape)


In [None]:
import numpy as np
import tf2onnx
import tensorflow as tf
import json
import matplotlib.pyplot as plt
from PIL import Image

# After training, export to onnx (network.onnx) and create a data file (input.json)
png_file_path = "./MNIST_57_0.png"

# Load the sample digit as a 28x28 grayscale image (the MNIST input size).
with Image.open(png_file_path) as img:
    img = img.convert("L").resize((28, 28))

    # Scale pixel values to the 0-1 range.
    # NOTE(review): the training pipeline additionally rounds pixels to 0/1
    # (see normalize_img); this image is only scaled -- confirm that the
    # difference in preprocessing is intended.
    img_array = np.array(img) / 255.0

    # Add a leading batch axis so the shape is [1, 28, 28].
    x = img_array.reshape((1, 28, 28))

# Show the image that will be fed to the model (first item of the batch).
fig, ax = plt.subplots()
ax.imshow(x[0], cmap="gray")
ax.set_title("Displayed Image")
plt.show()
# Fixed input signature for the exported graph: a single float32 tensor of
# shape [1, 28, 28] named "input_0".
spec = tf.TensorSpec([1, 28, 28], tf.float32, name="input_0")


# Convert the trained Keras model to ONNX (opset 12) and write it to model_path.
# NOTE(review): inputs_as_nchw presumably targets 4-D image inputs, while this
# input is 3-D -- confirm against the tf2onnx docs that the option is needed.
tf2onnx.convert.from_keras(
    model,
    input_signature=[spec],
    inputs_as_nchw=["input_0"],
    opset=12,
    output_path=model_path,
)

# Run the trained model on the sample image and show its raw outputs.
logits = model.predict(x)
print(logits)

# Softmax turns the outputs into probabilities; the index of the largest
# probability is the predicted digit.
probabilities = tf.nn.softmax(logits).numpy()
predicted_class_index = probabilities.argmax()
predicted_digit = str(predicted_class_index)
print("The model predicts this image is a:", predicted_digit)

# Flatten the sample image into a plain list of floats, in the
# {"input_data": [[...]]} layout that is written to data_path.
data_array = x.reshape([-1]).tolist()

data = dict(input_data=[data_array])

# Serialize data into file. The context manager guarantees the file is flushed
# and closed before later cells read it back; the previous bare open() call
# leaked the handle.
with open(data_path, "w") as data_file:
    json.dump(data, data_file)


In [None]:
import ezkl

# Run arguments for settings generation: public inputs and outputs, fixed
# parameters, and a batch size of 1.
run_args = ezkl.PyRunArgs()
run_args.input_visibility = "public"
run_args.param_visibility = "fixed"
run_args.output_visibility = "public"
run_args.variables = [("batch_size", 1)]

# Capture a set of data points for calibration (up to 30 samples of one batch).
# Only one batch is needed, so pull it from the iterator instead of
# materializing the whole training set with list(ds_train).
x = next(iter(ds_train))[0].numpy()[0:30]

data_array = x.reshape([-1]).tolist()

data = dict(input_data = [data_array])

cal_path = os.path.join('cal_data.json')

# Serialize data into file. The context manager guarantees the file is flushed
# and closed before ezkl reads it back below.
with open(cal_path, 'w') as cal_file:
    json.dump(data, cal_file)

# The former `!RUST_LOG=trace` line only set the variable inside a throwaway
# subshell, so it never affected this kernel. To enable trace logging, set it
# in-process instead, e.g.:
#   os.environ["RUST_LOG"] = "trace"
# NOTE(review): confirm that ezkl still honours RUST_LOG when set after import.

# TODO: Dictionary outputs
res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)
assert res == True

# Calibrate the generated settings against the calibration data, targeting
# "resources" and trying the scales 1 and 7.
res = ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources", scales = [1, 7])
assert res == True


In [None]:
# Compile the ONNX model with the calibrated settings into compiled_model_path;
# the assertion stops the notebook here unless ezkl reports success.
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True


In [None]:
# Fetch the structured reference string (SRS) matching the circuit settings
# and store it at srs_path.
res = ezkl.get_srs(srs_path, settings_path)

# Fail here, with a clear message, rather than later in setup/prove if the SRS
# file was not actually written.
assert os.path.isfile(srs_path), f"SRS file was not created at {srs_path}"


In [None]:
# now generate the witness file
# witness_path is already defined in the paths cell; it is deliberately not
# re-assigned here so that cell stays the single place to change locations.
res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)


In [None]:
# Mock-prove the witness against the compiled circuit; the assertion stops the
# notebook unless ezkl reports success.
# NOTE(review): presumably a cheap sanity check before real proving -- confirm
# against the ezkl docs. Comment this cell out to skip it.
res = ezkl.mock(witness_path, compiled_model_path)
assert res == True


In [None]:

# Set up the circuit: generate the verification key and the proving key for
# the compiled model, using the previously fetched SRS.
res = ezkl.setup(compiled_model_path, vk_path, pk_path, srs_path)

assert res == True

# Both keys, and the settings file they belong to, must exist on disk now.
for required_file in (vk_path, pk_path, settings_path):
    assert os.path.isfile(required_file)


In [None]:
# Generate a proof for the witness and write it to proof_path.
proof_path = 'test.pf'

res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, srs_path, "single")

print(res)
assert os.path.isfile(proof_path)


In [None]:
# Verify the proof locally against the settings, verification key and SRS.
res = ezkl.verify(proof_path, settings_path, vk_path, srs_path)

assert res == True
print("verified")


We can now create an EVM / `.sol` verifier that can be deployed on chain to verify submitted proofs using a view function.

In [None]:

# Output locations for the generated verifier contract: its ABI and its
# Solidity source.
abi_path = 'test.abi'
sol_code_path = 'test.sol'

# Generate the verifier from the verification key, SRS and settings; the
# assertion stops the notebook unless ezkl reports success.
res = ezkl.create_evm_verifier(
        vk_path,
        srs_path,
        settings_path,
        sol_code_path,
        abi_path,
    )
assert res == True


## Verify on the EVM

In [None]:
# Make sure anvil is running locally first
# run with $ anvil -p 3030
# we use the default anvil node here
import json

# File that receives the address of the deployed verifier contract.
address_path = "address.json"

# Deploy the generated Solidity verifier to the local node on port 3030.
res = ezkl.deploy_evm(address_path, sol_code_path, 'http://127.0.0.1:3030')

assert res == True

# Read the deployed contract address back, dropping trailing whitespace.
with open(address_path) as address_file:
    addr = address_file.read().rstrip()


In [None]:
# make sure anvil is running locally
# $ anvil -p 3030

# Verify the proof against the contract at `addr` (read from the address file
# after deployment), using the same RPC endpoint the contract was deployed to.
# The assertion stops the notebook unless ezkl reports success.
res = ezkl.verify_evm(
    proof_path,
    addr,
    "http://127.0.0.1:3030"
)
assert res == True
