## 1. Setup
- Install/import libraries
- Set seeds
- GPU check

In [3]:
# If needed, install extras (uncomment)
# %pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
# %pip install pandas numpy matplotlib scikit-learn tqdm

import json
import math
import os
import random
from collections import Counter, defaultdict
from pathlib import Path
from typing import List, Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import classification_report, f1_score, precision_recall_fscore_support, accuracy_score
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

SEED = 42

# Seed every RNG the pipeline draws from: Python, NumPy, PyTorch CPU and all CUDA devices.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

_has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _has_cuda else torch.device('cpu')
print(f'Using device: {device}')
# NOTE(review): cuDNN autotuning speeds up CUDA runs but, per the PyTorch
# reproducibility notes, can make results nondeterministic despite the seeds above.
torch.backends.cudnn.benchmark = _has_cuda

Using device: cpu


  from .autonotebook import tqdm as notebook_tqdm


## 2. Load CUAD Dataset
Assumes CUAD v1 JSON files are in `data/cuad/`.
- `CUAD_v1.json` (contracts)
- `CUAD_v1_questions.json` (41 clause types)
- `CUAD_v1_annotations.json` (labels)

In [None]:
DATA_DIR = Path('data/cuad')
# The three CUAD v1 JSON inputs, all resolved relative to DATA_DIR.
CONTRACTS_PATH, QUESTIONS_PATH, ANNOTATIONS_PATH = (
    DATA_DIR / file_name
    for file_name in ('CUAD_v1.json', 'CUAD_v1_questions.json', 'CUAD_v1_annotations.json')
)

def load_json(path: Path):
    """Read the UTF-8 encoded JSON file at *path* and return the decoded object."""
    with open(path, encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload

contracts = load_json(CONTRACTS_PATH)  # list of dicts with 'doc_id' and 'text'
questions = load_json(QUESTIONS_PATH)  # list of clause definitions
annotations = load_json(ANNOTATIONS_PATH)  # list of annotation objects

print(f'Contracts: {len(contracts)} | Questions: {len(questions)} | Annotations: {len(annotations)}')

# Each question's text is one clause label; its list position is the label index.
clause_labels = [question['question_text'] for question in questions]
label_to_idx = {label: idx for idx, label in enumerate(clause_labels)}
num_labels = len(clause_labels)
print('Num labels:', num_labels)

# doc_id -> multi-hot label vector; a doc_id seen for the first time gets a fresh all-zero vector.
doc_to_labels: Dict[str, List[int]] = defaultdict(lambda: [0] * num_labels)
for ann in annotations:
    doc_id, q_text = ann['doc_id'], ann['question_text']
    label_idx = label_to_idx.get(q_text)
    if label_idx is not None:  # annotations for unknown questions are ignored
        doc_to_labels[doc_id][label_idx] = 1

# One record per contract; contracts without any annotation get all-zero labels.
records = [
    {'doc_id': contract['doc_id'], 'text': contract.get('text', ''), 'labels': doc_to_labels[contract['doc_id']]}
    for contract in contracts
]

df = pd.DataFrame(records)
print(df.head())
positives_per_doc = [sum(row) for row in df.labels]
print('Label density (avg positives per doc):', np.mean(positives_per_doc))

## 3. Head–Tail Sampling Function
Take the first 1000 and last 1000 words of each contract and join them with an `[HT_SPLIT]` marker token.

In [None]:
HT_SPLIT = '[HT_SPLIT]'
HEAD_N = 1000
TAIL_N = 1000
MAX_SEQ_WORDS = HEAD_N + TAIL_N + 1  # include split token

def simple_word_tokenize(text: str) -> List[str]:
    """Split *text* on runs of whitespace (``str.split()`` semantics)."""
    return text.split()

def head_tail_sample(text: str, head_n: int = HEAD_N, tail_n: int = TAIL_N) -> List[str]:
    """Return the first *head_n* and last *tail_n* words of *text* joined by HT_SPLIT.

    The head and tail windows never overlap. When the text has fewer than
    ``head_n + tail_n`` words, the tail is just the words that follow the head,
    so no word is repeated. The result has at most ``head_n + tail_n + 1``
    tokens, and exactly one HT_SPLIT is inserted between head and tail.
    """
    tokens = simple_word_tokenize(text)
    head = tokens[:head_n]
    # Start the tail after the head so short documents are not duplicated.
    # max() also makes tail_n == 0 give an empty tail; tokens[-0:] would
    # return the whole list.
    tail_start = max(len(head), len(tokens) - tail_n)
    tail = tokens[tail_start:]
    return head + [HT_SPLIT] + tail

# Sanity-check the head–tail sampler on the first contract.
first_text = df.iloc[0].text
sample_tokens = head_tail_sample(first_text)
print('Sample length:', len(sample_tokens))
preview = ' '.join(sample_tokens[:30])
print('Snippet:', preview, '...')

## 4. Tokenization + Numerical Encoding
Build a word-level vocabulary from scratch, encode each sampled contract as token IDs, and pad or truncate every sequence to a fixed length.

In [None]:
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
VOCAB_MIN_FREQ = 2  # tokens seen fewer times than this are encoded as UNK_TOKEN
# Derive the encoded length from the head–tail budget (currently 2001) so that
# changing HEAD_N/TAIL_N cannot silently truncate or over-pad the sequences.
MAX_SEQ_LEN = MAX_SEQ_WORDS

def build_vocab(tokenized_texts: List[List[str]], min_freq: int = VOCAB_MIN_FREQ) -> Dict[str, int]:
    """Map each sufficiently frequent token to a contiguous integer id.

    Ids 0 and 1 are reserved for PAD_TOKEN and UNK_TOKEN. Every other token
    occurring at least *min_freq* times across *tokenized_texts* receives the
    next free id, in first-seen order.
    """
    freqs = Counter(tok for doc in tokenized_texts for tok in doc)
    vocab = {PAD_TOKEN: 0, UNK_TOKEN: 1}
    for tok, count in freqs.items():
        if count < min_freq or tok in vocab:
            continue  # too rare, or collides with a reserved special token
        vocab[tok] = len(vocab)
    return vocab

def encode_tokens(tokens: List[str], vocab: Dict[str, int], max_len: int = MAX_SEQ_LEN) -> List[int]:
    """Convert *tokens* to vocabulary ids, truncated to *max_len* and right-padded with the PAD id.

    Out-of-vocabulary tokens map to the UNK_TOKEN id.
    """
    clipped = tokens[:max_len]
    ids = [vocab.get(tok, vocab[UNK_TOKEN]) for tok in clipped]
    shortfall = max_len - len(ids)
    if shortfall > 0:
        ids.extend([vocab[PAD_TOKEN]] * shortfall)
    return ids

# Head–tail sample every contract, build the vocabulary, then encode to fixed-length ids.
# NOTE(review): the vocabulary is built before the train/val/test split, so
# validation/test tokens also influence which words are in-vocabulary.
df['tokens'] = df.text.apply(head_tail_sample)
vocab = build_vocab(df.tokens.tolist(), min_freq=VOCAB_MIN_FREQ)
vocab_size = len(vocab)
print('Vocab size:', vocab_size)

df['input_ids'] = df.tokens.apply(encode_tokens, args=(vocab, MAX_SEQ_LEN))
print('Encoded example length:', len(df.input_ids.iloc[0]))

## 5. Dataset + DataLoader
Custom `Dataset`, a 70/15/15 train/validation/test split, and DataLoaders.

In [None]:
class CuadDataset(Dataset):
    """Map-style dataset yielding ``(input_ids, labels)`` tensor pairs.

    Reads the ``input_ids`` and ``labels`` columns of *df_slice*. In this
    notebook those hold per-row token-id lists and 0/1 multi-hot lists.
    """

    def __init__(self, df_slice: pd.DataFrame):
        self.inputs = df_slice['input_ids'].tolist()
        self.labels = df_slice['labels'].tolist()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        token_ids, target = self.inputs[idx], self.labels[idx]
        # Long ids feed nn.Embedding; float targets feed the BCE loss.
        return (
            torch.tensor(token_ids, dtype=torch.long),
            torch.tensor(target, dtype=torch.float),
        )

# 70/15/15 split: hold out 30%, then halve the hold-out into validation and test.
train_df, temp_df = train_test_split(df, test_size=0.30, random_state=SEED, shuffle=True)
val_df, test_df = train_test_split(temp_df, test_size=0.50, random_state=SEED, shuffle=True)

BATCH_SIZE = 8

# Only the training loader shuffles; evaluation order stays fixed.
loader_kwargs = dict(batch_size=BATCH_SIZE, drop_last=False)
train_loader = DataLoader(CuadDataset(train_df), shuffle=True, **loader_kwargs)
val_loader = DataLoader(CuadDataset(val_df), shuffle=False, **loader_kwargs)
test_loader = DataLoader(CuadDataset(test_df), shuffle=False, **loader_kwargs)

len(train_loader), len(val_loader), len(test_loader)

## 6. Model Architecture (Stacked LSTM)
Embedding from scratch, 2–3 LSTM layers, dropout, sigmoid output.

In [None]:
class StackedLSTMClassifier(nn.Module):
    """Multi-label classifier: embedding -> stacked unidirectional LSTM -> linear -> sigmoid.

    ``forward`` takes a (batch, seq_len) tensor of token ids and returns
    per-label probabilities of shape (batch, num_labels). Sequences are
    assumed to be right-padded with id 0 (the embedding's ``padding_idx``),
    as produced by ``encode_tokens``.
    """

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_layers: int, num_labels: int, dropout: float = 0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # nn.LSTM applies dropout only between layers, so disable it for a single layer.
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0.0, bidirectional=False)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, num_labels)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        emb = self.embedding(x)
        out, _ = self.lstm(emb)
        # Pool at each sequence's last real (non-pad) token. Inputs are
        # right-padded, so out[:, -1] would be the state after also reading
        # every padding step, not the state at the end of the text. Because the
        # LSTM is unidirectional, the output at that position equals the output
        # for the unpadded sequence.
        # NOTE(review): this counts non-pad ids; it assumes pads only occur at the end.
        lengths = (x != self.embedding.padding_idx).sum(dim=1).clamp(min=1)
        last_idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, out.size(-1))
        pooled = out.gather(1, last_idx).squeeze(1)
        logits = self.fc(self.dropout(pooled))
        return self.activation(logits)

# Architecture hyperparameters for the 2-layer LSTM classifier.
lstm_kwargs = dict(embed_dim=200, hidden_dim=256, num_layers=2, dropout=0.3)
model = StackedLSTMClassifier(vocab_size=vocab_size, num_labels=num_labels, **lstm_kwargs)
model = model.to(device)

# The model's forward already applies a sigmoid, so plain BCELoss on probabilities is used.
criterion = nn.BCELoss()
# NOTE: the training cell rebuilds `optimizer` via choose_optimizer(...).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print(model)

## 7. Training Loop
Train with binary cross-entropy and a selectable optimizer (SGD/Adam/RMSprop), checkpointing the model with the best validation micro-F1.

In [None]:
def choose_optimizer(name: str, params, lr: float):
    """Build an optimizer selected by case-insensitive *name*.

    'sgd' gives SGD with momentum 0.9 and 'rmsprop' gives RMSprop; any other
    name falls back to Adam.
    """
    factories = {
        'sgd': lambda: torch.optim.SGD(params, lr=lr, momentum=0.9),
        'rmsprop': lambda: torch.optim.RMSprop(params, lr=lr),
    }
    build = factories.get(name.lower(), lambda: torch.optim.Adam(params, lr=lr))
    return build()

# Training hyperparameters for this run.
OPTIMIZER_NAME, LR, EPOCHS = 'adam', 1e-3, 5

# Replace the optimizer created alongside the model with the configured one.
optimizer = choose_optimizer(OPTIMIZER_NAME, model.parameters(), LR)

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over *loader* and return the mean per-sample training loss."""
    model.train()
    running = 0.0
    for inputs, targets in tqdm(loader, leave=False):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        # Re-weight by batch size so a smaller last batch is not over-counted
        # (assumes a mean-reduced criterion, e.g. nn.BCELoss's default).
        running += batch_loss.item() * inputs.size(0)
    return running / len(loader.dataset)

def eval_epoch(model, loader, criterion, device):
    """Evaluate *model* on *loader* without gradient tracking.

    Returns ``(mean_loss, preds, labels)``. ``preds`` and ``labels`` are the
    model outputs and targets for every sample, stacked row-wise into NumPy
    arrays in loader order.
    """
    model.eval()
    loss_sum = 0.0
    pred_batches = []
    label_batches = []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, targets).item() * inputs.size(0)
            pred_batches.append(outputs.cpu().numpy())
            label_batches.append(targets.cpu().numpy())
    preds = np.vstack(pred_batches)
    labels = np.vstack(label_batches)
    return loss_sum / len(loader.dataset), preds, labels

def threshold_outputs(preds, thresh=0.5):
    """Binarize probability scores to an int array: 1 where ``preds >= thresh``, else 0."""
    is_positive = preds >= thresh
    return is_positive.astype(int)

# Start below any achievable F1 so epoch 1 always writes a checkpoint. With 0.0
# and a strict '>', a run whose val F1 never rises above 0 saves nothing, and
# the evaluation cell would then load a stale best_model.pt from an earlier run.
best_val_f1 = -1.0
history = {'train_loss': [], 'val_loss': [], 'val_f1': []}
BEST_PATH = 'best_model.pt'

for epoch in range(1, EPOCHS + 1):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_preds, val_labels = eval_epoch(model, val_loader, criterion, device)
    val_bin = threshold_outputs(val_preds)
    val_f1 = f1_score(val_labels, val_bin, average='micro', zero_division=0)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_f1'].append(val_f1)

    print(f'Epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f} val_f1={val_f1:.4f}')

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), BEST_PATH)
        print('Saved new best model')

# 1-based x values so the plotted epochs match the printed "Epoch N" lines.
epochs_axis = range(1, len(history['train_loss']) + 1)

plt.figure(figsize=(6,4))
plt.plot(epochs_axis, history['train_loss'], label='train_loss')
plt.plot(epochs_axis, history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

plt.figure(figsize=(6,4))
plt.plot(epochs_axis, history['val_f1'], label='val_f1')
plt.xlabel('Epoch')
plt.ylabel('F1 (micro)')
plt.legend()
plt.show()

## 8. Evaluation
Per-class metrics and overall scores; optional confusion matrix.

In [None]:
from pathlib import Path

# Prefer the best checkpoint written during training; otherwise keep the in-memory weights.
checkpoint = Path(BEST_PATH)
if not checkpoint.exists():
    print(f"Checkpoint {BEST_PATH} not found. Using current model weights.")
else:
    model.load_state_dict(torch.load(BEST_PATH, map_location=device))
    print(f"Loaded checkpoint: {BEST_PATH}")

model.to(device)

test_loss, test_preds, test_labels = eval_epoch(model, test_loader, criterion, device)
test_bin = threshold_outputs(test_preds)

# Micro averaging pools every (document, label) decision before computing P/R/F1.
overall_precision, overall_recall, overall_f1, _ = precision_recall_fscore_support(
    test_labels, test_bin, average='micro', zero_division=0
)
# NOTE: computed over the flattened label matrix, so this is per-label
# (element-wise) accuracy, not exact-match accuracy over whole label sets.
overall_acc = accuracy_score(test_labels.ravel(), test_bin.ravel())

print(f"Test loss: {test_loss:.4f}")
summary = (
    f"Accuracy: {overall_acc:.4f}  Precision: {overall_precision:.4f}  "
    f"Recall: {overall_recall:.4f}  F1: {overall_f1:.4f}"
)
print(summary)

report = classification_report(test_labels, test_bin, target_names=clause_labels, zero_division=0)
print(report)

# Optional: one 2x2 confusion matrix per clause label.
from sklearn.metrics import multilabel_confusion_matrix
cm = multilabel_confusion_matrix(test_labels, test_bin)
print('Confusion matrix shape:', cm.shape)

## 9. Ablation Setup (Optional)
Compare head-only, tail-only, head–tail, simple truncation.

In [None]:
def make_variant_tokens(text: str, variant: str) -> List[str]:
    """Select the words of *text* used by an ablation input variant.

    'head' gives the first HEAD_N words, 'tail' the last TAIL_N words, and
    'truncate' the first MAX_SEQ_WORDS words. Any other value uses
    head_tail_sample.
    """
    if variant not in ('head', 'tail', 'truncate'):
        return head_tail_sample(text)
    words = simple_word_tokenize(text)
    if variant == 'head':
        return words[:HEAD_N]
    if variant == 'tail':
        return words[-TAIL_N:]
    return words[:MAX_SEQ_WORDS]

def prepare_variant_df(df_in: pd.DataFrame, variant: str) -> Tuple[pd.DataFrame, Dict[str, int]]:
    """Re-tokenize and re-encode a copy of *df_in* for *variant*.

    *df_in* itself is not modified. The copy's ``tokens`` and ``input_ids``
    columns are overwritten, and a fresh vocabulary built from the variant's
    tokens is returned alongside it.
    """
    variant_df = df_in.copy()
    variant_df['tokens'] = variant_df.text.apply(make_variant_tokens, args=(variant,))
    variant_vocab = build_vocab(variant_df['tokens'].tolist(), min_freq=VOCAB_MIN_FREQ)
    variant_df['input_ids'] = variant_df['tokens'].apply(encode_tokens, args=(variant_vocab, MAX_SEQ_LEN))
    return variant_df, variant_vocab

# Demo only: build the head-only variant. A fair comparison requires re-training
# with the same pipeline on head_df.
head_df, head_vocab = prepare_variant_df(df, 'head')
head_vocab_size = len(head_vocab)
print('Head-only vocab size:', head_vocab_size)

## 10. Inference on a Sample Contract
Run prediction for a single contract string with the current model. The best checkpoint, if one was found, was already loaded in the Evaluation step.

In [None]:
def predict_text(text: str, model: nn.Module, vocab: Dict[str, int], threshold: float = 0.5):
    """Score one raw contract and return ``(clause_label, probability)`` pairs at or above *threshold*.

    The text goes through the same head–tail sampling and encoding as the
    training data. As a side effect, *model* is switched to eval mode.
    """
    encoded = encode_tokens(head_tail_sample(text), vocab, MAX_SEQ_LEN)
    batch = torch.tensor([encoded], dtype=torch.long).to(device)
    model.eval()
    with torch.no_grad():
        scores = model(batch).cpu().numpy()[0]
    return [(clause_labels[i], score) for i, score in enumerate(scores) if score >= threshold]

# Run the classifier on the first contract and list the clauses it flags.
sample_text = df.iloc[0].text
preds = predict_text(sample_text, model, vocab, threshold=0.5)
print('Predicted clauses:')
for clause_name, score in preds:
    print(f'{clause_name}: {score:.3f}')