In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm

In [2]:
# Location of the precomputed embeddings archive (a NumPy .npz file).
EMBEDDINGS_PATH = 'data/ontonotes_embeddings_full.npz'
data = np.load(EMBEDDINGS_PATH)

In [3]:
# Show which arrays are stored in the archive.
archive_keys = data.keys()
print(f"Data keys: {archive_keys}")

Data keys: KeysView(NpzFile 'data/ontonotes_embeddings_full.npz' with keys: X, Y)


In [4]:
# Pull the feature matrix (X) and the label vector (Y) out of the archive.
X, y = data['X'], data['Y']

# Report both shapes so a row-count mismatch is visible immediately.
print(f"X shape: {X.shape}")
print(f"y shape: {y.shape}")


X shape: (2200865, 768)
y shape: (2200865,)


In [5]:
# Count how many samples carry each label; the ratio is reused later to
# weight the loss.
num_positives = (y == 1).sum()
num_negatives = (y == 0).sum()
print(f"Positives: {num_positives}, Negatives: {num_negatives}")


Positives: 125904, Negatives: 2074961


In [6]:
# Turning the full 2,200,865 x 768 matrix into a tensor is memory-hungry.
# torch.tensor() always copies its input; torch.as_tensor() reuses the NumPy
# buffer when the dtype already matches, so the copy is avoided in that case.
# Consequence: X_tensor / y_tensor may share memory with X / y, so do not
# modify them in place.
# NOTE(review): whether X is stored as float32 is not visible here; if it is
# not, one conversion copy is still made. TODO confirm.
MAX_SAMPLES = None  # e.g. 200_000 to cap memory while experimenting; None keeps every row
SPLIT_SEED = 42     # fixes the train/validation partition across kernel restarts

# Convert to PyTorch tensors
X_tensor = torch.as_tensor(X, dtype=torch.float32)
y_tensor = torch.as_tensor(y, dtype=torch.float32)
if MAX_SAMPLES is not None:
    # Keep only the first MAX_SAMPLES rows (features and labels stay aligned).
    X_tensor = X_tensor[:MAX_SAMPLES]
    y_tensor = y_tensor[:MAX_SAMPLES]

# Create Dataset
dataset = TensorDataset(X_tensor, y_tensor)

# Split into training and validation sets (80/20). This is a stand-in until the
# real validation set is used. The split is seeded so that the validation
# metrics computed later always refer to the same held-out rows.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Create DataLoaders (only the training data is reshuffled every epoch)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512)


In [8]:
# Sanity check: features and labels must have the same number of rows.
print(len(X_tensor), len(y_tensor), sep="\n")

2200865
2200865


In [None]:
class confidence_model(nn.Module):
    """Single-hidden-layer binary classifier over fixed-size embeddings.

    Architecture: Linear(input_dim, hidden_dim) -> ReLU -> Dropout(dropout)
    -> Linear(hidden_dim, 1). The network returns raw logits (no sigmoid), so
    it is meant to be paired with a logit-based loss such as
    ``nn.BCEWithLogitsLoss``; apply ``torch.sigmoid`` to obtain probabilities.

    Args:
        input_dim: Size of each input embedding. Defaults to 768.
        hidden_dim: Width of the hidden layer. Defaults to 256.
        dropout: Dropout probability applied after the ReLU. Defaults to 0.2.

    The defaults reproduce the original fixed architecture, and the submodule
    layout (``self.model`` as an ``nn.Sequential``) is unchanged, so existing
    state dicts keep loading.
    """

    def __init__(self, input_dim=768, hidden_dim=256, dropout=0.2):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            # No nn.Sigmoid() here: the network deliberately emits logits.
        )

    def forward(self, x):
        """Return logits of shape (batch, 1) for inputs of shape (batch, input_dim)."""
        return self.model(x)


In [None]:
# Build the classifier and its Adam optimiser.
model = confidence_model()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# The labels are heavily imbalanced, so every positive example is up-weighted
# by the negative/positive ratio inside the loss.
pos_weight = torch.tensor([num_negatives / num_positives], dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

epochs = 50
for epoch in range(epochs):
    model.train()
    batch_losses = []
    for features, targets in train_loader:
        features = features.to(device)
        # Labels arrive as (batch,); add a trailing axis to match the (batch, 1) logits.
        targets = targets.to(device).unsqueeze(1)

        optimizer.zero_grad()
        loss = criterion(model(features), targets)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    # Mean of the per-batch losses for this epoch.
    total_loss = sum(batch_losses)
    print(f"Epoch {epoch+1}, Training Loss: {total_loss/len(train_loader):.4f}")


Epoch 1, Training Loss: 0.7141
Epoch 2, Training Loss: 0.6102
Epoch 3, Training Loss: 0.5770
Epoch 4, Training Loss: 0.5560
Epoch 5, Training Loss: 0.5508
Epoch 6, Training Loss: 0.5376
Epoch 7, Training Loss: 0.5303
Epoch 8, Training Loss: 0.5271
Epoch 9, Training Loss: 0.5230
Epoch 10, Training Loss: 0.5230
Epoch 11, Training Loss: 0.5223
Epoch 12, Training Loss: 0.5169
Epoch 13, Training Loss: 0.5111
Epoch 14, Training Loss: 0.5116
Epoch 15, Training Loss: 0.5044
Epoch 16, Training Loss: 0.5010
Epoch 17, Training Loss: 0.5009
Epoch 18, Training Loss: 0.4960
Epoch 19, Training Loss: 0.4946
Epoch 20, Training Loss: 0.4902
Epoch 21, Training Loss: 0.4871
Epoch 22, Training Loss: 0.4858
Epoch 23, Training Loss: 0.4876
Epoch 24, Training Loss: 0.4815
Epoch 25, Training Loss: 0.4839
Epoch 26, Training Loss: 0.4816
Epoch 27, Training Loss: 0.4840
Epoch 28, Training Loss: 0.4807
Epoch 29, Training Loss: 0.4814
Epoch 30, Training Loss: 0.4777
Epoch 31, Training Loss: 0.4757
Epoch 32, Trainin

In [None]:
# Training-loss log from the 50-epoch run above, kept for choosing the number of epochs later.
"""""
Epoch 1, Training Loss: 0.7141
Epoch 2, Training Loss: 0.6102
Epoch 3, Training Loss: 0.5770
Epoch 4, Training Loss: 0.5560
Epoch 5, Training Loss: 0.5508
Epoch 6, Training Loss: 0.5376
Epoch 7, Training Loss: 0.5303
Epoch 8, Training Loss: 0.5271
Epoch 9, Training Loss: 0.5230
Epoch 10, Training Loss: 0.5230
Epoch 11, Training Loss: 0.5223
Epoch 12, Training Loss: 0.5169
Epoch 13, Training Loss: 0.5111
Epoch 14, Training Loss: 0.5116
Epoch 15, Training Loss: 0.5044
Epoch 16, Training Loss: 0.5010
Epoch 17, Training Loss: 0.5009
Epoch 18, Training Loss: 0.4960
Epoch 19, Training Loss: 0.4946
Epoch 20, Training Loss: 0.4902
Epoch 21, Training Loss: 0.4871
Epoch 22, Training Loss: 0.4858
Epoch 23, Training Loss: 0.4876
Epoch 24, Training Loss: 0.4815
Epoch 25, Training Loss: 0.4839
Epoch 26, Training Loss: 0.4816
Epoch 27, Training Loss: 0.4840
Epoch 28, Training Loss: 0.4807
Epoch 29, Training Loss: 0.4814
Epoch 30, Training Loss: 0.4777
Epoch 31, Training Loss: 0.4757
Epoch 32, Training Loss: 0.4736
Epoch 33, Training Loss: 0.4740
Epoch 34, Training Loss: 0.4738
Epoch 35, Training Loss: 0.4718
Epoch 36, Training Loss: 0.4690
Epoch 37, Training Loss: 0.4713
Epoch 38, Training Loss: 0.4680
Epoch 39, Training Loss: 0.4687
Epoch 40, Training Loss: 0.4662
Epoch 41, Training Loss: 0.4663
Epoch 42, Training Loss: 0.4662
Epoch 43, Training Loss: 0.4648
Epoch 44, Training Loss: 0.4630
Epoch 45, Training Loss: 0.4625
Epoch 46, Training Loss: 0.4608
Epoch 47, Training Loss: 0.4611
Epoch 48, Training Loss: 0.4609
Epoch 49, Training Loss: 0.4602
Epoch 50, Training Loss: 0.4582
"""""

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

# Probability cut-off above which a sample is predicted positive.
EVAL_THRESHOLD = 0.8

model.eval()

# Collect one tensor per batch instead of one tiny array per sample, and keep
# everything 1-D so scikit-learn receives plain label vectors rather than
# (N, 1) column vectors.
pred_batches = []
label_batches = []

with torch.no_grad():
    for X_batch, y_batch in val_loader:
        logits = model(X_batch.to(device)).squeeze(1)  # (batch, 1) -> (batch,)
        probs = torch.sigmoid(logits)
        pred_batches.append((probs > EVAL_THRESHOLD).float().cpu())
        label_batches.append(y_batch.cpu())

all_preds = torch.cat(pred_batches).numpy()
all_labels = torch.cat(label_batches).numpy()

correct = int((all_preds == all_labels).sum())
total = int(all_labels.shape[0])

accuracy = correct / total
precision = precision_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)

print(f"Validation Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")


Validation Accuracy: 0.9483
Precision: 0.5341, Recall: 0.7758, F1-score: 0.6327


# Inference

In [19]:
from transformers import AutoTokenizer, AutoModel

In [20]:
# Tokenizer and encoder are loaded from the same pretrained checkpoint.
NER_CHECKPOINT = "dslim/bert-base-NER"
tokenizer = AutoTokenizer.from_pretrained(NER_CHECKPOINT)
bert_model = AutoModel.from_pretrained(NER_CHECKPOINT)

In [None]:
import torch
import torch.nn.functional as F 

def _module_device(module):
    """Return the device of `module`'s first parameter, or CPU if it has none."""
    first_param = next(module.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def infer_entity_transitions(sentence, tokenizer, bert_model, classifier_model, threshold=0.8):
    """
    Score every word-level prefix of `sentence` with the trained classifier.

    The sentence is split on whitespace. For each prefix (first word, first
    two words, ...) the prefix is tokenized, run through `bert_model`, and the
    hidden state at position 0 of the last layer is passed to
    `classifier_model`. The resulting logit goes through a sigmoid and is
    compared with `threshold` (strictly greater -> label 1).

    Tensors are moved to whichever device each model lives on, so a classifier
    left on the GPU after training works together with a BERT model on the CPU.

    NOTE(review): this is intended to mirror how the training embeddings were
    precomputed; that code is not part of this notebook, so confirm it.

    Args:
        sentence: Raw input text.
        tokenizer: Callable returning a mapping of input tensors when called
            with ``return_tensors="pt"``.
        bert_model: Encoder whose output exposes ``last_hidden_state``.
        classifier_model: Module mapping an embedding to a single logit.
        threshold: Probability above which a prefix is labelled 1.

    Returns:
        None. One line per prefix is printed (score and label).
    """
    bert_model.eval()
    classifier_model.eval()

    bert_device = _module_device(bert_model)
    classifier_device = _module_device(classifier_model)

    words = sentence.strip().split()

    print(f"\n Inference for: \"{sentence}\"\n")

    for i in range(1, len(words) + 1):
        partial_sentence = " ".join(words[:i])
        encoded = tokenizer(partial_sentence, return_tensors="pt")
        # Put the inputs on the encoder's device before the forward pass.
        inputs = {name: tensor.to(bert_device) for name, tensor in encoded.items()}

        with torch.no_grad():
            outputs = bert_model(**inputs)
            # Hidden state of the first token of the (single) sequence.
            cls_embedding = outputs.last_hidden_state[:, 0, :]

            # The classifier may live on a different device than the encoder.
            logits = classifier_model(cls_embedding.to(classifier_device))
            prob = torch.sigmoid(logits).item()
            label = int(prob > threshold)

        print(f"Step {i:2}: {partial_sentence:60} → score: {prob:.3f}, label: {label}")


In [25]:
# Demo input: step through the sentence one word at a time and print the
# classifier score for every prefix.
sentence = "Apple, Microsoft and Google are tech giants."
# sentence = "On the afternoon of August 22 , Peng Dehuai was listening to the combat operation director report on battle developments at Eighth Route Army operational headquarters ."
infer_entity_transitions(
    sentence,
    tokenizer=tokenizer,
    bert_model=bert_model,
    classifier_model=model,
)


 Inference for: "Apple, Microsoft and Google are tech giants."

Step  1: Apple,                                                       → score: 0.000, label: 0
Step  2: Apple, Microsoft                                             → score: 0.953, label: 1
Step  3: Apple, Microsoft and                                         → score: 0.000, label: 0
Step  4: Apple, Microsoft and Google                                  → score: 0.962, label: 1
Step  5: Apple, Microsoft and Google are                              → score: 0.000, label: 0
Step  6: Apple, Microsoft and Google are tech                         → score: 0.147, label: 0
Step  7: Apple, Microsoft and Google are tech giants.                 → score: 0.003, label: 0


In [29]:
# Estimate how many BERT calls the classifier gate would save, and how many
# positive (Y == 1) timesteps it still lets through.
# NOTE(review): this scores every row of X_tensor, including the rows the
# classifier was trained on, so the recall reported here is optimistic
# compared with a held-out set.
model.eval()

# Run on whichever device the model lives on, in chunks, so that a GPU-resident
# model works and the whole matrix is never pushed through in a single pass.
eval_device = next(model.parameters()).device
EVAL_BATCH_SIZE = 8192

probs = torch.empty(len(X_tensor), dtype=torch.float32)
with torch.no_grad():
    for start in range(0, len(X_tensor), EVAL_BATCH_SIZE):
        stop = start + EVAL_BATCH_SIZE
        batch_logits = model(X_tensor[start:stop].to(eval_device))
        probs[start:stop] = torch.sigmoid(batch_logits).squeeze(1).cpu()

threshold = 0.8
# Strict '>' matches the comparison used by infer_entity_transitions and by the
# validation cell.
Y_pred = (probs > threshold).int().numpy()

# Labels are taken from y_tensor so they stay aligned with X_tensor even if
# the tensors were truncated; with the full dataset this equals y.
y_true = y_tensor.numpy()
is_entity_completion = y_true == 1

total_timesteps = len(y_true)
bert_calls_baseline = total_timesteps
bert_calls_classifier = int(Y_pred.sum())
reduction = 1 - (bert_calls_classifier / bert_calls_baseline) if bert_calls_baseline > 0 else 0

# Vectorised counts instead of Python-level loops over millions of elements.
true_positives_captured = int(np.count_nonzero(is_entity_completion & (Y_pred == 1)))
total_entity_completions = int(np.count_nonzero(is_entity_completion))

entity_recall = true_positives_captured / total_entity_completions if total_entity_completions > 0 else 0


print("=== Inference Evaluation Summary ===")
print(f"Total Timesteps: {total_timesteps}")
print(f"BERT Calls (Baseline): {bert_calls_baseline}")
print(f"BERT Calls (Classifier-Gated): {bert_calls_classifier}")
print(f"Reduction in BERT Calls: {reduction:.2%}")
print(f"Entity Completion Recall (Y=1 captured): {entity_recall:.2%}")


=== Inference Evaluation Summary ===
Total Timesteps: 2200865
BERT Calls (Baseline): 2200865
BERT Calls (Classifier-Gated): 182468
Reduction in BERT Calls: 91.71%
Entity Completion Recall (Y=1 captured): 78.10%
