In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import json
import torch
from datetime import datetime
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model
from evaluate import load as load_metric
from transformers import TrainerCallback


class PrintStepCallback(TrainerCallback):
    """Print the training loss each time the Trainer logs it.

    The previous implementation hooked ``on_step_end`` and read
    ``state.log_history[-1]``. The Trainer fires ``on_step_end`` *before* it
    appends the current step's log entry, so every line showed the loss from
    the previous logging step (and ``N/A`` at the first one). ``on_log``
    receives the freshly computed ``logs`` dict directly. It fires on the
    Trainer's own ``logging_steps`` cadence, so no separate step modulus is
    needed.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        # The final summary log (train_runtime, train_loss, ...) has no "loss"
        # key; skip it and any other non-loss log events.
        if not logs or "loss" not in logs:
            return
        print(f"Step {state.global_step}, Loss: {logs['loss']}")


# ================================================================
# 1. Load Data
# ================================================================
DATA_PATH = "./summary_claims.json"

with open(DATA_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

# This is the raw count, before filtering; the usable count is printed below.
print(f"Loaded {len(data)} patents from {DATA_PATH}")

# Keep only patents that have both a non-empty description (model input) and a
# non-empty summary (training target). A missing title is tolerated because
# the title is only used for display.
NUM_EXAMPLES_TO_SHOW = 3
records = []
for item in data:
    if not item.get("summary") or not item.get("description"):
        continue
    records.append({
        "text": item["description"],
        "title": item.get("title", ""),
        "summary": item["summary"],
    })
    # Preview the first few *kept* records. Indexing by position in the raw
    # data meant that skipped items reduced the number of previews.
    if len(records) <= NUM_EXAMPLES_TO_SHOW:
        print(f"\nExample {len(records)}: {records[-1]['title']}")
        print(f"  Summary: {records[-1]['summary'][:150]}...")

print(f"\nTotal records: {len(records)}")

# ================================================================
# 2. Load Model/Tokenizer
# ================================================================
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    # `torch_dtype` triggers a deprecation warning in the installed
    # transformers release; `dtype` is its replacement.
    dtype=torch.bfloat16,
).cuda()

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Training features are right-padded to MAX_LENGTH. Pin the padding side rather
# than relying on whatever default the tokenizer config ships with.
tokenizer.padding_side = "right"

# ================================================================
# 3. Create Dataset
# ================================================================
MAX_LENGTH = 512  # total tokens per training example (prompt + summary)
MAX_TEXT_TOKENS = 350  # token budget for the patent description in the prompt

def truncate_text(text, max_tokens):
    """Keep only the first *max_tokens* tokens of *text*.

    Text already within budget is returned as-is, with no decode round-trip.
    Longer text is cut at the token level and decoded back to a string.
    Re-encoding that string may not yield exactly *max_tokens* tokens.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) <= max_tokens:
        return text
    return tokenizer.decode(token_ids[:max_tokens], skip_special_tokens=True)

# Deterministic 80/20 split (fixed seed) into "train" and "test".
dataset = Dataset.from_list(records).train_test_split(test_size=0.2, seed=42)
print("\nTrain: {}, Test: {}".format(len(dataset["train"]), len(dataset["test"])))

# ================================================================
# 4. Preprocessing
# ================================================================
def preprocess(batch):
    """Tokenize a batch of (text, summary) pairs into padded causal-LM features.

    Each example becomes ``prompt_ids + target_ids``. The prompt is the
    truncated patent description; the target is the reference summary
    followed by EOS. Prompt tokens are masked out of the loss (label -100),
    so only the summary is learned.

    Args:
        batch: A batched ``datasets`` slice with parallel ``"text"`` and
            ``"summary"`` lists.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels`` lists.
        Each inner list is exactly ``MAX_LENGTH`` long and right-padded.
    """
    pad_id = tokenizer.pad_token_id
    input_ids_list = []
    attention_list = []
    labels_list = []

    for text, summary in zip(batch["text"], batch["summary"]):
        text = truncate_text(text, MAX_TEXT_TOKENS)

        # Tokenize prompt and target separately, so the label boundary is
        # exact rather than inferred from a re-tokenized target length. The
        # prompt ends at "Summary:" (the same prompt generate_summary uses).
        # The separating space is carried by the target, so the first summary
        # word keeps its usual " word" token, exactly as a joint tokenization
        # of "Summary: <summary>" would produce.
        prompt = f"Summarize this patent:\n\n{text}\n\nSummary:"
        prompt_ids = tokenizer(prompt, add_special_tokens=True)["input_ids"]
        target_ids = tokenizer.encode(" " + summary, add_special_tokens=False)
        target_ids.append(tokenizer.eos_token_id)

        # Truncate from the end as before. Because labels are built alongside
        # input_ids, they stay on the surviving target tokens instead of
        # sliding onto prompt tokens when the sequence is cut.
        input_ids = (prompt_ids + target_ids)[:MAX_LENGTH]
        labels = ([-100] * len(prompt_ids) + target_ids)[:MAX_LENGTH]

        # Explicit right padding, independent of tokenizer.padding_side.
        pad_len = MAX_LENGTH - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * pad_len
        input_ids = input_ids + [pad_id] * pad_len
        labels = labels + [-100] * pad_len

        input_ids_list.append(input_ids)
        attention_list.append(attention_mask)
        labels_list.append(labels)

    return {
        "input_ids": input_ids_list,
        "attention_mask": attention_list,
        "labels": labels_list
    }

# Tokenize the training split. The raw text/title/summary columns are dropped
# so that only model inputs remain.
train_split = dataset["train"]
tokenized_train = train_split.map(
    preprocess,
    batched=True,
    remove_columns=train_split.column_names,
)

# Sanity check: count how many positions of the first example contribute to
# the loss (i.e. are not masked with -100).
example = tokenized_train[0]
valid_count = sum(label != -100 for label in example["labels"])
print(f"Valid label tokens: {valid_count}")

def collate_fn(batch):
    """Stack already-padded feature lists into ``torch.long`` tensors.

    Every example in *batch* must carry equal-length ``input_ids``,
    ``attention_mask`` and ``labels`` lists. ``preprocess`` pads them all to
    ``MAX_LENGTH``, so no further padding is done here.
    """
    feature_names = ("input_ids", "attention_mask", "labels")
    return {
        name: torch.tensor([features[name] for features in batch], dtype=torch.long)
        for name in feature_names
    }

# ================================================================
# 5. Baseline Evaluation
# ================================================================
# ROUGE scorer shared by the baseline and fine-tuned evaluations.
rouge = load_metric(path="rouge")

def generate_summary(mdl, text, max_new_tokens=100):
    """Greedily generate a summary of *text* with *mdl*.

    Args:
        mdl: A causal LM (base or PEFT-wrapped) exposing ``generate``.
        text: Raw patent description, truncated to ``MAX_TEXT_TOKENS`` tokens.
        max_new_tokens: Generation budget (default 100, as before).

    Returns:
        Only the newly generated continuation, decoded and stripped.
    """
    text = truncate_text(text, MAX_TEXT_TOKENS)
    prompt = f"Summarize this patent:\n\n{text}\n\nSummary:"
    # Put the inputs wherever the model's weights live, instead of assuming "cuda".
    device = next(mdl.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]

    mdl.eval()
    with torch.no_grad():
        output = mdl.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # For a decoder-only model, generate() returns prompt + continuation, so
    # the prompt is sliced off by token count. Splitting the decoded text on
    # the last "Summary:" would discard everything the model wrote before any
    # "Summary:" it emitted itself.
    new_tokens = output[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Reference summaries for the held-out split, reused for the final evaluation.
test_refs = [row["summary"] for row in dataset["test"]]

# Score the untouched base model before any LoRA training.
print("\nBaseline evaluation...")
baseline_preds = [generate_summary(model, row["text"]) for row in dataset["test"]]

baseline_rouge = rouge.compute(predictions=baseline_preds, references=test_refs)
print(f"Baseline ROUGE-L: {baseline_rouge['rougeL']:.4f}")

# ================================================================
# 6. LoRA Fine-tuning
# ================================================================
# LoRA adapters on every attention and MLP projection.
config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()

# NOTE(review): with only ~15 training examples, 50 epochs drives the logged
# training loss to ~1e-3, which suggests the adapter is memorizing the
# training summaries. Consider fewer epochs, or an eval split with early
# stopping.
training_args = TrainingArguments(
    output_dir="./qwen_lora_patent_real",
    num_train_epochs=50,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
    bf16=True,
    logging_steps=10,
    # No intermediate checkpoints: the adapters are saved explicitly after
    # the final evaluation. This replaces the `save_steps=999999` workaround.
    save_strategy="no",
    report_to="none",
    disable_tqdm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    data_collator=collate_fn,
    callbacks=[PrintStepCallback()]
)

print("\n=== Training ===")
trainer.train()

# ================================================================
# 7. Final Evaluation
# ================================================================
# Re-run the same greedy generation on the held-out split, now with the
# LoRA-adapted model, and score it against the same references.
print("\nFinal evaluation...")
model.eval()
finetuned_preds = [generate_summary(model, row["text"]) for row in dataset["test"]]

finetuned_rouge = rouge.compute(predictions=finetuned_preds, references=test_refs)

# ================================================================
# 8. SAVE MODEL PROPERLY
# ================================================================
print("\n" + "="*70)
print("SAVING MODEL")
print("="*70)

SAVE_DIR = "./qwen_lora_patent_real"
ROUGE_KEYS = ("rouge1", "rouge2", "rougeL")

# Save LoRA adapters
model.save_pretrained(SAVE_DIR)
print("✓ LoRA adapters saved")

# Save tokenizer
tokenizer.save_pretrained(SAVE_DIR)
print("✓ Tokenizer saved")

# Record the hyperparameters actually used, read from the live config
# objects. The previous hand-copied literals had drifted from the real run
# (they said 10 epochs and r=32, versus 50 and 64).
lora_target_modules = config.target_modules
if not isinstance(lora_target_modules, str):
    # PEFT may normalize target_modules to a set, which json cannot serialize.
    lora_target_modules = sorted(lora_target_modules)

metadata = {
    "model_info": {
        "base_model": model_name,
        "model_type": "LoRA_fine-tuned",
        "task": "patent_summarization"
    },
    "training_info": {
        "training_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "dataset": DATA_PATH,
        "num_train_examples": len(dataset["train"]),
        "num_test_examples": len(dataset["test"]),
        "num_epochs": training_args.num_train_epochs,
        "batch_size": training_args.per_device_train_batch_size,
        "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
        "effective_batch_size": (
            training_args.per_device_train_batch_size
            * training_args.gradient_accumulation_steps
        ),
        "learning_rate": training_args.learning_rate,
        "max_length": MAX_LENGTH,
        "max_text_tokens": MAX_TEXT_TOKENS
    },
    "lora_config": {
        "r": config.r,
        "lora_alpha": config.lora_alpha,
        "lora_dropout": config.lora_dropout,
        "target_modules": lora_target_modules
    },
    "results": {
        "baseline": {k: float(baseline_rouge[k]) for k in ROUGE_KEYS},
        "finetuned": {k: float(finetuned_rouge[k]) for k in ROUGE_KEYS},
        "improvement": {
            k: float(finetuned_rouge[k] - baseline_rouge[k]) for k in ROUGE_KEYS
        }
    }
}

with open(os.path.join(SAVE_DIR, "metadata.json"), "w", encoding="utf-8") as f:
    json.dump(metadata, f, indent=2)
print("✓ Metadata saved")

# Save sample predictions for reference
samples = [
    {
        "title": dataset['test'][i]['title'],
        "reference": test_refs[i],
        "baseline": baseline_preds[i],
        "finetuned": finetuned_preds[i]
    }
    for i in range(min(5, len(test_refs)))
]

# ensure_ascii=False writes raw non-ASCII characters, so the file encoding
# must be explicit rather than platform-dependent.
with open(os.path.join(SAVE_DIR, "sample_predictions.json"), "w", encoding="utf-8") as f:
    json.dump(samples, f, indent=2, ensure_ascii=False)
print("✓ Sample predictions saved")

print(f"\n✓ Complete model package saved to: {SAVE_DIR}/")
# List what is actually on disk, rather than a hard-coded file list whose
# names and sizes (e.g. "adapter_model.bin (~13MB)") did not match the
# saved adapter.
print("\nSaved files:")
for file_name in sorted(os.listdir(SAVE_DIR)):
    file_path = os.path.join(SAVE_DIR, file_name)
    if os.path.isfile(file_path):
        print(f"  - {file_name} ({os.path.getsize(file_path) / 1e6:.1f} MB)")

# ================================================================
# 9. Display Results
# ================================================================
def _preview(s, limit=150):
    """Return the first *limit* characters of *s*, adding '...' only if it was cut."""
    return s if len(s) <= limit else s[:limit] + "..."


print("\n" + "="*70)
print("FINAL RESULTS")
print("="*70)
print(f"Baseline ROUGE-L:   {baseline_rouge['rougeL']:.4f}")
print(f"Fine-tuned ROUGE-L: {finetuned_rouge['rougeL']:.4f}")
delta = finetuned_rouge['rougeL'] - baseline_rouge['rougeL']
# ROUGE scores lie in [0, 1], so delta*100 is an absolute change in percentage
# *points*, not a relative percent change. The `+` format spec also avoids
# the old "+-0.0000" output for a negative zero.
print(f"Change: {delta:+.4f} ({delta*100:+.2f} pts)")

print("\nDetailed scores:")
for key, label in (("rouge1", "ROUGE-1"), ("rouge2", "ROUGE-2"), ("rougeL", "ROUGE-L")):
    print(f"  {label}: {baseline_rouge[key]:.4f} -> {finetuned_rouge[key]:.4f}")

print("\n=== Sample Comparisons ===")
for i in range(min(3, len(test_refs))):
    print(f"\n--- {dataset['test'][i]['title']} ---")
    print(f"Reference:  {_preview(test_refs[i])}")
    print(f"Baseline:   {_preview(baseline_preds[i])}")
    print(f"Fine-tuned: {_preview(finetuned_preds[i])}")

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)

Loaded 19 patents with real summaries

Example 1: Cup.
  Summary: The object of the present invention is to provide a convenient cup or container for this or similar purposes. My improved cup, designated as a whole b...

Example 2: Binder.
  Summary: My invention relates to certain new and useful improvements in binders. It is fully described and explained in the speciiication and shown in the acco...

Example 3: Crochet hook
  Summary: Patent No. 2,024,794. (01.66-118) This invention relates to a crochet hook or the like. I claim that the movement of the stitches onto and off of the ...

Total records: 19


`torch_dtype` is deprecated! Use `dtype` instead!



Train: 15, Test: 4


Map:   0%|          | 0/15 [00:00<?, ? examples/s]

Valid label tokens: 213


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



Baseline evaluation...
Baseline ROUGE-L: 0.2091
trainable params: 35,192,832 || all params: 529,225,600 || trainable%: 6.6499

=== Training ===


Step,Training Loss
10,2.1322
20,1.1533
30,0.2871
40,0.0253
50,0.0064
60,0.0016
70,0.0011
80,0.0011


Step 10, Loss: N/A
Step 20, Loss: 2.1322
Step 30, Loss: 1.1533
Step 40, Loss: 0.2871
Step 50, Loss: 0.0253
Step 60, Loss: 0.0064
Step 70, Loss: 0.0016
Step 80, Loss: 0.0011
