In [None]:

import os
import glob
import random
import itertools
import numpy as np
import pandas as pd
from typing import List, Tuple

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    roc_auc_score, average_precision_score, classification_report
)

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Seed the Python, NumPy and TensorFlow random number generators with the
# same value so that pair sampling, splitting and training are more
# reproducible across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# The clause CSV files are expected in an "archive" folder inside the current
# working directory.
DATA_DIR = os.path.join(os.getcwd(), "archive")
# Raise explicitly instead of asserting: assert statements are removed when
# Python runs with -O, which would let a missing folder go unnoticed until
# the loader fails later on.
if not os.path.isdir(DATA_DIR):
    raise FileNotFoundError(f"Expected data directory at {DATA_DIR}")


# ===============================================
# 2. Load Dataset
# ===============================================
def load_all_clauses(data_dir: str) -> pd.DataFrame:
    """Load every clause CSV in ``data_dir`` into one de-duplicated frame.

    Each ``*.csv`` file directly inside ``data_dir`` is read in sorted path
    order. Only the ``clause_text`` and ``clause_type`` columns are kept, and
    rows with a missing value in either column are dropped. Files that cannot
    be read or that lack one of the two columns are skipped with a printed
    message.

    Args:
        data_dir: Directory containing the CSV files.

    Returns:
        A frame with columns ``clause_text`` and ``clause_type`` in which
        every ``clause_text`` value appears once (first occurrence wins) and
        the index runs from 0.

    Raises:
        ValueError: If no file in ``data_dir`` provides both required columns.
    """
    required = ["clause_text", "clause_type"]
    frames = []
    for path in sorted(glob.glob(os.path.join(data_dir, "*.csv"))):
        try:
            df = pd.read_csv(path)
        except (OSError, ValueError) as e:
            # Best effort: one unreadable file must not abort the whole load.
            # pandas' parser/empty-data errors and decoding errors are all
            # ValueError subclasses.
            print(f"Skipping {path}: {e}")
            continue
        if not set(required).issubset(df.columns):
            print(f"Skipping {path}: missing required columns {required}")
            continue
        frames.append(df[required].dropna())

    if not frames:
        # pd.concat([]) would raise an opaque "No objects to concatenate".
        raise ValueError(
            f"No CSV file with columns {required} found in {data_dir}"
        )

    all_df = pd.concat(frames, ignore_index=True)
    return all_df.drop_duplicates(subset=["clause_text"]).reset_index(drop=True)

# Load all clause CSVs found in DATA_DIR, then print the frame's shape and
# the first five entries of the per-type clause counts.
clauses_df = load_all_clauses(DATA_DIR)
print(clauses_df.shape)
print(clauses_df["clause_type"].value_counts().head())
# Preview of the first rows; as a bare expression this is only rendered when
# it is the last statement of a notebook cell.
clauses_df.head()


# ===============================================
# 3. Build Positive/Negative Pairs for Siamese Input
# ===============================================
def build_pairs(df: pd.DataFrame, max_pairs_per_class: int = 4000, negatives_multiplier: int = 1) -> pd.DataFrame:
    """Build labelled clause pairs for training a Siamese similarity model.

    Positive pairs (label 1) combine two different rows of the same
    ``clause_type``; at most ``max_pairs_per_class`` randomly chosen pairs are
    kept per type. Negative pairs (label 0) combine one randomly chosen clause
    from each of two different types. The type pairs are visited in a shuffled
    round-robin order until ``len(positives) * negatives_multiplier``
    negatives exist, so the requested ratio is honoured even when there are
    fewer distinct type pairs than positives. Negative pairs are sampled with
    replacement and may therefore repeat.

    Sampling uses the global ``random`` module state; seed it for
    reproducible output.

    Args:
        df: Frame with ``clause_text`` and ``clause_type`` columns.
        max_pairs_per_class: Upper bound on positive pairs per clause type.
        negatives_multiplier: Number of negatives per positive pair.

    Returns:
        A shuffled frame with columns ``text_left``, ``text_right`` and
        ``label``. No negatives are produced when fewer than two clause types
        are present.
    """
    by_type = df.groupby("clause_type")["clause_text"].apply(list).to_dict()

    # Positive pairs: every unordered pair of rows within a type, capped.
    pos_pairs = []
    for texts in by_type.values():
        if len(texts) < 2:
            continue
        index_pairs = list(itertools.combinations(range(len(texts)), 2))
        random.shuffle(index_pairs)
        for i, j in index_pairs[:max_pairs_per_class]:
            pos_pairs.append((texts[i], texts[j], 1))

    # Negative pairs: only type pairs where both sides have text can
    # contribute, which also guarantees the loop below makes progress.
    type_pairs = [
        (t1, t2)
        for t1, t2 in itertools.combinations(by_type, 2)
        if by_type[t1] and by_type[t2]
    ]
    random.shuffle(type_pairs)
    target_negs = len(pos_pairs) * negatives_multiplier

    neg_pairs = []
    if type_pairs:
        # Cycle rather than stopping after a single pass: one pass yields at
        # most one negative per type pair, which leaves the data set heavily
        # skewed towards positives.
        for t1, t2 in itertools.cycle(type_pairs):
            if len(neg_pairs) >= target_negs:
                break
            neg_pairs.append(
                (random.choice(by_type[t1]), random.choice(by_type[t2]), 0)
            )

    data = pos_pairs + neg_pairs
    random.shuffle(data)
    return pd.DataFrame(data, columns=["text_left", "text_right", "label"]).dropna()

# Build the labelled pair set: at most 3000 positive pairs per clause type
# and a requested negatives-to-positives ratio of 1.
pairs_df = build_pairs(clauses_df, max_pairs_per_class=3000, negatives_multiplier=1)
# Print the shape and the label counts to check the class balance.
print(pairs_df.shape, pairs_df["label"].value_counts())
# Preview of the first rows (only rendered as the last expression of a cell).
pairs_df.head()


# ===============================================
# 4. Split into Train / Val / Test
# ===============================================
# Hold out 15% of the pairs for testing, then 15% of the remainder for
# validation. Both splits are stratified on the label and use SEED.
# NOTE(review): the split is done on pairs, not on clauses, so one clause
# text can occur in pairs on both sides of a split - confirm this is
# acceptable for the reported test metrics.
train_df, test_df = train_test_split(pairs_df, test_size=0.15, random_state=SEED, stratify=pairs_df["label"])
train_df, val_df  = train_test_split(train_df, test_size=0.15, random_state=SEED, stratify=train_df["label"])
len(train_df), len(val_df), len(test_df)


# ===============================================
# 5. Text Vectorization
# ===============================================
MAX_VOCAB = 30000  # max_tokens passed to the TextVectorization layer
MAX_LEN = 128      # output_sequence_length of the vectorizer

# Shared text-to-integer layer: lower-cases the text, strips punctuation and
# emits integer token ids with a fixed sequence length of MAX_LEN.
vectorizer = layers.TextVectorization(
    max_tokens=MAX_VOCAB,
    output_mode="int",
    output_sequence_length=MAX_LEN,
    standardize="lower_and_strip_punctuation",
)

# Fit the vocabulary on the training split only, using both sides of every
# training pair; validation and test text is not seen here.
vectorizer.adapt(tf.data.Dataset.from_tensor_slices(pd.concat([
    train_df["text_left"], train_df["text_right"]
]).astype(str)).batch(1024))

# Size of the fitted vocabulary as reported by the layer.
vocab_size = len(vectorizer.get_vocabulary())
print("Vocab size:", vocab_size)


# ===============================================
# 6. Model & Dataset Builders
# ===============================================
def to_ds(df: pd.DataFrame, batch_size: int = 64) -> tf.data.Dataset:
    """Turn a pair frame into a batched, prefetching ``tf.data`` pipeline.

    Args:
        df: Frame with ``text_left``, ``text_right`` and ``label`` columns.
        batch_size: Number of pairs per batch.

    Returns:
        A dataset of ``(features, label)`` elements in the row order of
        ``df``; ``features`` is a dict keyed by ``"text_left"`` and
        ``"text_right"`` holding the texts converted to strings.
    """
    features = {
        column: df[column].astype(str).values
        for column in ("text_left", "text_right")
    }
    labels = df["label"].values

    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)


# Hyper-parameters shared by the encoder builders below.
EMBED_DIM = 128  # token embedding width and width of each encoder's output Dense layer
LSTM_DIM = 64    # units of the LSTM wrapped by Bidirectional in the BiLSTM encoder
DROPOUT = 0.3    # rate of the Dropout layer placed before the output Dense layer


# ----- BiLSTM Encoder -----
def build_text_encoder_bilstm():
    """Build the BiLSTM sentence encoder.

    The model takes a scalar string per example, tokenises it with the
    module-level ``vectorizer``, embeds the token ids (id 0 is masked), runs
    a bidirectional LSTM that returns only its final output, applies dropout
    and projects to ``EMBED_DIM`` units with a linear Dense layer.

    Returns:
        A ``keras.Model`` mapping raw strings to ``EMBED_DIM``-wide vectors.
    """
    text_in = keras.Input(shape=(), dtype=tf.string)
    token_ids = vectorizer(text_in)

    embedded = layers.Embedding(
        input_dim=len(vectorizer.get_vocabulary()),
        output_dim=EMBED_DIM,
        mask_zero=True,
    )(token_ids)
    summary = layers.Bidirectional(
        layers.LSTM(units=LSTM_DIM, return_sequences=False)
    )(embedded)
    regularised = layers.Dropout(rate=DROPOUT)(summary)
    projected = layers.Dense(units=EMBED_DIM)(regularised)

    return keras.Model(inputs=text_in, outputs=projected)


# ----- Attention Encoder -----
def build_text_encoder_attention():
    """Build the self-attention sentence encoder.

    The model takes a scalar string per example, tokenises it with the
    module-level ``vectorizer``, embeds the token ids (id 0 is masked),
    applies four-head self-attention (the embedded sequence is both query and
    value), averages over the sequence axis, applies dropout and projects to
    ``EMBED_DIM`` units with a linear Dense layer.

    Returns:
        A ``keras.Model`` mapping raw strings to ``EMBED_DIM``-wide vectors.
    """
    text_in = keras.Input(shape=(), dtype=tf.string)
    token_ids = vectorizer(text_in)

    embedded = layers.Embedding(
        input_dim=len(vectorizer.get_vocabulary()),
        output_dim=EMBED_DIM,
        mask_zero=True,
    )(token_ids)
    attended = layers.MultiHeadAttention(num_heads=4, key_dim=EMBED_DIM)(
        query=embedded, value=embedded
    )
    pooled = layers.GlobalAveragePooling1D()(attended)
    regularised = layers.Dropout(rate=DROPOUT)(pooled)
    projected = layers.Dense(units=EMBED_DIM)(regularised)

    return keras.Model(inputs=text_in, outputs=projected)


# ----- Siamese Network -----
def build_siamese(encoder_builder) -> keras.Model:
    """Build and compile a Siamese pair classifier around a shared encoder.

    Both inputs are passed through the same encoder instance, so the weights
    are shared. The element-wise absolute difference of the two encodings
    feeds a single sigmoid unit.

    Args:
        encoder_builder: Zero-argument callable returning a ``keras.Model``
            that maps a batch of strings to fixed-width vectors.

    Returns:
        A model with string inputs named ``text_left`` and ``text_right``,
        compiled with Adam (learning rate 1e-3), binary cross-entropy loss
        and an accuracy metric.
    """
    encoder = encoder_builder()

    left_inp = keras.Input(shape=(), dtype=tf.string, name="text_left")
    right_inp = keras.Input(shape=(), dtype=tf.string, name="text_right")

    left_emb = encoder(left_inp)
    right_emb = encoder(right_inp)

    # |l - r| expressed as max(l - r, r - l) using built-in merge layers.
    # Calling tf.abs directly on symbolic Keras tensors only works with the
    # legacy tf.keras functional API and is rejected by Keras 3.
    forward = layers.Subtract()([left_emb, right_emb])
    backward = layers.Subtract()([right_emb, left_emb])
    diff = layers.Maximum()([forward, backward])
    out = layers.Dense(1, activation="sigmoid")(diff)

    model = keras.Model(inputs=[left_inp, right_inp], outputs=out)
    model.compile(
        optimizer=keras.optimizers.Adam(1e-3),
        loss="binary_crossentropy",
        metrics=["accuracy"]
    )
    return model


# ----- Prepare Datasets -----
# Wrap each split in a tf.data pipeline using to_ds' default batch size.
train_ds = to_ds(train_df)
val_ds   = to_ds(val_df)
test_ds  = to_ds(test_df)


# ===============================================
# 7. Train BiLSTM Siamese Model (reduced epochs)
# ===============================================
bilstm_model = build_siamese(build_text_encoder_bilstm)
# Stop once validation accuracy has not improved for 2 consecutive epochs and
# restore the weights of the best epoch.
cb = [keras.callbacks.EarlyStopping(patience=2, restore_best_weights=True, monitor="val_accuracy")]

# Train for at most 5 epochs; the returned History object is kept in h_bilstm.
h_bilstm = bilstm_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5,
    callbacks=cb,
    verbose=1,
)


# ===============================================
# 8. Train Attention Siamese Model (reduced epochs)
# ===============================================
# Same training setup as the BiLSTM model: the same datasets, at most 5
# epochs, and the `cb` callback list created in the BiLSTM training section.
attn_model = build_siamese(build_text_encoder_attention)
h_attn = attn_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5,
    callbacks=cb,
    verbose=1,
)


# ===============================================
# 9. Evaluation Function
# ===============================================
def evaluate_model(model: keras.Model, df: pd.DataFrame, name: str):
    """Score ``model`` on the pairs in ``df`` and print a metric summary.

    Predicted probabilities are thresholded at 0.5 to obtain class labels.
    ROC-AUC and PR-AUC are computed from the raw probabilities; when
    scikit-learn rejects the input with a ``ValueError`` the metric is
    reported as NaN and the reason is printed.

    Args:
        model: Trained pair classifier that outputs one probability per pair.
        df: Frame with ``text_left``, ``text_right`` and ``label`` columns.
        name: Display name used in the printed report and the result dict.

    Returns:
        Dict with keys ``name``, ``accuracy``, ``precision``, ``recall``,
        ``f1``, ``roc_auc`` and ``pr_auc``.
    """
    y_true = df["label"].values
    probs = model.predict(to_ds(df), verbose=0).ravel()
    preds = (probs >= 0.5).astype(int)

    acc = accuracy_score(y_true, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, preds, average="binary")

    def _score_or_nan(metric):
        # Only ValueError is treated as "metric undefined"; any other
        # exception indicates a real bug and is allowed to propagate.
        try:
            return metric(y_true, probs)
        except ValueError as e:
            print(f"{metric.__name__} is undefined for {name}: {e}")
            return np.nan

    roc = _score_or_nan(roc_auc_score)
    pr_auc = _score_or_nan(average_precision_score)

    print(f"\n{name} Results (Test)")
    print(f"Accuracy:  {acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-Score:  {f1:.4f}")
    print(f"ROC-AUC:   {roc:.4f}")
    print(f"PR-AUC:    {pr_auc:.4f}")
    print("\nClassification Report:\n", classification_report(y_true, preds, digits=4))

    return {
        "name": name,
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "roc_auc": roc,
        "pr_auc": pr_auc,
    }


# ===============================================
# 10. Evaluate Models
# ===============================================
# Score both trained models on the held-out test pairs.
bilstm_metrics = evaluate_model(bilstm_model, test_df, "BiLSTM Siamese")
attn_metrics   = evaluate_model(attn_model,   test_df, "Attention Siamese")

# One row per model for a side-by-side comparison (only rendered as the last
# expression of a notebook cell).
pd.DataFrame([bilstm_metrics, attn_metrics])


# ===============================================
# 11. Qualitative Predictions
# ===============================================
def sample_predictions(model, df: pd.DataFrame, k: int = 6):
    """Return model predictions for a seeded random sample of pairs.

    Args:
        model: Trained pair classifier that outputs one probability per pair.
        df: Frame with ``text_left``, ``text_right`` and ``label`` columns.
        k: Number of rows to sample; capped at ``len(df)``.

    Returns:
        The sampled rows with two added columns: ``prob`` (predicted
        probability) and ``pred`` (1 where ``prob >= 0.5``, else 0).
    """
    n_rows = min(k, len(df))
    picked = df.sample(n=n_rows, random_state=SEED).copy()

    scores = model.predict(to_ds(picked), verbose=0).ravel()
    picked["pred"] = (scores >= 0.5).astype(int)
    picked["prob"] = scores

    shown_columns = ["text_left", "text_right", "label", "pred", "prob"]
    return picked[shown_columns]

# Print six sampled test pairs per model. The rendered table is cut to its
# first 2000 characters because clause texts can be long.
print("BiLSTM samples:")
print(sample_predictions(bilstm_model,  test_df, k=6).to_string(index=False)[:2000])

print("\nAttention samples:")
print(sample_predictions(attn_model, test_df, k=6).to_string(index=False)[:2000])



(150545, 2)
clause_type
time-of-essence                   630
time-of-the-essence               620
capitalized-terms                 590
definitions-and-interpretation    590
captions                          580
Name: count, dtype: int64
(1254654, 3) label
1    1177233
0      77421
Name: count, dtype: int64


Vocab size: 30000
Epoch 1/5

Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
  198/14164 [..............................] - ETA: 2:29:08 - loss: 0.3004 - accuracy: 0.9269

KeyboardInterrupt: 