In [1]:
# 1. Install all required libraries
# We use datasets==2.19.0 to support the CUAD loading script.
# The 'gcsfs' dependency error in the output is EXPECTED and SAFE to ignore.
!pip install --upgrade datasets langchain-text-splitters transformers sentence-transformers faiss-cpu evaluate accelerate -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m511.6/511.6 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m27.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25h

# Restoration Code
If you restart your session and want to skip the 1+ hour cell run:

1. Upload `rag_backup.zip` and `official_results.csv` using the file upload icon in Colab.

2. Run the code below

# Load Data & Implement Sliding Window

In [2]:
def load_dataset():
  # Load from your refactored JSON file
  json_path = "test_refactored.json"

  with open(json_path, 'r', encoding='utf-8') as f:
      data = json.load(f)


  dataset = []
  if 'data' in data:
      for contract in data['data']:
          contract_id = contract.get('title', 'unknown')
          paragraphs = contract.get('paragraphs', [])

          # Store all paragraphs with their QAs
          contract_paragraphs = []
          for para in paragraphs:
              paragraph_data = {
                  'context': para.get('context', ''),
                  'qas': []
              }

              # Extract all QAs from this paragraph
              if 'qas' in para:
                  for qa in para['qas']:
                      qa_data = {
                          'id': qa.get('id', ''),
                          'question': qa.get('question', ''),
                          'refactored_question': qa.get('refactored_question', ''),
                          'answers': qa.get('answers', []),  # Array of {text, answer_start}
                          'is_impossible': qa.get('is_impossible', False)
                      }
                      paragraph_data['qas'].append(qa_data)

              contract_paragraphs.append(paragraph_data)

          dataset.append({
              'id': contract_id,
              'paragraphs': contract_paragraphs
          })
  return dataset

In [3]:

import json
from langchain_text_splitters import RecursiveCharacterTextSplitter
from tqdm import tqdm

# --- A. LOAD DATASET FROM REFACTORED JSON ---
print("1. Loading CUAD dataset (REFACTORED TEST SPLIT)...")

dataset = load_dataset()

print(f"   Loaded {len(dataset)} contracts.")

# --- B. DEFINE SLIDING WINDOW ---
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=2500,
    chunk_overlap=250,
    length_function=len,
)

# --- C. APPLY CHUNKING (FULL DATASET) ---
print("2. Applying Sliding Window to ALL contracts...")
all_passages = []
doc_map = []

# Iterate over each contract in the dataset
for contract in tqdm(dataset):
    contract_id = contract['id']

    # Iterate over each paragraph in the contract
    for para in contract['paragraphs']:
        context = para['context']
        chunks = text_splitter.split_text(context)

        for chunk in chunks:
            all_passages.append(chunk)
            doc_map.append({
                "id": contract_id,
                "text": chunk,
                "qas": para['qas']  # Keep QAs linked for evaluation
            })

print(f"\nSuccess! Created {len(all_passages)} searchable passages from the full dataset.")

1. Loading CUAD dataset (REFACTORED TEST SPLIT)...
   Loaded 102 contracts.
2. Applying Sliding Window to ALL contracts...


100%|██████████| 102/102 [00:00<00:00, 107.66it/s]


Success! Created 2793 searchable passages from the full dataset.





# Build the Retriever and the Generator
the **Retriever** (The "Search Engine"):

It uses a Sentence Transformer to convert every text chunk into a list of numbers (an "embedding").

It stores these embeddings in a FAISS Index, which allows for ultra-fast similarity searching.

**Generator** (The "Answerer"):

It loads **FLAN-T5-Base**, a
Google model trained to follow instructions.

It defines a function `generate_rag_answer(question)` that connects the whole pipeline: **Retrieve -> Combine -> Generate**.

In [4]:
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import torch

# # 1. Load the Embedding Model
# # 'all-mpnet-base-v2' is widely used for semantic search (it understands meaning, not just keywords)
# print("1. Loading Retriever Model...")
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# retriever_model = SentenceTransformer('all-mpnet-base-v2', device=device)

# # 2. Encode all passages (Convert text to vectors)
# # This might take 1-2 minutes depending on how many chunks you have
# print(f"2. Encoding {len(all_passages)} passages...")
# passage_embeddings = retriever_model.encode(all_passages, show_progress_bar=True)

# # 3. Build the FAISS Index
# d = passage_embeddings.shape[1]  # Dimension of the vector (768 for mpnet)
# index = faiss.IndexFlatL2(d)     # Use L2 distance (Euclidean) for similarity
# index.add(passage_embeddings)    # Add our chunk vectors to the index

print("1. Loading Retriever Model...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
retriever_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device=device)

# 2. Encode all passages (Convert text to vectors)
# BGE models benefit from adding an instruction prefix for queries (not passages)
print(f"2. Encoding {len(all_passages)} passages...")
passage_embeddings = retriever_model.encode(all_passages, show_progress_bar=True, normalize_embeddings=True)

# 3. Build the FAISS Index
# Using Inner Product (IP) since embeddings are normalized - equivalent to cosine similarity
d = passage_embeddings.shape[1]  # Dimension of the vector (1024 for bge-large)
index = faiss.IndexFlatIP(d)     # Use Inner Product for normalized vectors
index.add(passage_embeddings)

print(f"   Success! Index built with {index.ntotal} vectors.")

1. Loading Retriever Model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/191 [00:00<?, ?B/s]

2. Encoding 2793 passages...


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

   Success! Index built with 2793 vectors.


In [5]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import CrossEncoder
import torch

model_name = "google/flan-t5-large"
print(f"Loading Generator: {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
generator_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")

Loading Generator: google/flan-t5-large...


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [6]:
def llm_answer(question, retrieved_text):
  prompt = f"""
  Task: Check if the extracted legal text EXACTLY answers the question or not. If the retrieved text perfectly answers the question, then output 'Answer', in any other case output 'No Answer'. Do not output anything else.
  Retrieved Text: {retrieved_text}

  Question: {question}"""

  inputs = tokenizer(prompt, return_tensors="pt", max_length=1500, truncation=True).to(generator_model.device)

  # max_new_tokens=300 allows for long legal clauses
  outputs = generator_model.generate(**inputs, max_new_tokens=300, temperature=0.1)
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
  return answer

In [7]:
from collections import defaultdict
import numpy as np

# Build separate indices per contract
print("Building per-contract indices...")
contract_indices = {}
contract_passages = defaultdict(list)
contract_embeddings = defaultdict(list)

for i, meta in enumerate(doc_map):
    cid = meta["id"]
    contract_passages[cid].append(all_passages[i])
    contract_embeddings[cid].append(passage_embeddings[i])

# Create a FAISS index for each contract
for cid, embeddings in contract_embeddings.items():
    embeddings_array = np.array(embeddings).astype('float32')
    idx = faiss.IndexFlatIP(embeddings_array.shape[1])
    idx.add(embeddings_array)
    contract_indices[cid] = idx

print(f"Built {len(contract_indices)} contract indices")


def generate_rag_answer(query, k=5, contract_id=None, threshold=0.7):
    """
    Generate RAG answer with optional similarity threshold filtering.

    Args:
        query: The search query
        k: Maximum number of passages to retrieve
        contract_id: Optional contract ID to search within
        threshold: Minimum similarity score (0.0 to 1.0) for retrieved passages

    Returns:
        answer: LLM-generated answer
        retrieved_text: Concatenated retrieved passages
        scores: Similarity scores of retrieved passages (for debugging)
    """
    query_with_instruction = "Represent this sentence for searching relevant passages: " + query
    query_embedding = retriever_model.encode([query_with_instruction], normalize_embeddings=True)

    if contract_id:
        idx = contract_indices[contract_id]
        passages = contract_passages[contract_id]
    else:
        idx = index
        passages = all_passages

    # Ensure we don't request more than available
    k_actual = min(k, idx.ntotal)
    scores, indices = idx.search(query_embedding, k_actual)

    # Filter by threshold
    filtered_chunks = []
    filtered_scores = []
    for score, i in zip(scores[0], indices[0]):
        if score >= threshold:
            filtered_chunks.append(passages[i])
            filtered_scores.append(score)

    # Handle case where no passages meet threshold
    if not filtered_chunks:
        return "No Answer", "", []

    retrieved_text = " ".join(filtered_chunks)
    answer = llm_answer(query, retrieved_text)
    return answer, retrieved_text, filtered_scores

Building per-contract indices...
Built 102 contract indices


## Performance Analysis

* It runs the RAG model on the Official questions.

* It runs the RAG model on the Paraphrased questions.

* It calculates F1 scores (accuracy).

* It prints the "Performance by Clause Category" breakdown.

In [12]:
def compute_f1_em(prediction, truth, answer):
    prediction = prediction.strip()
    truth = truth.strip()

    if len(truth) == 0:
      print(f"Truth: {truth}")

    # Normalize for comparison
    pred_lower = prediction.lower()
    truth_lower = truth.lower()
    answer_lower = answer.lower()

    if (answer_lower == "no answer"):
      em = 100.0 if truth_lower in answer_lower else 0.0

    else:
      # EM: exact substring match
      em = 100.0 if truth_lower in pred_lower else 0.0

    # F1: Use token overlap with recall-focused scoring
    pred_tokens = pred_lower.split()
    truth_tokens = truth_lower.split()

    if len(truth_tokens) == 0:
        return em, 100.0 if len(pred_tokens) == 0 else 0.0

    # Count matching tokens
    truth_counter = {}
    for t in truth_tokens:
        truth_counter[t] = truth_counter.get(t, 0) + 1

    pred_counter = {}
    for t in pred_tokens:
        pred_counter[t] = pred_counter.get(t, 0) + 1

    num_common = sum(min(truth_counter[t], pred_counter.get(t, 0)) for t in truth_counter)

    if num_common == 0:
        return em, 0.0

    precision = num_common / len(pred_tokens)
    recall = num_common / len(truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall) * 100

    return em, f1

In [13]:
all_em_scores = []
all_f1_scores = []
count = 0

for contract in dataset:
    contract_id = contract['id']
    for para in contract['paragraphs']:
        context = para['context']
        for qa in para['qas']:
            # if count >= 100:
            #       break
            # if count % 100:
            #       print(f"Completed: {count}")

            question = qa['refactored_question']

            if qa['is_impossible'] == True:
                ground_truth = "No Answer"

            else:
              ground_truth = qa['answers'][0]['text']

            pred_ans, retrieved_text, _ = generate_rag_answer(question, 5, contract_id)
            em, f1 = compute_f1_em(retrieved_text, ground_truth, pred_ans)

            # if(em==0):
            #   print("++++++++++++++++++++++++++++++++++")
            #   print(f"Context: {context}")
            #   print("@@@@@@@@@@@@@@@@@")
            #   print(f"Question: {question}")
            #   print(f"ground_truth: {ground_truth}")
            #   print(f"retrieved_text: {retrieved_text}")

            all_em_scores.append(em)
            all_f1_scores.append(f1)

    #         count += 1
    #     if count >= 100:
    #         break
    # if count >= 100:
    #     break
    #         count += 1
    #     if count % 100:
    #         print(f"Completed: {count}")
    # if count % 100:
    #     print(f"Completed: {count}")

# Calculate averages
avg_em = sum(all_em_scores) / len(all_em_scores) if all_em_scores else 0
avg_f1 = sum(all_f1_scores) / len(all_f1_scores) if all_f1_scores else 0

print("Paraphrased Questions:")
print(f"Average EM: {avg_em:.2f}")
print(f"Average F1: {avg_f1:.2f}")

ValueError: At least 2 points are needed to compute area under curve, but x.shape = 0

In [16]:
import numpy as np
from sklearn.metrics import auc

def compute_precision_recall_curve(scores, labels):
    """
    Compute precision-recall curve from scores and binary labels.
    """
    # Sort by scores in descending order
    sorted_indices = np.argsort(scores)[::-1]
    sorted_labels = np.array(labels)[sorted_indices]

    precisions = []
    recalls = []
    thresholds = []

    tp = 0
    fp = 0
    total_positives = np.sum(labels)

    if total_positives == 0:
        return np.array([1.0]), np.array([0.0]), np.array([])

    for i, label in enumerate(sorted_labels):
        if label == 1:
            tp += 1
        else:
            fp += 1

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / total_positives

        precisions.append(precision)
        recalls.append(recall)
        thresholds.append(scores[sorted_indices[i]])

    return np.array(precisions), np.array(recalls), np.array(thresholds)


def compute_aupr(scores, labels):
    """
    Compute Area Under Precision-Recall curve.
    """
    precisions, recalls, _ = compute_precision_recall_curve(scores, labels)

    # Sort by recall for proper AUC calculation
    sorted_indices = np.argsort(recalls)
    recalls_sorted = recalls[sorted_indices]
    precisions_sorted = precisions[sorted_indices]

    return auc(recalls_sorted, precisions_sorted)


def compute_precision_at_recall(scores, labels, target_recall=0.8):
    """
    Compute precision at a specific recall level.
    """
    precisions, recalls, _ = compute_precision_recall_curve(scores, labels)

    # Find the precision at the target recall
    valid_indices = recalls >= target_recall

    if not np.any(valid_indices):
        return None  # Target recall not achievable

    # Return the maximum precision at or above target recall
    return np.max(precisions[valid_indices])

def print_metrics(question_type = "question"):
    all_em_scores = []
    all_f1_scores = []
    count = 0

    for contract in dataset:
        contract_id = contract['id']
        for para in contract['paragraphs']:
            context = para['context']
            for qa in para['qas']:
                question = qa[question_type]

                if qa['is_impossible'] == True:
                    ground_truth = "No Answer"
                else:
                    ground_truth = qa['answers'][0]['text']

                pred_ans, retrieved_text, _ = generate_rag_answer(question, 5, contract_id)
                em, f1 = compute_f1_em(retrieved_text, ground_truth, pred_ans)

                all_em_scores.append(em)
                all_f1_scores.append(f1)

    # Calculate averages
    avg_em = sum(all_em_scores) / len(all_em_scores) if all_em_scores else 0
    avg_f1 = sum(all_f1_scores) / len(all_f1_scores) if all_f1_scores else 0

    # Calculate AUPR and Precision@Recall metrics
    # Use F1 scores as confidence scores and EM as binary labels
    scores = np.array([f1 / 100.0 for f1 in all_f1_scores])  # Normalize to [0, 1]
    labels = np.array([1 if em == 100.0 else 0 for em in all_em_scores])

    aupr = compute_aupr(scores, labels)
    p_at_80r = compute_precision_at_recall(scores, labels, target_recall=0.8)
    p_at_90r = compute_precision_at_recall(scores, labels, target_recall=0.9)

    if (question_type == "question"):
        print("Actuall CUAD Questions:")
    else:
        print("Paraphrased Questions:")
    print(f"Average EM: {avg_em:.2f}")
    print(f"Average F1: {avg_f1:.2f}")
    print(f"AUPR: {aupr:.4f}")
    print(f"Precision@80%R: {p_at_80r:.4f}" if p_at_80r is not None else "Precision@80%R: N/A (target recall not achievable)")
    print(f"Precision@90%R: {p_at_90r:.4f}" if p_at_90r is not None else "Precision@90%R: N/A (target recall not achievable)")

In [17]:
print_metrics('question')

Actuall CUAD Questions:
Average EM: 70.40
Average F1: 0.26
AUPR: 0.7071
Precision@80%R: 0.7084
Precision@90%R: 0.7075


In [18]:
print_metrics('refactored_question')

Paraphrased Questions:
Average EM: 70.61
Average F1: 0.40
AUPR: 0.7047
Precision@80%R: 0.7104
Precision@90%R: 0.7095


In [10]:

from collections import defaultdict
import pandas as pd

def analyze_by_clause_category(dataset, generate_rag_answer, compute_f1_em, question_type = 'question', max_samples=None):
    """
    Analyze RAG model performance by clause category on CUAD dataset.

    Args:
        dataset: CUAD dataset
        generate_rag_answer: Function to generate RAG answers
        compute_f1_em: Function to compute F1 and EM scores
        max_samples: Maximum number of samples to evaluate (None for all)
    """
    clause_categories = {
        'Document Name': ['document name'],
        'Parties': ['parties'],
        'Agreement Date': ['agreement date', 'effective date'],
        'Expiration Date': ['expiration date', 'renewal term'],
        'Governing Law': ['governing law'],
        'Termination': ['termination', 'can be terminated'],
        'IP Rights': ['intellectual property', 'ip ownership'],
        'Confidentiality': ['confidential information', 'confidentiality'],
        'Liability': ['liability', 'cap on liability'],
        'Payment Terms': ['payment', 'price', 'cost'],
        'Non-Compete': ['non-compete', 'competitive restriction'],
        'Insurance': ['insurance'],
        'Warranties': ['warranties', 'representations'],
        'Indemnification': ['indemnification', 'indemnify'],
        'Audit Rights': ['audit', 'auditing'],
    }

    category_results = defaultdict(lambda: {'em': [], 'f1': [], 'count': 0})
    count = 0

    for contract in dataset:
        contract_id = contract['id']
        for para in contract['paragraphs']:
            for qa in para['qas']:
                if max_samples and count >= max_samples:
                    break

                question = qa[question_type]

                if qa['is_impossible'] == True:
                    ground_truth = "No Answer"

                else:
                  ground_truth = qa['answers'][0]['text']
                pred_ans, retrieved_text, _ = generate_rag_answer(question, 5, contract_id)

                em, f1 = compute_f1_em(retrieved_text, ground_truth, pred_ans)

                # Determine category
                question_lower = question.lower()
                category = 'Other'
                for cat_name, keywords in clause_categories.items():
                    if any(kw in question_lower for kw in keywords):
                        category = cat_name
                        break

                category_results[category]['em'].append(em)
                category_results[category]['f1'].append(f1)
                category_results[category]['count'] += 1

                count += 1

            if max_samples and count >= max_samples:
                break
        if max_samples and count >= max_samples:
            break

    # Build results list
    results = []
    for category, metrics in category_results.items():
        if metrics['count'] > 0:
            results.append({
                'Category': category,
                'Count': metrics['count'],
                'EM': sum(metrics['em']) / len(metrics['em']),
                'F1': sum(metrics['f1']) / len(metrics['f1'])
            })

    df = pd.DataFrame(results).sort_values('F1', ascending=False)

    # Print results
    print("\n" + "=" * 60)
    print("Performance by Clause Category:")
    print("=" * 60)
    print(df.to_string(index=False))
    print("=" * 60)

    # Print summary
    total_count = sum(r['Count'] for r in results)
    overall_em = sum(sum(m['em']) for m in category_results.values()) / total_count if total_count > 0 else 0
    overall_f1 = sum(sum(m['f1']) for m in category_results.values()) / total_count if total_count > 0 else 0

    print(f"\nTotal Questions Evaluated: {total_count}")
    print(f"Number of Categories: {len(results)}")
    print(f"Overall EM: {overall_em:.2f}")
    print(f"Overall F1: {overall_f1:.2f}")

Actuall CUAD Dataset

In [None]:
analyze_by_clause_category(dataset, generate_rag_answer, compute_f1_em, "question")


Performance by Clause Category:
       Category  Count        EM       F1
  Governing Law    102 22.549020 0.761849
Expiration Date    204 51.470588 0.609399
   Audit Rights    102 63.725490 0.577507
      Liability    102 57.843137 0.446688
    Termination    306 76.797386 0.303356
      Insurance    102 69.607843 0.295539
          Other   1938 81.269350 0.289251
    Non-Compete    204 81.372549 0.175089
      IP Rights    306 81.699346 0.107588
        Parties    306 53.594771 0.054001
  Payment Terms    102 99.019608 0.004830
  Document Name    102  0.000000 0.000000
 Agreement Date    204 20.098039 0.000000
Confidentiality    102 87.254902 0.000000

Total Questions Evaluated: 4182
Number of Categories: 14
Overall EM: 70.40
Overall F1: 0.26


Paraphrased questions

In [None]:

analyze_by_clause_category(dataset, generate_rag_answer, compute_f1_em, "refactored_question")


Performance by Clause Category:
       Category  Count         EM       F1
  Governing Law    102  27.450980 1.764035
Expiration Date    204  52.941176 1.591437
      Liability    102  58.823529 1.083258
   Audit Rights    102  63.725490 0.577507
    Termination    306  83.006536 0.491262
    Non-Compete    204  81.372549 0.347604
          Other   1938  76.522188 0.317998
        Parties    408  57.843137 0.252965
      Insurance    102  69.607843 0.196078
      IP Rights    306  81.372549 0.090549
 Agreement Date    204  20.588235 0.025716
  Payment Terms    102 100.000000 0.000000
Confidentiality    102  87.254902 0.000000

Total Questions Evaluated: 4182
Number of Categories: 13
Overall EM: 70.61
Overall F1: 0.40
