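# In [None]:
# Notebook shell magic (not valid in a plain .py script); run it in a notebook cell or a shell first:
# !pip install transformers peft datasets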

# 1. Imports & Environment checks
import os
import json
import re
import random
import time
from pathlib import Path
from collections import Counter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import math
from itertools import product
from pprint import pprint
from datasets import Dataset as HFDataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    TrainingArguments, Trainer, EarlyStoppingCallback, DataCollatorWithPadding
)
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import evaluate
from peft import LoraConfig, get_peft_model, PeftModel

# Optional dependency: NVML (through pynvml) is only used for GPU telemetry.
# Any failure while importing or initialising it simply switches telemetry off.
try:
    import pynvml
    pynvml.nvmlInit()
except Exception:
    _NVML_AVAILABLE = False
else:
    _NVML_AVAILABLE = True

# Report the runtime environment up front.
for caption, value in (
    ("Torch version:", torch.__version__),
    ("CUDA available:", torch.cuda.is_available()),
    ("NVML available:", _NVML_AVAILABLE),
):
    print(caption, value)
if _NVML_AVAILABLE:
    gpu_count = pynvml.nvmlDeviceGetCount()
    print("Number GPUs:", gpu_count)
# Reproducibility
GLOBAL_SEED = 42


def set_seed(seed=GLOBAL_SEED):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch;
    when CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed: Integer seed applied to each generator (defaults to GLOBAL_SEED).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


set_seed(GLOBAL_SEED)

# 2. Paths - edit these for your environment
# CUAD_JSON_PATH: the CUAD v1 JSON file that parse_cuad_json() reads
# (the default below is a Kaggle-style input path).
CUAD_JSON_PATH = "/kaggle/input/atticus-open-contract-dataset-aok-beta/CUAD_v1/CUAD_v1.json"
# OUTPUT_DIR: folder that receives the label counts, label mapping, fitted
# label encoder and the train/val/test CSV splits written further down.
OUTPUT_DIR = Path("./dataset_processed")
# Create the folder (and missing parents) up front; an existing folder is fine.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# In [None]:
# 3. Utility: GPU stats
def get_gpu_stats(index=0):
    """Return a snapshot of memory use and utilisation for one GPU.

    Args:
        index: NVML device index to query.

    Returns:
        ``{"error": "pynvml not available"}`` when NVML could not be
        initialised; otherwise a dict with ``gpu_index``, ``mem_used_GB``,
        ``mem_total_GB`` (both in GiB, i.e. bytes / 1024**3),
        ``gpu_util_pct`` and ``gpu_sm_pct``.

    NOTE(review): ``gpu_sm_pct`` is filled from ``util.memory``, which NVML
    documents as memory read/write activity rather than SM occupancy. The key
    name is kept because it is written to the per-epoch JSONL logs.
    """
    if _NVML_AVAILABLE:
        bytes_per_gib = 1024**3
        device = pynvml.nvmlDeviceGetHandleByIndex(index)
        memory = pynvml.nvmlDeviceGetMemoryInfo(device)
        rates = pynvml.nvmlDeviceGetUtilizationRates(device)
        return dict(
            gpu_index=index,
            mem_used_GB=memory.used / bytes_per_gib,
            mem_total_GB=memory.total / bytes_per_gib,
            gpu_util_pct=rates.gpu,
            gpu_sm_pct=rates.memory,
        )
    return {"error": "pynvml not available"}

# Quick GPU print: one stats line per device NVML can see
if _NVML_AVAILABLE:
    visible_gpus = pynvml.nvmlDeviceGetCount()
    for gpu_idx in range(visible_gpus):
        print("GPU", gpu_idx, get_gpu_stats(gpu_idx))

# 4. Load CUAD JSON and parse into snippet-level DataFrame
def parse_cuad_json(path):
    """Flatten a CUAD JSON file into one row per (answer snippet, clause label).

    The file is read as ``{"data": [doc, ...]}`` (or a bare list of docs).
    Each doc carries ``paragraphs`` -> ``qas`` -> ``answers``. The clause label
    is taken from the question text: the quoted phrase following
    ``related to``, else the first quoted phrase, else the whole question.

    Questions flagged ``is_impossible`` are skipped. Answer texts are
    whitespace-normalised, and snippets with fewer than 3 words or fewer than
    12 characters are dropped. Duplicate (text, label) pairs are removed,
    keeping the first occurrence.

    Args:
        path: Path to the CUAD JSON file (read as UTF-8).

    Returns:
        DataFrame with columns ``document``, ``text`` and ``label``. The
        columns are present even when no snippet survives the filters.
    """
    # Compiled once instead of being looked up again for every question.
    related_to_re = re.compile(r'related to\s+"([^"]+)"', flags=re.IGNORECASE)
    quoted_re = re.compile(r'"([^"]+)"')

    with open(path, "r", encoding="utf-8") as f:
        cuad = json.load(f)

    records = []
    docs = cuad.get("data", cuad) if isinstance(cuad, dict) else cuad
    for doc in docs:
        doc_title = doc.get("title") or doc.get("document_id") or doc.get("name") or ""
        for para in doc.get("paragraphs", []):
            for qa in para.get("qas", []):
                # Skip unanswerable questions before doing any label extraction.
                if qa.get("is_impossible", False):
                    continue
                # `or ""` / `or []` tolerate explicit nulls in the JSON.
                question = (qa.get("question") or "").strip()
                # Extract label/clause robustly (quoted phrase after "related to",
                # else any quoted phrase, else fall back to the full question)
                match = related_to_re.search(question) or quoted_re.search(question)
                clause = match.group(1).strip() if match else question
                for a in qa.get("answers") or []:
                    text = (a.get("text") or "").strip()
                    if not text:
                        continue
                    # Normalize whitespace
                    text = re.sub(r"\s+", " ", text)
                    # Filter out extremely short snippets
                    if len(text.split()) < 3 or len(text) < 12:
                        continue
                    records.append({"document": doc_title, "text": text, "label": clause})

    # Explicit columns keep the schema stable when `records` is empty, so
    # downstream column access does not fail on a column-less frame.
    df = pd.DataFrame(records, columns=["document", "text", "label"])
    return df.drop_duplicates(subset=["text", "label"]).reset_index(drop=True)

# Parse the raw CUAD file into one row per (snippet, clause label) pair.
df = parse_cuad_json(CUAD_JSON_PATH)
print("Total parsed snippets:", len(df))
# Preview of the first rows; as a bare expression it has no effect when this
# file runs as a script.
df.head(6)


# 5. Basic EDA - counts, length distribution, top labels
# Snippet length in whitespace-separated words.
df['text_len_words'] = df['text'].apply(lambda t: len(t.split()))
# Number of snippets per clause label, most frequent first.
label_counts = df['label'].value_counts()

print("\nNumber of unique labels:", label_counts.shape[0])
print("Top 20 labels:\n", label_counts.head(20).to_string())

# Keep the full label frequency table next to the processed splits.
label_counts.to_csv(OUTPUT_DIR / "label_counts.csv")

# Plot counts (bar for top 30)
plt.figure(figsize=(14,5))
label_counts.head(30).plot(kind='bar')
plt.title("Top 30 Clause Label Counts")
plt.xlabel("Clause label")
plt.ylabel("Count")
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()

# Text length histogram
plt.figure(figsize=(10,4))
plt.hist(df['text_len_words'], bins=50)
plt.title("Distribution of snippet lengths (words)")
plt.xlabel("Words")
plt.ylabel("Number of snippets")
plt.show()


# 6. Label encoding and stratified splits
# Fit the encoder on the clause names; labels[i] is the clause name for id i.
le = LabelEncoder()
df['label_id'] = le.fit_transform(df['label'])
labels = list(le.classes_)
num_labels = len(labels)
print("Number of classes (label encoder):", num_labels)

# Save mapping (readable CSV plus the fitted encoder itself)
mapping_df = pd.DataFrame({"label": labels, "label_id": list(range(num_labels))})
mapping_df.to_csv(OUTPUT_DIR / "label_mapping.csv", index=False)
joblib.dump(le, OUTPUT_DIR / "label_encoder.joblib")

# Stratified split: hold out 15%, then halve the hold-out -> train 85%, val 7.5%, test 7.5%
train_df, temp_df = train_test_split(
    df,
    test_size=0.15,
    random_state=GLOBAL_SEED,
    stratify=df['label_id'],
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=GLOBAL_SEED,
    stratify=temp_df['label_id'],
)

print("Train / Val / Test sizes:", *(len(part) for part in (train_df, val_df, test_df)))

# Save CSVs for downstream training scripts
for split_name, split_frame in (("train", train_df), ("val", val_df), ("test", test_df)):
    split_frame.to_csv(OUTPUT_DIR / f"{split_name}.csv", index=False)

# Also save a small sample to sanity-check loading later
train_sample = train_df.sample(5, random_state=GLOBAL_SEED)
train_sample.to_csv(OUTPUT_DIR / "train_sample.csv", index=False)

# In [None]:
# 7. Quick sanity functions to read processed dataset (used later)
def load_processed_split(split="train"):
    """Read one processed split back from ``OUTPUT_DIR``.

    Args:
        split: File stem of the split; ``<split>.csv`` is loaded.

    Returns:
        The CSV contents as a pandas DataFrame.
    """
    return pd.read_csv(OUTPUT_DIR / f"{split}.csv")

# Re-use label encoder
# (joblib files are pickle-based: only load an encoder this notebook wrote itself)
le = joblib.load(OUTPUT_DIR / "label_encoder.joblib")
labels = list(le.classes_)
num_labels = len(labels)

# Load processed CSVs
train_df, val_df, test_df = (
    pd.read_csv(OUTPUT_DIR / f"{split_name}.csv")
    for split_name in ("train", "val", "test")
)

# Simple helper to make HF datasets
def make_hf_dataset(df):
    """Convert a processed split into a Hugging Face dataset.

    Only the ``text`` and ``label_id`` columns are kept, and ``label_id`` is
    renamed to ``labels``.

    Args:
        df: DataFrame containing at least ``text`` and ``label_id`` columns.

    Returns:
        An ``HFDataset`` with columns ``text`` and ``labels``.
    """
    model_inputs = df[["text", "label_id"]].rename(columns={"label_id": "labels"})
    # preserve_index=False keeps the DataFrame index out of the dataset, so a
    # frame with a non-default index does not gain an extra index column.
    return HFDataset.from_pandas(model_inputs, preserve_index=False)

# One HF dataset per split, in train / val / test order
train_hf, val_hf, test_hf = (
    make_hf_dataset(frame) for frame in (train_df, val_df, test_df)
)

# Tokenizer + model name (Legal-BERT)
MODEL_NAME = "nlpaueb/legal-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

def tokenize_fn(examples, max_length=256):
    """Tokenise the ``text`` field of a batch with the module-level tokenizer.

    Args:
        examples: Mapping with a ``"text"`` entry holding the raw text(s).
        max_length: Truncation limit passed to the tokenizer.

    Returns:
        Whatever the tokenizer returns for the truncated inputs (no padding is
        requested here).
    """
    raw_texts = examples["text"]
    return tokenizer(raw_texts, truncation=True, max_length=max_length)

# Tokenise every split in batches (truncating at 256 tokens)
train_hf, val_hf, test_hf = (
    split.map(lambda batch: tokenize_fn(batch, max_length=256), batched=True)
    for split in (train_hf, val_hf, test_hf)
)

# Collator built from the same tokenizer; handed to the Trainer below
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Metrics helpers
def compute_metrics_from_preds(y_true, y_pred):
    """Compute accuracy plus macro- and weighted-averaged precision/recall/F1.

    Args:
        y_true: Ground-truth class ids.
        y_pred: Predicted class ids.

    Returns:
        Dict with keys ``accuracy``, ``precision_macro``, ``recall_macro``,
        ``f1_macro``, ``precision_weighted``, ``recall_weighted`` and
        ``f1_weighted``. Undefined ratios are reported as 0
        (``zero_division=0``).
    """
    scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    res = {"accuracy": accuracy_score(y_true, y_pred)}
    for averaging in ("macro", "weighted"):
        for metric_name, scorer in scorers:
            res[f"{metric_name}_{averaging}"] = scorer(
                y_true, y_pred, average=averaging, zero_division=0
            )
    return res

def bootstrap_ci(y_true, y_pred, metric_fn=f1_score, n_boot=1000, alpha=0.05, average="macro"):
    """Estimate a percentile bootstrap confidence interval for a metric.

    Draws ``n_boot`` resamples (with replacement, same size as the input) of
    the paired labels/predictions using NumPy's global random state, and
    scores each resample with ``metric_fn``.

    Args:
        y_true: Ground-truth class ids.
        y_pred: Predicted class ids, aligned with ``y_true``.
        metric_fn: Callable invoked as ``metric_fn(y_true, y_pred, average=...)``.
        n_boot: Number of bootstrap resamples.
        alpha: Significance level; the interval spans the ``alpha/2`` and
            ``1 - alpha/2`` percentiles of the resampled scores.
        average: Forwarded to ``metric_fn``.

    Returns:
        ``(mean_score, (lower, upper))``.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in length.
    """
    # Convert once; the original re-built both arrays on every resample.
    y_true_arr = np.asarray(y_true)
    y_pred_arr = np.asarray(y_pred)
    n = len(y_true_arr)
    if len(y_pred_arr) != n:
        # Index-based resampling would otherwise silently ignore the extra tail.
        raise ValueError("y_true and y_pred must have the same length")

    scores = []
    for _ in range(n_boot):
        idx = np.random.choice(n, n, replace=True)
        scores.append(metric_fn(y_true_arr[idx], y_pred_arr[idx], average=average))
    lower = np.percentile(scores, 100*alpha/2)
    upper = np.percentile(scores, 100*(1-alpha/2))
    return np.mean(scores), (lower, upper)

# Trainer compute_metrics wrapper
def trainer_compute_metrics(eval_pred):
    """Adapt the Trainer's ``(logits, label_ids)`` pair to the metrics helper.

    Args:
        eval_pred: Two-item iterable of logits and true label ids.

    Returns:
        The dict from ``compute_metrics_from_preds`` for the argmax
        predictions over the last logits axis.
    """
    logits, labels_ids = eval_pred
    return compute_metrics_from_preds(labels_ids, np.argmax(logits, axis=-1))

# Custom TrainerCallback for logging GPU stats after each epoch
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

class GPUStatsCallback(TrainerCallback):
    """Trainer callback that appends per-GPU stats to a JSONL file each epoch.

    One JSON object per epoch is appended to
    ``<args.output_dir>/gpu_epoch_stats.jsonl`` with the keys ``epoch``,
    ``step``, ``time`` (``time.time()``) and ``gpu_info`` (one
    ``get_gpu_stats`` dict per NVML device). Nothing is written when NVML is
    unavailable.

    Args:
        nvml_available: Whether NVML was initialised; logging is skipped when
            false.
        interval_sec: Accepted for backward compatibility and stored on the
            instance. Stats are sampled once per epoch, so no polling interval
            is applied.
    """

    def __init__(self, nvml_available=_NVML_AVAILABLE, interval_sec=2):
        super().__init__()
        self.nvml_available = nvml_available
        # Previously dropped silently; kept so callers can at least read it back.
        self.interval_sec = interval_sec

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Append one stats record for the epoch that just finished."""
        if not self.nvml_available:
            return
        gpu_info = [get_gpu_stats(i) for i in range(pynvml.nvmlDeviceGetCount())]
        # append to file
        logpath = Path(args.output_dir) / "gpu_epoch_stats.jsonl"
        # The output directory is not guaranteed to exist yet at the first epoch end.
        logpath.parent.mkdir(parents=True, exist_ok=True)
        record = {
            "epoch": state.epoch,
            "step": state.global_step,
            "time": time.time(),
            "gpu_info": gpu_info,
        }
        with open(logpath, "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")

# Function to create model (optionally wrap in LoRA)
def build_model(model_name=MODEL_NAME, num_labels=num_labels, lora=False, lora_r=8, lora_alpha=16, lora_dropout=0.1):
    """Create the sequence-classification model, optionally wrapped with LoRA.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained``.
        num_labels: Number of output classes.
        lora: When true, wrap the base model with a PEFT LoRA adapter.
        lora_r: LoRA rank (``r``).
        lora_alpha: LoRA ``lora_alpha`` value.
        lora_dropout: Dropout applied by the LoRA layers.

    Returns:
        The base model when ``lora`` is false, otherwise the PEFT-wrapped model.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
    if not lora:
        return classifier

    adapter_config = LoraConfig(
        task_type="SEQ_CLS",
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        # common targets; adjust if errors
        target_modules=["query", "value", "dense", "key"],
    )
    return get_peft_model(classifier, adapter_config)

# CPU latency helper (use a small set of texts)
def measure_cpu_latency(predict_fn, samples, n_runs=50):
    """Time repeated calls of ``predict_fn(samples)``.

    The caller is responsible for making ``predict_fn`` run on the CPU (for
    example by moving the model there inside the callable).

    Args:
        predict_fn: Callable taking ``samples``; its return value is ignored.
        samples: Input handed unchanged to every call.
        n_runs: Number of timed calls.

    Returns:
        ``(mean_seconds, std_seconds)`` over the ``n_runs`` timings.
    """
    times = []
    for _ in range(n_runs):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with system clock adjustments and is too coarse for short intervals.
        t0 = time.perf_counter()
        predict_fn(samples)
        times.append(time.perf_counter() - t0)
    return np.mean(times), np.std(times)

# Prediction helper using the loaded trainer/model for inference
def predict_with_model(trainer_or_model, texts, batch_size=16, device=None):
    """Predict class ids for ``texts`` with either a Trainer or a bare model.

    Args:
        trainer_or_model: Object exposing ``.predict`` (treated as a Trainer)
            or a model (PeftModel or base) called directly.
        texts: A single string or an iterable of strings.
        batch_size: Number of texts per forward pass on the bare-model path.
            The Trainer path batches according to the Trainer's own
            configuration.
        device: Optional device the bare model is moved to before inference.

    Returns:
        NumPy array of predicted class ids, one per input text.
    """
    # A lone string is treated as a batch of one, as the tokenizer would.
    texts = [texts] if isinstance(texts, str) else list(texts)

    # Accept either Trainer or model (PeftModel or base)
    if hasattr(trainer_or_model, "predict"):
        ds = HFDataset.from_dict({"text": texts}).map(
            lambda x: tokenizer(x["text"], truncation=True, max_length=256), batched=True
        )
        preds_out = trainer_or_model.predict(ds)
        return np.argmax(preds_out.predictions, axis=1)

    # treat as model
    model = trainer_or_model
    model.eval()
    if device:
        model.to(device)
    if not texts:
        return np.array([], dtype=np.int64)

    # Honour batch_size: the original pushed every text through one forward
    # pass, so memory use grew with the number of inputs.
    step = max(1, int(batch_size))
    batch_preds = []
    with torch.no_grad():
        for start in range(0, len(texts), step):
            inputs = tokenizer(
                texts[start:start + step],
                truncation=True,
                padding=True,
                return_tensors="pt",
                max_length=256,
            )
            outputs = model(**{k: v.to(model.device) for k, v in inputs.items()})
            logits = outputs.logits.cpu().numpy()
            batch_preds.append(np.argmax(logits, axis=1))
    return np.concatenate(batch_preds)


# Hyper-parameter grid for the fine-tuning sweep: every combination of
# learning rate, epoch count and batch size is trained once per seed.
search_lrs = [2e-5]
search_epochs = [10, 20]
# The original literal `[,1632]` was a syntax error; it reads as a misplaced
# comma in `16,32`. Adjust if other batch sizes were intended.
search_batch = [16, 32]
seeds = [42, 123]

# Grid search driver: for every (lr, epochs, batch size) combination and every
# seed, fine-tune a fresh model, evaluate it on the validation and test splits,
# save the artifacts under ./runs/<config>_seed<seed>, and record one result row.
results = []
runs_dir = Path("./runs")
runs_dir.mkdir(exist_ok=True)

# Only the "full" fine-tuning mode is listed here; the "lora" branches below
# become active once "lora" is added to this list.
for model_mode in ["full"]:
    print(f"\n=== Starting experiments for mode: {model_mode} ===")
    for lr, epochs, batch_size in product(search_lrs, search_epochs, search_batch):
        config_name = f"{model_mode}_lr{lr}_ep{epochs}_bs{batch_size}"
        print("\nConfig:", config_name)
        run_records = []
        for seed in seeds:
            # Re-seed all RNGs so each run is governed by its own seed.
            set_seed(seed)
            run_out = runs_dir / f"{config_name}_seed{seed}"
            run_out.mkdir(parents=True, exist_ok=True)

            # Build model (LoRA wraps base)
            use_lora = model_mode == "lora"
            model = build_model(MODEL_NAME, num_labels, lora=use_lora, lora_r=8, lora_alpha=16, lora_dropout=0.1)

            # TrainingArguments
            # Evaluation and checkpointing both happen once per epoch, and the
            # checkpoint with the highest validation f1_macro is reloaded at the end.
            training_args = TrainingArguments(
                output_dir=str(run_out),
                per_device_train_batch_size=batch_size,
                per_device_eval_batch_size=batch_size,
                learning_rate=lr,
                num_train_epochs=epochs,
                eval_strategy="epoch",
                save_strategy="epoch",
                load_best_model_at_end=True,
                metric_for_best_model="f1_macro",
                greater_is_better=True,
                logging_steps=50,
                save_total_limit=3,
                seed=seed,
                fp16=torch.cuda.is_available(),
                report_to="none"
            )

            # Trainer
            # Early stopping (patience 2) and per-epoch GPU logging are attached as callbacks.
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_hf,
                eval_dataset=val_hf,
                processing_class=tokenizer,
                data_collator=data_collator,
                compute_metrics=trainer_compute_metrics,
                callbacks=[EarlyStoppingCallback(early_stopping_patience=2), GPUStatsCallback()]
            )

            # Train
            # train_time is wall-clock seconds spent inside trainer.train().
            t0 = time.time()
            trainer.train()
            train_time = time.time() - t0

            # Evaluate on validation (best model loaded)
            eval_out = trainer.evaluate(eval_dataset=val_hf)

            # Test-set predictions: argmax over the class axis of the returned logits.
            test_preds_out = trainer.predict(test_hf)
            test_preds = np.argmax(test_preds_out.predictions, axis=1)
            test_labels = test_preds_out.label_ids

            # per-class report
            # NOTE(review): target_names=labels presumably requires every class to
            # occur in the test labels/predictions - confirm against scikit-learn.
            report = classification_report(test_labels, test_preds, target_names=labels, output_dict=True, zero_division=0)
            report_df = pd.DataFrame(report).transpose()
            report_df.to_csv(run_out / "classification_report_test.csv")

            # bootstrap CI for macro F1
            f1_mean, (f1_low, f1_high) = bootstrap_ci(test_labels, test_preds, f1_score, n_boot=500, average="macro")

            # GPU stats file (last logged if any)
            # Reads back every JSONL record GPUStatsCallback appended for this run.
            gpu_stats_file = run_out / "gpu_epoch_stats.jsonl"
            gpu_stats = []
            if gpu_stats_file.exists():
                with open(gpu_stats_file, "r") as f:
                    for l in f:
                        gpu_stats.append(json.loads(l))

            # Save artifacts: model & tokenizer
            if use_lora:
                # Save PEFT adapters (this will save only adapter weights)
                trainer.model.save_pretrained(run_out / "peft_adapter")
                tokenizer.save_pretrained(run_out / "tokenizer")
            else:
                trainer.save_model(run_out / "full_model")
                tokenizer.save_pretrained(run_out / "tokenizer")

            # CPU latency measurement (move model to cpu temporarily)
            # Best-effort: any failure is printed and recorded as None instead of
            # aborting the sweep.
            try:
                # load best model for CPU test
                if use_lora:
                    base_m = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=num_labels)
                    peft_loaded = PeftModel.from_pretrained(base_m, run_out / "peft_adapter")
                    peft_loaded = peft_loaded.merge_and_unload()
                    cpu_model = peft_loaded.to("cpu")
                else:
                    cpu_model = AutoModelForSequenceClassification.from_pretrained(run_out / "full_model").to("cpu")
                # 32 validation texts, timed over 10 repeated prediction calls.
                sample_texts = val_df["text"].sample(32, random_state=seed).tolist()
                cpu_mean, cpu_std = measure_cpu_latency(lambda texts: predict_with_model(cpu_model, texts, device="cpu"), sample_texts, n_runs=10)
            except Exception as e:
                cpu_mean, cpu_std = None, None
                print("CPU latency measurement failed:", e)

            # One flat record per (config, seed) run; val_metrics holds the dict
            # returned by trainer.evaluate().
            rec = {
                "config": config_name,
                "seed": seed,
                "model_mode": model_mode,
                "lr": lr,
                "epochs": epochs,
                "batch_size": batch_size,
                "train_time_s": train_time,
                "val_metrics": eval_out,
                "test_f1_macro": f1_score(test_labels, test_preds, average="macro"),
                "test_f1_weighted": f1_score(test_labels, test_preds, average="weighted"),
                "test_accuracy": accuracy_score(test_labels, test_preds),
                "f1_macro_boot_mean": f1_mean,
                "f1_macro_ci_low": f1_low,
                "f1_macro_ci_high": f1_high,
                "gpu_stats": gpu_stats,
                "cpu_latency_mean_s": cpu_mean,
                "cpu_latency_std_s": cpu_std,
                "run_dir": str(run_out)
            }
            run_records.append(rec)
            # free memory
            # NOTE(review): after this `del`, the global name `model` no longer
            # exists once the loops finish; code below that reads `model` will fail.
            del trainer
            del model
            torch.cuda.empty_cache()
        # aggregate seed runs for this config
        results.extend(run_records)
        # Save partial results after each config
        # (the file is rewritten in full each time, so it always holds all runs so far)
        pd.DataFrame(results).to_json("transformer_grid_results.jsonl", orient="records", lines=True)



# Finalize results to CSV
# Re-read the JSONL written by the grid search and store the same rows as CSV.
results_df = pd.read_json("transformer_grid_results.jsonl", lines=True)
results_df.to_csv("transformer_grid_results.csv", index=False)
print("Saved transformer grid results to transformer_grid_results.csv")

# `res` is the CSV-backed copy of the results used by the plots below.
res = pd.read_csv("transformer_grid_results.csv")
# explode val_metrics column if present
def extract_metric(d, key):
    """Pull one metric out of a ``val_metrics`` cell.

    The cell may be a dict already, a JSON object string, or the Python
    ``repr`` of a dict (which is what a dict column becomes after a CSV
    round-trip, and which ``json.loads`` alone cannot parse because of the
    single quotes).

    Args:
        d: The cell value.
        key: Metric name to look up.

    Returns:
        The metric value, or ``None`` when the cell cannot be interpreted as a
        dict or the key is absent.
    """
    import ast  # local import: only needed for the dict-repr fallback

    if isinstance(d, dict):
        return d.get(key)
    if not isinstance(d, (str, bytes, bytearray)):
        # e.g. NaN for a missing cell
        return None
    for parse in (json.loads, ast.literal_eval):
        try:
            parsed = parse(d)
        except (ValueError, SyntaxError, TypeError, RecursionError):
            continue
        if isinstance(parsed, dict):
            return parsed.get(key)
    return None

# example plot: test F1 by lr for LoRA vs Full (mean across seeds)
# The mean is taken over every run sharing a (model_mode, lr) pair, so runs
# with different epochs / batch sizes are averaged together as well.
agg = res.groupby(["model_mode","lr"])["test_f1_macro"].mean().reset_index()
plt.figure(figsize=(8,5))
# One line per training mode found in the results.
for mode in agg["model_mode"].unique():
    sub = agg[agg["model_mode"]==mode]
    plt.plot(sub["lr"], sub["test_f1_macro"], marker="o", label=mode)
# Log-scaled x axis for the learning rate.
plt.xscale("log")
plt.xlabel("Learning rate")
plt.ylabel("Mean test F1 (macro)")
plt.title("LR sweep: mean test macro-F1 across modes")
plt.legend()
plt.show()

# Per-class F1 for best run (choose highest test_f1_macro)
best_idx = res["test_f1_macro"].idxmax()
best_run = res.loc[best_idx]
print("Best run record:\n", best_run.to_dict())

best_report = pd.read_csv(Path(best_run["run_dir"]) / "classification_report_test.csv", index_col=0)
# The saved report ends with the summary rows "accuracy", "macro avg" and
# "weighted avg". Drop them by name BEFORE sorting: slicing off the last three
# rows after the sort would remove the three lowest-F1 classes instead and
# leave the summary rows in the chart.
best_report = best_report.drop(index=["accuracy", "macro avg", "weighted avg"], errors="ignore")
best_report = best_report.sort_values("f1-score", ascending=False)
plt.figure(figsize=(12,5))
plt.bar(best_report.index, best_report["f1-score"])
plt.xticks(rotation=90)
plt.ylabel("F1-Score")
plt.title("Per-class F1 (best run)")
plt.show()

print("Training + grid search complete. Results saved. Use transformer_grid_results.csv for numeric tables and per-run folders for artifacts.")


# Cache for the lazily loaded best-run model used by predict_clause().
_BEST_MODEL_CACHE = {}


def _load_best_run_model():
    """Load (once) the model saved by the best grid-search run."""
    if "model" not in _BEST_MODEL_CACHE:
        run_dir = Path(best_run["run_dir"])
        adapter_dir = run_dir / "peft_adapter"
        if adapter_dir.exists():
            # LoRA run: only adapter weights were saved, so rebuild the base model first.
            base = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=num_labels)
            loaded = PeftModel.from_pretrained(base, adapter_dir).merge_and_unload()
        else:
            loaded = AutoModelForSequenceClassification.from_pretrained(run_dir / "full_model")
        loaded.eval()
        _BEST_MODEL_CACHE["model"] = loaded
    return _BEST_MODEL_CACHE["model"]


def predict_clause(text, model=None):
    """Classify a single clause text.

    The training loop deletes its ``model`` variable after every run, so this
    function no longer depends on that global. Pass a model explicitly, or
    leave ``model`` unset to use the saved model of the best run
    (``best_run["run_dir"]``), which is loaded on first use and cached.

    Args:
        text: Clause text to classify.
        model: Optional sequence-classification model to use.

    Returns:
        Dict with ``text``, ``predicted_label`` (clause name from ``labels``)
        and ``confidence`` (softmax probability of the predicted class).
    """
    if model is None:
        model = _load_best_run_model()
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
        pred_id = logits.argmax(dim=-1).item()
    return {
        "text": text,
        "predicted_label": labels[pred_id],   # use true label names
        "confidence": float(probs[pred_id])
    }

# Smoke test: classify two hand-written clauses and print each result dict
texts = [
    "This Agreement shall be governed by the laws of the State of New York.",
    "The liability of the supplier shall not exceed the total fees paid under this agreement.",
]

for clause_text in texts:
    prediction = predict_clause(clause_text)
    print(prediction)