In [2]:
import sys
sys.path.append('..')

from zklora import export_lora_submodules, generate_proofs, batch_verify_proofs, get_merkle_root

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
base_model_name = "distilgpt2"
lora_model_name = "q1e123/peft-starcoder-lora-a100"
base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
lora_model = PeftModel.from_pretrained(base_model, lora_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
lora_model.eval()

texts = ["Hello from LoRA", "And another test", "One more line..."]

export_lora_submodules(
    model=lora_model,
    tokenizer=tokenizer,
    input_texts=texts,
    submodule_key="attn.c_attn",
)



In [5]:
import os

activation_files = os.listdir("intermediate_activations")
merkle_roots = {}

for file in activation_files:
    if file.endswith(".json"):
        file_path = os.path.join("intermediate_activations", file)
        merkle_root = get_merkle_root(file_path)
        merkle_roots[file] = merkle_root
        print(f"Merkle root for {file}: {merkle_root}")


Merkle root for base_model_model_transformer_h_3_attn_c_attn.json: 0xf86ca87f2efb14b2d78b9d0906154c23d9ee3c541f09d66009afa70cca03bbc1
Merkle root for base_model_model_transformer_h_2_attn_c_attn.json: 0xcd552fdea1ea1ff23544266ae091b98c6ee3fdda004d46396a197d10d788776c
Merkle root for base_model_model_transformer_h_1_attn_c_attn.json: 0x1c33a1b9bd68e40acd4445c2c1a4d06a3fa7e8c431cd804520195a79a71f9dbf
Merkle root for base_model_model_transformer_h_5_attn_c_attn.json: 0x5d04294b28890929cbdf5e7147a964c05c3f01c8eca203b0d6ac5dfca6b6ee04
Merkle root for base_model_model_transformer_h_0_attn_c_attn.json: 0x7680e5c16d8022f23120d4fb7e1748d2cd8029d62085ed30001ad0860f36e0d0
Merkle root for base_model_model_transformer_h_4_attn_c_attn.json: 0x54b880fa9ae593f632efe48d2d041a94573ff1da12efb6f0dfbfe552d17851ff


In [7]:
async def main():
    return await generate_proofs(verbose=True)

await main()

Found 6 ONNX files in lora_onnx_params.
Processing ONNX files for proof generation...
Preparing to prove with ONNX: lora_onnx_params/base_model_model_transformer_h_5_attn_c_attn.onnx
Matching JSON: intermediate_activations/base_model_model_transformer_h_5_attn_c_attn.json
Number of parameters: 24,576
Generating settings & compiling circuit...
Setup for base_model_model_transformer_h_5_attn_c_attn took 77.25 sec
Input shape from JSON: (1, 9216)
Local ONNX output shape: (1, 27648)
Generating witness (async)...
Witness gen took 13.75 sec
Generating proof...
Proof gen took 39.24 sec
Done with base_model_model_transformer_h_5_attn_c_attn.

Preparing to prove with ONNX: lora_onnx_params/base_model_model_transformer_h_1_attn_c_attn.onnx
Matching JSON: intermediate_activations/base_model_model_transformer_h_1_attn_c_attn.json
Number of parameters: 24,576
Generating settings & compiling circuit...
Setup for base_model_model_transformer_h_1_attn_c_attn took 43.87 sec
Input shape from JSON: (1, 9

(298.1483585834503, 91.26909136772156, 230.97669887542725, np.int64(147456), 6)

In [109]:
import json

witness_file = 'proof_artifacts/base_model_model_transformer_h_0_attn_c_attn_witness.json'

with open(witness_file, 'r') as f:
    witness_data = json.load(f)

print(witness_data['inputs'][0])


['faffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'edffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '0000000000000000000000000000000000000000000000000000000000000000', 'f7ffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'fdffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'faffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '1400000000000000000000000000000000000000000000000000000000000000', 'f0ffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'fcffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'fbffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '000000f093f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'f5ffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'ffffffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '0300000000000000000000000000000000000000000000000000000000000000', 'ffffffef93f5e1439170b97948e833285d588181b64550

In [114]:
def hex_to_signed_int(hex_str):
    value = int(hex_str, 16)
    # If the highest bit is set (negative number in two's complement)
    if value & (1 << (256 - 1)):  # Assuming 256-bit numbers
        value -= 1 << 256
    return value

# Load both witness and settings files
witness_file = 'proof_artifacts/base_model_model_transformer_h_0_attn_c_attn_witness.json'
settings_file = 'proof_artifacts/base_model_model_transformer_h_0_attn_c_attn_settings.json'

with open(witness_file, 'r') as f:
    witness_data = json.load(f)

with open(settings_file, 'r') as f:
    settings_data = json.load(f)

# Get scaling factor from settings
scale_factor = settings_data['run_args']['input_scale']

# Convert witness inputs back to float
original_values = [hex_to_signed_int(x) / scale_factor for x in witness_data['inputs'][0]]
print("Original float values:")
print(original_values)

Original float values:
[-3.2308066937905276e+74, -1.1630902453194047e+75, 0.0, -5.169290330575955e+74, -1.2923230570051e+74, -3.2308066937905276e+74, 1.2923224245236183e+75, -9.692418816408619e+74, -1.9384842692669092e+74, -2.5846454815287183e+74, 9.265670365901085e+68, -6.461612755099573e+74, -6.324814818956501e+67, 1.9384836367854273e+74, -6.324814818956501e+67, -7.107773967361382e+74, 7.1077733348799e+74, 2.5846448490472364e+74, -4.523129118314146e+74, -8.400096391885001e+74, -1.2923230570051e+74, -2.5846454815287183e+74, 0.0, -5.169290330575955e+74, -6.46161844743291e+73, 3.2308060613090457e+74, -1.2923230570051e+74, -6.324814818956501e+67, 9.265670365901085e+68, 6.461612122618091e+74, -1.0338580028670427e+75, -7.753935179623191e+74, 9.265670365901085e+68, 1.2923224245236182e+74, 1.2923224245236182e+74, -1.2923230570051e+74, 9.265670365901085e+68, 9.265670365901085e+68, -6.46161844743291e+73, 6.461612122618091e+74, -1.9384842692669092e+74, -4.523129118314146e+74, 7.1077733348799e+7