<a href="https://colab.research.google.com/github/hamidb201214-svg/Lectures/blob/main/peft_prompt_tuning_patent_claim_style_FIXED_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prompt Tuning (PEFT) for Patent Claim Style  
**Lecture notebook** — train a tiny prompt-tuning adapter that nudges a base LLM to write in *patent-claim style*.

**Dataset:** `AI-Growth-Lab/patents_claims_1.5m_traim_test` (claim text plus labels).  
For this lecture we use **only the claim text** (`text`) to do *style adaptation*.

> ⚠️ Not legal advice. Outputs must be reviewed by a qualified patent professional.


## 0) Install dependencies

> **Compatibility note:** PEFT versions below 0.18 are not compatible with Transformers v5, so we pin `transformers<5` for a stable lecture environment.
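After installing, you can check that the pins took effect by comparing installed versions against the upper bounds. The sketch below uses only the standard library (`importlib.metadata`). The helper names (`major_minor`, `below_pin`, `check_pins`) are our own and are not part of Transformers or PEFT. The parsing is a simplification that looks only at the major and minor numbers and ignores patch and pre-release suffixes.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Upper bounds mirroring the pip pins in the install cell: transformers<5, peft<0.18.
PINS = {"transformers": (5, 0), "peft": (0, 18)}

def major_minor(ver):
    """Extract (major, minor) from strings like '4.46.3', '0.17.1' or '4.47.0.dev0'."""
    match = re.match(r"(\d+)\.(\d+)", ver)
    if match is None:
        raise ValueError(f"Unrecognised version string: {ver!r}")
    return int(match.group(1)), int(match.group(2))

def below_pin(ver, upper):
    """True if the (major, minor) of `ver` is strictly below the upper bound."""
    return major_minor(ver) < upper

def check_pins(pins=PINS):
    """Print whether each pinned package is installed and inside its pinned range."""
    for pkg, upper in pins.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            print(f"{pkg}: not installed")
            continue
        status = "OK" if below_pin(installed, upper) else "outside pinned range"
        print(f"{pkg} {installed}: {status}")

check_pins()
```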



In [None]:
!pip -q install -U "transformers>=4.38,<5" "peft>=0.8.2,<0.18" "datasets>=2.14.5" accelerate


## 1) Imports + configuration

In [None]:
import os
import itertools
import torch

from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

from peft import get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit

print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())


### Model choice

For a live lecture, a small base model is preferable because it keeps both training and generation fast.

- Default: `bigscience/bloomz-560m` (instruction-tuned BLOOM)
- If you're on CPU and want faster runs, try `distilgpt2`


In [None]:
model_name = "bigscience/bloomz-560m"
# model_name = "distilgpt2"  # faster on CPU

NUM_VIRTUAL_TOKENS = 8     # number of trainable virtual (soft-prompt) tokens
MAX_LENGTH = 256           # sequence length for training
TRAIN_SAMPLES = 2000       # how many claims to stream in for training
EVAL_SAMPLES = 200         # small held-out set for quick sanity check
MAX_STEPS = 80             # for lecture speed; increase for better results

SEED = 42
torch.manual_seed(SEED)


## 2) Load tokenizer + base model

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Ensure pad token exists (important for batching)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load base model
# (Kept simple for the lecture; Trainer moves the model to the GPU when one is available)
foundational_model = AutoModelForCausalLM.from_pretrained(model_name)


# Make sure the model knows what to use for padding (avoids warnings)
foundational_model.config.pad_token_id = tokenizer.pad_token_id


## 3) Load a small subset of the patents claims dataset (streaming)

The full training CSV is several GB. For a lecture, we **stream** and only take a small sample.

We use only the `text` field (claim text).
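`take_text` relies on `itertools.islice`, which pulls only as many records from the stream as requested. The sketch below shows that laziness with a hypothetical infinite generator (`fake_claim_stream`) standing in for the Hugging Face streaming split, and a counter that records how many records were actually produced.

```python
import itertools

def fake_claim_stream(counter):
    """Hypothetical stand-in for a streaming split: yields claim records forever."""
    i = 0
    while True:
        counter["produced"] += 1
        yield {"text": f"1. A device comprising a part number {i}.", "label": 0}
        i += 1

counter = {"produced": 0}
# Same pattern as take_text: keep only the claim text from the first n records.
sample = [{"text": ex["text"]} for ex in itertools.islice(fake_claim_stream(counter), 5)]
print(len(sample), counter["produced"])  # 5 5
```

Even though the generator never ends, only five records are ever produced, which is why streaming avoids downloading the full multi-GB file.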


In [None]:
dataset_name = "AI-Growth-Lab/patents_claims_1.5m_traim_test"

# Stream to avoid downloading multi-GB files.
train_stream = load_dataset(dataset_name, split="train", streaming=True)
test_stream  = load_dataset(dataset_name, split="test",  streaming=True)

def take_text(stream, n):
    out = []
    for ex in itertools.islice(stream, n):
        # Keep just claim text for language modeling
        out.append({"text": ex["text"]})
    return Dataset.from_list(out)

train_ds_raw = take_text(train_stream, TRAIN_SAMPLES)
eval_ds_raw  = take_text(test_stream,  EVAL_SAMPLES)

train_ds_raw, eval_ds_raw


### Quick look at a couple claim examples

In [None]:
for i in range(2):
    print("\n--- Example", i, "---")
    print(train_ds_raw[i]["text"][:600])


## 4) Tokenize

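We train with the standard causal language-modeling objective: predict the next token.  
`DataCollatorForLanguageModeling(mlm=False)` copies `input_ids` into `labels` and sets padding positions to `-100`, so the loss ignores them; the model shifts the labels by one position internally. If the pad token is the EOS token (as it is for `distilgpt2` here), EOS positions are masked as well.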
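The label construction can be sketched in plain Python on a single sequence. This is a simplified illustration of what the collator and the model's loss do together, not their actual code; the token ids and pad id below are made up.

```python
IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss

def make_clm_labels(input_ids, pad_token_id):
    """Mimic DataCollatorForLanguageModeling(mlm=False): copy ids, mask padding."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]

def next_token_pairs(input_ids, labels):
    """List (token at position i, target at position i) for every supervised position.

    The model conditions on all tokens up to i and must predict labels[i + 1];
    this shift happens inside the model's loss, not in the collator.
    """
    pairs = []
    for i in range(len(input_ids) - 1):
        target = labels[i + 1]
        if target != IGNORE_INDEX:
            pairs.append((input_ids[i], target))
    return pairs

# Hypothetical ids: 3 real tokens followed by 2 pad tokens (pad id 0).
ids = [17, 42, 99, 0, 0]
labels = make_clm_labels(ids, pad_token_id=0)
print(labels)                         # [17, 42, 99, -100, -100]
print(next_token_pairs(ids, labels))  # [(17, 42), (42, 99)]
```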


In [None]:
def tokenize_function(batch):
    return tokenizer(
        batch["text"],
        truncation=True,
        max_length=MAX_LENGTH,
        padding="max_length",
    )

train_ds = train_ds_raw.map(tokenize_function, batched=True, remove_columns=["text"])
eval_ds  = eval_ds_raw.map(tokenize_function,  batched=True, remove_columns=["text"])

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

train_ds, eval_ds


## 5) Baseline generation (before prompt tuning)

We first run the untuned base model on a patent-claim-like prompt to get a baseline for later comparison.


In [None]:
def generate(model, prompt, max_new_tokens=80):
    model.eval()
    # Temporarily set padding_side to 'left' for generation with causal LMs
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"

    inputs = tokenizer(prompt, return_tensors="pt")

    # Restore original padding_side
    tokenizer.padding_side = original_padding_side

    # Move inputs to the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_length=inputs['input_ids'].shape[1] + max_new_tokens, # Explicitly calculate total max_length
            repetition_penalty=1.2,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=False,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)
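The sampling arguments passed to `generate` (`repetition_penalty`, `temperature`, `top_p`) can be illustrated on a toy vocabulary. The sketch below is a simplified, pure-Python rendering of the ideas, not the Transformers implementation: the repetition penalty divides positive logits (and multiplies negative ones) for tokens already in the sequence (in Transformers, as we understand it, this includes the prompt tokens); temperature rescales logits; top-p keeps the smallest set of most likely tokens whose probability mass reaches `top_p`.

```python
import math

def apply_repetition_penalty(logits, seen_ids, penalty=1.2):
    """CTRL-style penalty: make tokens that already appeared less likely."""
    out = list(logits)
    for tok in set(seen_ids):
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out

def nucleus_distribution(logits, temperature=0.8, top_p=0.9):
    """Temperature-scale, softmax, then keep the smallest top set with mass >= top_p."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    kept, mass = [], 0.0
    for tok in sorted(range(len(probs)), key=probs.__getitem__, reverse=True):
        kept.append(tok)
        mass += probs[tok]
        if mass >= top_p:
            break
    return {tok: probs[tok] / mass for tok in kept}  # renormalise over the nucleus

# Toy vocabulary of 4 token ids; token 0 already appears in the sequence.
logits = [2.0, 1.0, 0.1, -1.0]
penalised = apply_repetition_penalty(logits, seen_ids=[0])
print(sorted(nucleus_distribution(logits)))     # [0, 1]
print(sorted(nucleus_distribution(penalised)))  # [0, 1, 2]
```

Penalising the repeated token flattens the distribution, so more candidates enter the nucleus and sampling becomes more varied.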

In [None]:
# A simple patent-claim-like prompt we will reuse throughout this notebook
baseline_prompt = """1. A system comprising:
    a processor; and
    a memory storing instructions that, when executed by the processor, cause the processor to:
        receive input data;
        determine an output based on the input data; and
        provide the output to a user interface.
"""

print("=== Baseline (Base Model; before prompt tuning) ===")
print(generate(foundational_model, baseline_prompt, max_new_tokens=140))


## 6) Create a Prompt Tuning adapter (PEFT)

Only the **virtual prompt embeddings** are trainable; the base model weights stay frozen.
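Conceptually, prompt tuning prepends `NUM_VIRTUAL_TOKENS` trainable vectors to the frozen token embeddings, extends the attention mask so every position can attend to them, and pads the labels with `-100` so the model is never asked to predict the virtual tokens. The sketch below shows that input construction for one sequence, using plain Python lists in place of tensors; it mirrors the idea, not PEFT's internal code, and the function name is hypothetical.

```python
IGNORE_INDEX = -100

def prepend_soft_prompt(token_embeds, attention_mask, labels, soft_prompt):
    """Build prompt-tuning inputs for one sequence.

    token_embeds: embedding vectors looked up from the frozen embedding table.
    soft_prompt:  trainable vectors (num_virtual_tokens x hidden_size); in real
                  training these are the only parameters that receive gradients.
    """
    n = len(soft_prompt)
    embeds = [list(vec) for vec in soft_prompt] + token_embeds  # virtual tokens go first
    mask = [1] * n + attention_mask                              # always attend to them
    new_labels = [IGNORE_INDEX] * n + labels                     # never predict them
    return embeds, mask, new_labels

# Toy example: hidden size 3, two virtual tokens, three real tokens.
soft_prompt = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
token_embeds = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
embeds, mask, labels = prepend_soft_prompt(token_embeds, [1, 1, 1], [11, 12, 13], soft_prompt)
print(len(embeds), mask, labels)  # 5 [1, 1, 1, 1, 1] [-100, -100, 11, 12, 13]
```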


In [None]:
prompt_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.RANDOM,
    num_virtual_tokens=NUM_VIRTUAL_TOKENS,
    tokenizer_name_or_path=model_name,
)

peft_model = get_peft_model(foundational_model, prompt_config)
peft_model.print_trainable_parameters()


## 7) Training

For lecture speed we use `max_steps` (instead of full epochs).
Increase steps (and/or sample size) for better style adaptation.
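Because each Trainer step is one optimizer update over `per_device_train_batch_size × gradient_accumulation_steps` examples per device, the arithmetic below shows how much of the streamed data an 80-step run actually covers. It assumes a single device (one GPU or CPU); the batch values are copied from the training cell.

```python
import math

PER_DEVICE_BATCH = 4   # per_device_train_batch_size in the training cell
GRAD_ACCUM = 4         # gradient_accumulation_steps
NUM_DEVICES = 1        # assumption: a single GPU (or CPU)
MAX_STEPS = 80
TRAIN_SAMPLES = 2000

effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM * NUM_DEVICES
examples_seen = MAX_STEPS * effective_batch
steps_per_epoch = math.ceil(TRAIN_SAMPLES / effective_batch)
epochs = examples_seen / TRAIN_SAMPLES
print(effective_batch, examples_seen, steps_per_epoch, round(epochs, 2))  # 16 1280 125 0.64
```

So the default run sees roughly two-thirds of the 2,000 streamed claims once; setting `MAX_STEPS` above 125 starts repeating data unless `TRAIN_SAMPLES` is raised too.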


In [None]:
training_args = TrainingArguments(
    output_dir="./peft_patent_claim_style",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch size 16
    learning_rate=3e-3,
    max_steps=MAX_STEPS,
    logging_steps=10,
    save_strategy="no",
    report_to="none",
    remove_unused_columns=False,
    fp16=torch.cuda.is_available(),
    seed=SEED,
)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=data_collator,
)

trainer.train()
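The notebook passes `eval_ds` to the Trainer but never evaluates on it. One option is to call `trainer.evaluate()`, which returns a metrics dict containing `eval_loss` (the mean token-level cross-entropy), and convert that loss to perplexity, ideally before and after training. The helper below is a minimal sketch; the metrics dict in the example is hypothetical, not a real result.

```python
import math

def perplexity_from_metrics(metrics):
    """Convert the mean cross-entropy in `eval_loss` into perplexity (exp of the loss)."""
    return math.exp(metrics["eval_loss"])

# Hypothetical metrics dict with the same key Trainer.evaluate() reports.
example_metrics = {"eval_loss": 2.0}
print(round(perplexity_from_metrics(example_metrics), 3))  # 7.389
```

In the notebook this would be used as `perplexity_from_metrics(trainer.evaluate())`; lower perplexity on held-out claims suggests the soft prompt has moved the model toward claim-style text.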

## 8) Generation after prompt tuning

We run the same prompt again. With enough steps, you’ll usually see more consistent claim-like phrasing and structure.
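To make the before/after comparison a little less impressionistic, you could score each output with a crude keyword heuristic. The marker list and scoring rule below are our own assumptions, not a validated metric for patent-claim style: the score counts typical claim phrases plus lines that end a limitation with `;` or `; and`.

```python
import re

# Hypothetical markers of patent-claim register; adjust for your own use.
CLAIM_MARKERS = ("comprising", "wherein", "configured to", "said", "the method of claim")

def claim_style_score(text):
    """Rough heuristic: claim-phrase hits plus semicolon-terminated limitation lines."""
    lowered = text.lower()
    marker_hits = sum(
        len(re.findall(rf"\b{re.escape(m)}\b", lowered)) for m in CLAIM_MARKERS
    )
    limitations = len(re.findall(r";[ \t]*(?:and[ \t]*)?$", lowered, flags=re.MULTILINE))
    return marker_hits + limitations

claim = (
    "1. A device comprising:\n"
    "    a lever; and\n"
    "    a spring, wherein the spring is configured to bias the lever."
)
print(claim_style_score(claim))                        # 4
print(claim_style_score("The weather is nice today."))  # 0
```

In the notebook, one would compare `claim_style_score(...)` on the baseline and tuned generations; since sampling is random, averaging over several generations per model gives a fairer picture.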


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
peft_model.to(device)

print("=== After Prompt Tuning (PEFT adapter) ===")
print(generate(peft_model, baseline_prompt, max_new_tokens=140))

## 9) Save + reload adapter

You only need to save the **adapter**, not the full base model.
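The adapter is tiny because prompt tuning stores one `hidden_size`-dimensional vector per virtual token. The arithmetic below assumes a single transformer submodule (our understanding for decoder-only models) and hidden sizes of 1024 for `bigscience/bloomz-560m` and 768 for `distilgpt2`, taken from the models' published configs; check `foundational_model.config.hidden_size` if you swap models. The saved file will be slightly larger than these figures because of metadata.

```python
NUM_VIRTUAL_TOKENS = 8
# Assumed hidden sizes from each model's config.
HIDDEN_SIZE = {"bigscience/bloomz-560m": 1024, "distilgpt2": 768}

def soft_prompt_size(num_virtual_tokens, hidden_size, bytes_per_param=4):
    """Return (parameter count, bytes) of a prompt-tuning embedding table in fp32."""
    params = num_virtual_tokens * hidden_size
    return params, params * bytes_per_param

for name, hidden in HIDDEN_SIZE.items():
    params, nbytes = soft_prompt_size(NUM_VIRTUAL_TOKENS, hidden)
    print(f"{name}: {params:,} params, ~{nbytes / 1024:.0f} KiB in fp32")
# bigscience/bloomz-560m: 8,192 params, ~32 KiB in fp32
# distilgpt2: 6,144 params, ~24 KiB in fp32
```

For the default model, the trainable-parameter count reported by `print_trainable_parameters()` should match the first figure, compared with hundreds of millions of frozen base-model weights that never need to be saved.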


In [None]:
import os
from peft import PeftModel

adapter_dir = "./peft_patent_claim_style_adapter"
os.makedirs(adapter_dir, exist_ok=True)

peft_model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)  # optional convenience

print("Saved adapter (and tokenizer) to:", adapter_dir)


In [None]:
# Reload adapter on top of a *fresh* base model instance
# (More reliable than re-using `foundational_model`, which may already be wrapped/modified by PEFT.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model_for_reload = AutoModelForCausalLM.from_pretrained(model_name).to(device)
base_model_for_reload.config.pad_token_id = tokenizer.pad_token_id

reloaded = PeftModel.from_pretrained(base_model_for_reload, adapter_dir, is_trainable=False).to(device)
reloaded.eval()

print("=== Reloaded adapter output ===")
print(generate(reloaded, baseline_prompt, max_new_tokens=140))


## 10) Suggested lecture exercises

1. Change `baseline_prompt` to:
   - `"1. A method comprising:"`
   - `"1. A computer-readable medium storing instructions that, when executed, cause:"`
2. Increase `MAX_STEPS` to 300–1000 (if time/compute allows).
3. Try `NUM_VIRTUAL_TOKENS` in {4, 8, 16, 32} and compare.
4. Switch base model to `distilgpt2` to demonstrate how *base model choice* affects output.
