In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline,AutoModel

  torch.utils._pytree._register_pytree_node(


In [2]:
from src.utils import convert_predictions
from src.streaming import process_ontonotes_example, stream_sentence

In [3]:
ontonotes = load_dataset("conll2012_ontonotesv5", "english_v12")

In [4]:
max_sentences = 500 # it takes 2 whole days to precompute this...
sentences = []
for doc in ontonotes["train"]:
    for sent in doc['sentences']:
        if 'words' in sent and len(sent['words']) > 1:
            sentences.append(sent['words'])
    if len(sentences) >= max_sentences:
        break

In [None]:
print(sentences[:50])

In [5]:
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, pipeline
from seqeval.metrics import f1_score
from tqdm import tqdm
from torch import nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [6]:
model_name = "dslim/bert-base-NER"
f1_threshold = 0.9

In [7]:
# Pretrained Models
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoder = AutoModel.from_pretrained(model_name)
ner_pipeline = pipeline("ner", model=model_name, tokenizer=tokenizer, aggregation_strategy="simple")

  torch.utils._pytree._register_pytree_node(
Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


$$ \mathcal{L}(f) = \sum_{t \in \mathcal{I}f} \ell{\mathrm{NER}}(\hat{y}_t, y_t) + \lambda, C(\mathcal{I}_f), $$

In [8]:
# If can encapsulate this, then we need to do it.

def run_ner_on_tokens(tokens):
    text = " ".join(tokens)
    return ner_pipeline(text)

def get_f1_label(pred_bio, gold_bio_prefix, threshold):
    return int(f1_score([gold_bio_prefix], [pred_bio]) >= threshold)

def get_cls_embedding(tokens):
    text = " ".join(tokens)
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = encoder(**inputs)
    return outputs.last_hidden_state[:, 0, :].squeeze().numpy()

# Main Part for dataset generation

In [9]:
X, y = [], []
sentence_ids = []
sentence_lengths = []

sentence_idx = 0  # Sentence counter

for sentence in tqdm(sentences, desc="Processing Sentences"):
    try:
        output = run_ner_on_tokens(sentence)
        labels = convert_predictions(sentence, output)

        prefix_count = 0  # Count how many prefixes this sentence has

        for k in range(1, len(sentence) + 1):
            prefix = sentence[:k]
            prefix_output = run_ner_on_tokens(prefix)
            pred_bio = convert_predictions(prefix, prefix_output)

            prefix_bio = labels[:k]
            label = get_f1_label(pred_bio, prefix_bio, threshold=f1_threshold)

            embedding = get_cls_embedding(prefix)

            X.append(embedding)
            y.append(label)
            sentence_ids.append(sentence_idx)
            prefix_count += 1

        sentence_lengths.append(prefix_count)
        sentence_idx += 1

    except Exception as e:
        print(f"Skipping sentence due to error: {e}")
        continue

X = np.array(X)
y = np.array(y)
sentence_ids = np.array(sentence_ids)
sentence_lengths = np.array(sentence_lengths)

np.savez("ner_trigger_dataset.npz", X=X, y=y, ids=sentence_ids, lengths=sentence_lengths)
print(f"\nSaved dataset with {len(X)} examples from {sentence_idx} sentences to ner_trigger_dataset.npz")


  _warn_prf(
Processing Sentences: 100%|██████████| 749/749 [23:25<00:00,  1.88s/it] 


Saved dataset with 11829 examples from 749 sentences to ner_trigger_dataset.npz





# Model Specific Stuff

In [10]:
# Preprocessing stuff

def load_ner_trigger_dataset(path="ner_trigger_dataset.npz"):
    data = np.load(path)
    X = data["X"]                 # CLS embeddings (num_prefixes, 768)
    y = data["y"]                 # Labels (0 or 1)
    sentence_ids = data["ids"]    # Sentence ID for each prefix
    sentence_lengths = data["lengths"]  # Number of prefixes per sentence
    return X, y, sentence_ids, sentence_lengths

X, y, sentence_ids, sentence_lengths = load_ner_trigger_dataset()

In [11]:
print(f"Total prefixes: {len(X)}")
print(f"Total sentences: {len(sentence_lengths)}")
print(f"Sentence 0 has {sentence_lengths[0]} prefixes.")
print(f"First 5 sentence IDs: {sentence_ids[:5]}")

Total prefixes: 11829
Total sentences: 749
Sentence 0 has 5 prefixes.
First 5 sentence IDs: [0 0 0 0 0]


In [12]:
class PrefixDataset(Dataset):
    def __init__(self, X, y, sentence_ids):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.ids = torch.tensor(sentence_ids, dtype=torch.int64)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.ids[idx]

dataset = PrefixDataset(X, y, sentence_ids)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [13]:
model = nn.Sequential(
    nn.Linear(768, 128),
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid()
)

In [None]:
class NERTriggerLoss(nn.Module):
    def __init__(self, sentence_lengths):
        super().__init__()
        self.sentence_lengths = sentence_lengths  

    def forward(self, predictions, targets, sentence_ids):
        """
        predictions: tensor of shape (batch_size,) — model outputs (logits or probs)
        targets:     tensor of shape (batch_size,) — 0 or 1
        sentence_ids: tensor of shape (batch_size,) — maps each prefix to sentence index
        """
        losses = []
        for i in range(len(predictions)):
            pred = predictions[i]
            target = targets[i]
            sent_id = sentence_ids[i].item()

            # Get total length of that sentence
            total_len = self.sentence_lengths[sent_id]
            if total_len == 0:
                continue  # skip edge case

            # Compute current position within the sentence
            prefix_pos = (sentence_ids[:i] == sent_id).sum().item()

            # Delay penalty is linear here. maybe consider something else 
            delay_penalty = prefix_pos / total_len

            # Standard BCE loss
            bce = F.binary_cross_entropy_with_logits(pred, target.float(), reduction="none")

            # If the target is 1 (we should trigger), penalize late triggers
            if target == 1:
                # Penalize early triggers more — inverse of delay
                early_penalty = (1.0 - delay_penalty)

                # increase penatly here. atm it is 20%
                if delay_penalty <= 0.2:
                    early_penalty *= 2.0  # here double, but set as needed

                weighted_loss = bce * early_penalty

            else:
                weighted_loss = bce  

            losses.append(weighted_loss)

        return torch.stack(losses).mean()


In [17]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = NERTriggerLoss(sentence_lengths)

In [19]:
epochs = 20
losses = []

for epoch in range(epochs):
    epoch_loss = 0.0
    for X_batch, y_batch, id_batch in dataloader:
        preds = model(X_batch).squeeze()
        loss = criterion(preds, y_batch, id_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(dataloader)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{epochs} — Loss: {avg_loss:.4f}")

Epoch 1/20 — Loss: 0.7092
Epoch 2/20 — Loss: 0.7091
Epoch 3/20 — Loss: 0.7097
Epoch 4/20 — Loss: 0.7081
Epoch 5/20 — Loss: 0.7055
Epoch 6/20 — Loss: 0.7066
Epoch 7/20 — Loss: 0.7046
Epoch 8/20 — Loss: 0.7067
Epoch 9/20 — Loss: 0.7048
Epoch 10/20 — Loss: 0.7030
Epoch 11/20 — Loss: 0.7026
Epoch 12/20 — Loss: 0.7020
Epoch 13/20 — Loss: 0.7023
Epoch 14/20 — Loss: 0.7017
Epoch 15/20 — Loss: 0.7042
Epoch 16/20 — Loss: 0.7031
Epoch 17/20 — Loss: 0.7035
Epoch 18/20 — Loss: 0.7022
Epoch 19/20 — Loss: 0.7007
Epoch 20/20 — Loss: 0.7031


In [21]:
model.eval()
predictions_by_sentence = {}
truths_by_sentence = {}

with torch.no_grad():
    for X_batch, y_batch, id_batch in dataloader:
        outputs = model(X_batch).squeeze()
        preds_binary = (outputs >= 0.5).int()

        for i in range(len(X_batch)):
            sid = id_batch[i].item()
            if sid not in predictions_by_sentence:
                predictions_by_sentence[sid] = []
                truths_by_sentence[sid] = []
            predictions_by_sentence[sid].append(preds_binary[i].item())
            truths_by_sentence[sid].append(y_batch[i].item())

# Print examples
for sid in sorted(predictions_by_sentence.keys())[:20]:
    pred_seq = predictions_by_sentence[sid]
    truth_seq = truths_by_sentence[sid]
    pred_trigger = next((i for i, val in enumerate(pred_seq) if val == 1), None)
    true_trigger = next((i for i, val in enumerate(truth_seq) if val == 1), None)

    print(f"Sentence {sid}:")
    print(f"Ground Truth Trigger : {true_trigger}")
    print(f"Model Trigger        : {pred_trigger}")
    print(f"Prediction Sequence  : {pred_seq}")
    print(f"Ground Truth Sequence: {truth_seq}")
    print("-" * 60)

Sentence 0:
Ground Truth Trigger : None
Model Trigger        : None
Prediction Sequence  : [0, 0, 0, 0, 0]
Ground Truth Sequence: [0.0, 0.0, 0.0, 0.0, 0.0]
------------------------------------------------------------
Sentence 1:
Ground Truth Trigger : 5
Model Trigger        : 5
Prediction Sequence  : [0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0]
Ground Truth Sequence: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0]
------------------------------------------------------------
Sentence 2:
Ground Truth Trigger : 5
Model Trigger        : 0
Prediction Sequence  : [1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1]
Ground Truth Sequence: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0]
------------------------------------------------------------
Sentence 3:
Ground Truth Trigger : 2
Model Trigger        : 0
Prediction Sequence  : [1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1]
Ground Truth Sequence: [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 