In [None]:
import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

SEED = 42
# Seed every RNG the notebook touches (Python `random`, NumPy, PyTorch) so a
# Restart & Run All reproduces the same pairs, splits, initial weights and
# shuffles (up to any nondeterministic GPU kernels).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


#### Load Dataset

In [None]:
PATH = '../data'

# Read every CSV in PATH and stack them into one frame.
# sorted() makes row order independent of the filesystem's listdir order,
# which keeps the downstream pairing and splitting reproducible.
csv_files = sorted(f for f in os.listdir(PATH) if f.endswith(".csv"))
if not csv_files:
    raise FileNotFoundError(f"No .csv files found in {os.path.abspath(PATH)}")

all_data = []
for file in csv_files:
    df = pd.read_csv(os.path.join(PATH, file))
    df["source_file"] = file  # provenance of each clause
    all_data.append(df)

data = pd.concat(all_data, ignore_index=True)
print("Total clauses:", len(data))
print("Unique clause types:", data["clause_type"].nunique())
data.head()


Total clauses: 150881
Unique clause types: 395


Unnamed: 0,clause_text,clause_type,source_file
0,Certain Definitions. For purposes of this Agre...,certain-definitions,certain-definitions.csv
1,Certain Definitions. As used in this Agreement...,certain-definitions,certain-definitions.csv
2,Certain Definitions. For purposes of this Agre...,certain-definitions,certain-definitions.csv
3,Certain Definitions. For purposes of this Agre...,certain-definitions,certain-definitions.csv
4,Certain Definitions. As used in this Agreement...,certain-definitions,certain-definitions.csv


In [None]:
def get_shape(path=None):
    """Print (and return) the shape of every CSV file in a directory.

    Args:
        path: directory to scan; defaults to the notebook-level ``PATH``.

    Returns:
        dict mapping each CSV's joined path to its ``(rows, cols)`` shape,
        in sorted filename order.
    """
    if path is None:
        path = PATH  # resolved at call time, so later changes to PATH are honoured
    shapes = {}
    for file in sorted(os.listdir(path)):
        if file.endswith(".csv"):
            csv_file = os.path.join(path, file)
            df = pd.read_csv(csv_file)
            print(f"{csv_file} | Shape: {df.shape}")
            shapes[csv_file] = df.shape
    return shapes

In [None]:
def clean_text(t):
    """Normalise a clause: stringify, lowercase, and collapse whitespace runs to single spaces.

    Non-string inputs go through ``str()``, so a missing value (NaN) comes back
    as the literal ``"nan"``.
    """
    words = str(t).lower().split()
    return " ".join(words)

# Drop missing rows *before* cleaning: clean_text() stringifies NaN into the
# literal "nan", which a later dropna() would no longer recognise as missing.
# Clauses that are empty after whitespace normalisation are dropped as well.
data = (
    data.dropna(subset=["clause_text", "clause_type"])
        .assign(clause_text=lambda d: d["clause_text"].apply(clean_text))
        .loc[lambda d: d["clause_text"] != ""]
        .reset_index(drop=True)
)


#### Build Pair Dataset (positive & negative)

In [None]:
def make_pairs(df, num_pos=1, num_neg=1, seed=42):
    """Build a labelled pair dataset for clause-type similarity.

    For every clause ``ex`` of type ``t``:
      * ``num_pos`` positive pairs ``(ex, other)``, where ``other`` is a
        *different row* of the same type. Types with a single clause yield no
        positives.
      * ``num_neg`` negative pairs ``(ex, neg)``, where ``neg`` is a random
        clause from a uniformly chosen different type.

    Args:
        df: DataFrame with ``clause_text`` and ``clause_type`` columns.
        num_pos: positive pairs per clause.
        num_neg: negative pairs per clause.
        seed: seed for a private RNG; the global ``random`` state is left untouched.

    Returns:
        DataFrame with columns ``text1``, ``text2``, ``label`` (1 = same type).

    Raises:
        ValueError: if negatives are requested but ``df`` holds exactly one
            clause type, so no negative can be drawn.
    """
    rng = random.Random(seed)  # private RNG: reproducible without mutating global state
    grouped = df.groupby("clause_type")["clause_text"].apply(list).to_dict()
    types = list(grouped)
    if num_neg > 0 and len(types) == 1:
        raise ValueError("need at least two clause types to sample negative pairs")

    rows = []
    for t, examples in grouped.items():
        n = len(examples)
        other_types = [o for o in types if o != t]
        for i, ex in enumerate(examples):
            # Positive pairs: draw an index != i so a clause is never paired with itself.
            if n > 1:
                for _ in range(num_pos):
                    j = rng.randrange(n - 1)
                    if j >= i:
                        j += 1
                    rows.append((ex, examples[j], 1))
            # Negative pairs: a random clause from a uniformly chosen other type.
            for _ in range(num_neg):
                neg_type = rng.choice(other_types)
                rows.append((ex, rng.choice(grouped[neg_type]), 0))

    return pd.DataFrame(rows, columns=["text1", "text2", "label"])

pair_df = make_pairs(df=data, num_pos=1, num_neg=1)
print(f"Total pairs: {len(pair_df)}")
pair_df.head()


Total pairs: 301350


Unnamed: 0,text1,text2,label
0,absence of certain changes. there have been no...,absence of certain changes. except as disclose...,1
1,absence of certain changes. there have been no...,compensation and benefits. for all services re...,0
2,absence of certain changes. since september 30...,absence of certain changes. as of the closing ...,1
3,absence of certain changes. since september 30...,"erisa. borrower shall, and shall cause each of...",0
4,"absence of certain changes. since december 31,...",absence of certain changes. except as disclose...,1


#### Tokenization and Vocabulary

In [None]:
from collections import Counter
import re

def tokenize(text):
    """Lowercase ``text`` and return its word tokens (maximal runs of word characters)."""
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r"\b\w+\b", lowered)]

# Build vocab: count token frequencies over every cleaned clause, then keep the
# tokens seen at least `min_freq` times. Ids 0 and 1 are reserved for padding
# and unknown tokens.
token_counts = Counter()
for clause in tqdm(data["clause_text"], desc="Building vocab"):
    token_counts.update(tokenize(clause))

min_freq = 3
frequent_tokens = [tok for tok, cnt in token_counts.items() if cnt >= min_freq]
vocab = ["<pad>", "<unk>", *frequent_tokens]
stoi = dict(zip(vocab, range(len(vocab))))
itos = {idx: tok for tok, idx in stoi.items()}

def encode(text, max_len=100):
    """Convert ``text`` to a fixed-length list of vocabulary ids.

    Tokens beyond ``max_len`` are truncated; shorter sequences are right-padded
    with the ``<pad>`` id. Out-of-vocabulary tokens map to the ``<unk>`` id.

    Args:
        text: clause string.
        max_len: output length in tokens.

    Returns:
        (ids, length): ``ids`` is a list of ``max_len`` ints; ``length`` is the
        number of real tokens, clamped to at least 1 because
        ``pack_padded_sequence`` rejects zero-length sequences and the
        mean-pooling encoder divides by this value.
    """
    pad_id, unk_id = stoi["<pad>"], stoi["<unk>"]
    tokens = tokenize(text)[:max_len]
    ids = [stoi.get(tok, unk_id) for tok in tokens]
    ids += [pad_id] * (max_len - len(ids))
    return ids, max(len(tokens), 1)


Building vocab: 100%|██████████| 150881/150881 [00:06<00:00, 25007.26it/s]


In [None]:
class ClausePairDataset(Dataset):
    """Map-style dataset that yields encoded clause pairs from a pair DataFrame.

    Each item is the tensor tuple ``(ids1, len1, ids2, len2, label)``, built by
    running ``encode`` on the row's ``text1`` and ``text2``. Encoding happens
    on every access, so each epoch re-tokenises the texts.
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        ids_a, len_a = encode(record.text1)
        ids_b, len_b = encode(record.text2)
        return tuple(
            torch.tensor(value)
            for value in (ids_a, len_a, ids_b, len_b, record.label)
        )

# Split
from sklearn.model_selection import train_test_split

# NOTE(review): the split is done on *pairs*. make_pairs uses each clause as
# text1 of its own pairs and may reuse it as text2 of others, so the same
# clause text usually lands in train and in val/test. Scores therefore include
# memorisation; consider splitting clauses (or clause types) before pairing.
train_df, holdout_df = train_test_split(pair_df, test_size=0.2, random_state=42, stratify=pair_df["label"])
val_df, test_df = train_test_split(holdout_df, test_size=0.5, random_state=42, stratify=holdout_df["label"])

train_ds = ClausePairDataset(train_df)
val_ds = ClausePairDataset(val_df)
test_ds = ClausePairDataset(test_df)

# Dedicated seeded generator: the training shuffle order is reproducible
# regardless of how much of torch's global RNG earlier cells consumed.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
val_loader = DataLoader(val_ds, batch_size=64)
test_loader = DataLoader(test_ds, batch_size=64)

print("Train:", len(train_ds), "Val:", len(val_ds), "Test:", len(test_ds))


Train: 241080 Val: 30135 Test: 30135


## Model 1: Siamese BiLSTM (mean-pooled encoder)

In [None]:
class BiLSTMEncoder(nn.Module):
    """Bidirectional LSTM sentence encoder with masked mean pooling.

    Args:
        vocab_size: number of embedding rows; id 0 is the padding index.
        emb_dim: embedding size.
        hidden: LSTM hidden size per direction (output dim is ``2 * hidden``).
        num_layers: number of stacked LSTM layers.
        dropout: inter-layer LSTM dropout, only applied when ``num_layers > 1``.
    """

    def __init__(self, vocab_size, emb_dim=200, hidden=256, num_layers=1, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(emb_dim, hidden, num_layers=num_layers,
                            bidirectional=True, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0)
        self.hidden_dim = hidden * 2

    def forward(self, x, lengths):
        """Encode a padded batch of token ids.

        Args:
            x: LongTensor (batch, seq_len) of ids, padded with 0.
            lengths: LongTensor (batch,) of real token counts.

        Returns:
            FloatTensor (batch, 2 * hidden): mean of the BiLSTM outputs over
            non-padding positions.
        """
        # Zero lengths crash pack_padded_sequence and would divide by zero below.
        lengths = lengths.clamp(min=1)
        emb = self.emb(x)
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        # total_length keeps `out` as wide as `x`. Without it, a batch whose
        # longest sequence is shorter than x.size(1) yields a narrower `out`,
        # and the mask multiplication fails to broadcast.
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
        mask = (x != 0).float().unsqueeze(-1)
        summed = (out * mask).sum(1)
        return summed / lengths.unsqueeze(1).to(summed.dtype)

class SiameseSim(nn.Module):
    """Siamese pair classifier: one shared encoder, then an MLP over pair features.

    Both inputs pass through the *same* ``encoder``, so its weights are shared.
    Given encodings ``u`` and ``v``, the MLP receives ``[u, v, |u - v|, u * v]``
    and emits one similarity logit per pair.

    Args:
        encoder: module mapping ``(ids, lengths)`` to ``(batch, encoder.hidden_dim)``.
        hidden_mlp: width of the MLP hidden layer.
    """

    def __init__(self, encoder, hidden_mlp=256):
        super().__init__()
        self.enc = encoder
        feat_dim = 4 * self.enc.hidden_dim  # u, v, |u-v|, u*v concatenated
        self.mlp = nn.Sequential(
            nn.Linear(feat_dim, hidden_mlp),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_mlp, 1),
        )

    def forward(self, x1, len1, x2, len2):
        """Return raw logits of shape ``(batch,)``; apply a sigmoid for probabilities."""
        u, v = self.enc(x1, len1), self.enc(x2, len2)
        pair_features = torch.cat((u, v, (u - v).abs(), u * v), dim=1)
        return self.mlp(pair_features).squeeze(1)

# Seed immediately before construction so model 1's initial weights are
# reproducible regardless of how many random draws earlier cells made.
torch.manual_seed(42)
encoder = BiLSTMEncoder(vocab_size=len(vocab))
model1 = SiameseSim(encoder).to(device)


#### Train Function

In [None]:
def train_model(model, train_loader, val_loader, epochs=5, lr=1e-3,
                ckpt_path="best_model.pt", restore_best=True):
    """Train a pair-similarity model with BCE loss and keep the best-val-F1 checkpoint.

    Args:
        model: module called as ``model(x1, l1, x2, l2)`` that returns one logit per pair.
        train_loader: yields ``(x1, l1, x2, l2, y)`` training batches.
        val_loader: same batch format; used for model selection after each epoch.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate (weight decay is fixed at 1e-5).
        ckpt_path: where the best ``state_dict`` is saved. Give each model its
            own path, otherwise a later run overwrites an earlier model's checkpoint.
        restore_best: reload the best checkpoint into ``model`` after training,
            so later evaluation uses the selected weights rather than the last epoch's.

    Returns:
        The best validation F1 observed.
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    best_f1 = 0
    saved = False
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for x1, l1, x2, l2, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            x1, l1, x2, l2 = x1.to(device), l1.to(device), x2.to(device), l2.to(device)
            y = y.float().to(device)
            logits = model(x1, l1, x2, l2)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()

        # evaluate() returns a metrics dict, so read fields by key; tuple-unpacking
        # the dict iterates its six keys and raises "too many values to unpack".
        # verbose=False keeps the per-epoch log to a single line.
        metrics = evaluate(model, val_loader, verbose=False)
        val_f1, val_acc = metrics["f1"], metrics["accuracy"]
        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1}: loss={avg_loss:.4f}, val_f1={val_f1:.4f}, val_acc={val_acc:.4f}")
        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), ckpt_path)
            saved = True
    print("Best Val F1:", best_f1)

    if restore_best and saved:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    return best_f1


In [None]:
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    roc_auc_score, average_precision_score, classification_report
)
import numpy as np

def evaluate(model, loader, verbose=True, threshold=0.5):
    """
    Evaluate a similarity model on a given DataLoader.

    Args:
        model: module called as ``model(x1, l1, x2, l2)`` that returns one logit per pair.
        loader: yields ``(x1, l1, x2, l2, y)`` batches.
        verbose: print a formatted metrics table.
        threshold: probability cut-off for predicting the positive class.

    Returns:
        dict with accuracy, precision, recall, f1, roc_auc, pr_auc. The two AUCs
        are NaN when ``loader`` contains a single class: they are undefined
        there, and 0.0 would read as a perfectly inverted ranking.
    """
    model.eval()
    y_true, y_prob = [], []

    with torch.no_grad():
        for x1, l1, x2, l2, y in loader:
            x1, l1, x2, l2 = x1.to(device), l1.to(device), x2.to(device), l2.to(device)
            logits = model(x1, l1, x2, l2)
            y_prob.append(torch.sigmoid(logits).cpu().numpy())
            y_true.append(y.numpy())

    y_true = np.concatenate(y_true) if y_true else np.array([], dtype=int)
    y_prob = np.concatenate(y_prob) if y_prob else np.array([], dtype=float)
    y_pred = (y_prob > threshold).astype(int)

    # ---- Core Metrics ----
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    both_classes = np.unique(y_true).size > 1
    roc_auc = roc_auc_score(y_true, y_prob) if both_classes else float("nan")
    pr_auc = average_precision_score(y_true, y_prob) if both_classes else float("nan")

    if verbose:
        print("\n===== Evaluation Metrics =====")
        print(f"Accuracy  (Overall correctness):       {acc:.4f}")
        print(f"Precision (Exactness):                 {prec:.4f}")
        print(f"Recall    (Completeness):              {rec:.4f}")
        print(f"F1-Score  (Balance of P/R):            {f1:.4f}")
        print(f"ROC-AUC   (Ranking ability):           {roc_auc:.4f}")
        print(f"PR-AUC    (Precision-Recall curve):    {pr_auc:.4f}")
        print("=====================================")

    return {
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "roc_auc": roc_auc,
        "pr_auc": pr_auc
    }


In [None]:
# Validation F1 plateaus around 0.999 within the first few epochs (see the
# logged runs). Five epochs, the same budget as model 2, keeps the comparison
# fair and lets Restart & Run All finish without a manual interrupt.
train_model(model1, train_loader, val_loader, epochs=5)

Epoch 2/15: 100%|██████████| 3767/3767 [04:09<00:00, 15.10it/s]


Epoch 2: loss=0.0042, val_f1=0.9995, val_acc=0.9995


Epoch 3/15: 100%|██████████| 3767/3767 [04:09<00:00, 15.10it/s]


Epoch 3: loss=0.0033, val_f1=0.9990, val_acc=0.9990


Epoch 4/15: 100%|██████████| 3767/3767 [04:09<00:00, 15.11it/s]


Epoch 4: loss=0.0028, val_f1=0.9996, val_acc=0.9996


Epoch 5/15: 100%|██████████| 3767/3767 [04:10<00:00, 15.05it/s]


Epoch 5: loss=0.0024, val_f1=0.9995, val_acc=0.9995


Epoch 6/15: 100%|██████████| 3767/3767 [04:09<00:00, 15.10it/s]


Epoch 6: loss=0.0027, val_f1=0.9995, val_acc=0.9995


Epoch 7/15: 100%|██████████| 3767/3767 [04:09<00:00, 15.09it/s]


Epoch 7: loss=0.0022, val_f1=0.9995, val_acc=0.9995


Epoch 8/15: 100%|██████████| 3767/3767 [04:10<00:00, 15.06it/s]


Epoch 8: loss=0.0023, val_f1=0.9992, val_acc=0.9992


Epoch 9/15: 100%|██████████| 3767/3767 [04:09<00:00, 15.10it/s]


Epoch 9: loss=0.0022, val_f1=0.9996, val_acc=0.9996


Epoch 10/15: 100%|██████████| 3767/3767 [04:08<00:00, 15.13it/s]


Epoch 10: loss=0.0022, val_f1=0.9995, val_acc=0.9995


Epoch 11/15:   1%|          | 20/3767 [00:01<04:20, 14.41it/s]


KeyboardInterrupt: 

## Model 2: Siamese BiLSTM + Attention-Pooling Encoder

In [None]:
class AttentionPool(nn.Module):
    """Attention pooling: score each timestep with a linear layer, softmax over valid steps, weighted sum.

    Args:
        hidden_dim: feature size of the sequence being pooled.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.w = nn.Linear(hidden_dim, 1)

    def forward(self, mem, mask):
        """Pool ``mem`` (batch, seq, hidden_dim) into (batch, hidden_dim).

        ``mask`` is a bool tensor (batch, seq) that is True at real tokens.
        Masked positions get a -1e9 score and therefore ~zero attention weight.
        """
        scores = self.w(mem).squeeze(-1)
        scores = scores.masked_fill(~mask, -1e9)
        attn = torch.softmax(scores, dim=1).unsqueeze(-1)
        return (mem * attn).sum(1)

class BiLSTMAttentionEncoder(nn.Module):
    """BiLSTM encoder that pools timesteps with learned attention instead of a mean.

    Args:
        vocab_size: number of embedding rows; id 0 is the padding index.
        emb_dim: embedding size.
        hidden: LSTM hidden size per direction (output dim is ``2 * hidden``).
        dropout: accepted for signature parity with the mean-pooling encoder but
            currently unused (single-layer LSTM, no extra dropout layer).
    """

    def __init__(self, vocab_size, emb_dim=200, hidden=256, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(emb_dim, hidden, bidirectional=True, batch_first=True)
        self.attn = AttentionPool(hidden * 2)
        self.hidden_dim = hidden * 2

    def forward(self, x, lengths):
        """Encode padded ids ``x`` (batch, seq_len) into (batch, 2 * hidden)."""
        # pack_padded_sequence rejects zero-length sequences.
        lengths = lengths.clamp(min=1)
        emb = self.emb(x)
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        # total_length keeps `out` as wide as `x` (and hence `mask`). Otherwise a
        # batch whose longest sequence is shorter than x.size(1) makes
        # masked_fill fail on mismatched shapes.
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
        mask = x != 0
        return self.attn(out, mask)

# Seed immediately before construction so model 2's initial weights are
# reproducible regardless of what ran earlier in the kernel.
torch.manual_seed(42)
encoder2 = BiLSTMAttentionEncoder(vocab_size=len(vocab))
model2 = SiameseSim(encoder2).to(device)

train_model(model2, train_loader, val_loader, epochs=5)


Epoch 1/5: 100%|██████████| 3767/3767 [04:13<00:00, 14.84it/s]


Epoch 1: loss=0.0254, val_f1=0.9991, val_acc=0.9991


Epoch 2/5: 100%|██████████| 3767/3767 [04:14<00:00, 14.82it/s]


Epoch 2: loss=0.0037, val_f1=0.9991, val_acc=0.9991


Epoch 3/5: 100%|██████████| 3767/3767 [04:14<00:00, 14.80it/s]


Epoch 3: loss=0.0022, val_f1=0.9995, val_acc=0.9995


Epoch 4/5: 100%|██████████| 3767/3767 [04:14<00:00, 14.83it/s]


Epoch 4: loss=0.0017, val_f1=0.9995, val_acc=0.9995


Epoch 5/5: 100%|██████████| 3767/3767 [04:13<00:00, 14.85it/s]


In [None]:
# verbose=False: show one side-by-side table instead of two printed metric
# blocks followed by raw dict reprs of the same numbers.
m1 = evaluate(model1, test_loader, verbose=False)
m2 = evaluate(model2, test_loader, verbose=False)

comparison = (
    pd.DataFrame({"Siamese BiLSTM": m1, "BiLSTM + Attention": m2})
    .T
    .rename_axis("model")
)
comparison.round(4)


===== Evaluation Metrics =====
Accuracy  (Overall correctness):       0.9997
Precision (Exactness):                 0.9996
Recall    (Completeness):              0.9999
F1-Score  (Balance of P/R):            0.9997
ROC-AUC   (Ranking ability):           0.9999
PR-AUC    (Precision-Recall curve):    0.9998

===== Evaluation Metrics =====
Accuracy  (Overall correctness):       0.9992
Precision (Exactness):                 0.9985
Recall    (Completeness):              0.9999
F1-Score  (Balance of P/R):            0.9992
ROC-AUC   (Ranking ability):           0.9998
PR-AUC    (Precision-Recall curve):    0.9995

===== Final Comparison =====
Siamese BiLSTM     — {'accuracy': 0.9997345279575245, 'precision': 0.9996013553916683, 'recall': 0.9998670831394962, 'f1': 0.9997342016080802, 'roc_auc': 0.9999338675191012, 'pr_auc': 0.9997977179639045}
BiLSTM + Attention — {'accuracy': 0.999170399867264, 'precision': 0.9984735864082824, 'recall': 0.9998670831394962, 'f1': 0.999169848912502, 'roc_auc'