# Intro
CUAD_V1 Overview (Contract Understanding Atticus Dataset)
- Public dataset of 510 legal contracts annotated for clause extraction

- Includes 41 clause types (e.g., Confidentiality, Termination, Indemnification, Force Majeure)

- Annotations are span-based, making it ideal for NER and clause detection tasks

- Labeled by legal professionals to ensure real-world accuracy and relevance

- Enables training of AI models to automate contract review and analysis

- Popular in legal NLP research and fine-tuning of transformer-based models

- Developed and released by The Atticus Project to promote AI in law

https://www.atticusprojectai.org/cuad


https://youtu.be/hFUSdgryXyU?si=IPkKI_SSwUpfLWhG

https://publications.cohubicol.com/typology/cuad/?utm_source=chatgpt.com

# Step 0: Setup & Package Installation

In [70]:
!nvidia-smi        # Check A100 is actually assigned
import torch
torch.cuda.is_available()

Sun Aug  3 01:04:50 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   70C    P0             32W /   72W |    2231MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

True

In [71]:
!pip install seqeval > /dev/null 2>&1
!pip install datasets > /dev/null 2>&1
!pip install evaluate > /dev/null 2>&1  # Install the evaluate library
!pip install fitz > /dev/null 2>&1
!pip install rapidfuzz > /dev/null 2>&1
!pip install --upgrade --force-reinstall PyMuPDF > /dev/null 2>&1
!pip install datefinder > /dev/null 2>&1
!pip install thefuzz > /dev/null 2>&1
!pip install fuzzywuzzy > /dev/null 2>&1
!pip install fuzzysearch



#### Load Packages

In [72]:
# ======================
# STANDARD LIBRARIES
# ======================
import os
import re
import math
import json
import string
import random
import csv
import time
from pathlib import Path
from collections import Counter, defaultdict
from difflib import SequenceMatcher
from datetime import datetime
import shutil  # file moves, copies, and deletes

# ======================
# DATA HANDLING
# ======================
import pandas as pd
import numpy as np

# ======================
# VISUALIZATION
# ======================
import matplotlib.pyplot as plt
from IPython.display import display, HTML

# ======================
# PROGRESS DISPLAY
# ======================
from tqdm import tqdm

# ======================
# TORCH AND GPU
# ======================
import torch
from torch.utils.data import Dataset, DataLoader  # local datasets and loaders

# ======================
# TRANSFORMERS
# ======================
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
)

# ======================
# HUGGING FACE DATASETS
# Avoid name clash with torch.utils.data.Dataset by aliasing
# ======================
from datasets import Dataset as HFDataset, DatasetDict as HFDatasetDict

# ======================
# EVALUATION
# seqeval is for sequence labeling metrics, sklearn for general reports if needed
# ======================
from seqeval.metrics import (
    classification_report as seqeval_report,
    f1_score,
    precision_score,
    recall_score,
    accuracy_score,
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report as sk_classification_report

# ======================
# PDF PARSING
# ======================
import fitz  # PyMuPDF

# ======================
# NLP UTILITIES
# ======================
import spacy
from spacy.lang.en import English
from dateutil import parser

# ======================
# FUZZY MATCHING
# ======================
import rapidfuzz
import fuzzysearch
from thefuzz import fuzz, process

In [73]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# Step 1: Define File Paths & Variables

In [74]:
## Random seed for reproducibility
RANDOM_SEED = 42

# Model selection
MODEL_NAME = "bert-base-cased"   # or "distilbert-base-uncased", "legal-bert-base-uncased"

# Training hyperparameters
BATCH_SIZE = 16
EPOCHS = 6
LR_SCHEDULER = 'cosine'
WARMUP_RATIO = 0.1
WEIGHT_DECAY = 0.01

In [75]:
# File structure
MASTER_PATH = "/content/drive/MyDrive/CUAD/CUAD_v1/"            # Root CUAD folder
TC_PATH = os.path.join(MASTER_PATH, "full_contract_pdf/")       # Folder with raw PDFs
OUTPUT_DIR  = os.path.join(MASTER_PATH, "model_outputs")


# Filenames
MASTER_CLAUSES = "master_clauses.csv"                            # CSV with annotated answers
JSON_EXPORT = "jsonl_cuadv1.json"                                # Output: spans before cleanup
JSON_EXPORT_CLEANED = "jsonl_cuadv1_cleaned.json"                # Output: cleaned spans
UNMATCHED_EXPORT = "unmatched_parties_cases.csv"                 # Optional export of unmatched rows

# Canonical cleaned file lives under MASTER_PATH
CLEAN_JSONL = os.path.join(MASTER_PATH, NER_SUBDIR, JSON_EXPORT_CLEANED)
# Where to store NER JSONL files
NER_DIR = "ner_outputs"

In [76]:
# Step 1: Mount Google Drive, handle "mountpoint already contains files"
from google.colab import drive
import os, shutil

MOUNT = "/content/drive"

if os.path.ismount(MOUNT):
    print("Drive already mounted at", MOUNT)
else:
    # If the folder exists and is nonempty, clear it before mounting
    if os.path.isdir(MOUNT) and os.listdir(MOUNT):
        shutil.rmtree(MOUNT, ignore_errors=True)
    drive.mount(MOUNT, force_remount=True)

print("MyDrive visible:", os.path.exists("/content/drive/MyDrive"))

Drive already mounted at /content/drive
MyDrive visible: True


In [77]:
# Create directories if they don't already exist
os.makedirs(NER_DIR, exist_ok=True)                  # Folder for NER-related outputs
os.makedirs(OUTPUT_DIR, exist_ok=True)               # Folder for main project outputs
os.makedirs(os.path.dirname(CLEAN_JSONL), exist_ok=True)  # Parent folder for cleaned JSONL file

# Step 2: Load PDF Filenames and Clause Labels

This step loads:

1. A list of PDF contract filenames from the folder.
2. The master clause labels provided by the CUAD team.

We then:
- Keep only a subset of relevant clause types: `Document Name`, `Parties`, and `Agreement Date`
- Sort both datasets alphabetically to ensure proper alignment
- Insert the PDF filenames into the clause DataFrame for downstream merging

This step produces a table where each row represents a contract and its corresponding labeled clauses.


In [78]:
# Collect all PDF filenames from the contracts folder
pdf_files = []

for dirpath, dirnames, filenames in os.walk(TC_PATH):
    pdf_files.extend([f for f in filenames if f.lower().endswith(".pdf")])

# Create DataFrame of PDF filenames
pdf_df = pd.DataFrame({"PDF Files": sorted(pdf_files)})
pdf_df.head(3)

Unnamed: 0,PDF Files
0,2ThemartComInc_19990826_10-12G_EX-10.10_670028...
1,ABILITYINC_06_15_2020-EX-4.25-SERVICES AGREEME...
2,ACCELERATEDTECHNOLOGIESHOLDINGCORP_04_24_2003-...


In [79]:
# Load clause answer labels from the master clause file
mc_df = pd.read_csv(os.path.join(MASTER_PATH, MASTER_CLAUSES))

# Keep only the relevant clause columns
mc_df_cut = mc_df[[
    "Filename",
    "Document Name",
    "Document Name-Answer",
    "Parties",
    "Parties-Answer",
    "Agreement Date",
    "Agreement Date-Answer"
]].copy()

# Sort both DataFrames for alignment
mc_df_cut.sort_values("Filename", inplace=True, ignore_index=True)
pdf_df.sort_values("PDF Files", inplace=True, ignore_index=True)

# Insert the sorted PDF filenames into the master clause DataFrame
mc_df_cut.insert(loc=1, column="PDF Files", value=pdf_df)

# Display sample rows and info
display(mc_df_cut.head(3))
mc_df_cut.info()

Unnamed: 0,Filename,PDF Files,Document Name,Document Name-Answer,Parties,Parties-Answer,Agreement Date,Agreement Date-Answer
0,2ThemartComInc_19990826_10-12G_EX-10.10_670028...,2ThemartComInc_19990826_10-12G_EX-10.10_670028...,['CO-BRANDING AND ADVERTISING AGREEMENT'],CO-BRANDING AND ADVERTISING AGREEMENT,"['2THEMART.COM, INC.', '2TheMart', 'i-Escrow',...","I-ESCROW, INC. (""i-Escrow"" ); 2THEMART.COM, I...","['June 21, 1999']",6/21/99
1,ABILITYINC_06_15_2020-EX-4.25-SERVICES AGREEME...,ABILITYINC_06_15_2020-EX-4.25-SERVICES AGREEME...,['Services Agreement'],Services Agreement,"['""Provider""', 'TELCOSTAR PTE, LTD.', 'Each of...","[ * * * ] (""Provider""); TELCOSTAR PTE, LTD.; A...","['October 1, 2019']",10/1/19
2,ACCELERATEDTECHNOLOGIESHOLDINGCORP_04_24_2003-...,ACCELERATEDTECHNOLOGIESHOLDINGCORP_04_24_2003-...,['JOINT VENTURE AGREEMENT'],JOINT VENTURE AGREEMENT,"['Pivotal Self Service Tech, Inc.', '(the ""Par...","Collectible Concepts Group, Inc. (""CCGI""); Piv...",[],


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 510 entries, 0 to 509
Data columns (total 8 columns):
 #   Column                 Non-Null Count  Dtype 
---  ------                 --------------  ----- 
 0   Filename               510 non-null    object
 1   PDF Files              510 non-null    object
 2   Document Name          510 non-null    object
 3   Document Name-Answer   510 non-null    object
 4   Parties                510 non-null    object
 5   Parties-Answer         509 non-null    object
 6   Agreement Date         510 non-null    object
 7   Agreement Date-Answer  465 non-null    object
dtypes: object(8)
memory usage: 32.0+ KB


# Step 3: Extract and Clean Text from PDFs

This step extracts raw contract text from each PDF file in the dataset using PyMuPDF (`fitz`).

We also clean the text using `pre_process_doc_common()`:
- Removes newline characters and unusual formatting
- Replaces multiple dashes, stars, and underscores
- Normalizes whitespace

The result is a DataFrame `cuad_df` with:
- `FileName`: the name of the PDF
- `text`: the cleaned contract content

This cleaned text will be used to match clause spans in later steps.

In [80]:
# Use rglob to find all PDFs recursively (including upper/lower case)
file_list = list(Path(TC_PATH).rglob("*.[pP][dD][fF]"))

In [81]:
# Text cleaning function for standard PDF parsing workflow
def pre_process_doc_common(text):
    text = text.replace("\n", " ")  # Simple replacement for "\n"
    text = text.replace("\xa0", " ")  # Simple replacement for "\xa0"
    text = text.replace("\x0c", " ")  # Simple replacement for "\x0c"

    regex = "\ \.\ "
    subst = "."
    text = re.sub(regex, subst, text, 0)  # Get rid of multiple dots

    regex = "_"
    subst = " "
    text = re.sub(regex, subst, text, 0)  # Get rid of underscores

    regex = "--+"
    subst = " "
    text = re.sub(regex, subst, text, 0)   # Get rid of multiple dashes

    regex = "\*+"
    subst = "*"
    text = re.sub(regex, subst, text, 0)  # Get rid of multiple stars

    regex = "\ +"
    subst = " "
    text = re.sub(regex, subst, text, 0)  # Get rid of multiple whitespace

    text = text.strip()  #Strip leading and trailing whitespace
    return text

In [82]:
def extract_text_from_pdfs(files_list, clean_text=True):
    """
    Extract text from a list of PDF file paths.
    Optionally applies preprocessing for cleaning.
    Returns a list of [filename, cleaned_text].
    """
    text_list = []

    for filepath in tqdm(files_list):
        doc = fitz.open(filepath)
        full_text = "".join([page.get_text("text") for page in doc])
        doc.close()

        if clean_text:
            full_text = pre_process_doc_common(full_text)

        text_list.append([filepath.name, full_text])

    return text_list

In [83]:
# Extract and clean text
text_data = extract_text_from_pdfs(file_list, clean_text=True)

100%|██████████| 510/510 [13:56<00:00,  1.64s/it]


In [84]:
# Create DataFrame of extracted text
cuad_df = pd.DataFrame(text_data, columns=["FileName", "text"])
cuad_df.head(3)

Unnamed: 0,FileName,text
0,MARTINMIDSTREAMPARTNERSLP_01_23_2004-EX-10.3-T...,EXHIBIT 10.3 TRANSPORTATION SERVICES AGREEMENT...
1,ENTERPRISEPRODUCTSPARTNERSLP_07_08_1998-EX-10....,EXHIBIT 10.3 [ENTERPRISE LOGO APPEARS HERE] EN...
2,ENERGYXXILTD_05_08_2015-EX-10.13-Transportatio...,Exhibit 10.13 TRANSPORTATION AGREEMENT BETWEEN...


In [85]:
# What does an agreement look like?
cuad_df['text'][0]

'EXHIBIT 10.3 TRANSPORTATION SERVICES AGREEMENT THIS MARINE TRANSPORTATION AGREEMENT (this “Agreement”) is executed this 23rd day of December, 2003, by and between Martin Operating Partnership L.P., a Delaware limited partnership (“Owner”), and Midstream Fuel Service LLC, an Alabama limited liability company (“Charterer”), in order to evidence the agreement of such parties with respect to Owner’s provision of marine transportation services with respect to #2 fuel oil and high sulfur diesel on board its marine vessels under the following terms and conditions. 1. TERM; TERMINATION The initial term of this Agreement shall be for 3 years (the “Initial Term”) commencing on the date first set forth above (the “Commencement Date”) and ending on the 3rd anniversary of the Commencement Date. This Agreement will automatically renew for successive one year terms (each a “Renewal Term”, and together with the Initial Term, the “Term”), unless either Charterer or Owner elects not to renew this Agree

# Step 4: Normalize and Merge Filenames

This step ensures that the extracted contract text (`cuad_df`) aligns with the clause labels from the official CUAD dataset (`mc_df_cut`).

Why normalization is necessary:
- PDF filenames may contain uppercase letters, extra characters, or inconsistent formatting.
- Clause labels in the CSV may use different formatting conventions.
- To align them, we create a `clean_name` field in both datasets using a normalization function.

Once aligned, we merge the clause answer columns (e.g., `Agreement Date-Answer`, `Parties-Answer`) into our text DataFrame.

---

### Why the Clause Answers Matter

These **ground truth clause answers** are critical for model training.

- We use them to **locate the exact spans of each clause** in the contract text (e.g., where "July 1, 2020" appears).
- Once located, we convert those spans into **token-level BIO tags** (e.g., `B-Agreement Date`, `I-Agreement Date`, etc.).
- These token-level labels become the **supervised learning signal** for fine-tuning a pre-trained transformer (like BERT or DistilBERT).

In short:
> The clause answers are what let us fine-tune a general-purpose transformer model into a legal-specific clause extractor.

Without these answers:
- We would not know what the model is supposed to predict.
- Fine-tuning wouldn’t be possible. We’d be doing zero-shot inference instead.


This merge step ensures that every training example has both:
1. The full contract text
2. The correct clause spans (used to supervise the model)

---
Just an extra note:
- **Zero-shot inference** means using a model to make predictions or perform tasks without giving it any examples from the target task during training. The model relies entirely on what it learned during its original pretraining.



In [86]:
def normalize_filename(name):
    """
    Normalize filenames by:
    - Lowercasing
    - Removing .pdf/.txt extensions
    - Replacing non-alphanumeric characters with underscores
    - Collapsing repeated underscores
    """
    name = name.lower().strip()
    name = re.sub(r"\.pdf$", "", name)
    name = re.sub(r"\.txt$", "", name)
    name = re.sub(r"[^a-z0-9_]", "_", name)
    name = re.sub(r"_+", "_", name)
    return name.strip("_")

In [87]:
# Apply normalization to both DataFrames
cuad_df["clean_name"] = cuad_df["FileName"].apply(normalize_filename)
mc_df_cut["clean_name"] = mc_df_cut["Filename"].apply(normalize_filename)

In [88]:
# Merge clause labels into text dataframe
cuad_df = pd.merge(
    cuad_df.drop(columns=["Document Name-Answer", "Parties-Answer", "Agreement Date-Answer"], errors="ignore"),
    mc_df_cut[["clean_name", "Document Name-Answer", "Parties-Answer", "Agreement Date-Answer"]],
    on="clean_name",
    how="left"
)

cuad_df

Unnamed: 0,FileName,text,clean_name,Document Name-Answer,Parties-Answer,Agreement Date-Answer
0,MARTINMIDSTREAMPARTNERSLP_01_23_2004-EX-10.3-T...,EXHIBIT 10.3 TRANSPORTATION SERVICES AGREEMENT...,martinmidstreampartnerslp_01_23_2004_ex_10_3_t...,MARINE TRANSPORTATION AGREEMENT,"Martin Operating Partnership L.P. (""Owner""); M...",12/23/03
1,ENTERPRISEPRODUCTSPARTNERSLP_07_08_1998-EX-10....,EXHIBIT 10.3 [ENTERPRISE LOGO APPEARS HERE] EN...,enterpriseproductspartnerslp_07_08_1998_ex_10_...,TRANSPORTATION CONTRACT,"Enterprise Transportation Company (""Carrier"");...",6/1/98
2,ENERGYXXILTD_05_08_2015-EX-10.13-Transportatio...,Exhibit 10.13 TRANSPORTATION AGREEMENT BETWEEN...,energyxxiltd_05_08_2015_ex_10_13_transportatio...,Transportation Agreement,"Energy XXI Gulf Coast, Inc. (""Shipper""); Energ...",3/11/15
3,"HEMISPHERX - Sales, Marketing, Distribution, a...","Exhibit 10.1 Sales, Marketing, Distribution, a...",hemispherx_sales_marketing_distribution_and_su...,"Sales, Marketing, Distribution, and Supply Agr...","HEMISPHERX BIOPHARMA, INC (""HEMISPHERX""); Scie...",3/31/16
4,FUSIONPHARMACEUTICALSINC_06_05_2020-EX-10.17-S...,Exhibit 10.17 Supply Agreement - FUSION CERTAI...,fusionpharmaceuticalsinc_06_05_2020_ex_10_17_s...,SUPPLY AGREEMENT,Centre for Probe Development and Commercializa...,
...,...,...,...,...,...,...
505,SouthernStarEnergyInc_20051202_SB-2A_EX-9_8018...,Exhibit 10.8 Affiliate Program / Premium Affil...,southernstarenergyinc_20051202_sb_2a_ex_9_8018...,Affiliate Program / Premium Affiliate Manageme...,"Web site owners (hereafter, ""Affiliates""); Sof...",
506,CybergyHoldingsInc_20140520_10-Q_EX-10.27_8605...,Exhibit 10.27 MARKETING AFFILIATE AGREEMENT Be...,cybergyholdingsinc_20140520_10_q_ex_10_27_8605...,MARKETING AFFILIATE AGREEMENT,"Birch First Global Investments Inc. (""Company""...",5/8/14
507,CreditcardscomInc_20070810_S-1_EX-10.33_362297...,"Exhibit 10.33 Last Updated: April 6, 2007 CHAS...",creditcardscominc_20070810_s_1_ex_10_33_362297...,CHASE AFFILIATE AGREEMENT,"Chase Bank USA, N.A., (""Chase""); You (""Affilia...",4/6/07
508,UnionDentalHoldingsInc_20050204_8-KA_EX-10_334...,EXHIBIT 10.1 BUSINESS AFFILIATE AGREEMENT This...,uniondentalholdingsinc_20050204_8_ka_ex_10_334...,BUSINESS AFFILIATE AGREEMENT,"Dr. George D. Green (""Business Affiliate""); UN...",1/28/05


# Step 5: Drop Missing Clause Answers

After merging the clause answers from the CUAD master clause file into our main contract DataFrame (`cuad_df`), we must remove any rows where the required clause answers are missing.

### Why This Matters

We're training a **supervised model**, which means:
- It learns by comparing model predictions to the **true spans** (labels).
- If a label is missing (e.g., the correct *Agreement Date*), the model has no ground truth to learn from.

### Common Causes of Missing Values
- The contract doesn't include a particular clause
- Annotation errors or gaps in the CUAD master clause file
- Text extraction failures (e.g., unreadable or empty PDFs)

### What Gets Dropped

We only drop rows where **any** of the following clause answers is missing:
- `Agreement Date-Answer`
- `Parties-Answer`
- `Document Name-Answer`

This ensures that the training dataset contains only usable examples.answers** (`Agreement Date`, `Parties`, `Document Name`) is `NaN`.

In [89]:
# Step 5: Drop rows from cuad_df where clause answers are missing
CLAUSES = ["Agreement Date", "Document Name", "Parties"]

# Drop contracts missing any target clause answer
cuad_df = cuad_df.dropna(subset=[f"{clause}-Answer" for clause in CLAUSES]).reset_index(drop=True)

print(f"Remaining contracts with all clause answers: {len(cuad_df)}")

Remaining contracts with all clause answers: 463


---

### After This Step

We are left with a filtered version of `cuad_df` that:
- Has valid, extracted contract text
- Has complete answers for all three target clauses
- Can be safely used in the next step: extracting the **exact span locations** of each clause in the text

# Step 6: Fuzzy Matching Clause Spans

To fine-tune a token classification model, we must convert clause answers (e.g., `January 1, 2020`) into **character spans** within the full contract text (e.g., characters 215–230). These spans are required to train a model to label the exact tokens corresponding to each clause.

However, the CUAD dataset only provides clause **answers**, not their **locations** within the contract text. Therefore, we must **find the position of each answer** using a combination of exact and fuzzy matching logic.

---

### Objective

For each contract and each clause (`Agreement Date`, `Parties`, and `Document Name`), identify the **start** and **end** character positions of the answer within the contract text. This information is stored in a new column called `entities`.

---

### Matching Logic

We apply clause-specific logic in a single function called `fuzzy_search_entity()`:

- **Agreement Date**  
  - Parses the answer string into a `datetime` object using `dateutil.parser`.  
  - Searches for common date formats (e.g., `"January 1, 2020"`) in the first ~3000 characters of the contract.  
  - If a date in the contract text matches the parsed value, its span is returned.

- **Parties**  
  - Uses regex to extract named legal entities from the answer (e.g., `Acme Corp. ("Acme")`).  
  - Searches the first ~8000 characters of the contract for those entities.  
  - Returns a match if the detected span contains legal entity keywords like `Inc`, `LLC`, `Corp`, etc.

- **Document Name**  
  - Performs an exact search and fallback fuzzy match using `difflib.SequenceMatcher`.  
  - Searches only the first ~1000 characters (names typically appear at the top).  
  - Returns the highest-confidence match if similarity is above threshold.

---

### Output Format

The matched spans are stored in a new column called `entities` for each row in `cuad_df`. Each entry is a list of dictionaries like:

```json
[
  {"start": 215, "end": 230, "label": "Agreement Date"},
  {"start": 1015, "end": 1032, "label": "Document Name"}
]
```

These spans represent the training labels we’ll use for fine-tuning.

---

### ✅ Clause Match Summary

After extracting all clause spans, we count how many matches were found per clause type and how many contracts contain at least one match. This gives us visibility into match success rates and where extraction logic may need improvement.

You can also filter and export unmatched examples to a `.csv` file for debugging.


In [90]:
# Define the clauses we are extracting
CLAUSES = ["Agreement Date", "Parties", "Document Name"]

In [91]:
def fuzzy_search_entity(answer_text, contract_text, clause_type, max_dist_ratio=0.85):
    if not isinstance(answer_text, str) or not isinstance(contract_text, str):
        return None

    # Expand search window for Parties
    search_space = contract_text[:8000] if clause_type == "Parties" else contract_text[:3000]
    answer_text = answer_text.strip()

    # --- AGREEMENT DATE ---
    if clause_type == "Agreement Date":
        try:
            parsed_date = parser.parse(answer_text, fuzzy=True)
            date_regex = r'(?i)(January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},\s+\d{4}'
            matches = list(re.finditer(date_regex, search_space))
            for m in matches:
                try:
                    if parser.parse(m.group()) == parsed_date:
                        return {"start": m.start(), "end": m.end(), "text": m.group()}
                except:
                    continue
        except:
            return None

    # --- PARTIES ENHANCED REGEX LOGIC ---
    if clause_type == "Parties":
        party_pattern = r'([A-Z][A-Za-z0-9\s\.,&\-]+?\b(?:Inc\.|LLC|Ltd\.|L\.P\.|Corp\.|Company|Corporation))\s*\(".*?"\)'
        matches = re.findall(party_pattern, answer_text)
        for name in matches:
            name_clean = name.strip()
            if len(name_clean) < 5 or len(name_clean) > 200:
                continue
            m = re.search(re.escape(name_clean), search_space, re.IGNORECASE)
            if m:
                span_text = m.group()
                legal_terms = ["inc", "llc", "corp", "company", "l.p.", "plc", "limited", "corporation"]
                if any(term in span_text.lower() for term in legal_terms):
                    return {"start": m.start(), "end": m.end(), "text": span_text}

    # --- EXACT MATCH (ALL CLAUSES) ---
    pattern = re.escape(answer_text)
    m = re.search(pattern, search_space, re.IGNORECASE)
    if m:
        return {"start": m.start(), "end": m.end(), "text": m.group()}

    # --- FUZZY MATCH (ALL CLAUSES) ---
    best_match = None
    best_ratio = 0
    for i in range(len(search_space) - len(answer_text)):
        window = search_space[i:i+len(answer_text)+10]
        ratio = SequenceMatcher(None, window.lower(), answer_text.lower()).ratio()
        if ratio > best_ratio and ratio > max_dist_ratio:
            best_match = {"start": i, "end": i+len(window), "text": window}
            best_ratio = ratio

    # Filter fuzzy spans for Parties
    if best_match and clause_type == "Parties":
        text = best_match["text"].lower()
        legal_terms = ["inc", "llc", "corp", "company", "l.p.", "plc", "limited", "corporation"]
        if (
            len(text) < 20 or
            len(text) > 300 or
            not any(term in text for term in legal_terms)
        ):
            return None

    return best_match

In [92]:
# Apply the span extraction to each row
def extract_spans(row):
    spans = []
    for clause in CLAUSES:
        target = row.get(f"{clause}-Answer", None)
        result = fuzzy_search_entity(target, row["text"], clause)
        if result:
            spans.append({
                "start": result["start"],
                "end": result["end"],
                "label": clause
            })
    return spans

In [93]:
# Run on all contracts
tqdm.pandas()
cuad_df["entities"] = cuad_df.progress_apply(extract_spans, axis=1)

100%|██████████| 463/463 [08:19<00:00,  1.08s/it]


In [94]:
# Track how many clause spans were found per type
clause_counts = Counter()
for ents in cuad_df["entities"]:
    for ent in ents:
        clause_counts[ent["label"]] += 1

# Track match coverage
cuad_df["has_match"] = cuad_df["entities"].apply(lambda x: len(x) > 0)
matched = cuad_df["has_match"].sum()
unmatched = len(cuad_df) - matched

# Output results
print("\nClause Match Summary:")
for clause in CLAUSES:
    print(f" - {clause}: {clause_counts[clause]} matches")
print(f"\nTotal contracts with at least one match: {matched}")
print(f"Total contracts with no matches: {unmatched}")


Clause Match Summary:
 - Agreement Date: 232 matches
 - Parties: 266 matches
 - Document Name: 443 matches

Total contracts with at least one match: 462
Total contracts with no matches: 1


In [95]:
# Loop over each clause and identify unmatched spans
for clause in CLAUSES:
    col_answer = f"{clause}-Answer"

    unmatched = cuad_df[
        (cuad_df[col_answer].notna()) &
        (~cuad_df["entities"].apply(lambda ents: any(ent["label"] == clause for ent in ents)))
    ]

    print(f"Unmatched '{clause}' spans despite ground truth answer: {len(unmatched)}")

    # Saves CSVs for manual inspection
    save_cols = ["FileName", col_answer, "text"]
    unmatched[save_cols].to_csv(f"unmatched_{clause.lower().replace(' ', '_')}_cases.csv", index=False)
    print(f"Exported to unmatched_{clause.lower().replace(' ', '_')}_cases.csv\n")

Unmatched 'Agreement Date' spans despite ground truth answer: 231
Exported to unmatched_agreement_date_cases.csv

Unmatched 'Parties' spans despite ground truth answer: 197
Exported to unmatched_parties_cases.csv

Unmatched 'Document Name' spans despite ground truth answer: 20
Exported to unmatched_document_name_cases.csv



In [96]:
# Save Step 6 output as a raw JSONL file so Step 7 can start from here if needed.
# It writes ner_outputs/jsonl_cuadv1.json and you can skip earlier steps next time.
os.makedirs("ner_outputs", exist_ok=True)
RAW_JSONL = "ner_outputs/jsonl_cuadv1.json"


raw_djson = [
    {"text": row["text"], "entities": (row["entities"] if isinstance(row["entities"], list) else [])}
    for _, row in cuad_df.iterrows()
]
with open(RAW_JSONL, "w") as f:
    for ex in raw_djson:
        f.write(json.dumps(ex) + "\n")
print(f"Wrote {len(raw_djson)} samples to {RAW_JSONL}")

Wrote 463 samples to ner_outputs/jsonl_cuadv1.json


### Optional - code to show why we need Step 7

In [97]:
# Point this to your raw JSONL from Step 6
RAW_JSONL = "ner_outputs/jsonl_cuadv1.json"   # change if needed
OUT_JSON = "ner_outputs/example_overlap.json"

def spans_overlap(a, b):
    return not (a["end"] <= b["start"] or b["end"] <= a["start"])

def find_first_overlap(path):
    with open(path, "r") as f:
        for line in f:
            ex = json.loads(line)
            ents = ex.get("entities", [])
            parties = [e for e in ents if e.get("label") == "Parties"]
            others  = [e for e in ents if e.get("label") in {"Document Name", "Agreement Date"}]
            for p in parties:
                for o in others:
                    if spans_overlap(p, o):
                        return ex, p, o
    return None, None, None

ex, p, o = find_first_overlap(RAW_JSONL)

if ex is None:
    print("No overlaps found or file not found. Check RAW_JSONL path.")
else:
    text = ex["text"]
    # Build a readable excerpt around the overlap
    lo = min(p["start"], o["start"])
    hi = max(p["end"], o["end"])
    pad = 80
    s = max(0, lo - pad)
    e = min(len(text), hi + pad)

    excerpt = text[s:lo] + "[[" + text[lo:hi] + "]]" + text[hi:e]
    print("Excerpt around overlap:\n")
    print(excerpt)
    print("\nSpans:")
    print(" Parties:", p)
    print(" Other  :", o)

    # Save a mini file
    os.makedirs(os.path.dirname(OUT_JSON), exist_ok=True)
    payload = {
        "excerpt": excerpt,
        "text_window_start": s,
        "text_window_end": e,
        "parties": p,
        "other": o
    }
    with open(OUT_JSON, "w") as fout:
        json.dump(payload, fout, indent=2)
    print(f"\nSaved mini example to {OUT_JSON}")

Excerpt around overlap:

[[United National Bancorp Enters Into Outsourcing Agreement with the BISYS Group, Inc.]] Bridgewater, NJ February 18, 1999 United National Bancorp (Nasdaq: UNBJ) announ

Spans:
 Parties: {'start': 67, 'end': 84, 'label': 'Parties'}
 Other  : {'start': 0, 'end': 84, 'label': 'Document Name'}

Saved mini example to ner_outputs/example_overlap.json


# Step 7: Remove Overlapping Party Spans (Preprocessing Before Training)

Before fine-tuning our token classification model, we perform a cleanup step to reduce annotation noise, especially for the **"Parties"** clause, which often overlaps with other clauses like **"Document Name"** or **"Agreement Date"**.

---

### Why This Step Matters

In legal contracts, the **Parties** clause often overlaps or appears near other clause types. For example:

> `"United National Bancorp Enters Into Outsourcing Agreement with the BISYS Group, Inc." Bridgewater, NJ February 18, 1999 United National Bancorp (Nasdaq: UNBJ)...`

In this excerpt:
- **Document Name** = `"United National Bancorp Enters Into Outsourcing Agreement with the BISYS Group, Inc."`  
- **Parties**       = `"the BISYS Group, Inc."` (a substring inside the document name)

Here, `"the BISYS Group, Inc."` appears as part of both the **Document Name** and **Parties** clause answers. If we include both, we risk **double-labeling overlapping text**, which confuses the model during training. Tokens must be assigned only one label.

---

### What the Code Does

- Iterates over all spans in the `entities` list.
- Keeps all **non-Parties** spans.
- For **Parties** spans, checks if they **overlap** with any other clause span.
- If a **Parties** span overlaps another clause, it is **removed**.
- Only **non-overlapping Parties spans** are kept.

---

### Output

- Saves a cleaned JSONL file:  
  `ner_outputs/jsonl_cuadv1_cleaned.json`
- Each entry now contains de-overlapped spans.
- A summary is printed showing:
  - A sample of labels per contract
  - Number of samples with at least one entity
  - Global label distribution (via `Counter`)

This cleaned file will be used in the next step to tokenize and align labels with the model’s expected input format.


In [98]:
def remove_party_overlaps(entities):
    """
    Remove 'Parties' spans that overlap with other clause spans.
    Keeps non-Parties spans and only non-overlapping Parties spans.
    """
    # Store spans for non-Parties clauses
    non_party_spans = [(e['start'], e['end']) for e in entities if e['label'] != 'Parties']
    cleaned = []

    for ent in entities:
        if ent['label'] != 'Parties':
            cleaned.append(ent)
        else:
            overlap = False
            for ns_start, ns_end in non_party_spans:
                # Check for overlap
                if not (ent['end'] <= ns_start or ent['start'] >= ns_end):
                    overlap = True
                    break
            if not overlap:
                cleaned.append(ent)

    return cleaned

In [99]:
## Load original JSONL file (from Step 6)
with open("ner_outputs/jsonl_cuadv1.json", "r") as f:
    raw_djson = [json.loads(line) for line in f]

In [100]:
# Apply overlap cleaning to each example
cleaned_djson = []

for row in raw_djson:
    row["entities"] = remove_party_overlaps(row["entities"])
    cleaned_djson.append(row)

In [101]:
# Save cleaned version to Drive so it persists
with open(CLEAN_JSONL, "w", encoding="utf-8") as f:
    for row in cleaned_djson:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")
print(f"Cleaned data saved to {CLEAN_JSONL}")

Cleaned data saved to /content/drive/MyDrive/CUAD/CUAD_v1/ner_outputs/jsonl_cuadv1_cleaned.json


In [102]:
## Show results
print("Cleaned djson saved to ner_outputs/jsonl_cuadv1_cleaned.json")

# Sample labels from first example
print(f"Sample raw span labels from djson: {[ent['label'] for ent in cleaned_djson[0]['entities']]}")

# Count how many entries have at least one span
non_empty_count = sum(1 for ex in cleaned_djson if ex["entities"])
print(f"Number of training samples with at least one entity: {non_empty_count}")

# Label distribution across all cleaned spans
label_counts = Counter([e["label"] for x in cleaned_djson for e in x["entities"]])
print("Label counts in saved djson:", label_counts)

Cleaned djson saved to ner_outputs/jsonl_cuadv1_cleaned.json
Sample raw span labels from djson: ['Parties', 'Document Name']
Number of training samples with at least one entity: 462
Label counts in saved djson: Counter({'Document Name': 443, 'Parties': 265, 'Agreement Date': 232})


# Step 8: Tokenize Contract Text and Align Clause Labels

Before training a token classification model (like BERT), we need to tokenize each contract and convert the clause **character spans** into **token-level BIO labels** (`B-`, `I-`, `O`).

---

### Objective

Transform each training sample from this format:

```json
{
  "text": "This Agreement is made on January 1, 2020 by ACME LLC...",
  "entities": [
    {"start": 28, "end": 42, "label": "Agreement Date"},
    {"start": 46, "end": 55, "label": "Parties"}
  ]
}
```

Into this format:

- `tokens`: `["This", "Agreement", "is", "made", "on", "January", "1", ",", "2020", ...]`
- `labels`: `["O", "O", "O", "O", "O", "B-Agreement Date", "I-Agreement Date", "I-Agreement Date", "I-Agreement Date", ...]`

---

### What This Step Does

- Loads the **cleaned dataset** from Step 7
- Loads a Hugging Face tokenizer (`bert-base-cased`)
- Converts clause spans into token-level **BIO labels** using character offset alignment
- Skips special tokens like `[CLS]` and `[SEP]` when labeling
- Stores the final label IDs inside the tokenized dictionary as `labels`

---

### BIO Format Explained

We use the **BIO tagging scheme** to mark where entities start and continue:

| Token | Label             |
|-------|------------------|
| January | B-Agreement Date |
| 1       | I-Agreement Date |
| ,       | I-Agreement Date |
| 2020    | I-Agreement Date |

Tokens not part of any clause get `"O"`.

---

### Output

- A list of tokenized examples, ready for training with `input_ids`, `attention_mask`, and `labels`
- These will be passed into a Hugging Face `Trainer` in the next step


In [103]:
# Load tokenizer (same as the model you'll train)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, max_length = 512)

In [104]:
# Names and preferred locations
FILENAME = "jsonl_cuadv1_cleaned.json"
DRIVE_DIR = Path("/content/drive/MyDrive/CUAD_v1/ner_outputs")
LOCAL_DIRS = [Path("/content/ner_outputs"), Path("/content")]  # matches your Colab sidebar

# Ensure Drive directory exists (safe if already there)
DRIVE_DIR.mkdir(parents=True, exist_ok=True)

In [105]:
def resolve_clean_jsonl() -> Path | None:
    """Return the first existing, non-empty path to the file."""
    candidates = [DRIVE_DIR / FILENAME] + [d / FILENAME for d in LOCAL_DIRS]
    for p in candidates:
        if p.exists() and p.stat().st_size > 0:
            return p
    # last resort: search anywhere under /content
    for p in Path("/content").rglob(FILENAME):
        if p.stat().st_size > 0:
            return p
    return None

In [106]:
# Locate the cleaned, model ready JSONL produced in Step 7
CLEAN_JSONL_PATH = resolve_clean_jsonl()

# Fail fast if the file is missing, tell the user exactly what to run
if CLEAN_JSONL_PATH is None:
    raise FileNotFoundError(
        f"Expected model-ready file named {FILENAME}. "
        f"Checked Drive {DRIVE_DIR} and local {LOCAL_DIRS}. Run Step 7 to create it."
    )

print(f"Using: {CLEAN_JSONL_PATH}")

# If the file exists only locally, copy a backup into Drive for persistence
if CLEAN_JSONL_PATH.parent != DRIVE_DIR:
    dest = DRIVE_DIR / FILENAME
    shutil.copy2(CLEAN_JSONL_PATH, dest)
    print(f"Copied a backup to Drive: {dest}")

Using: /content/drive/MyDrive/CUAD_v1/ner_outputs/jsonl_cuadv1_cleaned.json


In [107]:
# Load the cleaned JSONL file created in Step 7
# This file contains one JSON object per line, each with the contract text and its labeled clause spans,
# formatted for model training.
with open(CLEAN_JSONL_PATH, "r", encoding="utf-8") as f:
    cleaned_djson = [json.loads(line) for line in f]

print(f"Loaded {len(cleaned_djson)} records from cleaned, model-ready JSONL.")

Loaded 463 records from cleaned, model-ready JSONL.


In [108]:
# Get full label list and mappings
label_list = ["O"] + [f"{prefix}-{clause}" for clause in ["Agreement Date", "Parties", "Document Name"] for prefix in ["B", "I"]]
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}

In [109]:
## Helper to align labels with tokens
def align_labels_with_tokens(text, entities):
    """
    Convert entity character spans to BIO token labels aligned with tokenizer output.
    """
    tokenized = tokenizer(text, return_offsets_mapping=True, truncation=True, max_length=512)
    offsets = tokenized["offset_mapping"]
    labels = ["O"] * len(offsets)

    for ent in entities:
        start, end, label = ent["start"], ent["end"], ent["label"]
        entity_start = None

        for i, (token_start, token_end) in enumerate(offsets):
            if token_start == token_end:
                continue  # Skip special tokens like [CLS], [SEP]

            if token_end <= start:
                continue
            if token_start >= end:
                break

            if entity_start is None:
                labels[i] = f"B-{label}"
                entity_start = i
            else:
                labels[i] = f"I-{label}"

    label_ids = [label2id[label] for label in labels]
    tokenized["labels"] = label_ids
    return tokenized

In [110]:
## Apply to all samples
tokenized_dataset = []

for example in tqdm(cleaned_djson, desc="Tokenizing"):
    tokenized = align_labels_with_tokens(example["text"], example["entities"])
    tokenized_dataset.append(tokenized)

print(f"Tokenized {len(tokenized_dataset)} samples")
print(f"Example labels: {tokenized_dataset[0]['labels'][:30]}")
print(f" Example tokens: {tokenizer.convert_ids_to_tokens(tokenized_dataset[0]['input_ids'][:30])}")

Tokenizing: 100%|██████████| 463/463 [00:13<00:00, 33.35it/s]

Tokenized 463 samples
Example labels: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5]
 Example tokens: ['[CLS]', 'E', '##X', '##H', '##IB', '##IT', '10', '.', '3', 'T', '##RA', '##NS', '##PO', '##RT', '##AT', '##ION', 'SE', '##R', '##VI', '##CE', '##S', 'AG', '##RE', '##EM', '##EN', '##T', 'T', '##H', '##IS', 'MA']





# Step 9: Prepare Dataset and DataLoaders for Training

Now that we've tokenized our contract text and aligned the labels using the BIO format, we need to prepare the data for PyTorch training.

---

### Objective

Split the tokenized data into **training** and **validation** sets, and wrap them into DataLoader-compatible objects.

---

### What This Step Does

- Randomly shuffles the tokenized dataset
- Splits it into:
  - **Train set**: 90%
  - **Validation set**: 10%
- Converts each example into a format compatible with Hugging Face’s `Trainer` API:
  - Each item is a dictionary with:
    - `"input_ids"`: Tokenized input
    - `"attention_mask"`: Mask to ignore padding
    - `"labels"`: Token-level BIO labels
- Uses `DataCollatorForTokenClassification` to handle padding dynamically at batch time

---

### Why Use a Data Collator?

Hugging Face models expect **batches of equal-length tensors**, but legal contracts vary widely in length.

Instead of manually padding everything to 512 tokens:
- We use `DataCollatorForTokenClassification` to pad **on-the-fly** within each batch.
- This is efficient and avoids wasting computation on long sequences full of padding.

---

### Output

- `train_dataset` and `val_dataset` objects ready for model training
- Automatic batching, padding, and label alignment at training time

This sets us up for the next step: defining the model and starting training!


In [111]:
# Split tokenized dataset
train_data, val_data = train_test_split(tokenized_dataset, test_size=0.1, random_state=42)

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

Training samples: 416
Validation samples: 47


In [112]:
## Hugging Face Trainer expects dicts, not lists
class NERDataset(torch.utils.data.Dataset):
    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return {
            "input_ids": torch.tensor(self.examples[idx]["input_ids"]),
            "attention_mask": torch.tensor(self.examples[idx]["attention_mask"]),
            "labels": torch.tensor(self.examples[idx]["labels"]),
        }

train_dataset = NERDataset(train_data)
val_dataset = NERDataset(val_data)

In [113]:
## Collator for dynamic padding
data_collator = DataCollatorForTokenClassification(tokenizer)

# Sanity check
sample = train_dataset[0]
print("Sample tokens:", tokenizer.convert_ids_to_tokens(sample["input_ids"][:20]))
print("Sample labels:", sample["labels"][:20])

Sample tokens: ['[CLS]', 'P', '##H', '##OT', '##O', 'R', '##ET', '##O', '##UC', '##H', '##ING', 'O', '##UT', '##SO', '##UR', '##CI', '##NG', 'AG', '##RE', '##EM']
Sample labels: tensor([0, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6])


# Step 10: Load Pretrained BERT Model for Token Classification

We now load a **pretrained BERT model** from Hugging Face and configure it for **token classification**.

---

### Why Use a Pretrained Model?

Instead of training from scratch (which would require millions of documents), we start with a model that already:

- Understands English syntax and semantics
- Knows how to tokenize and represent words and phrases
- Has been trained on a large corpus (e.g., books, Wikipedia)

We then **fine-tune** it on our legal contracts by teaching it to label tokens with clause types like `Agreement Date`, `Document Name`, and `Parties`.

---

### Configuration Details

- We're using `BertForTokenClassification`, which adds a classification head to BERT for each token.
- We specify the number of unique BIO labels.
- The model outputs logits for each token and each class during training.

---

### Output

A ready-to-train model that can accept tokenized input and output clause predictions for each token.


In [114]:
## Get number of unique labels
unique_labels = sorted({label for sample in tokenized_dataset for label in sample["labels"]})
num_labels = len(unique_labels)

print(f"Number of unique labels (including O, B-, I- tags): {num_labels}")

Number of unique labels (including O, B-, I- tags): 7


In [115]:
## Load pretrained model for token classification
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id
)

print(f"Loaded {MODEL_NAME} with {len(label_list)} labels")

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded bert-base-cased with 7 labels


# Step 11: Define TrainingArguments, Evaluation Metrics, and Initialize the Trainer

The Hugging Face `Trainer` class handles the full training loop for us — all we need to do is provide:

1. The model
2. Training arguments (hyperparameters, evaluation strategy, batch size, etc.)
3. Data (train and validation sets)
4. A data collator for batching/padding
5. A function to compute evaluation metrics (defined in the next step)

---

### What are `TrainingArguments`?

This is a configuration object that controls:
- How long to train (`num_train_epochs`)
- Batch sizes
- Evaluation strategy (e.g., evaluate every epoch)
- Logging frequency
- Learning rate schedule
- Whether to save the best model

---

### Why Use `Trainer`?

`Trainer` simplifies everything:
- It handles batching, GPU usage, gradient updates
- Evaluates at set intervals
- Loads the best model after training
- Integrates tightly with Hugging Face token classification

---

### Output

A `Trainer` object ready to begin training and evaluation.

## Step 11A:Define Training Args

In [116]:
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,                         # Where to save model checkpoints
    num_train_epochs=EPOCHS,                       # Defined in Step 1
    per_device_train_batch_size=BATCH_SIZE,        # From Step 1
    per_device_eval_batch_size=BATCH_SIZE,
    lr_scheduler_type=LR_SCHEDULER,                # From Step 1
    warmup_ratio=WARMUP_RATIO,
    weight_decay=WEIGHT_DECAY,
    eval_strategy="epoch",                         # Evaluate after each epoch
    logging_strategy="steps",                      # Log loss every few steps
    logging_steps=50,
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=False,                  # no checkpoints, so do not try to load "best"
    overwrite_output_dir=True,
    report_to="none",
    save_safetensors=True,
)

## Step 11B: Define Evaluation Metrics (F1, Precision, Recall)

Now that our `Trainer` is initialized, we need to define how model performance is evaluated.

---

### What Metrics Are We Using?

For token classification, we use:
- **Precision**: How many predicted entities were correct?
- **Recall**: How many actual entities did we successfully detect?
- **F1 Score**: The harmonic mean of precision and recall (our main metric)

---

### How Does This Work?

- Hugging Face's `Trainer` will call `compute_metrics()` after each evaluation step.
- It passes in `EvalPrediction`, which contains:
  - `predictions`: predicted token-level label IDs
  - `label_ids`: ground truth token-level label IDs

---

### BIO Tagging Reminder

We're using BIO-format labels like `B-Parties`, `I-Document Name`, etc.  
So the F1 score refers to **entity-level matches**, not just individual token tags.

---

### Output

A dictionary containing:
- `precision`
- `recall`
- `f1`


In [117]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    preds = predictions.argmax(axis=-1)

    true_labels = []
    pred_labels = []

    for pred, label in zip(preds, labels):
        current_true = []
        current_pred = []
        for p, l in zip(pred, label):
            if l != -100:  # ignore special tokens
                current_true.append(id2label[l])
                current_pred.append(id2label[p])
        true_labels.append(current_true)
        pred_labels.append(current_pred)

    # Use seqeval metrics directly with lists-of-lists
    precision = precision_score(true_labels, pred_labels)
    recall = recall_score(true_labels, pred_labels)
    f1 = f1_score(true_labels, pred_labels)

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1
    }

## Step 11C: Initialize Trainer

In [118]:
trainer = Trainer(
    model=model,                               # The DistilBERT model for token classification
    args=training_args,                        # Training configuration defined above
    train_dataset=train_dataset,               # Your training data
    eval_dataset=val_dataset,                  # Your validation data
    tokenizer=tokenizer,                       # Tokenizer for decoding input tokens
    data_collator=data_collator,               # Handles padding and label alignment
    compute_metrics=compute_metrics,           # Custom evaluation function (returns precision, recall, F1, accuracy)


    # Stops after 1 evaluation round (epoch) of no improvement to validation F1
)

  trainer = Trainer(


# Step 12: Train the Model

Now that everything is set up — the dataset, model, training arguments, and evaluation metrics — we’re ready to fine-tune our pre-trained BERT model on the CUAD clause spans.

---

### What Happens Here?

- The Hugging Face `Trainer` API handles:
  - Feeding batches of tokenized inputs and aligned labels to the model
  - Logging loss every 10 steps
  - Evaluating after each epoch using our custom `compute_metrics()` function
  - Saving the model with the best F1 score to `outputs/`

---

### How Long Will This Take?

Training time depends on:
- The number of epochs (`EPOCHS`)
- The size of the model (`MODEL_NAME`)
- Your Colab hardware

From my experience with T4 4–6 epochs and a lightweight model like `distilbert-base-uncased`, it typically takes 80–120 minutes.

---

### Output

- A fine-tuned model saved in the `outputs/` directory
- Training logs (loss, F1, precision, recall)


In [119]:
# Train with Hugging Face Trainer
train_result = trainer.train()

  return forward_call(*args, **kwargs)


Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,No log,0.123314,0.0,0.0,0.0
2,0.434000,0.072484,0.351562,0.473684,0.403587
3,0.434000,0.064495,0.302857,0.557895,0.392593
4,0.057600,0.066404,0.324468,0.642105,0.431095
5,0.057600,0.058865,0.443709,0.705263,0.544715
6,0.035300,0.058881,0.433333,0.684211,0.530612


  return forward_call(*args, **kwargs)


In [120]:
# Drop intermediate checkpoints
for p in Path(OUTPUT_DIR).glob("checkpoint-*"):
    shutil.rmtree(p, ignore_errors=True)

# Drop optimizer state, scheduler, trainer state, args bin
for fname in ["optimizer.pt", "scheduler.pt", "trainer_state.json", "training_args.bin"]:
    fpath = Path(OUTPUT_DIR) / fname
    if fpath.exists():
        os.remove(fpath)

In [121]:
# Save only what is needed for inference
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
trainer.save_model(OUTPUT_DIR)         # config.json + model.safetensors
tokenizer.save_pretrained(OUTPUT_DIR)  # tokenizer files

# Free some VRAM and temporary storage
torch.cuda.empty_cache()
shutil.rmtree("/content/tmp_training", ignore_errors=True)

print("Training complete. Saved essentials to:", OUTPUT_DIR)

Training complete. Saved essentials to: /content/drive/MyDrive/CUAD/CUAD_v1/model_outputs


# Step 13: Evaluate on the Validation Set and Inspect Predictions

In this step, we evaluate the fine-tuned model on the held-out validation dataset.

### **What this does:**
- Uses `trainer.predict(val_dataset)` to get logits and true labels.
- Converts logits to predicted BIO tags and aligns them with the non-padding tokens (skips `-100` labels).
- Computes entity-level Precision, Recall, and F1 using `seqeval` on lists of tag sequences.
- Prints a compact classification report by label.
- Shows a readable preview for the first validation example, pairing tokens with their true and predicted tags.

### **Why this matters:**
- Validation metrics track generalization performance, not just training loss.
- The token preview helps spot systematic errors, for example, broken BIO spans or confusion between clauses like `Document Name` and `Parties`.

### **Output:**
- A printed table of Precision, Recall, and F1 per label.
- A short token-level preview for manual inspection.


In [122]:
# 1) Run prediction on the validation set
pred_output = trainer.predict(val_dataset)
logits = pred_output.predictions
label_ids = pred_output.label_ids

In [123]:
# 2) Convert logits to predicted ids
pred_ids = np.argmax(logits, axis=-1)

In [124]:
# 3) Convert ids to tag strings, ignoring padding (-100)
# Inputs: pred_ids, label_ids, id2label

true_seqs, pred_seqs = [], []

for p_row, l_row in zip(pred_ids, label_ids):
    t_seq, p_seq = [], []
    for p, l in zip(p_row, l_row):
        if l == -100:  # skip padding or specials
            continue
        t_seq.append(id2label[int(l)])   # true tag
        p_seq.append(id2label[int(p)])   # predicted tag
    true_seqs.append(t_seq)
    pred_seqs.append(p_seq)

# true_seqs and pred_seqs are now token aligned per example, ready for seqeval

In [125]:
# 4) Print a seqeval classification report
print("Validation classification report (entity-level):")
print(seqeval_report(true_seqs, pred_seqs, digits=4))

Validation classification report (entity-level):
                precision    recall  f1-score   support

Agreement Date     0.7419    0.9583    0.8364        24
 Document Name     0.5345    0.7381    0.6200        42
       Parties     0.1803    0.3793    0.2444        29

     micro avg     0.4333    0.6842    0.5306        95
     macro avg     0.4856    0.6919    0.5669        95
  weighted avg     0.4788    0.6842    0.5600        95



In [126]:
# 5) Quick label distribution check
flat_true = [t for seq in true_seqs for t in seq]
flat_pred = [t for seq in pred_seqs for t in seq]
print("Label distribution, true :", Counter(flat_true))
print("Label distribution, pred :", Counter(flat_pred))

Label distribution, true : Counter({'O': 22917, 'I-Document Name': 483, 'I-Parties': 229, 'I-Agreement Date': 72, 'B-Document Name': 42, 'B-Parties': 29, 'B-Agreement Date': 24})
Label distribution, pred : Counter({'O': 22744, 'I-Document Name': 590, 'I-Parties': 263, 'I-Agreement Date': 92, 'B-Document Name': 45, 'B-Parties': 33, 'B-Agreement Date': 29})


In [127]:
# 6) Human-readable preview for the first validation example
example_idx = 0
tokens = tokenizer.convert_ids_to_tokens(val_dataset[example_idx]["input_ids"])
true_tags = true_seqs[example_idx]
pred_tags = pred_seqs[example_idx]

# Keep the same length after removing specials; align by skipping specials in tokens
# Fast tokenizers keep special tokens in the ids. We skip them by checking the label list length.
# The labels already exclude specials, so we trim tokens accordingly.
trimmed_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens][:len(true_tags)]

print("\nPreview of first validation example:")
for tok, tl, pl in zip(trimmed_tokens[:80], true_tags[:80], pred_tags[:80]):
    print(f"{tok:15} true: {tl:20} pred: {pl}")


Preview of first validation example:
Ex              true: O                    pred: O
##hibit         true: O                    pred: O
28              true: O                    pred: O
(               true: O                    pred: O
h               true: O                    pred: O
)               true: O                    pred: O
(               true: O                    pred: O
1               true: O                    pred: O
)               true: O                    pred: O
(               true: O                    pred: O
a               true: O                    pred: O
)               true: O                    pred: O
under           true: O                    pred: O
Form            true: O                    pred: O
N               true: O                    pred: O
‐               true: O                    pred: O
1A              true: O                    pred: O
Ex              true: O                    pred: O
##hibit         true: O                    p

# Step 14: Inference on Unseen PDFs

In this step, we run the fine-tuned model on new, unseen PDF contracts to extract predicted clause spans.

### **What this does:**
- Reads only unseen PDFs from `NEW_PDFS_PATH`.  
  If the folder is empty, it raises a **FileNotFoundError** and stops (no fallback).
- Loads the tokenizer and model from `OUTPUT_DIR` to ensure label consistency.
- Extracts and lightly cleans text from each PDF.
- Applies sliding-window inference with stride to handle long documents without truncating clauses.
- Decodes BIO tags back into contiguous character spans.
- Collects predictions into a DataFrame with:  
  `filename`, `source_dir`, `clause`, `start`, `end`, and `pred_text`.
- Writes results to `OUTPUT_DIR/predictions.csv` for later cleaning and summarization (Step 14B).

### **Why this matters:**
- Allows you to evaluate the model on real-world, unseen contracts.
- The sliding-window approach prevents missing entities that cross chunk boundaries.
- Saving to CSV enables downstream cleanup, deduplication, and reporting.

### **Output:**
- `predictions.csv` in `OUTPUT_DIR` containing one row per predicted span.
- Columns include filename, clause label, character positions, and extracted text.


In [128]:
# 0) Input folder with unseen PDFs
NEW_PDFS_PATH = os.path.join(MASTER_PATH, "new_pdfs")

pdf_paths = sorted(Path(NEW_PDFS_PATH).rglob("*.pdf"))
print(f"Found {len(pdf_paths)} unseen PDFs in {NEW_PDFS_PATH}")
for p in pdf_paths[:10]:
    print(" -", p.name)

if len(pdf_paths) == 0:
    raise FileNotFoundError(
        f"No PDFs found in {NEW_PDFS_PATH}. Add files there, or change NEW_PDFS_PATH."
    )
source_dir = NEW_PDFS_PATH

Found 5 unseen PDFs in /content/drive/MyDrive/CUAD/CUAD_v1/new_pdfs
 - ex10-4.pdf
 - ex101.pdf
 - ex21to8k11125001_02222018.pdf
 - exh10-21.pdf
 - rathgibson0303088kex101.pdf


In [129]:
# 1) Load tokenizer and model from OUTPUT_DIR
tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)           # load the saved tokenizer
model = AutoModelForTokenClassification.from_pretrained(OUTPUT_DIR)  # load the fine tuned model
model.eval()                                                     # switch to eval mode

# choose device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)                                                 # move model to GPU if available

# label maps
id2label = model.config.id2label                                 # int id -> tag string
label2id = {v: k for k, v in id2label.items()}                   # tag string -> int id

In [130]:
# 2) Utilities: light text clean, BIO decode, sliding-window prediction
def pre_process_doc_common(text: str) -> str:
    text = text.replace("\n", " ").replace("\xa0", " ").replace("\x0c", " ")
    text = re.sub(r"\ \.\ ", ".", text)
    text = re.sub(r"_", " ", text)
    text = re.sub(r"--+", " ", text)
    text = re.sub(r"\*+", "*", text)
    text = re.sub(r"\ +", " ", text)
    return text.strip()

In [131]:
def extract_text_from_pdf(path: Path) -> str:
    """Open a PDF, extract plain text, then apply light cleaning."""
    # open with a context manager to ensure the file is closed
    with fitz.open(str(path)) as doc:
        txt = "".join(page.get_text("text") for page in doc)  # page by page text
    return pre_process_doc_common(txt)  # normalize whitespace and artifacts

In [132]:
def decode_bio_to_spans(offsets, pred_ids):
    """Convert BIO tag ids to contiguous character spans.
    offsets: list of (start_char, end_char) per token
    pred_ids: list of predicted label ids aligned to offsets
    Uses id2label in scope.
    """
    spans = []
    active = None  # current span as (label, start_char, end_char)

    for (s, e), tag_id in zip(offsets, pred_ids):
        if s == e:  # empty offset, skip
            continue

        tag = id2label[int(tag_id)]
        if tag == "O":
            # close any open span on outside
            if active:
                spans.append({"start": active[1], "end": active[2], "label": active[0]})
                active = None
            continue

        prefix, ent = tag.split("-", 1)  # "B-CLAUSE" -> ("B", "CLAUSE")

        if prefix == "B":
            # start a new span, closing the previous one if needed
            if active:
                spans.append({"start": active[1], "end": active[2], "label": active[0]})
            active = (ent, s, e)
        else:  # prefix == "I"
            # continue same entity if consistent, otherwise start a new one
            if active and active[0] == ent:
                active = (ent, active[1], e)
            else:
                active = (ent, s, e)

    # flush any remaining open span
    if active:
        spans.append({"start": active[1], "end": active[2], "label": active[0]})

    return spans

In [133]:
def predict_spans_long_text(
    text,
    max_length=256,
    stride=128,
    char_chunk=4000,
    back_overlap=2000
):
    """Run NER over long text by sliding windows and stitch spans back together."""
    spans_all = []
    start_char = 0

    while start_char < len(text):
        # slice a character chunk and tokenize with offsets
        end_char = min(len(text), start_char + char_chunk)
        chunk = text[start_char:end_char]

        inputs = tokenizer(
            chunk,
            return_offsets_mapping=True,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            stride=stride,
            padding=False,
        )
        offsets = inputs.pop("offset_mapping")[0].tolist()
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # forward pass and argmax ids
        with torch.no_grad():
            logits = model(**inputs).logits[0]
            pred_ids = torch.argmax(logits, dim=-1).detach().cpu().numpy()

        # make offsets absolute, then decode BIO into spans
        abs_offsets = [(s + start_char, e + start_char) for (s, e) in offsets]
        spans_all.extend(decode_bio_to_spans(abs_offsets, pred_ids))

        # advance window with overlap to avoid cutting entities
        if end_char == len(text):
            break
        start_char = max(0, end_char + 0 - back_overlap)

    return spans_all

In [134]:
# 3) Run inference, collect rows
rows = []

for pdf_path in tqdm(pdf_paths, desc="Running inference"):
    try:
        text = extract_text_from_pdf(pdf_path)  # read and lightly clean PDF text
        spans = predict_spans_long_text(text, max_length=256, stride=128)  # BIO -> spans
        for sp in spans:
            rows.append({
                "filename": pdf_path.name,
                "source_dir": source_dir,
                "clause": sp["label"],
                "start": sp["start"],
                "end": sp["end"],
                "pred_text": text[sp["start"]:sp["end"]],  # slice original text for the span
            })
    except Exception as e:
        # capture failures so a bad file does not stop the run
        rows.append({
            "filename": pdf_path.name,
            "source_dir": source_dir,
            "clause": "ERROR",
            "start": -1,
            "end": -1,
            "pred_text": f"{type(e).__name__}: {e}",
        })

pred_df = pd.DataFrame(rows)  # one row per predicted span or error

Running inference: 100%|██████████| 5/5 [00:03<00:00,  1.26it/s]


In [135]:
# 4) Save predictions
os.makedirs(OUTPUT_DIR, exist_ok=True)
out_csv = os.path.join(OUTPUT_DIR, "predictions.csv")
pred_df.to_csv(out_csv, index=False, encoding="utf-8")

print(f"\nWrote {len(pred_df)} rows to {out_csv}")
print("Run Step 14B next to clean duplicates and build a per-file summary.")


Wrote 50 rows to /content/drive/MyDrive/CUAD/CUAD_v1/model_outputs/predictions.csv
Run Step 14B next to clean duplicates and build a per-file summary.


# Step 14B: Clean, De-duplicate, and Summarize Predictions

This step tidies the raw predictions from Step 14.

### What this does:
- Removes tiny spans and exact duplicates within the same file and clause.
- Merges overlapping spans for the same file and clause, keeping the widest span.
- Builds a compact summary per file with the top predicted span for each clause.
- Saves two files:
  - `predictions_clean.csv`  cleaned, de-duplicated spans
  - `predictions_summary.csv`  one row per file with best span per clause

### Why this matters:
Raw sliding-window inference can produce near-duplicates. A quick cleanup makes QA easier and gives you one candidate span per clause per file.

In [136]:
# Control saving behavior for this step
SAVE = False  # set to True to write cleaned CSVs

# Load raw predictions produced in Step 14
PRED_PATH = os.path.join(OUTPUT_DIR, "predictions.csv")
assert os.path.exists(PRED_PATH), f"Missing {PRED_PATH}. Run Step 14 first."

In [137]:
# 1) Load and basic cleanup
# Load the raw span-level predictions from Step 14's output CSV
# Each row represents one predicted clause span with filename, start/end positions, and extracted text
df = pd.read_csv(PRED_PATH)

# Remove rows where inference failed (clause == "ERROR")
df = df[df["clause"] != "ERROR"].copy()

# Compute span length in characters, ensure no negative values
df["length"] = (df["end"] - df["start"]).clip(lower=0)

# Drop very short spans (<5 chars) which are usually noise or tokenization artifacts
df = df[df["length"] >= 5].copy()

In [138]:
# 2) Merge overlapping spans within each filename+clause, keep widest text
def merge_overlaps(g: pd.DataFrame) -> pd.DataFrame:
    """Merge overlapping character spans in a group and keep the longest text for the merged span.
    Expects columns: start, end, pred_text.
    """
    if g.empty:
        return g

    # sort so overlaps are adjacent
    g = g.sort_values(["start", "end"]).reset_index(drop=True)

    merged = []
    cur_s, cur_e = g.loc[0, "start"], g.loc[0, "end"]
    cur_texts = [g.loc[0, "pred_text"]]

    for i in range(1, len(g)):
        s, e = g.loc[i, "start"], g.loc[i, "end"]
        if s <= cur_e:
            # extend current window
            cur_e = max(cur_e, e)
            cur_texts.append(g.loc[i, "pred_text"])
        else:
            # close current window, keep the longest candidate text
            merged.append((cur_s, cur_e, max(cur_texts, key=len)))
            # start a new window
            cur_s, cur_e, cur_texts = s, e, [g.loc[i, "pred_text"]]

    # flush last window
    merged.append((cur_s, cur_e, max(cur_texts, key=len)))

    return pd.DataFrame(merged, columns=["start", "end", "pred_text"])

In [139]:
# Group by file and clause, merge overlaps per group, collect cleaned spans
clean_parts = []


for (fname, clause), g in df.groupby(["filename", "clause"], sort=False):
    m = merge_overlaps(g[["start", "end", "pred_text"]])  # collapse overlapping spans
    m["filename"] = fname
    m["clause"] = clause
    clean_parts.append(m)

In [140]:
# Combine all per group results into a single DataFrame of cleaned spans
# Each row = one merged span with columns: filename, clause, start, end, pred_text
clean_df = pd.concat(clean_parts, ignore_index=True)

# Compute the character length of each merged span
clean_df["length"] = clean_df["end"] - clean_df["start"]

In [141]:
# 3) Per file summary: take the single longest span per clause, then pivot to wide format
summary = (
    clean_df
      # sort so the first row per group is the longest span
      .sort_values(["filename", "clause", "length"], ascending=[True, True, False])
      # collapse to one row per filename+clause (keeps the longest pred_text)
      .groupby(["filename", "clause"], as_index=False)
      .first()
)

# Wide table: one row per file, one column per clause, cell contains the chosen pred_text
summary_pivot = (
    summary
      .pivot(index="filename", columns="clause", values="pred_text")
      .reset_index()
)

In [142]:
# 4) Optionally save artifacts
if SAVE:
    clean_path = os.path.join(OUTPUT_DIR, "predictions_clean.csv")
    summary_path = os.path.join(OUTPUT_DIR, "predictions_summary.csv")
    clean_df[["filename", "clause", "start", "end", "length", "pred_text"]].to_csv(clean_path, index=False, encoding="utf-8")
    summary_pivot.to_csv(summary_path, index=False, encoding="utf-8")
    print(f"Saved cleaned predictions to {clean_path}  and  summary to {summary_path}")

In [143]:
# 5) Light preview
print("Cleaned rows:", len(clean_df), " Unique files:", clean_df['filename'].nunique())
display(summary_pivot.head(5))

Cleaned rows: 42  Unique files: 4


clause,filename,Agreement Date,Document Name,Parties
0,ex101.pdf,"November 13, 2012",ASSET PURCHASE AGREEMENT,"P Electric Vehicles, Inc."
1,ex21to8k11125001_02222018.pdf,"October 30, 2017",ASSET PURCHASE AGREEMENT,"Holdings, Inc."
2,exh10-21.pdf,"March 1, 2007",LICENSE AGREEMENT,"ps Company, Inc."
3,rathgibson0303088kex101.pdf,"February 27, 2008",REPRESENTATIONS AND WARRANTIES OF,"RathGibson, Inc."
