In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from collections import defaultdict
import time
from sklearn.metrics import precision_recall_curve, auc

import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    TrainingArguments,
    Trainer,
    default_data_collator
)

import random
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)

Device: cuda
GPU: NVIDIA L4
VRAM: 23.80 GB


The next cell require two files to run:
'train_separate_questions.json' and 'test.json'

In [None]:
with open('train_separate_questions.json', 'r') as f:
    train_data = json.load(f)

with open('test.json', 'r') as f:
    test_data = json.load(f)

def convert_cuad(data):
    examples = []
    for doc in data['data']:
        title = doc['title']
        for paragraph in doc['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                examples.append({
                    'id': qa['id'],
                    'title': title,
                    'question': qa['question'],
                    'context': context,
                    'answers': {
                        'text': [ans['text'] for ans in qa.get('answers', [])] if 'answers' in qa else [''],
                        'answer_start': [ans['answer_start'] for ans in qa.get('answers', [])] if 'answers' in qa else [0]
                    }
                })
    return examples

train_examples = convert_cuad(train_data)
test_examples = convert_cuad(test_data)

cuad = {
    'train': Dataset.from_list(train_examples),
    'test': Dataset.from_list(test_examples)
}

print(f"Train: {len(cuad['train']):,} examples")
print(f"Test: {len(cuad['test']):,} examples")

Train: 22,450 examples
Test: 4,182 examples


In [None]:
model_name = "nlpaueb/legal-bert-base-uncased"

print(f"Loading {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
model.to(device)

print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

Loading nlpaueb/legal-bert-base-uncased...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Parameters: 108,893,186


In [None]:
def preprocess_with_stride(examples, tokenizer, max_length=512, stride=128, max_chunks=10):
    all_input_ids = []
    all_attention_mask = []
    all_start_positions = []
    all_end_positions = []

    for i in range(len(examples['question'])):
        question = examples['question'][i]
        context = examples['context'][i]
        answers = examples['answers'][i]

        question_tokens = tokenizer(question, add_special_tokens=False, truncation=False)
        context_tokens = tokenizer(context, add_special_tokens=False, truncation=False, return_offsets_mapping=True)

        question_len = len(question_tokens['input_ids'])
        context_len = len(context_tokens['input_ids'])
        offset_mapping = context_tokens['offset_mapping']

        answer_start_char = None
        answer_end_char = None
        if answers['text'] and len(answers['text'][0]) > 0:
            answer_start_char = answers['answer_start'][0]
            answer_end_char = answer_start_char + len(answers['text'][0])

        max_context_len = max_length - question_len - 3
        stride_step = max_context_len - stride

        chunk_count = 0
        for start_idx in range(0, context_len, stride_step):
            if chunk_count >= max_chunks:
                break

            end_idx = min(start_idx + max_context_len, context_len)

            input_ids = (
                [tokenizer.cls_token_id] +
                question_tokens['input_ids'] +
                [tokenizer.sep_token_id] +
                context_tokens['input_ids'][start_idx:end_idx] +
                [tokenizer.sep_token_id]
            )

            padding_length = max_length - len(input_ids)
            input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
            attention_mask = [1] * (len(input_ids) - padding_length) + [0] * padding_length

            start_pos = 0
            end_pos = 0

            if answer_start_char is not None:
                chunk_offset_mapping = offset_mapping[start_idx:end_idx]

                for token_idx, (char_start, char_end) in enumerate(chunk_offset_mapping):
                    if char_start <= answer_start_char < char_end:
                        start_pos = question_len + 2 + token_idx
                    if char_start < answer_end_char <= char_end:
                        end_pos = question_len + 2 + token_idx

                if start_pos == 0 or end_pos == 0 or start_pos > end_pos:
                    start_pos = 0
                    end_pos = 0

            all_input_ids.append(input_ids)
            all_attention_mask.append(attention_mask)
            all_start_positions.append(start_pos)
            all_end_positions.append(end_pos)

            chunk_count += 1

    return {
        'input_ids': all_input_ids,
        'attention_mask': all_attention_mask,
        'start_positions': all_start_positions,
        'end_positions': all_end_positions
    }

print("Tokenizing...")

tokenized_train = cuad['train'].map(
    lambda x: preprocess_with_stride(x, tokenizer, max_length=512, stride=128, max_chunks=10),
    batched=True,
    remove_columns=cuad['train'].column_names,
    desc="Tokenizing train"
)

tokenized_test = cuad['test'].map(
    lambda x: preprocess_with_stride(x, tokenizer, max_length=512, stride=128, max_chunks=10),
    batched=True,
    remove_columns=cuad['test'].column_names,
    desc="Tokenizing test"
)

print(f"Training examples: {len(tokenized_train):,}")
print(f"Test examples: {len(tokenized_test):,}")
print(f"Expansion factor: {len(tokenized_train) / len(cuad['train']):.1f}x")

Tokenizing...


Tokenizing train:   0%|          | 0/22450 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (7369 > 512). Running this sequence through the model will result in indexing errors


Tokenizing test:   0%|          | 0/4182 [00:00<?, ? examples/s]

Training examples: 206,226
Test examples: 36,511
Expansion factor: 9.2x


In [None]:
def compute_exact_match(pred, truth):
    return int(pred.strip().lower() == truth.strip().lower())

def compute_f1(pred, truth):
    pred_tokens = pred.strip().lower().split()
    truth_tokens = truth.strip().lower().split()
    if len(pred_tokens) == 0 and len(truth_tokens) == 0:
        return 1.0
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return 0.0
    common = set(pred_tokens) & set(truth_tokens)
    if len(common) == 0:
        return 0.0
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(truth_tokens)
    return 2 * (precision * recall) / (precision + recall)

def predict_batch_with_confidence(model, tokenizer, dataset, device, batch_size=16):
    model.eval()
    predictions = []
    confidences = []
    for idx in tqdm(range(0, len(dataset), batch_size), desc="Predicting"):
        batch = dataset[idx:min(idx + batch_size, len(dataset))]
        if not isinstance(batch['question'], list):
            batch = {k: [v] for k, v in batch.items()}
        inputs = tokenizer(batch['question'], batch['context'], max_length=512, truncation="only_second", padding=True, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        for j in range(len(batch['question'])):
            start_logits = outputs.start_logits[j]
            end_logits = outputs.end_logits[j]
            start_idx = torch.argmax(start_logits).item()
            end_idx = torch.argmax(end_logits).item()
            start_probs = torch.softmax(start_logits, dim=0)
            end_probs = torch.softmax(end_logits, dim=0)
            confidence = (start_probs[start_idx] + end_probs[end_idx]) / 2
            confidence = confidence.item()
            if start_idx <= end_idx:
                answer_tokens = inputs['input_ids'][j][start_idx:end_idx+1]
                pred = tokenizer.decode(answer_tokens, skip_special_tokens=True)
            else:
                pred = ""
                confidence = 0.0
            predictions.append(pred)
            confidences.append(confidence)
    return predictions, confidences

def evaluate_model(dataset, predictions, confidences=None, model_name="Model"):
    from sklearn.metrics import precision_recall_curve, auc

    em_scores = []
    f1_scores = []
    binary_labels = []

    for example, pred in zip(dataset, predictions):
        truth = example['answers']['text'][0] if example['answers']['text'] and len(example['answers']['text']) > 0 else ""
        em_scores.append(compute_exact_match(pred, truth))
        f1_score = compute_f1(pred, truth)
        f1_scores.append(f1_score)
        binary_labels.append(1 if f1_score > 0 else 0)

    results = {
        "model": model_name,
        "exact_match": np.mean(em_scores) * 100,
        "f1": np.mean(f1_scores) * 100,
        "total": len(predictions)
    }

    if confidences is not None:
        binary_labels = np.array(binary_labels)
        confidences_array = np.array(confidences)

        precisions, recalls, thresholds = precision_recall_curve(binary_labels, confidences_array)
        aupr = auc(recalls, precisions)
        results["aupr"] = aupr * 100

        idx_80 = np.argmin(np.abs(recalls - 0.80))
        results["precision_at_80_recall"] = precisions[idx_80] * 100

        idx_90 = np.argmin(np.abs(recalls - 0.90))
        results["precision_at_90_recall"] = precisions[idx_90] * 100

    return results

def print_results(results):
    print("="*60)
    print(f"{results['model']} RESULTS")
    print("="*60)
    print(f"Exact Match:        {results['exact_match']:6.2f}%")
    print(f"F1 Score:           {results['f1']:6.2f}%")

    if "aupr" in results:
        print(f"AUPR:               {results['aupr']:6.2f}%")
        print(f"Precision @ 80% R:  {results['precision_at_80_recall']:6.2f}%")
        print(f"Precision @ 90% R:  {results['precision_at_90_recall']:6.2f}%")

    print(f"Total examples:     {results['total']:,}")
    print("="*60)

In [None]:
print("STAGE 1: Zero-Shot Evaluation\n")

model_name = "nlpaueb/legal-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name).to(device)

stage1_predictions, stage1_confidences = predict_batch_with_confidence(model, tokenizer, cuad['test'], device, batch_size=16)
stage1_results = evaluate_model(cuad['test'], stage1_predictions, stage1_confidences, model_name="Legal-BERT Zero-Shot")

print_results(stage1_results)

STAGE 1: Zero-Shot Evaluation



Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Predicting:   0%|          | 0/262 [00:00<?, ?it/s]

Legal-BERT Zero-Shot RESULTS
Exact Match:         46.58%
F1 Score:            47.27%
AUPR:                39.27%
Precision @ 80% R:   56.43%
Precision @ 90% R:   56.43%
Total examples:     4,182


In [None]:
print("STAGE 2: Fine-tuning Legal-BERT\n")

training_args = TrainingArguments(
    output_dir="./legal_bert_cuad_stride",
    eval_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    weight_decay=0.01,
    warmup_steps=1000,
    logging_steps=100,
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    fp16=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
)

print("Starting training...\n")
train_result = trainer.train()

print(f"\nTraining complete")
print(f"Time: {train_result.metrics['train_runtime']/60:.1f} minutes")
print(f"Loss: {train_result.training_loss:.4f}")

model.save_pretrained("./legal_bert_cuad_stride_final")
tokenizer.save_pretrained("./legal_bert_cuad_stride_final")
print("Model saved")

  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.


STAGE 2: Fine-tuning Legal-BERT

Starting training...



Epoch,Training Loss,Validation Loss
1,0.0914,0.067427
2,0.0703,0.060516



Training complete
Time: 182.3 minutes
Loss: 0.1522
Model saved


In [None]:
import shutil
from google.colab import files

print("Zipping first model...")
shutil.make_archive('legal_bert_cuad_stride', 'zip', './legal_bert_cuad_stride')
print("Downloading stride model...")
files.download('legal_bert_cuad_stride.zip')

print("\nZipping final model...")
shutil.make_archive('legal_bert_cuad_stride_final', 'zip', './legal_bert_cuad_stride_final')
print("Downloading stride_final model...")
files.download('legal_bert_cuad_stride_final.zip')

print("\n✅ Both downloads complete!")

Zipping first model...
Downloading stride model...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


Zipping final model...
Downloading stride_final model...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


✅ Both downloads complete!


Fine-tuned Model : legal_bert_cuad_stride_final

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create folder first
!mkdir -p legal_bert_cuad_stride_final

# Unzip INTO the folder
!unzip -q "/content/drive/MyDrive/legal_bert_cuad_stride_final.zip" -d ./legal_bert_cuad_stride_final

# Now load from folder
model = AutoModelForQuestionAnswering.from_pretrained("./legal_bert_cuad_stride_final")
tokenizer = AutoTokenizer.from_pretrained("./legal_bert_cuad_stride_final")
model.to(device)

print("✅ Fine-tuned model loaded")

Mounted at /content/drive
✅ Fine-tuned model loaded


Fine-tuned Legal BERT model named 'legal_bert_cuad_stride_final.zip' file is provided. Unzip it, upload the extracted folder to Colab, and then run the next cell.

In [None]:
model = AutoModelForQuestionAnswering.from_pretrained("./legal_bert_cuad_stride_final")
tokenizer = AutoTokenizer.from_pretrained("./legal_bert_cuad_stride_final")
model.to(device)

In [None]:
print("Evaluating fine-tuned model...\n")

stage2_predictions, stage2_confidences = predict_batch_with_confidence(
    model, tokenizer, cuad['test'], device, batch_size=16
)

stage2_results = evaluate_model(
    cuad['test'],
    stage2_predictions,
    stage2_confidences,
    model_name="Legal-BERT Fine-tuned"
)

print_results(stage2_results)

Evaluating fine-tuned model...



Predicting:   0%|          | 0/262 [00:00<?, ?it/s]

Legal-BERT Fine-tuned RESULTS
Exact Match:         73.74%
F1 Score:            74.65%
AUPR:                79.46%
Precision @ 80% R:   77.49%
Precision @ 90% R:   77.29%
Total examples:     4,182


In [None]:
def analyze_by_clause_category(dataset, predictions):
    clause_categories = {
        'Parties': ['parties'],
        'Agreement Date': ['agreement date', 'effective date'],
        'Expiration Date': ['expiration date', 'renewal term'],
        'Governing Law': ['governing law'],
        'Termination': ['termination', 'can be terminated'],
        'IP Rights': ['intellectual property', 'ip ownership'],
        'Confidentiality': ['confidential information', 'confidentiality'],
        'Liability': ['liability', 'cap on liability'],
        'Payment Terms': ['payment', 'price', 'cost'],
        'Non-Compete': ['non-compete', 'competitive restriction'],
        'Insurance': ['insurance'],
        'Warranties': ['warranties', 'representations'],
        'Indemnification': ['indemnification', 'indemnify'],
        'Audit Rights': ['audit', 'auditing'],
    }

    category_results = defaultdict(lambda: {'em': [], 'f1': [], 'count': 0})

    for i, example in enumerate(dataset):
        question = example['question'].lower()
        truth = example['answers']['text'][0] if example['answers']['text'] and len(example['answers']['text']) > 0 else ""
        pred = predictions[i]

        category = 'Other'
        for cat_name, keywords in clause_categories.items():
            if any(kw in question for kw in keywords):
                category = cat_name
                break

        em = compute_exact_match(pred, truth)
        f1 = compute_f1(pred, truth)

        category_results[category]['em'].append(em)
        category_results[category]['f1'].append(f1)
        category_results[category]['count'] += 1

    results = []
    for category, metrics in category_results.items():
        results.append({
            'Category': category,
            'Count': metrics['count'],
            'EM': np.mean(metrics['em']) * 100,
            'F1': np.mean(metrics['f1']) * 100
        })

    df = pd.DataFrame(results).sort_values('F1', ascending=False)
    return df

def analyze_by_answer_length(dataset, predictions):
    length_buckets = {
        'No Answer': (0, 0),
        'Short (1-5 words)': (1, 5),
        'Medium (6-15 words)': (6, 15),
        'Long (16-30 words)': (16, 30),
        'Very Long (>30 words)': (31, 1000)
    }

    bucket_results = defaultdict(lambda: {'em': [], 'f1': [], 'count': 0})

    for i, example in enumerate(dataset):
        truth = example['answers']['text'][0] if example['answers']['text'] and len(example['answers']['text']) > 0 else ""
        pred = predictions[i]

        truth_len = len(truth.split()) if truth else 0
        bucket = 'No Answer'
        for bucket_name, (min_len, max_len) in length_buckets.items():
            if min_len <= truth_len <= max_len:
                bucket = bucket_name
                break

        em = compute_exact_match(pred, truth)
        f1 = compute_f1(pred, truth)

        bucket_results[bucket]['em'].append(em)
        bucket_results[bucket]['f1'].append(f1)
        bucket_results[bucket]['count'] += 1

    results = []
    for bucket, metrics in bucket_results.items():
        results.append({
            'Answer Length': bucket,
            'Count': metrics['count'],
            'EM': np.mean(metrics['em']) * 100,
            'F1': np.mean(metrics['f1']) * 100
        })

    df = pd.DataFrame(results)
    order = ['No Answer', 'Short (1-5 words)', 'Medium (6-15 words)', 'Long (16-30 words)', 'Very Long (>30 words)']
    df['Answer Length'] = pd.Categorical(df['Answer Length'], categories=order, ordered=True)
    df = df.sort_values('Answer Length')
    return df

def analyze_by_contract_length(dataset, predictions):
    length_results = defaultdict(lambda: {'em': [], 'f1': [], 'count': 0})

    for i, example in enumerate(dataset):
        context = example['context']
        truth = example['answers']['text'][0] if example['answers']['text'] and len(example['answers']['text']) > 0 else ""
        pred = predictions[i]

        token_count = len(context.split())
        if token_count < 2000:
            bucket = 'Short (<2K tokens)'
        elif token_count < 5000:
            bucket = 'Medium (2K-5K)'
        elif token_count < 10000:
            bucket = 'Long (5K-10K)'
        else:
            bucket = 'Very Long (>10K)'

        em = compute_exact_match(pred, truth)
        f1 = compute_f1(pred, truth)

        length_results[bucket]['em'].append(em)
        length_results[bucket]['f1'].append(f1)
        length_results[bucket]['count'] += 1

    results = []
    for bucket, metrics in length_results.items():
        results.append({
            'Contract Length': bucket,
            'Count': metrics['count'],
            'EM': np.mean(metrics['em']) * 100,
            'F1': np.mean(metrics['f1']) * 100
        })

    df = pd.DataFrame(results)
    order = ['Short (<2K tokens)', 'Medium (2K-5K)', 'Long (5K-10K)', 'Very Long (>10K)']
    df['Contract Length'] = pd.Categorical(df['Contract Length'], categories=order, ordered=True)
    df = df.sort_values('Contract Length')
    return df


In [None]:
print("\n" + "="*70)
print("DEEPER PERFORMANCE ANALYSIS")
print("="*70)

print("\n1. Performance by Clause Category:\n")
category_df = analyze_by_clause_category(cuad['test'], stage2_predictions)
print(category_df.to_string(index=False))

print("\n2. Performance by Answer Length:\n")
length_df = analyze_by_answer_length(cuad['test'], stage2_predictions)
print(length_df.to_string(index=False))

print("\n3. Performance by Contract Length:\n")
contract_df = analyze_by_contract_length(cuad['test'], stage2_predictions)
print(contract_df.to_string(index=False))

print("\n" + "="*70)


DEEPER PERFORMANCE ANALYSIS

1. Performance by Clause Category:

       Category  Count         EM         F1
  Payment Terms    102 100.000000 100.000000
Confidentiality    102  87.254902  87.254902
      IP Rights    306  82.679739  82.679739
    Non-Compete    204  80.882353  80.882353
          Other   2040  80.098039  80.768438
    Termination    306  76.470588  76.470588
      Insurance    102  68.627451  68.627451
 Agreement Date    204  59.313725  63.725537
   Audit Rights    102  62.745098  62.745098
        Parties    306  56.209150  58.956845
      Liability    102  56.862745  57.683539
Expiration Date    204  50.490196  53.329279
  Governing Law    102  18.627451  18.627451

2. Performance by Answer Length:

        Answer Length  Count        EM        F1
            No Answer   2938 99.727706 99.727706
    Short (1-5 words)    315 43.492063 50.237002
  Medium (6-15 words)     87 14.942529 25.416037
   Long (16-30 words)    213  0.938967  2.193218
Very Long (>30 words)   

The following cell requires 'test_refactored.json' file to run.

In [None]:
print("\n" + "="*70)
print("STAGE 3: Paraphrased Questions (Question Type Analysis)")
print("="*70)

with open('test_refactored.json', 'r') as f:
    paraphrased_data = json.load(f)

def convert_paraphrased_cuad(data):
    examples = []
    for doc in data['data']:
        for paragraph in doc['paragraphs']:
            for qa in paragraph['qas']:
                examples.append({
                    'id': qa['id'],
                    'title': doc['title'],
                    'question': qa['refactored_question'],
                    'original_question': qa['question'],
                    'context': paragraph['context'],
                    'answers': {
                        'text': [ans['text'] for ans in qa.get('answers', [])] if qa.get('answers') else [''],
                        'answer_start': [ans['answer_start'] for ans in qa.get('answers', [])] if qa.get('answers') else [0]
                    }
                })
    return examples

paraphrased_examples = convert_paraphrased_cuad(paraphrased_data)
paraphrased_test = Dataset.from_list(paraphrased_examples)

print(f"\nLoaded {len(paraphrased_test):,} paraphrased questions")

stage3_predictions, stage3_confidences = predict_batch_with_confidence(
    model, tokenizer, paraphrased_test, device, batch_size=16
)

stage3_results = evaluate_model(
    paraphrased_test,
    stage3_predictions,
    stage3_confidences,
    model_name="Paraphrased Questions"
)

print_results(stage3_results)

print(f"\nOriginal Legal Phrasing (Stage 2):")
print(f"  EM: {stage2_results['exact_match']:.2f}%")
print(f"  F1: {stage2_results['f1']:.2f}%")
print(f"  AUPR: {stage2_results['aupr']:.2f}%")
print(f"  P@80%R: {stage2_results['precision_at_80_recall']:.2f}%")
print(f"  P@90%R: {stage2_results['precision_at_90_recall']:.2f}%")

print(f"\nParaphrased User Phrasing (Stage 3):")
print(f"  EM: {stage3_results['exact_match']:.2f}%")
print(f"  F1: {stage3_results['f1']:.2f}%")
print(f"  AUPR: {stage3_results['aupr']:.2f}%")
print(f"  P@80%R: {stage3_results['precision_at_80_recall']:.2f}%")
print(f"  P@90%R: {stage3_results['precision_at_90_recall']:.2f}%")

gap_f1 = stage2_results['f1'] - stage3_results['f1']
gap_em = stage2_results['exact_match'] - stage3_results['exact_match']
gap_aupr = stage2_results['aupr'] - stage3_results['aupr']
gap_p80 = stage2_results['precision_at_80_recall'] - stage3_results['precision_at_80_recall']
gap_p90 = stage2_results['precision_at_90_recall'] - stage3_results['precision_at_90_recall']
retention_f1 = (stage3_results['f1'] / stage2_results['f1']) * 100 if stage2_results['f1'] > 0 else 0
retention_em = (stage3_results['exact_match'] / stage2_results['exact_match']) * 100 if stage2_results['exact_match'] > 0 else 0
retention_aupr = (stage3_results['aupr'] / stage2_results['aupr']) * 100 if stage2_results['aupr'] > 0 else 0

print(f"\nQuestion Type Generalization:")
print(f"  F1 Gap: {gap_f1:.2f}% (Retention: {retention_f1:.1f}%)")
print(f"  EM Gap: {gap_em:.2f}% (Retention: {retention_em:.1f}%)")
print(f"  AUPR Gap: {gap_aupr:.2f}% (Retention: {retention_aupr:.1f}%)")
print(f"  P@80%R Gap: {gap_p80:.2f}%")
print(f"  P@90%R Gap: {gap_p90:.2f}%")

if abs(gap_f1) < 5:
    print("  ✅ Strong - Handles both phrasings well")
elif abs(gap_f1) < 10:
    print("  ⚠️ Moderate - Some sensitivity to phrasing")
else:
    print("  ❌ Large - Overfits to legal phrasing")

print("\n" + "="*70)


STAGE 3: Paraphrased Questions (Question Type Analysis)

Loaded 4,182 paraphrased questions


Predicting:   0%|          | 0/262 [00:00<?, ?it/s]

Paraphrased Questions RESULTS
Exact Match:         68.05%
F1 Score:            68.26%
AUPR:                79.02%
Precision @ 80% R:   75.47%
Precision @ 90% R:   73.48%
Total examples:     4,182

Original Legal Phrasing (Stage 2):
  EM: 73.74%
  F1: 74.65%
  AUPR: 79.46%
  P@80%R: 77.49%
  P@90%R: 77.29%

Paraphrased User Phrasing (Stage 3):
  EM: 68.05%
  F1: 68.26%
  AUPR: 79.02%
  P@80%R: 75.47%
  P@90%R: 73.48%

Question Type Generalization:
  F1 Gap: 6.39% (Retention: 91.4%)
  EM Gap: 5.69% (Retention: 92.3%)
  AUPR Gap: 0.44% (Retention: 99.5%)
  P@80%R Gap: 2.02%
  P@90%R Gap: 3.82%
  ⚠️ Moderate - Some sensitivity to phrasing



In [None]:
import random

print("="*70)
print("SAMPLE PREDICTIONS ANALYSIS")
print("="*70)

with_answer_correct = []
with_answer_wrong = []
no_answer_correct = []
no_answer_wrong = []

for i, example in enumerate(cuad['test']):
    truth = example['answers']['text'][0] if example['answers']['text'] and len(example['answers']['text']) > 0 else ""
    pred = stage2_predictions[i]

    if truth:
        if compute_exact_match(pred, truth) == 1:
            with_answer_correct.append((example['question'], truth, pred))
        else:
            with_answer_wrong.append((example['question'], truth, pred))
    else:
        if pred.strip() == "":
            no_answer_correct.append((example['question'], truth, pred))
        else:
            no_answer_wrong.append((example['question'], truth, pred))

print(f"\n📊 Statistics:")
print(f"Questions WITH answer: {len(with_answer_correct) + len(with_answer_wrong)}")
print(f"  ✅ Correct: {len(with_answer_correct)} ({len(with_answer_correct)/(len(with_answer_correct)+len(with_answer_wrong))*100:.1f}%)")
print(f"  ❌ Wrong: {len(with_answer_wrong)} ({len(with_answer_wrong)/(len(with_answer_correct)+len(with_answer_wrong))*100:.1f}%)")

print(f"\nQuestions with NO answer: {len(no_answer_correct) + len(no_answer_wrong)}")
print(f"  ✅ Correct (predicted empty): {len(no_answer_correct)} ({len(no_answer_correct)/(len(no_answer_correct)+len(no_answer_wrong))*100:.1f}%)")
print(f"  ❌ Wrong (predicted something): {len(no_answer_wrong)} ({len(no_answer_wrong)/(len(no_answer_correct)+len(no_answer_wrong))*100:.1f}%)")

print("\n" + "="*70)
print("✅ CORRECT PREDICTIONS (Questions WITH Answers)")
print("="*70)
for q, truth, pred in random.sample(with_answer_correct, min(5, len(with_answer_correct))):
    print(f"\nQ: {q[:100]}...")
    print(f"Truth: {truth}")
    print(f"Predicted: {pred}")

print("\n" + "="*70)
print("❌ WRONG PREDICTIONS (Questions WITH Answers)")
print("="*70)
for q, truth, pred in random.sample(with_answer_wrong, min(5, len(with_answer_wrong))):
    print(f"\nQ: {q[:100]}...")
    print(f"Truth: {truth}")
    print(f"Predicted: {pred}")

print("\n" + "="*70)
print("✅ CORRECT: No Answer Questions (Correctly predicted empty)")
print("="*70)
for q, truth, pred in random.sample(no_answer_correct, min(3, len(no_answer_correct))):
    print(f"\nQ: {q[:100]}...")
    print(f"Truth: [No Answer]")
    print(f"Predicted: [No Answer]")

print("\n" + "="*70)
print("❌ WRONG: No Answer Questions (Incorrectly predicted something)")
print("="*70)
for q, truth, pred in random.sample(no_answer_wrong, min(5, len(no_answer_wrong))):
    print(f"\nQ: {q[:100]}...")
    print(f"Truth: [No Answer]")
    print(f"Predicted: {pred}")

print("\n" + "="*70)

SAMPLE PREDICTIONS ANALYSIS

📊 Statistics:
Questions WITH answer: 1244
  ✅ Correct: 154 (12.4%)
  ❌ Wrong: 1090 (87.6%)

Questions with NO answer: 2938
  ✅ Correct (predicted empty): 2930 (99.7%)
  ❌ Wrong (predicted something): 8 (0.3%)

✅ CORRECT PREDICTIONS (Questions WITH Answers)

Q: Highlight the parts (if any) of this contract related to "Effective Date" that should be reviewed by...
Truth: March 3, 2011
Predicted: march 3, 2011

Q: Highlight the parts (if any) of this contract related to "Agreement Date" that should be reviewed by...
Truth: January 24, 2014
Predicted: january 24, 2014

Q: Highlight the parts (if any) of this contract related to "Agreement Date" that should be reviewed by...
Truth: 1st day of November, 2002
Predicted: 1st day of november, 2002

Q: Highlight the parts (if any) of this contract related to "Document Name" that should be reviewed by ...
Truth: JOINT VENTURE CONTRACT
Predicted: joint venture contract

Q: Highlight the parts (if any) of this contract 