In [19]:
import os
import copy
import math
import hashlib
from collections import Counter, defaultdict

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score, classification_report

In [20]:
# =============================================================================
# 0) Load data
# =============================================================================
# The CSV path is resolved relative to the current working directory
# (<cwd>/../data/cleaned/full_data.csv), so the notebook must be started from
# a directory that sits next to `data/`.
data_path = os.path.join(os.getcwd(), "..", "data", "cleaned", "full_data.csv")
df = pd.read_csv(data_path)

# =============================================================================
# 1) Label merge (APPLY IT)
# =============================================================================
# Maps each raw annotation label to one of six coarse classes:
# Other, Name, Date, Time, Location, Description.
# Compound labels are folded into their first component
# (e.g. "NameLocation" -> "Name", "DateTime" -> "Date", "TimeLocation" -> "Time").
LABEL_MERGE_MAP = {
    # keep
    "Other": "Other",

    # Name variants
    "Name": "Name",
    "NameLink": "Name",
    "NameLocation": "Name",

    # Date variants
    "Date": "Date",
    "DateTime": "Date",

    # Time variants
    "Time": "Time",
    "StartTime": "Time",
    "EndTime": "Time",
    "StartEndTime": "Time",
    "TimeLocation": "Time",

    # Location (no variants)
    "Location": "Location",

    # Description variants
    "Description": "Description",
    "Desc": "Description",
    "Details": "Description",
}

def merge_labels(df, label_col="label", mapping=None, default_to_other=True):
    """Return a copy of ``df`` with ``label_col`` collapsed through ``mapping``.

    Every value is stringified before the lookup, so missing values are
    looked up under their string form (e.g. ``"nan"``).

    Args:
        df: Input frame; it is not modified.
        label_col: Name of the column holding the raw labels.
        mapping: Dict from raw label to merged label. ``None`` or an empty
            dict means "no known labels".
        default_to_other: When True, labels absent from ``mapping`` become
            ``"Other"``; when False they are kept (in stringified form).

    Returns:
        A new DataFrame whose ``label_col`` holds the merged labels.
    """
    lookup = mapping if mapping else {}

    def _canonical(raw):
        key = str(raw)
        fallback = "Other" if default_to_other else key
        return lookup.get(key, fallback)

    merged = df.copy()
    merged[label_col] = merged[label_col].map(_canonical)
    return merged

# Rebind `df` to the label-merged copy; raw labels that are not keys of
# LABEL_MERGE_MAP become "Other" (merge_labels' default_to_other=True).
df = merge_labels(df, "label", LABEL_MERGE_MAP)

In [21]:
# Order nodes page by page in rendering order; every per-page sequence built
# later relies on this ordering.
df = df.sort_values(["source", "rendering_order"]).reset_index(drop=True)

# in_event: 1 when the node carries an event_id, else 0.
df["in_event"] = df["event_id"].notna().astype(int)
df["start_event"] = 0

# start_event: 1 on the first row (in the order established above) of each
# (source, event_id) group. Rows without an event_id are excluded via `m`.
m = df["event_id"].notna()
first_idx = df.loc[m].groupby(["source", "event_id"], sort=False).head(1).index
df.loc[first_idx, "start_event"] = 1

# BIO: 0=O, 1=B, 2=I
# B = start of each event record
# The two assignments are ordered on purpose: every in-event node is first
# marked I (2), then event starts are overwritten with B (1).
df["bio"] = 0
df.loc[df["in_event"].eq(1), "bio"] = 2
df.loc[df["start_event"].eq(1), "bio"] = 1

# Sanity summary: the number of unique (source, event_id) pairs should equal
# the number of start positives.
print("Pages:", df["source"].nunique())
print("Total nodes:", len(df))
print("Unique events:", df.loc[m].drop_duplicates(["source","event_id"]).shape[0])
print("Start positives:", int(df["start_event"].sum()))
print("Start rate:", float(df["start_event"].mean()))
print("Label counts:\n", df["label"].value_counts())

Pages: 15
Total nodes: 2764
Unique events: 177
Start positives: 177
Start rate: 0.06403762662807526
Label counts:
 label
Other          1976
Date            314
Name            150
Time            146
Location        121
Description      57
Name: count, dtype: int64


In [22]:
# =============================================================================
# 3) Deduplicate identical pages BEFORE splitting (prevents leakage / misleading CV)
# =============================================================================
def _fingerprint_text(col: pd.Series) -> pd.Series:
    """Stringify a column for fingerprinting, mapping missing values to ''."""
    # Missing values must be blanked BEFORE the string cast: astype(str) turns
    # NaN into the literal "nan", after which a fillna("") has nothing left to
    # fill and a missing value would hash the same as the real text "nan".
    return col.astype(object).where(col.notna(), "").astype(str)


def page_fingerprint(g: pd.DataFrame) -> str:
    """Return an MD5 hex digest identifying a page by its ordered node content.

    Each row contributes ``tag <TAB> parent_tag <TAB> text_context``; rows are
    joined with newlines in the order they appear in ``g``, so the caller is
    responsible for sorting ``g`` (e.g. by rendering order) beforehand.
    Missing values contribute an empty string.

    Args:
        g: Rows of a single page with columns ``tag``, ``parent_tag`` and
            ``text_context``. More columns can be added to the fingerprint
            if a stronger identity is needed.

    Returns:
        32-character hexadecimal MD5 digest. It is used only for equality
        comparison between pages, not for security.
    """
    parts = (
        _fingerprint_text(g["tag"]) + "\t" +
        _fingerprint_text(g["parent_tag"]) + "\t" +
        _fingerprint_text(g["text_context"])
    ).tolist()
    s = "\n".join(parts)
    return hashlib.md5(s.encode("utf-8")).hexdigest()

# Fingerprint every page (one `source` == one page), with its nodes in
# rendering order so the fingerprint reflects the page's node sequence.
fps = []
for src, g in df.groupby("source", sort=False):
    g = g.sort_values("rendering_order")
    fps.append((src, page_fingerprint(g)))

# Group sources that share a fingerprint. Within each group the sources keep
# the order in which they were appended above.
fp_df = pd.DataFrame(fps, columns=["source", "fp"])
dup_groups = fp_df.groupby("fp")["source"].apply(list)

# Keep the first source per fingerprint, drop others
keep_sources = []
drop_sources = []
for fp, sources in dup_groups.items():
    keep_sources.append(sources[0])
    if len(sources) > 1:
        drop_sources.extend(sources[1:])

if drop_sources:
    print("\n[DEDUP] Dropping duplicate pages:")
    for s in drop_sources:
        print("  -", s)

# Rebind `df` to the de-duplicated pages and restore the canonical
# (source, rendering_order) ordering with a fresh 0..n-1 index.
df = df[df["source"].isin(set(keep_sources))].copy()
df = df.sort_values(["source", "rendering_order"]).reset_index(drop=True)

print("\nAfter dedup:")
print("Pages:", df["source"].nunique())
print("Total nodes:", len(df))


[DEDUP] Dropping duplicate pages:
  - members.sacac.org_pattern_labeled

After dedup:
Pages: 14
Total nodes: 2491


In [23]:
# Use the GPU when one is visible, otherwise fall back to CPU.
print("CUDA available:", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Text encoder checkpoint; the tokenizer is loaded once and shared globally.
MODEL_NAME = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Field-label vocabulary. Sorting makes the label <-> id assignment
# deterministic across runs.
LABELS = sorted(df["label"].unique().tolist())
label2id = {l: i for i, l in enumerate(LABELS)}
id2label = {i: l for l, i in label2id.items()}
OTHER_ID = label2id["Other"]

# Tag vocabularies are built from every page currently in `df` (values are
# stringified first, so a missing tag gets its own "nan" entry).
TAG_VOCAB = {t: i for i, t in enumerate(sorted(df["tag"].astype(str).unique().tolist()))}
PARENT_TAG_VOCAB = {t: i for i, t in enumerate(sorted(df["parent_tag"].astype(str).unique().tolist()))}

# Numeric structural feature columns (normalized downstream with mean/std).
STRUCT_COLS_NUM = [
    "depth","sibling_index","children_count","same_tag_sibling_count",
    "same_text_sibling_count","text_length","word_count",
    "letter_ratio","digit_ratio","whitespace_ratio","attribute_count"
]
# Boolean structural feature columns (cast to 0/1 floats downstream).
# NOTE(review): the last three names ("text_word_time", "text_word_description",
# "text_word_location") lack the "has_" infix used by "text_has_word_name" /
# "text_has_word_date" -- confirm they match the CSV column names.
STRUCT_COLS_BOOL = [
    "has_link","link_is_absolute","parent_has_link","is_leaf",
    "contains_date","contains_time","starts_with_digit","ends_with_digit",
    "has_class","has_id",
    "attr_has_word_name","attr_has_word_date","attr_has_word_time","attr_has_word_location","attr_has_word_link",
    "text_has_word_name","text_has_word_date","text_word_time","text_word_description","text_word_location"
]

CUDA available: True


In [24]:
# Page-level split: shuffle the sorted page names with a fixed seed so the
# hold-out selection is reproducible.
all_sources = np.array(sorted(df["source"].unique()))
rng = np.random.default_rng(42)
rng.shuffle(all_sources)

# The first TEST_N_PAGES shuffled pages form the hold-out test set (at most 2);
# the remaining pages are kept in `cv_sources` (presumably for the
# cross-validation folds -- their use is not shown in this cell).
TEST_N_PAGES = min(2, len(all_sources))
test_sources = set(all_sources[:TEST_N_PAGES])
cv_sources   = all_sources[TEST_N_PAGES:]

print("\nHoldout TEST pages:", len(test_sources), sorted(list(test_sources)))
test_df = df[df["source"].isin(test_sources)].copy()


Holdout TEST pages: 2 [np.str_('nacacnet.org_pattern_labeled'), np.str_('neacac_fall.net_pattern_labeled')]


In [25]:
# =============================================================================
# 6) Normalization: train-only stats per fold
# =============================================================================
def compute_num_stats(train_df: pd.DataFrame, cols):
    """Compute per-column mean and standard deviation for normalization.

    Missing values are treated as 0 before the statistics are taken. Columns
    whose standard deviation is below 1e-6 get a std of 1.0 so that dividing
    by it is safe.

    Args:
        train_df: Frame to derive the statistics from (training rows only).
        cols: Names of the numeric columns to summarize.

    Returns:
        ``(mean, std)`` as float32 arrays with one entry per column.
    """
    values = train_df[cols].fillna(0).to_numpy().astype("float32")
    center = values.mean(axis=0)
    spread = values.std(axis=0)
    # (Near-)constant columns: fall back to unit scale.
    degenerate = spread < 1e-6
    safe_spread = np.where(degenerate, 1.0, spread)
    return center.astype("float32"), safe_spread.astype("float32")

# =============================================================================
# 7) Dataset (stores BIO + in_event) and applies normalization
# =============================================================================
class PageDataset(Dataset):
    """
    One item per page (``source``): all nodes of the page in rendering order.

    Tokenization is done once in the constructor and cached as plain Python
    lists (no padding); numeric features are normalized with the supplied
    ``(num_mean, num_std)``.
    """

    def __init__(self, df, tokenizer, num_mean=None, num_std=None, max_tokens=64):
        """Build and cache the per-page records.

        Args:
            df: Node-level frame with ``source``, ``rendering_order``,
                ``text_context``, ``label``, ``bio``, ``in_event``, ``tag``,
                ``parent_tag`` and the structural feature columns.
            tokenizer: Callable tokenizer applied to each page's node texts.
            num_mean: Per-column mean for the numeric features; zeros if None.
            num_std: Per-column std for the numeric features; ones if None.
            max_tokens: Truncation length (in tokens) for every node text.
        """
        n_numeric = len(STRUCT_COLS_NUM)

        self.pages = []
        self.max_tokens = max_tokens
        self.tokenizer = tokenizer
        # Identity normalization when no statistics are supplied.
        self.num_mean = np.zeros((n_numeric,), dtype="float32") if num_mean is None else num_mean
        self.num_std = np.ones((n_numeric,), dtype="float32") if num_std is None else num_std

        for _, page_df in df.groupby("source", sort=False):
            self.pages.append(self._encode_page(page_df))

    def _encode_page(self, page_df):
        """Turn the rows of one page into the cached record dict."""
        nodes = page_df.sort_values("rendering_order").reset_index(drop=True)

        # One token sequence per node; padding is deferred to collate time.
        encoded = self.tokenizer(
            nodes["text_context"].astype(str).tolist(),
            padding=False,
            truncation=True,
            max_length=self.max_tokens,
            return_attention_mask=True,
            return_tensors=None
        )

        numeric = nodes[STRUCT_COLS_NUM].fillna(0).values.astype("float32")
        numeric = (numeric - self.num_mean) / self.num_std

        booleans = nodes[STRUCT_COLS_BOOL].astype(int).values.astype("float32")

        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "field_y": [label2id[name] for name in nodes["label"].tolist()],
            "bio_y": nodes["bio"].astype(int).tolist(),            # 0/1/2
            "in_event_y": nodes["in_event"].astype(int).tolist(),  # 0/1
            "tag_id": [TAG_VOCAB[str(tag)] for tag in nodes["tag"]],
            "parent_tag_id": [PARENT_TAG_VOCAB[str(tag)] for tag in nodes["parent_tag"]],
            "num_feats": numeric,
            "bool_feats": booleans,
        }

    def __len__(self):
        """Number of cached pages."""
        return len(self.pages)

    def __getitem__(self, idx):
        """Return the cached record of page ``idx``."""
        return self.pages[idx]


In [26]:
def collate_fn(batch):
    """Collate a list of page records into padded batch tensors.

    Node token sequences of all pages are flattened into one list and padded
    together by the global ``tokenizer``; ``node_offsets`` records which slice
    of that flat list belongs to which page. Per-node targets and features
    are padded to the largest page in the batch.

    Args:
        batch: List of page dicts as produced by ``PageDataset``.

    Returns:
        Dict with ``enc`` (padded token tensors for all nodes),
        ``node_offsets`` (list of ``(start, end)`` per page), ``node_mask``
        (bool ``[B, max_nodes]``, True for real nodes), the targets
        ``field_y`` / ``bio_y`` / ``in_event_y`` (padded with -100), the ids
        ``tag_id`` / ``parent_tag_id`` (padded with 0) and the float features
        ``num_feats`` / ``bool_feats`` (padded with 0.0).
    """
    batch_size = len(batch)
    nodes_per_page = [len(page["input_ids"]) for page in batch]
    max_nodes = max(nodes_per_page)

    node_mask = torch.zeros((batch_size, max_nodes), dtype=torch.bool)
    node_offsets = []
    flat_nodes = []
    cursor = 0
    for row, (page, n_nodes) in enumerate(zip(batch, nodes_per_page)):
        node_mask[row, :n_nodes] = True
        flat_nodes.extend(
            [
                {"input_ids": page["input_ids"][j], "attention_mask": page["attention_mask"][j]}
                for j in range(n_nodes)
            ]
        )
        node_offsets.append((cursor, cursor + n_nodes))
        cursor += n_nodes

    # Pad all node token sequences of the batch to a common length.
    enc = tokenizer.pad(flat_nodes, padding=True, return_tensors="pt")

    def _pad_ids(key, fill):
        """Stack per-page integer lists into a [B, max_nodes] long tensor."""
        padded = torch.full((batch_size, max_nodes), fill, dtype=torch.long)
        for row, page in enumerate(batch):
            seq = page[key]
            padded[row, :len(seq)] = torch.tensor(seq, dtype=torch.long)
        return padded

    def _pad_feats(key, width):
        """Stack per-page [n, width] arrays into a [B, max_nodes, width] float tensor."""
        padded = torch.full((batch_size, max_nodes, width), 0.0, dtype=torch.float32)
        for row, page in enumerate(batch):
            block = page[key]
            padded[row, :block.shape[0], :] = torch.tensor(block, dtype=torch.float32)
        return padded

    return {
        "enc": enc,
        "node_offsets": node_offsets,
        "node_mask": node_mask,
        # -100 marks padded target positions.
        "field_y": _pad_ids("field_y", -100),
        "bio_y": _pad_ids("bio_y", -100),
        "in_event_y": _pad_ids("in_event_y", -100),
        "tag_id": _pad_ids("tag_id", 0),
        "parent_tag_id": _pad_ids("parent_tag_id", 0),
        "num_feats": _pad_feats("num_feats", len(STRUCT_COLS_NUM)),
        "bool_feats": _pad_feats("bool_feats", len(STRUCT_COLS_BOOL)),
    }

In [27]:
class DOMAwareEventExtractor(nn.Module):
    """Per-node multi-task tagger over the node sequence of a page.

    Each node is represented by the sum of five ``d_model``-sized vectors: a
    projection of the text encoder's hidden state at token position 0, a tag
    embedding, a parent-tag embedding, and linear projections of the numeric
    and boolean structural features. The summed vectors are layer-normalized,
    contextualized across the page by a Transformer encoder, and fed to three
    linear heads (field label, BIO tag, in-event logit).
    """

    def __init__(
        self,
        text_model_name: str,
        num_field_labels: int,
        tag_vocab_size: int,
        parent_tag_vocab_size: int,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        dropout: float = 0.2
    ):
        """Create the text encoder, feature embeddings, node encoder and heads.

        Args:
            text_model_name: Checkpoint name passed to ``AutoModel.from_pretrained``.
            num_field_labels: Number of classes of the field head.
            tag_vocab_size: Number of rows of the tag embedding table.
            parent_tag_vocab_size: Number of rows of the parent-tag embedding table.
            d_model: Width of the node representation.
            nhead: Attention heads per node-encoder layer.
            num_layers: Number of node-encoder layers.
            dropout: Dropout used inside the node-encoder layers.
        """
        super().__init__()
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        text_dim = self.text_encoder.config.hidden_size

        # Project the text encoder output down to the node width.
        self.text_proj = nn.Linear(text_dim, d_model)

        self.tag_emb = nn.Embedding(tag_vocab_size, d_model)
        self.parent_tag_emb = nn.Embedding(parent_tag_vocab_size, d_model)

        # Input widths come from the notebook-level feature column lists.
        self.num_proj = nn.Linear(len(STRUCT_COLS_NUM), d_model)
        self.bool_proj = nn.Linear(len(STRUCT_COLS_BOOL), d_model)

        self.layernorm = nn.LayerNorm(d_model)

        # Page-level encoder: lets every node attend to the other nodes of its page.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True
        )
        self.node_encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        self.field_head = nn.Linear(d_model, num_field_labels)
        self.bio_head = nn.Linear(d_model, 3)       # O/B/I
        self.in_event_head = nn.Linear(d_model, 1)  # binary

    def forward(self, enc, node_offsets, node_mask, tag_id, parent_tag_id, num_feats, bool_feats):
        """Compute per-node logits for a batch of pages.

        Args:
            enc: Keyword arguments for the text encoder covering every node of
                the batch (flattened across pages).
            node_offsets: One ``(start, end)`` pair per page giving that page's
                slice in the flattened node dimension.
            node_mask: Bool tensor ``[B, max_nodes]``; True marks real nodes.
            tag_id, parent_tag_id: Index tensors ``[B, max_nodes]``.
            num_feats, bool_feats: Feature tensors ``[B, max_nodes, F]``.

        Returns:
            ``(field_logits [B, N, C], bio_logits [B, N, 3], in_event_logits [B, N])``.
        """
        out = self.text_encoder(**enc)
        # Hidden state at token position 0 is used as the node's text vector.
        cls = out.last_hidden_state[:, 0, :]     # [total_nodes, text_dim]
        node_text = self.text_proj(cls)          # [total_nodes, d_model]

        # Scatter the flat node vectors back into a zero-padded [B, max_nodes, d_model] grid.
        B, max_nodes = node_mask.shape
        packed = node_text.new_zeros((B, max_nodes, node_text.shape[-1]))
        for i, (s, e) in enumerate(node_offsets):
            packed[i, : (e - s), :] = node_text[s:e]

        # All feature views share d_model, so they are combined by addition.
        x = (
            packed
            + self.tag_emb(tag_id)
            + self.parent_tag_emb(parent_tag_id)
            + self.num_proj(num_feats)
            + self.bool_proj(bool_feats)
        )
        x = self.layernorm(x)

        # src_key_padding_mask expects True at positions to ignore, i.e. padding.
        key_padding_mask = ~node_mask
        x = self.node_encoder(x, src_key_padding_mask=key_padding_mask)

        field_logits = self.field_head(x)                  # [B, N, C]
        bio_logits   = self.bio_head(x)                    # [B, N, 3]
        in_event_logits = self.in_event_head(x).squeeze(-1)# [B, N]
        return field_logits, bio_logits, in_event_logits


In [28]:
def make_losses_for_train_df(
    train_df,
    LABELS,
    device,
    other_scale=0.05,
    weight_cap=50.0
):
    """Build the three class-weighted loss functions from training-set counts.

    Class weights are inverse frequencies (``total / count``; a class that
    never occurs is treated as count 1), capped at ``weight_cap``. The
    ``"Other"`` field weight is additionally multiplied by ``other_scale``
    after capping. The in-event ``pos_weight`` is ``neg / pos`` and is not
    capped.

    Args:
        train_df: Training rows with ``label``, ``bio`` and ``in_event`` columns.
        LABELS: Ordered list of field labels (index == class id).
        device: Device on which the weight tensors are created.
        other_scale: Multiplier applied to the ``"Other"`` field weight.
        weight_cap: Upper bound for field and BIO class weights.

    Returns:
        ``(field_loss_fn, bio_loss_fn, in_event_loss_fn)``: two
        ``CrossEntropyLoss`` modules (``ignore_index=-100``) and one
        ``BCEWithLogitsLoss``.
    """
    def _capped_inverse_freq(counts, classes, total):
        raw = [total / counts.get(cls, 1) for cls in classes]
        as_tensor = torch.tensor(raw, dtype=torch.float32, device=device)
        return torch.clamp(as_tensor, max=weight_cap)

    # ---- Field weights ----
    field_counts = Counter(train_df["label"].tolist())
    field_w = _capped_inverse_freq(field_counts, LABELS, sum(field_counts.values()))
    if "Other" in LABELS:
        field_w[LABELS.index("Other")] *= other_scale
    field_loss_fn = nn.CrossEntropyLoss(weight=field_w, ignore_index=-100)

    # ---- BIO weights (classes 0/1/2; any of them may be absent) ----
    bio_classes = [0, 1, 2]
    bio_counts = Counter(train_df["bio"].tolist())
    bio_total = sum(bio_counts.get(cls, 0) for cls in bio_classes)
    bio_w = _capped_inverse_freq(bio_counts, bio_classes, bio_total)
    bio_loss_fn = nn.CrossEntropyLoss(weight=bio_w, ignore_index=-100)

    # ---- in_event pos_weight ----
    n_pos = float(train_df["in_event"].sum())
    n_neg = float(len(train_df) - n_pos)
    pos_weight = torch.tensor([n_neg / (n_pos + 1e-6)], dtype=torch.float32, device=device)
    in_event_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    return field_loss_fn, bio_loss_fn, in_event_loss_fn

In [29]:
def pick_starts_from_probs(probs, threshold=0.5, nms_k=1, min_gap=2):
    """Decode event-start positions from per-node start probabilities.

    Three filters are applied in sequence:
      1. keep nodes with ``prob >= threshold``;
      2. if ``nms_k > 0``, keep only nodes that are a maximum of their
         ``+/- nms_k`` neighborhood (ties within 1e-12 are kept);
      3. visit the survivors by descending probability and accept a node only
         if it is more than ``min_gap`` positions away from every node
         accepted so far.

    Args:
        probs: 1D sequence of prob(B) per node (valid nodes only).
        threshold: Minimum probability for a candidate.
        nms_k: Half-width of the local-maximum window; 0 disables step 2.
        min_gap: Accepted starts must differ by more than this many positions.

    Returns:
        Sorted list of accepted start indices (Python ints).
    """
    scores = np.asarray(probs)
    n_nodes = scores.shape[0]
    if n_nodes == 0:
        return []

    # 1) threshold
    candidates = np.nonzero(scores >= threshold)[0]
    if candidates.size == 0:
        return []

    # 2) local maxima within +/- nms_k
    if nms_k > 0:
        peaks = [
            i for i in candidates
            if scores[i] >= scores[max(0, i - nms_k):min(n_nodes, i + nms_k + 1)].max() - 1e-12
        ]
        candidates = np.array(peaks, dtype=int)
        if candidates.size == 0:
            return []

    # 3) greedy selection, strongest first, enforcing the minimum gap
    ranked = candidates[np.argsort(-scores[candidates])]
    accepted = []
    for idx in ranked:
        for prev in accepted:
            if not abs(idx - prev) > min_gap:
                break
        else:
            accepted.append(int(idx))
    accepted.sort()
    return accepted

def start_prf_with_tolerance(true_starts, pred_starts, tol=1):
    """Precision / recall / F1 of predicted starts with positional tolerance.

    Predictions are processed in the given order. Each one is matched to the
    nearest still-unmatched true start within ``+/- tol`` (the earliest true
    start wins on equal distance); a true start can be matched only once.
    The matching is greedy, not globally optimal.

    Args:
        true_starts: Iterable of true start indices.
        pred_starts: Iterable of predicted start indices.
        tol: Maximum absolute index distance for a match.

    Returns:
        ``(precision, recall, f1)`` as floats; a 1e-9 term in each
        denominator keeps empty inputs at 0.0 instead of dividing by zero.
    """
    truths = list(true_starts)
    preds = list(pred_starts)

    claimed = set()
    for p in preds:
        reachable = [
            (abs(p - t), ti)
            for ti, t in enumerate(truths)
            if ti not in claimed and abs(p - t) <= tol
        ]
        if reachable:
            # min() returns the first minimal element, i.e. the earliest
            # true start among equally close ones.
            claimed.add(min(reachable, key=lambda pair: pair[0])[1])

    tp = len(claimed)
    fp = len(preds) - tp
    fn = len(truths) - tp

    prec = tp / (tp + fp + 1e-9)
    rec  = tp / (tp + fn + 1e-9)
    f1   = 2 * prec * rec / (prec + rec + 1e-9)
    return prec, rec, f1

In [30]:
# NOTE(review): an earlier cell already defines a function with this name and
# the same behavior; this later definition shadows it. Keep only one copy.
def make_losses_for_train_df(
    train_df,
    LABELS,
    device,
    other_scale=0.05,
    weight_cap=50.0
):
    """Build the three class-weighted loss functions from training-set counts.

    Class weights are inverse frequencies (``total / count``; a class that
    never occurs is treated as count 1), capped at ``weight_cap``. The
    ``"Other"`` field weight is additionally multiplied by ``other_scale``
    after capping. The in-event ``pos_weight`` is ``neg / pos`` and is not
    capped.

    Args:
        train_df: Training rows with ``label``, ``bio`` and ``in_event`` columns.
        LABELS: Ordered list of field labels (index == class id).
        device: Device on which the weight tensors are created.
        other_scale: Multiplier applied to the ``"Other"`` field weight.
        weight_cap: Upper bound for field and BIO class weights.

    Returns:
        ``(field_loss_fn, bio_loss_fn, in_event_loss_fn)``: two
        ``CrossEntropyLoss`` modules (``ignore_index=-100``) and one
        ``BCEWithLogitsLoss``.
    """
    def _capped_inverse_freq(counts, classes, total):
        raw = [total / counts.get(cls, 1) for cls in classes]
        as_tensor = torch.tensor(raw, dtype=torch.float32, device=device)
        return torch.clamp(as_tensor, max=weight_cap)

    # ---- Field weights ----
    field_counts = Counter(train_df["label"].tolist())
    field_w = _capped_inverse_freq(field_counts, LABELS, sum(field_counts.values()))
    if "Other" in LABELS:
        field_w[LABELS.index("Other")] *= other_scale
    field_loss_fn = nn.CrossEntropyLoss(weight=field_w, ignore_index=-100)

    # ---- BIO weights (classes 0/1/2; any of them may be absent) ----
    bio_classes = [0, 1, 2]
    bio_counts = Counter(train_df["bio"].tolist())
    bio_total = sum(bio_counts.get(cls, 0) for cls in bio_classes)
    bio_w = _capped_inverse_freq(bio_counts, bio_classes, bio_total)
    bio_loss_fn = nn.CrossEntropyLoss(weight=bio_w, ignore_index=-100)

    # ---- in_event pos_weight ----
    n_pos = float(train_df["in_event"].sum())
    n_neg = float(len(train_df) - n_pos)
    pos_weight = torch.tensor([n_neg / (n_pos + 1e-6)], dtype=torch.float32, device=device)
    in_event_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    return field_loss_fn, bio_loss_fn, in_event_loss_fn

# =============================================================================
# 11) Peak-based decoding + start metrics (NMS/min-gap/tolerance)
# =============================================================================
# NOTE(review): an earlier cell already defines a function with this name and
# the same behavior; this later definition shadows it. Keep only one copy.
def pick_starts_from_probs(probs, threshold=0.5, nms_k=1, min_gap=2):
    """Decode event-start positions from per-node start probabilities.

    Three filters are applied in sequence:
      1. keep nodes with ``prob >= threshold``;
      2. if ``nms_k > 0``, keep only nodes that are a maximum of their
         ``+/- nms_k`` neighborhood (ties within 1e-12 are kept);
      3. visit the survivors by descending probability and accept a node only
         if it is more than ``min_gap`` positions away from every node
         accepted so far.

    Args:
        probs: 1D sequence of prob(B) per node (valid nodes only).
        threshold: Minimum probability for a candidate.
        nms_k: Half-width of the local-maximum window; 0 disables step 2.
        min_gap: Accepted starts must differ by more than this many positions.

    Returns:
        Sorted list of accepted start indices (Python ints).
    """
    scores = np.asarray(probs)
    n_nodes = scores.shape[0]
    if n_nodes == 0:
        return []

    # 1) threshold
    candidates = np.nonzero(scores >= threshold)[0]
    if candidates.size == 0:
        return []

    # 2) local maxima within +/- nms_k
    if nms_k > 0:
        peaks = [
            i for i in candidates
            if scores[i] >= scores[max(0, i - nms_k):min(n_nodes, i + nms_k + 1)].max() - 1e-12
        ]
        candidates = np.array(peaks, dtype=int)
        if candidates.size == 0:
            return []

    # 3) greedy selection, strongest first, enforcing the minimum gap
    ranked = candidates[np.argsort(-scores[candidates])]
    accepted = []
    for idx in ranked:
        for prev in accepted:
            if not abs(idx - prev) > min_gap:
                break
        else:
            accepted.append(int(idx))
    accepted.sort()
    return accepted

# NOTE(review): an earlier cell already defines a function with this name and
# the same behavior; this later definition shadows it. Keep only one copy.
def start_prf_with_tolerance(true_starts, pred_starts, tol=1):
    """Precision / recall / F1 of predicted starts with positional tolerance.

    Predictions are processed in the given order. Each one is matched to the
    nearest still-unmatched true start within ``+/- tol`` (the earliest true
    start wins on equal distance); a true start can be matched only once.
    The matching is greedy, not globally optimal.

    Args:
        true_starts: Iterable of true start indices.
        pred_starts: Iterable of predicted start indices.
        tol: Maximum absolute index distance for a match.

    Returns:
        ``(precision, recall, f1)`` as floats; a 1e-9 term in each
        denominator keeps empty inputs at 0.0 instead of dividing by zero.
    """
    truths = list(true_starts)
    preds = list(pred_starts)

    claimed = set()
    for p in preds:
        reachable = [
            (abs(p - t), ti)
            for ti, t in enumerate(truths)
            if ti not in claimed and abs(p - t) <= tol
        ]
        if reachable:
            # min() returns the first minimal element, i.e. the earliest
            # true start among equally close ones.
            claimed.add(min(reachable, key=lambda pair: pair[0])[1])

    tp = len(claimed)
    fp = len(preds) - tp
    fn = len(truths) - tp

    prec = tp / (tp + fp + 1e-9)
    rec  = tp / (tp + fn + 1e-9)
    f1   = 2 * prec * rec / (prec + rec + 1e-9)
    return prec, rec, f1


In [31]:

@torch.no_grad()
def collect_page_probs_and_truth(loader, model, device):
    """Run the model over ``loader`` and gather per-page start probabilities.

    The model is switched to eval mode. For every page with at least one
    real (unmasked) node, one tuple is produced:

      * ``prob_B_valid``: numpy array with the softmax probability of BIO
        class 1 (B) for each real node;
      * ``true_start_valid``: numpy int array of the same length, 1 where the
        BIO target equals 1 and 0 elsewhere.

    Args:
        loader: Iterable of batches shaped like the output of ``collate_fn``.
        model: Module returning ``(field_logits, bio_logits, in_event_logits)``.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        List of ``(prob_B_valid, true_start_valid)`` tuples, one per page.
    """
    model.eval()
    pages = []
    for batch in loader:
        enc = {name: tensor.to(device, non_blocking=True) for name, tensor in batch["enc"].items()}
        node_mask = batch["node_mask"].to(device).bool()

        _, bio_logits, _ = model(
            enc=enc,
            node_offsets=batch["node_offsets"],
            node_mask=node_mask,
            tag_id=batch["tag_id"].to(device),
            parent_tag_id=batch["parent_tag_id"].to(device),
            num_feats=batch["num_feats"].to(device),
            bool_feats=batch["bool_feats"].to(device),
        )

        # Probability of class B (BIO label 1), shape [B, N], moved to CPU.
        start_prob = torch.softmax(bio_logits, dim=-1)[..., 1].detach().cpu()
        is_start = (batch["bio_y"].to(device) == 1).long().detach().cpu()
        mask_cpu = node_mask.detach().cpu()

        for row in range(start_prob.shape[0]):
            keep = torch.where(mask_cpu[row])[0]
            if keep.numel() == 0:
                # Page consists of padding only.
                continue
            pages.append(
                (start_prob[row, keep].numpy(), is_start[row, keep].numpy().astype(int))
            )
    return pages

def _count_tolerant_matches(true_idx, pred_idx, tol):
    """Greedily match predicted to true start indices; return ``(tp, fp, fn)``.

    Each prediction (in the given order) claims the nearest still-unclaimed
    true index within ``+/- tol``; the earliest true index wins on ties.
    """
    matched_true = set()
    for pr in pred_idx:
        best = None
        best_dist = None
        for ti, t in enumerate(true_idx):
            if ti in matched_true:
                continue
            d = abs(pr - t)
            if d <= tol and (best_dist is None or d < best_dist):
                best = ti
                best_dist = d
        if best is not None:
            matched_true.add(best)
    tp = len(matched_true)
    return tp, len(pred_idx) - tp, len(true_idx) - tp


@torch.no_grad()
def find_best_threshold_peak(
    loader,
    model,
    device,
    thresholds=None,
    nms_k=1,
    min_gap=2,
    tol=1
):
    """Pick the start-probability threshold that maximizes micro F1.

    The model is run once over ``loader``; every candidate threshold is then
    evaluated on the cached probabilities by decoding starts with
    ``pick_starts_from_probs`` and summing tolerance-matched TP/FP/FN over
    all pages (micro averaging is more stable than averaging per-page F1).

    Args:
        loader: Iterable of collated batches.
        model: Model passed on to ``collect_page_probs_and_truth``.
        device: Device used for the forward passes.
        thresholds: Candidate thresholds; defaults to 19 evenly spaced values
            from 0.05 to 0.95.
        nms_k: Local-maximum half-window for decoding.
        min_gap: Minimum index gap between decoded starts.
        tol: Index tolerance when matching predicted to true starts.

    Returns:
        ``(best_threshold, best_f1)``. The first threshold reaching the best
        F1 wins. Returns ``(0.5, 0.0)`` when the loader yields no pages.
    """
    if thresholds is None:
        thresholds = np.linspace(0.05, 0.95, 19)

    pages = collect_page_probs_and_truth(loader, model, device)
    if len(pages) == 0:
        return 0.5, 0.0

    # True start indices do not depend on the threshold: compute them once.
    pages = [
        (prob_B, np.where(true_start == 1)[0].tolist())
        for prob_B, true_start in pages
    ]

    best_th = 0.5
    best_f1 = -1.0

    for th in thresholds:
        TP = FP = FN = 0
        for prob_B, true_idx in pages:
            pred_idx = pick_starts_from_probs(prob_B, threshold=th, nms_k=nms_k, min_gap=min_gap)
            tp, fp, fn = _count_tolerant_matches(true_idx, pred_idx, tol)
            TP += tp
            FP += fp
            FN += fn

        prec = TP / (TP + FP + 1e-9)
        rec  = TP / (TP + FN + 1e-9)
        f1   = 2 * prec * rec / (prec + rec + 1e-9)

        if f1 > best_f1:
            best_f1 = f1
            best_th = float(th)

    return best_th, float(best_f1)

In [32]:
@torch.no_grad()
def boundary_metrics_peak(loader, model, device, threshold, nms_k=1, min_gap=2, tol=1):
    """Micro precision / recall / F1 of decoded event starts at one threshold.

    Starts are decoded per page with ``pick_starts_from_probs`` and matched
    greedily to the true starts: each prediction claims the nearest
    still-unclaimed true start within ``+/- tol`` (earliest wins on ties).
    TP/FP/FN are summed over all pages before the ratios are taken.

    Args:
        loader: Iterable of collated batches.
        model: Model passed on to ``collect_page_probs_and_truth``.
        device: Device used for the forward passes.
        threshold: Minimum start probability for a candidate.
        nms_k: Local-maximum half-window for decoding.
        min_gap: Minimum index gap between decoded starts.
        tol: Index tolerance when matching predicted to true starts.

    Returns:
        ``(precision, recall, f1)`` as floats; ``(0.0, 0.0, 0.0)`` when the
        loader yields no pages.
    """
    page_results = collect_page_probs_and_truth(loader, model, device)
    if not page_results:
        return 0.0, 0.0, 0.0

    total_tp = total_fp = total_fn = 0
    for start_probs, start_flags in page_results:
        gold = np.where(start_flags == 1)[0].tolist()
        predicted = pick_starts_from_probs(
            start_probs, threshold=threshold, nms_k=nms_k, min_gap=min_gap
        )

        claimed = set()
        for p in predicted:
            reachable = [
                (abs(p - t), gi)
                for gi, t in enumerate(gold)
                if gi not in claimed and abs(p - t) <= tol
            ]
            if reachable:
                # First minimal element == nearest, earliest on ties.
                claimed.add(min(reachable, key=lambda pair: pair[0])[1])

        hits = len(claimed)
        total_tp += hits
        total_fp += len(predicted) - hits
        total_fn += len(gold) - hits

    prec = total_tp / (total_tp + total_fp + 1e-9)
    rec  = total_tp / (total_tp + total_fn + 1e-9)
    f1   = 2 * prec * rec / (prec + rec + 1e-9)
    return float(prec), float(rec), float(f1)

# =============================================================================
# 12) Field metrics: ignore Other, ignore padding. (Optionally restrict to in_event)
# =============================================================================
@torch.no_grad()
def field_metrics_fast(loader, model, device, label2id, average="micro", restrict_to_true_in_event=False):
    """F1 of the field head over real nodes whose true label is not "Other".

    Padding positions (mask False or target -100) and nodes whose true label
    is "Other" are excluded. Predictions of "Other" on the remaining nodes
    still count (as errors).

    Args:
        loader: Iterable of collated batches.
        model: Module returning ``(field_logits, bio_logits, in_event_logits)``.
        device: Device the batch tensors are moved to.
        label2id: Mapping from label name to class id; must contain "Other".
        average: Averaging mode forwarded to ``f1_score``.
        restrict_to_true_in_event: If True, additionally keep only nodes whose
            ``in_event`` target is 1.

    Returns:
        The F1 score as a float, or 0.0 when no node qualifies.
    """
    model.eval()
    other_id = label2id["Other"]
    gold, guessed = [], []

    for batch in loader:
        enc = {name: tensor.to(device, non_blocking=True) for name, tensor in batch["enc"].items()}
        node_mask = batch["node_mask"].to(device).bool()
        field_y = batch["field_y"].to(device)
        in_event_y = batch["in_event_y"].to(device)

        field_logits, _, _ = model(
            enc=enc,
            node_offsets=batch["node_offsets"],
            node_mask=node_mask,
            tag_id=batch["tag_id"].to(device),
            parent_tag_id=batch["parent_tag_id"].to(device),
            num_feats=batch["num_feats"].to(device),
            bool_feats=batch["bool_feats"].to(device),
        )
        predicted = torch.argmax(field_logits, dim=-1)

        # Score only real nodes with a non-"Other" true label.
        keep = node_mask & (field_y != -100) & (field_y != other_id)
        if restrict_to_true_in_event:
            keep = keep & (in_event_y == 1)

        gold.extend(field_y[keep].detach().cpu().tolist())
        guessed.extend(predicted[keep].detach().cpu().tolist())

    if not gold:
        return 0.0
    return float(f1_score(gold, guessed, average=average, zero_division=0))

In [33]:
# =============================================================================
# 13) Loaders (pass num stats)
# =============================================================================
def make_loaders(train_df, val_df, num_mean, num_std, batch_size=2, max_tokens=64, num_workers=4):
    """Build the train and validation ``DataLoader`` objects for one split.

    Both datasets are normalized with the same ``(num_mean, num_std)`` so the
    validation pages are scaled with training statistics.

    Args:
        train_df: Node-level rows of the training pages.
        val_df: Node-level rows of the validation pages.
        num_mean: Per-column mean of the numeric features.
        num_std: Per-column std of the numeric features.
        batch_size: Number of pages per batch.
        max_tokens: Token truncation length per node text.
        num_workers: Worker processes per loader. Use 0 to load in the main
            process (e.g. where worker processes cannot be spawned).

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    train_dataset = PageDataset(train_df, tokenizer=tokenizer, num_mean=num_mean, num_std=num_std, max_tokens=max_tokens)
    val_dataset   = PageDataset(val_df, tokenizer=tokenizer, num_mean=num_mean, num_std=num_std, max_tokens=max_tokens)

    # Options shared by both loaders.
    loader_kwargs = dict(
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when there are no workers.
        persistent_workers=num_workers > 0,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    return train_loader, val_loader

In [34]:
# =============================================================================
# 14) Optim: freeze warmup + differential LRs
# =============================================================================
def init_model_and_optim(lr_bert=5e-6, lr_other=1e-4, weight_decay=0.01):
    """Create a fresh model on ``device`` and an AdamW optimizer with two LR groups.

    Parameters whose name starts with ``"text_encoder."`` form the first
    group (learning rate ``lr_bert``); all remaining parameters form the
    second group (learning rate ``lr_other``). Both groups use the same
    ``weight_decay``.

    Args:
        lr_bert: Learning rate for the text encoder parameters.
        lr_other: Learning rate for every other parameter.
        weight_decay: Weight decay applied to both groups.

    Returns:
        ``(model, optimizer)``.
    """
    model = DOMAwareEventExtractor(
        text_model_name=MODEL_NAME,
        num_field_labels=len(LABELS),
        tag_vocab_size=len(TAG_VOCAB),
        parent_tag_vocab_size=len(PARENT_TAG_VOCAB),
        d_model=128,
        nhead=4,
        num_layers=2
    ).to(device)

    named = list(model.named_parameters())
    encoder_params = [param for name, param in named if name.startswith("text_encoder.")]
    remaining_params = [param for name, param in named if not name.startswith("text_encoder.")]

    param_groups = [
        {"params": encoder_params, "lr": lr_bert, "weight_decay": weight_decay},
        {"params": remaining_params, "lr": lr_other, "weight_decay": weight_decay},
    ]
    return model, torch.optim.AdamW(param_groups)

def set_bert_trainable(model, trainable: bool):
    """Enable or disable gradients for every parameter of ``model.text_encoder``.

    Args:
        model: Object exposing a ``text_encoder`` module.
        trainable: True to let the text encoder receive gradients, False to
            freeze it.
    """
    for param in model.text_encoder.parameters():
        param.requires_grad_(trainable)

# =============================================================================
# 15) Train / Eval epoch (losses include Other; BIO + in_event auxiliary)
# =============================================================================
def _masked_loss(loss_fn, logits, targets, mask, targets_as_float=False):
    """Apply `loss_fn` to the entries of `logits`/`targets` selected by `mask`.

    If `mask` selects nothing, return a zero that is still attached to the
    autograd graph instead of calling `loss_fn` on empty tensors: mean-reduced
    losses evaluate to NaN on empty input, and a NaN loss would poison the
    weights on the next optimizer step.
    """
    if not bool(mask.any()):
        return logits.sum() * 0.0
    selected_targets = targets[mask]
    if targets_as_float:
        selected_targets = selected_targets.float()
    return loss_fn(logits[mask], selected_targets)


def run_epoch(
    model,
    optimizer,
    loader,
    field_loss_fn,
    bio_loss_fn,
    in_event_loss_fn,
    w_bio=2.0,
    w_in_event=1.0,
    training=True
):
    """Run one pass over `loader` and return the mean per-batch loss.

    The batch loss is `field_loss + w_bio * bio_loss + w_in_event * in_event_loss`.
    Each term is computed only over positions where `node_mask` is true and the
    corresponding target is not -100. A head with no such position in a batch
    contributes zero instead of NaN.

    Args:
        model: called with keyword args enc, node_offsets, node_mask, tag_id,
            parent_tag_id, num_feats, bool_feats; must return
            (field_logits, bio_logits, in_event_logits).
        optimizer: stepped once per batch when `training` is true; unused otherwise.
        loader: iterable of dict batches with keys "enc", "node_mask", "field_y",
            "bio_y", "in_event_y", "node_offsets", "tag_id", "parent_tag_id",
            "num_feats", "bool_feats".
        field_loss_fn, bio_loss_fn, in_event_loss_fn: callables `(logits, targets)`;
            the in_event targets are cast to float before the call.
        w_bio, w_in_event: weights of the auxiliary loss terms.
        training: if true, the model is put in train mode and gradients are
            computed, clipped to norm 1.0 and applied; otherwise eval mode with
            gradients disabled.

    Returns:
        Sum of batch losses divided by `max(1, len(loader))`.

    Note:
        Tensors are moved to the module-level `device`.
    """
    model.train(bool(training))
    total_loss = 0.0

    for batch in loader:
        enc = {k: v.to(device, non_blocking=True) for k, v in batch["enc"].items()}
        node_mask = batch["node_mask"].to(device).bool()

        field_y = batch["field_y"].to(device)
        bio_y   = batch["bio_y"].to(device)
        in_event_y = batch["in_event_y"].to(device)

        with torch.set_grad_enabled(training):
            field_logits, bio_logits, in_event_logits = model(
                enc=enc,
                node_offsets=batch["node_offsets"],
                node_mask=node_mask,
                tag_id=batch["tag_id"].to(device),
                parent_tag_id=batch["parent_tag_id"].to(device),
                num_feats=batch["num_feats"].to(device),
                bool_feats=batch["bool_feats"].to(device),
            )

            # Field loss: ALL nodes (including Other), excluding padding
            field_loss = _masked_loss(
                field_loss_fn, field_logits, field_y, node_mask & (field_y != -100)
            )

            # BIO loss: excluding padding
            bio_loss = _masked_loss(
                bio_loss_fn, bio_logits, bio_y, node_mask & (bio_y != -100)
            )

            # in_event loss: excluding padding (targets cast to float for the loss)
            in_event_loss = _masked_loss(
                in_event_loss_fn, in_event_logits, in_event_y,
                node_mask & (in_event_y != -100),
                targets_as_float=True,
            )

            loss = field_loss + w_bio * bio_loss + w_in_event * in_event_loss

            if training:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

        total_loss += loss.detach().item()

    return total_loss / max(1, len(loader))

In [35]:
# =============================================================================
# 16) Cross-val by source
# =============================================================================
# Folds are built over page sources (not rows), so each page is entirely in
# train or entirely in validation within a fold.
N_SPLITS = min(5, len(cv_sources))
kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42)

# Peak-decoding hyperparams (tune these a bit if needed)
NMS_K = 1
MIN_GAP = 2
TOL = 1

EPOCHS = 20
FREEZE_EPOCHS = 4  # freeze DistilBERT for first few epochs

cv_results = []

for fold, (tr_idx, va_idx) in enumerate(kf.split(cv_sources), start=1):
    fold_train_sources = set(cv_sources[tr_idx])
    fold_val_sources   = set(cv_sources[va_idx])

    fold_train_df = df[df["source"].isin(fold_train_sources)].copy()
    fold_val_df   = df[df["source"].isin(fold_val_sources)].copy()

    # Normalisation stats and loss functions are derived from the training fold only.
    num_mean, num_std = compute_num_stats(fold_train_df, STRUCT_COLS_NUM)

    field_loss_fn, bio_loss_fn, in_event_loss_fn = make_losses_for_train_df(
        fold_train_df,
        LABELS,
        device,
        other_scale=0.01,   # downweight Other strongly
        weight_cap=50.0
    )

    print(f"\n===== Fold {fold}/{N_SPLITS} =====")
    print("Train pages:", fold_train_df["source"].nunique(), "Val pages:", fold_val_df["source"].nunique())

    train_loader, val_loader = make_loaders(
        fold_train_df, fold_val_df,
        num_mean=num_mean, num_std=num_std,
        batch_size=2, max_tokens=64
    )

    model, optimizer = init_model_and_optim(lr_bert=5e-6, lr_other=1e-4)

    # Best validation start-F1 seen so far, with its threshold and a weight snapshot.
    best = {"f1": -1.0, "th": 0.5, "state": None}

    # Freeze BERT initially
    set_bert_trainable(model, False)

    for epoch in range(EPOCHS):
        if epoch == FREEZE_EPOCHS:
            set_bert_trainable(model, True)

        tr_loss = run_epoch(
            model, optimizer, train_loader,
            field_loss_fn, bio_loss_fn, in_event_loss_fn,
            w_bio=2.0, w_in_event=1.0,
            training=True
        )
        va_loss = run_epoch(
            model, optimizer, val_loader,
            field_loss_fn, bio_loss_fn, in_event_loss_fn,
            w_bio=2.0, w_in_event=1.0,
            training=False
        )

        th, f1 = find_best_threshold_peak(
            val_loader, model, device,
            thresholds=np.linspace(0.05, 0.95, 19),
            nms_k=NMS_K, min_gap=MIN_GAP, tol=TOL
        )

        if f1 > best["f1"]:
            best["f1"] = f1
            best["th"] = th
            best["state"] = copy.deepcopy(model.state_dict())

        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1:02d} | tr_loss={tr_loss:.4f} va_loss={va_loss:.4f} best_startF1={best['f1']:.4f} best_th={best['th']:.2f}")

    # No snapshot exists if no epoch ran (EPOCHS == 0) or no F1 ever compared
    # greater than the initial -1.0 (e.g. NaN scores). load_state_dict(None)
    # would raise, so keep the current weights in that case.
    if best["state"] is not None:
        model.load_state_dict(best["state"])
    # Make sure the fold metrics below are computed in eval mode.
    model.eval()

    # NOTE: the checkpoint and the threshold are both selected on this same
    # validation fold, so the fold metrics reported below are optimistic.
    bp, br, bf1 = boundary_metrics_peak(val_loader, model, device, threshold=best["th"], nms_k=NMS_K, min_gap=MIN_GAP, tol=TOL)

    # Field metrics: ignore Other. You can choose restrict_to_true_in_event=True if you prefer.
    field_micro_f1 = field_metrics_fast(val_loader, model, device, label2id=label2id, average="micro", restrict_to_true_in_event=False)
    field_macro_f1 = field_metrics_fast(val_loader, model, device, label2id=label2id, average="macro", restrict_to_true_in_event=False)

    print(
        f"Fold {fold} | START (peak-based): P={bp:.4f} R={br:.4f} F1={bf1:.4f} (th={best['th']:.2f}, nms_k={NMS_K}, min_gap={MIN_GAP}, tol=±{TOL})"
        f" | field: microF1={field_micro_f1:.4f} macroF1={field_macro_f1:.4f}"
    )

    cv_results.append({
        "fold": fold,
        "bp": bp, "br": br, "bf1": bf1, "th": best["th"],
        "field_micro_f1": field_micro_f1,
        "field_macro_f1": field_macro_f1,
        "num_mean": num_mean,
        "num_std": num_std,
    })

# Unweighted means over folds; the defaults only apply if no fold was run.
print("\n===== CV Summary =====")
mean_bf1 = float(np.mean([x["bf1"] for x in cv_results])) if cv_results else 0.0
mean_bp  = float(np.mean([x["bp"]  for x in cv_results])) if cv_results else 0.0
mean_br  = float(np.mean([x["br"]  for x in cv_results])) if cv_results else 0.0
mean_th  = float(np.mean([x["th"]  for x in cv_results])) if cv_results else 0.5
mean_field_micro = float(np.mean([x["field_micro_f1"] for x in cv_results])) if cv_results else 0.0
mean_field_macro = float(np.mean([x["field_macro_f1"] for x in cv_results])) if cv_results else 0.0

print(f"START (peak) Mean F1: {mean_bf1:.4f}")
print(f"START (peak) Mean P : {mean_bp:.4f}")
print(f"START (peak) Mean R : {mean_br:.4f}")
print(f"START (peak) Mean th: {mean_th:.4f}")
print(f"Field Mean micro-F1 (ignore Other): {mean_field_micro:.4f}")
print(f"Field Mean macro-F1 (ignore Other): {mean_field_macro:.4f}")

# The fold-averaged threshold is carried forward as the decision threshold for later cells.
best_th_cv = mean_th
print("Using CV-avg threshold:", best_th_cv)


===== Fold 1/5 =====
Train pages: 9 Val pages: 3


Loading weights: 100%|██████████| 100/100 [00:00<00:00, 308.75it/s, Materializing param=transformer.layer.5.sa_layer_norm.weight]   
[1mDistilBertModel LOAD REPORT[0m from: distilbert-base-uncased
Key                     | Status     |  | 
------------------------+------------+--+-
vocab_projector.bias    | UNEXPECTED |  | 
vocab_transform.weight  | UNEXPECTED |  | 
vocab_layer_norm.weight | UNEXPECTED |  | 
vocab_layer_norm.bias   | UNEXPECTED |  | 
vocab_transform.bias    | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m
  output = torch._nested_tensor_from_mask(


Epoch 05 | tr_loss=3.0877 va_loss=3.4717 best_startF1=0.5000 best_th=0.20
Epoch 10 | tr_loss=1.8160 va_loss=2.9478 best_startF1=0.5000 best_th=0.20
Epoch 15 | tr_loss=1.0159 va_loss=2.7307 best_startF1=0.6286 best_th=0.20
Epoch 20 | tr_loss=0.5482 va_loss=2.9602 best_startF1=0.6667 best_th=0.15
Fold 1 | START (peak-based): P=0.5500 R=0.8462 F1=0.6667 (th=0.15, nms_k=1, min_gap=2, tol=±1) | field: microF1=0.7872 macroF1=0.8193

===== Fold 2/5 =====
Train pages: 9 Val pages: 3


Loading weights: 100%|██████████| 100/100 [00:00<00:00, 388.88it/s, Materializing param=transformer.layer.5.sa_layer_norm.weight]   
[1mDistilBertModel LOAD REPORT[0m from: distilbert-base-uncased
Key                     | Status     |  | 
------------------------+------------+--+-
vocab_projector.bias    | UNEXPECTED |  | 
vocab_transform.weight  | UNEXPECTED |  | 
vocab_layer_norm.weight | UNEXPECTED |  | 
vocab_layer_norm.bias   | UNEXPECTED |  | 
vocab_transform.bias    | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Epoch 05 | tr_loss=3.2468 va_loss=4.5417 best_startF1=0.2727 best_th=0.15
Epoch 10 | tr_loss=1.8907 va_loss=3.8869 best_startF1=0.3415 best_th=0.15
Epoch 15 | tr_loss=1.1671 va_loss=3.5111 best_startF1=0.5556 best_th=0.30
Epoch 20 | tr_loss=0.6552 va_loss=3.5826 best_startF1=0.5556 best_th=0.30
Fold 2 | START (peak-based): P=0.6250 R=0.5000 F1=0.5556 (th=0.30, nms_k=1, min_gap=2, tol=±1) | field: microF1=0.5714 macroF1=0.5905

===== Fold 3/5 =====
Train pages: 10 Val pages: 2


Loading weights: 100%|██████████| 100/100 [00:00<00:00, 376.28it/s, Materializing param=transformer.layer.5.sa_layer_norm.weight]   
[1mDistilBertModel LOAD REPORT[0m from: distilbert-base-uncased
Key                     | Status     |  | 
------------------------+------------+--+-
vocab_projector.bias    | UNEXPECTED |  | 
vocab_transform.weight  | UNEXPECTED |  | 
vocab_layer_norm.weight | UNEXPECTED |  | 
vocab_layer_norm.bias   | UNEXPECTED |  | 
vocab_transform.bias    | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Epoch 05 | tr_loss=3.3755 va_loss=3.1104 best_startF1=0.7692 best_th=0.50
Epoch 10 | tr_loss=2.1441 va_loss=2.6599 best_startF1=0.7692 best_th=0.50
Epoch 15 | tr_loss=1.2009 va_loss=2.7127 best_startF1=0.7692 best_th=0.50
Epoch 20 | tr_loss=0.6288 va_loss=2.9545 best_startF1=0.7692 best_th=0.50
Fold 3 | START (peak-based): P=1.0000 R=0.6250 F1=0.7692 (th=0.50, nms_k=1, min_gap=2, tol=±1) | field: microF1=0.3617 macroF1=0.1773

===== Fold 4/5 =====
Train pages: 10 Val pages: 2


Loading weights: 100%|██████████| 100/100 [00:00<00:00, 394.87it/s, Materializing param=transformer.layer.5.sa_layer_norm.weight]   
[1mDistilBertModel LOAD REPORT[0m from: distilbert-base-uncased
Key                     | Status     |  | 
------------------------+------------+--+-
vocab_projector.bias    | UNEXPECTED |  | 
vocab_transform.weight  | UNEXPECTED |  | 
vocab_layer_norm.weight | UNEXPECTED |  | 
vocab_layer_norm.bias   | UNEXPECTED |  | 
vocab_transform.bias    | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Epoch 05 | tr_loss=3.2640 va_loss=6.2487 best_startF1=0.2571 best_th=0.45
Epoch 10 | tr_loss=1.9724 va_loss=7.1554 best_startF1=0.2571 best_th=0.45
Epoch 15 | tr_loss=1.1729 va_loss=8.1777 best_startF1=0.2571 best_th=0.45
Epoch 20 | tr_loss=0.6589 va_loss=9.6351 best_startF1=0.2571 best_th=0.45
Fold 4 | START (peak-based): P=0.4737 R=0.1765 F1=0.2571 (th=0.45, nms_k=1, min_gap=2, tol=±1) | field: microF1=0.1769 macroF1=0.1371

===== Fold 5/5 =====
Train pages: 10 Val pages: 2


Loading weights: 100%|██████████| 100/100 [00:00<00:00, 376.97it/s, Materializing param=transformer.layer.5.sa_layer_norm.weight]   
[1mDistilBertModel LOAD REPORT[0m from: distilbert-base-uncased
Key                     | Status     |  | 
------------------------+------------+--+-
vocab_projector.bias    | UNEXPECTED |  | 
vocab_transform.weight  | UNEXPECTED |  | 
vocab_layer_norm.weight | UNEXPECTED |  | 
vocab_layer_norm.bias   | UNEXPECTED |  | 
vocab_transform.bias    | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Epoch 05 | tr_loss=3.5180 va_loss=4.7029 best_startF1=0.8788 best_th=0.45
Epoch 10 | tr_loss=2.0957 va_loss=4.6493 best_startF1=0.8788 best_th=0.45
Epoch 15 | tr_loss=1.2147 va_loss=4.4498 best_startF1=0.8788 best_th=0.45
Epoch 20 | tr_loss=0.6494 va_loss=4.6067 best_startF1=0.8788 best_th=0.45
Fold 5 | START (peak-based): P=0.8788 R=0.8788 F1=0.8788 (th=0.45, nms_k=1, min_gap=2, tol=±1) | field: microF1=0.3646 macroF1=0.3083

===== CV Summary =====
START (peak) Mean F1: 0.6255
START (peak) Mean P : 0.7055
START (peak) Mean R : 0.6053
START (peak) Mean th: 0.3700
Field Mean micro-F1 (ignore Other): 0.4524
Field Mean macro-F1 (ignore Other): 0.4065
Using CV-avg threshold: 0.37


In [36]:
# =============================================================================
# 17) Final training on ALL CV pages, evaluate on HOLDOUT TEST
# =============================================================================
final_train_df = df[df["source"].isin(set(cv_sources))].copy()

# Normalize using ALL final-train stats
final_num_mean, final_num_std = compute_num_stats(final_train_df, STRUCT_COLS_NUM)

# make_loaders needs a validation frame; the training frame is passed for both
# and the second loader is discarded.
final_train_loader, _ = make_loaders(
    final_train_df, final_train_df,
    num_mean=final_num_mean, num_std=final_num_std,
    batch_size=2, max_tokens=64
)

# The holdout pages are normalised with the final-train statistics.
test_dataset = PageDataset(test_df, tokenizer=tokenizer, num_mean=final_num_mean, num_std=final_num_std, max_tokens=64)
test_loader  = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

field_loss_fn, bio_loss_fn, in_event_loss_fn = make_losses_for_train_df(
    final_train_df, LABELS, device, other_scale=0.01, weight_cap=50.0
)

final_model, final_optimizer = init_model_and_optim(lr_bert=5e-6, lr_other=1e-4)

# Freeze then unfreeze
set_bert_trainable(final_model, False)

EPOCHS_FINAL = 20
for epoch in range(EPOCHS_FINAL):
    if epoch == FREEZE_EPOCHS:
        set_bert_trainable(final_model, True)

    tr_loss = run_epoch(
        final_model, final_optimizer, final_train_loader,
        field_loss_fn, bio_loss_fn, in_event_loss_fn,
        w_bio=2.0, w_in_event=1.0,
        training=True
    )

    if (epoch + 1) % 5 == 0:
        print(f"[FINAL] Epoch {epoch+1:02d} | tr_loss={tr_loss:.4f}")

# run_epoch(training=True) leaves the model in train mode and, unlike the CV
# loop, no eval-mode pass follows here. Switch to eval explicitly so any
# dropout layers are inactive for the holdout metrics, whether or not the
# metric helpers change the mode themselves.
final_model.eval()

# Evaluate peak-based start detection on HOLDOUT
p, r, f1 = boundary_metrics_peak(test_loader, final_model, device, threshold=best_th_cv, nms_k=NMS_K, min_gap=MIN_GAP, tol=TOL)
print("\n===== HOLDOUT TEST (START / peak-based) =====")
print(f"Threshold={best_th_cv:.2f}  P={p:.4f}  R={r:.4f}  F1={f1:.4f} (nms_k={NMS_K}, min_gap={MIN_GAP}, tol=±{TOL})")

# Field metrics on HOLDOUT (ignore Other)
field_micro = field_metrics_fast(test_loader, final_model, device, label2id, average="micro", restrict_to_true_in_event=False)
field_macro = field_metrics_fast(test_loader, final_model, device, label2id, average="macro", restrict_to_true_in_event=False)
print("\n===== HOLDOUT TEST (FIELD / ignore Other) =====")
print(f"microF1={field_micro:.4f} macroF1={field_macro:.4f}")

print("\nHoldout test sources:", sorted(list(test_sources)))
print("Holdout start positives (total):", int(test_df["start_event"].sum()))
per_page = test_df.groupby("source")["start_event"].sum().sort_values(ascending=False)
print("Holdout start positives per page:\n", per_page.astype(int))

# Optionally tune threshold on holdout (NOT for reporting, just for debugging)
best_th_test, best_f1_test = find_best_threshold_peak(test_loader, final_model, device, nms_k=NMS_K, min_gap=MIN_GAP, tol=TOL)
print("\n[DEBUG] Best holdout threshold (peak metric):", best_th_test, "Best holdout F1:", best_f1_test)

Loading weights: 100%|██████████| 100/100 [00:00<00:00, 429.90it/s, Materializing param=transformer.layer.5.sa_layer_norm.weight]   
[1mDistilBertModel LOAD REPORT[0m from: distilbert-base-uncased
Key                     | Status     |  | 
------------------------+------------+--+-
vocab_projector.bias    | UNEXPECTED |  | 
vocab_transform.weight  | UNEXPECTED |  | 
vocab_layer_norm.weight | UNEXPECTED |  | 
vocab_layer_norm.bias   | UNEXPECTED |  | 
vocab_transform.bias    | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


[FINAL] Epoch 05 | tr_loss=3.1594
[FINAL] Epoch 10 | tr_loss=2.0299
[FINAL] Epoch 15 | tr_loss=1.1433
[FINAL] Epoch 20 | tr_loss=0.6481

===== HOLDOUT TEST (START / peak-based) =====
Threshold=0.37  P=0.5000  R=0.2812  F1=0.3600 (nms_k=1, min_gap=2, tol=±1)

===== HOLDOUT TEST (FIELD / ignore Other) =====
microF1=0.9073 macroF1=0.5762

Holdout test sources: [np.str_('nacacnet.org_pattern_labeled'), np.str_('neacac_fall.net_pattern_labeled')]
Holdout start positives (total): 32
Holdout start positives per page:
 source
nacacnet.org_pattern_labeled       22
neacac_fall.net_pattern_labeled    10
Name: start_event, dtype: int64

[DEBUG] Best holdout threshold (peak metric): 0.05 Best holdout F1: 0.7341772146893126
