# Introduction

This notebook shows how to use ZKLoRA to generate and verify zero-knowledge proofs of the forward pass through a LoRA model's adapter submodules.
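As background, here is a minimal NumPy sketch of what a LoRA-adapted projection computes. It assumes PEFT's default LoRA formulation: the frozen base layer's output plus a low-rank update scaled by `lora_alpha / r`. The shapes mirror the GPT-2 `c_attn` layers loaded later in this notebook (768 inputs, 2304 outputs, rank 8). The weights are random, and `lora_alpha = 16` is a hypothetical value because the adapter's configuration is not shown here.

```python
import numpy as np

def lora_forward(x, W, b, A, B, lora_alpha, r):
    """Forward pass of a LoRA-adapted GPT-2 Conv1D projection (sketch).

    x: (tokens, d_in) activations
    W: (d_in, d_out) frozen base weight; GPT-2's Conv1D computes x @ W + b
    A: (r, d_in) LoRA down-projection, stored like an nn.Linear weight
    B: (d_out, r) LoRA up-projection, stored like an nn.Linear weight
    lora_dropout is Identity in this model, so it is omitted.
    """
    base = x @ W + b
    update = (x @ A.T) @ B.T  # low-rank path: d_in -> r -> d_out
    return base + (lora_alpha / r) * update

rng = np.random.default_rng(0)
d_in, d_out, r = 768, 2304, 8
lora_alpha = 16  # hypothetical; the real value lives in the adapter config

x = rng.standard_normal((4, d_in))
W = rng.standard_normal((d_in, d_out)) * 0.02
b = np.zeros(d_out)
A = rng.standard_normal((r, d_in)) * 0.01
B = np.zeros((d_out, r))  # PEFT initializes lora_B to zeros by default

y = lora_forward(x, W, b, A, B, lora_alpha, r)
print(y.shape)  # (4, 2304)
```

Because `B` starts at zero, a freshly initialized adapter leaves the base layer's output unchanged; only after training does the low-rank path contribute.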

# Setup

In order to use this notebook, first install the zklora package:

```bash
pip install zklora
```

Alternatively, if you are running this notebook locally from the `examples` directory of this repository, add the repository root to `sys.path` so the notebook can find the zklora package:

```python
import sys

# Make the repository root (one level above examples/) importable.
sys.path.append('..')
```
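If you would rather not modify `sys.path` unconditionally, a small helper can add the repository root only when `zklora` cannot already be imported (for example, because it was installed with `pip`). This is a standard-library sketch; `ensure_importable` is a hypothetical helper, not part of zklora.

```python
import importlib.util
import sys
from pathlib import Path

def ensure_importable(package: str, repo_root: str = "..") -> bool:
    """Prepend repo_root to sys.path if `package` cannot be found.

    Returns True if sys.path was changed, False otherwise.
    """
    if importlib.util.find_spec(package) is not None:
        return False  # already installed or already on the path
    root = str(Path(repo_root).resolve())
    if root in sys.path:
        return False  # avoid adding duplicate entries
    sys.path.insert(0, root)
    return True

# In the notebook you would call: ensure_importable("zklora")
```

Resolving the path to an absolute one keeps the entry valid even if the notebook later changes its working directory.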

Now we can import the necessary functions from the zklora package.

```python
from zklora import export_lora_submodules, generate_proofs, batch_verify_proofs
```

We also use the Hugging Face Transformers and PEFT libraries to load the base model, the LoRA adapter, and the tokenizer.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
```

```python
base_model_name = "distilgpt2"
lora_model_name = "q1e123/peft-starcoder-lora-a100"

# Load the base model, then wrap it with the LoRA adapter weights.
base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
lora_model = PeftModel.from_pretrained(base_model, lora_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Switch to inference mode (disables dropout).
lora_model.eval()
```

```text
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPT2LMHeadModel(
      (transformer): GPT2Model(
        (wte): Embedding(50257, 768)
        (wpe): Embedding(1024, 768)
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0-5): 6 x GPT2Block(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2SdpaAttention(
              (c_attn): lora.Linear(
                (base_layer): Conv1D(nf=2304, nx=768)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=768, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2304, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnit
...
```
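The module tree above gives the shapes of each `c_attn` adapter: `lora_A` maps 768 features down to rank 8 and `lora_B` maps rank 8 back up to 2304, both without biases. Counting those weights reproduces the per-module figure that proof generation reports later (24,576 parameters), which suggests the exported circuits cover only the adapter weights. This is back-of-the-envelope arithmetic, not a zklora API.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Weights in a bias-free LoRA pair: A is (r, d_in), B is (d_out, r)."""
    return r * d_in + d_out * r

per_module = lora_param_count(d_in=768, d_out=2304, r=8)
n_blocks = 6  # the tree shows 6 x GPT2Block, each with one adapted c_attn
total = per_module * n_blocks

# Size of the frozen c_attn weight the adapter sits beside (Conv1D nx=768, nf=2304).
base_weights = 768 * 2304

print(per_module)                          # 24576
print(total)                               # 147456
print(f"{per_module / base_weights:.2%}")  # 1.39%
```

The total of 147,456 is the same number that appears in the tuple returned by `generate_proofs`.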

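Next, export the LoRA submodules selected by `submodule_key` (here, the `attn.c_attn` attention projections), using a few sample texts as inputs:

```python
texts = ["Hello from LoRA", "And another test", "One more line..."]

export_lora_submodules(
    model=lora_model,
    tokenizer=tokenizer,
    input_texts=texts,
    submodule_key="attn.c_attn",
)
```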
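Before starting the slow proving step, it can help to confirm that every exported ONNX file has a matching activations file. The proof-generation log further down pairs them by name (`lora_onnx_params/<name>.onnx` with `intermediate_activations/<name>.json`). The sketch below assumes those default directory names and that naming convention; `pair_exports` and `missing_activations` are hypothetical helpers, not part of zklora.

```python
from pathlib import Path

def pair_exports(onnx_dir="lora_onnx_params", json_dir="intermediate_activations"):
    """Map each exported submodule name to (onnx_path, json_path or None).

    Assumes an activations file shares its ONNX file's stem, as in the
    "Matching JSON" lines that generate_proofs prints.
    """
    pairs = {}
    for onnx_path in sorted(Path(onnx_dir).glob("*.onnx")):
        json_path = Path(json_dir) / f"{onnx_path.stem}.json"
        pairs[onnx_path.stem] = (onnx_path, json_path if json_path.is_file() else None)
    return pairs

def missing_activations(pairs):
    """Names of submodules whose activations JSON could not be found."""
    return [name for name, (_, json_path) in pairs.items() if json_path is None]
```

Usage: `pairs = pair_exports()` followed by `missing_activations(pairs)`. For this notebook you would expect six entries and an empty list of missing activations, if the assumptions above hold.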

Generating the proofs in the next step takes around 10 minutes, depending on the hardware.

`generate_proofs` is asynchronous, so its result must be awaited. Jupyter (IPython) allows `await` at the top level of a cell, so wrap the call in an async function and await it:
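Top-level `await` works in the notebook because Jupyter already runs an event loop. A plain Python script has no running loop, so the usual pattern is to start one with `asyncio.run`. The sketch below shows that pattern with a hypothetical stand-in coroutine, `fake_generate_proofs`, in place of zklora's `generate_proofs` so it runs on its own; in a real script you would await `generate_proofs(verbose=True)` instead.

```python
import asyncio

async def fake_generate_proofs(verbose: bool = False) -> str:
    """Hypothetical stand-in for zklora.generate_proofs; not the real API."""
    await asyncio.sleep(0)  # yield control to the event loop, as real async work would
    if verbose:
        print("stand-in: generating proofs...")
    return "done"

async def main() -> str:
    # With zklora installed, this line would be:
    #     return await generate_proofs(verbose=True)
    return await fake_generate_proofs(verbose=True)

if __name__ == "__main__":
    # asyncio.run starts a new event loop, runs main() to completion, then closes the loop.
    print(asyncio.run(main()))
```

Calling `asyncio.run` inside a notebook cell raises a `RuntimeError` because a loop is already running there, which is why the notebook awaits the coroutine directly instead.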

```python
async def main():
    return await generate_proofs(verbose=True)

await main()
```

```text
Found 6 ONNX files in lora_onnx_params.
Processing ONNX files for proof generation...
Preparing to prove with ONNX: lora_onnx_params/base_model_model_transformer_h_5_attn_c_attn.onnx
Matching JSON: intermediate_activations/base_model_model_transformer_h_5_attn_c_attn.json
Number of parameters: 24,576
Generating settings & compiling circuit...
Setup for base_model_model_transformer_h_5_attn_c_attn took 78.48 sec
Input shape from JSON: (1, 9216)
Local ONNX output shape: (1, 27648)
Generating witness (async)...
Witness gen took 12.75 sec
Generating proof...
Proof gen took 38.53 sec
Done with base_model_model_transformer_h_5_attn_c_attn.

Preparing to prove with ONNX: lora_onnx_params/base_model_model_transformer_h_1_attn_c_attn.onnx
Matching JSON: intermediate_activations/base_model_model_transformer_h_1_attn_c_attn.json
Number of parameters: 24,576
Generating settings & compiling circuit...
Setup for base_model_model_transformer_h_1_attn_c_attn took 44.72 sec
Input shape from JSON: (1, 9
...
```

```text
(297.28650307655334, 88.75116658210754, 227.826753616333, np.int64(147456), 6)
```
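The cell returns a tuple of five numbers. Its field meanings are not documented in this notebook, so the reading below is an assumption inferred from the log: total setup (settings and circuit compilation) time, total witness-generation time, and total proof-generation time in seconds, then the total number of LoRA parameters and the number of proofs. Under that reading, the three times add up to roughly ten minutes, in line with the estimate given before this step.

```python
# Values copied from the generate_proofs return value above (np.int64 shown as int).
# ASSUMPTION: field order is (setup_s, witness_s, proof_s, total_params, n_proofs).
result = (297.28650307655334, 88.75116658210754, 227.826753616333, 147456, 6)
setup_s, witness_s, proof_s, total_params, n_proofs = result

total_s = setup_s + witness_s + proof_s
print(f"total: {total_s:.1f} s ({total_s / 60:.1f} min)")  # total: 613.9 s (10.2 min)
print(f"per proof: {total_s / n_proofs:.1f} s")            # per proof: 102.3 s
print(f"params per module: {total_params // n_proofs}")    # params per module: 24576
```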

Finally, verify all of the generated proofs:

```python
batch_verify_proofs(verbose=True)
```

```text
Verifying proof for base_model_model_transformer_h_1_attn_c_attn...
Verification took 0.50 seconds
Proof verified successfully for base_model_model_transformer_h_1_attn_c_attn!

Verifying proof for base_model_model_transformer_h_0_attn_c_attn...
Verification took 0.47 seconds
Proof verified successfully for base_model_model_transformer_h_0_attn_c_attn!

Verifying proof for base_model_model_transformer_h_2_attn_c_attn...
Verification took 0.45 seconds
Proof verified successfully for base_model_model_transformer_h_2_attn_c_attn!

Verifying proof for base_model_model_transformer_h_5_attn_c_attn...
Verification took 0.47 seconds
Proof verified successfully for base_model_model_transformer_h_5_attn_c_attn!

Verifying proof for base_model_model_transformer_h_4_attn_c_attn...
Verification took 0.46 seconds
Proof verified successfully for base_model_model_transformer_h_4_attn_c_attn!

Verifying proof for base_model_model_transformer_h_3_attn_c_attn...
Verification took 0.46 seconds
Proof verif
...
```

```text
(2.80391001701355, 6)
```
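This return value appears to be the total verification time in seconds and the number of proofs checked; the six per-module times printed above sum to about 2.8 s, though that field reading is an assumption. Comparing it with the proof-generation time from the earlier tuple shows the asymmetry that makes the approach practical: a verifier spends under half a second per module, while the prover spent tens of seconds.

```python
# Totals copied from the outputs above.
# ASSUMPTION: proof_s is the third field of the generate_proofs tuple and
# verify_s is the first field of the batch_verify_proofs tuple.
proof_s, n_proofs = 227.826753616333, 6
verify_s, n_verified = 2.80391001701355, 6

avg_prove = proof_s / n_proofs      # seconds per proof generated
avg_verify = verify_s / n_verified  # seconds per proof verified
print(f"prove:  {avg_prove:.2f} s/module")                    # prove:  37.97 s/module
print(f"verify: {avg_verify:.3f} s/module")                   # verify: 0.467 s/module
print(f"verification is ~{avg_prove / avg_verify:.0f}x faster")  # verification is ~81x faster
```

These figures exclude setup and witness generation, which make proving even more expensive relative to verification.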