In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm

In [2]:
# Precomputed window embeddings; the archive holds the arrays X and y.
embeddings_path = 'data/ner_trigger_dataset_embeddings.npz'
data = np.load(embeddings_path)

In [3]:
# List the arrays stored in the .npz archive.
print("Data keys:", data.keys())

Data keys: KeysView(NpzFile 'data/ner_trigger_dataset_embeddings.npz' with keys: X, y)


In [4]:
X = data['X']  # embedding matrix, (N, 768) per the printed shape
y = data['y']  # one label per row of X
# Features and labels are paired row-by-row later, so fail fast if they disagree.
assert X.shape[0] == y.shape[0], f"X has {X.shape[0]} rows but y has {y.shape[0]}"

print(f"X shape: {X.shape}")
print(f"y shape: {y.shape}")


X shape: (2148223, 768)
y shape: (2148223,)


In [5]:
# Class balance of the labels; these counts also drive pos_weight in the training cell.
num_positives = (y == 1).sum()
num_negatives = (y == 0).sum()
print(f"Positives: {num_positives}, Negatives: {num_negatives}")

Positives: 519979, Negatives: 1628244


In [45]:
# Inspect 40 raw text windows (rows 50..89) next to their labels.
data_raw = np.load("data/ner_trigger_dataset.npz")
X_raw, y_raw = data_raw['X'], data_raw['y']
for row in range(50, 50 + 40):
    print(f"Raw X[{row}]: {X_raw[row]}, Raw y[{row}]: {y_raw[row]}")

Raw X[50]: ['of' 'a' 'primary' 'stele' ',' 'secondary'], Raw y[50]: 0
Raw X[51]: ['a' 'primary' 'stele' ',' 'secondary' 'steles'], Raw y[51]: 0
Raw X[52]: ['primary' 'stele' ',' 'secondary' 'steles' ','], Raw y[52]: 0
Raw X[53]: ['stele' ',' 'secondary' 'steles' ',' 'a'], Raw y[53]: 0
Raw X[54]: [',' 'secondary' 'steles' ',' 'a' 'huge'], Raw y[54]: 0
Raw X[55]: ['secondary' 'steles' ',' 'a' 'huge' 'round'], Raw y[55]: 0
Raw X[56]: ['steles' ',' 'a' 'huge' 'round' 'sculpture'], Raw y[56]: 0
Raw X[57]: [',' 'a' 'huge' 'round' 'sculpture' 'and'], Raw y[57]: 0
Raw X[58]: ['a' 'huge' 'round' 'sculpture' 'and' 'beacon'], Raw y[58]: 0
Raw X[59]: ['huge' 'round' 'sculpture' 'and' 'beacon' 'tower'], Raw y[59]: 0
Raw X[60]: ['round' 'sculpture' 'and' 'beacon' 'tower' ','], Raw y[60]: 0
Raw X[61]: ['sculpture' 'and' 'beacon' 'tower' ',' 'and'], Raw y[61]: 0
Raw X[62]: ['and' 'beacon' 'tower' ',' 'and' 'the'], Raw y[62]: 0
Raw X[63]: ['beacon' 'tower' ',' 'and' 'the' 'Great'], Raw y[63]: 0
Raw X[6

In [None]:
# Memory note: torch.tensor() always copies its input, so converting the full
# (2148223 x 768) matrix with it doubles peak memory. torch.as_tensor reuses the
# NumPy buffer when no dtype change is needed (i.e. when X is already float32 —
# TODO confirm X.dtype) and only copies otherwise. No subsampling is applied here.
X_tensor = torch.as_tensor(X, dtype=torch.float32)
y_tensor = torch.as_tensor(y, dtype=torch.float32)

dataset = TensorDataset(X_tensor, y_tensor)

# Temporary 80/20 random split; a dedicated validation file is meant to replace it.
# Seeded so the partition (and therefore the reported metrics) is reproducible.
# NOTE(review): the raw windows overlap heavily (consecutive raw rows are the same
# text shifted by one word). Assuming the embedding rows follow the same order, a
# random split places near-duplicate windows in both train and val, so validation
# metrics are likely optimistic. A contiguous or document-level split avoids this.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512)


In [7]:
# Row counts of features and labels; they should be equal.
for paired_tensor in (X_tensor, y_tensor):
    print(len(paired_tensor))

2148223
2148223


In [10]:
class window_model(nn.Module):
    """Single-hidden-layer MLP that scores one window embedding.

    Returns raw logits of shape (batch, 1). No sigmoid is applied inside the
    model: training uses BCEWithLogitsLoss, and evaluation/inference apply
    torch.sigmoid to the output themselves.
    """

    def __init__(self, input_dim=768, hidden_dim=256, dropout=0.2):
        """
        Args:
            input_dim: Size of each input embedding (768 matches X.shape[1]).
            hidden_dim: Width of the hidden layer.
            dropout: Dropout probability applied after the hidden ReLU.
        """
        super().__init__()
        # Layer order is kept so state_dict keys stay model.0.* and model.3.*.
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Map a (batch, input_dim) tensor to (batch, 1) logits."""
        return self.model(x)


In [None]:
# Seed weight init, dropout and the DataLoader shuffle order for reproducibility.
torch.manual_seed(42)

model = window_model()
optimizer = optim.Adam(model.parameters(), lr=0.001)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# The classes are imbalanced (roughly 1 positive : 3 negatives, per the counts
# cell), so positive examples are up-weighted by neg/pos in the loss.
pos_weight = torch.tensor([num_negatives / num_positives], dtype=torch.float32, device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

epochs = 50
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    samples_seen = 0
    for X_batch, y_batch in train_loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device).unsqueeze(1)  # (B,) -> (B, 1) to match the logits

        optimizer.zero_grad()
        output = model(X_batch)
        loss = criterion(output, y_batch)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by batch size so the
        # smaller final batch does not skew the epoch average.
        batch_size = X_batch.size(0)
        total_loss += loss.item() * batch_size
        samples_seen += batch_size

    print(f"Epoch {epoch+1}, Training Loss: {total_loss/samples_seen:.4f}")


Epoch 1, Training Loss: 0.5471
Epoch 2, Training Loss: 0.4882
Epoch 3, Training Loss: 0.4750
Epoch 4, Training Loss: 0.4664
Epoch 5, Training Loss: 0.4617
Epoch 6, Training Loss: 0.4596
Epoch 7, Training Loss: 0.4551
Epoch 8, Training Loss: 0.4517
Epoch 9, Training Loss: 0.4507
Epoch 10, Training Loss: 0.4486
Epoch 11, Training Loss: 0.4457
Epoch 12, Training Loss: 0.4433
Epoch 13, Training Loss: 0.4449
Epoch 14, Training Loss: 0.4434
Epoch 15, Training Loss: 0.4413
Epoch 16, Training Loss: 0.4394
Epoch 17, Training Loss: 0.4381
Epoch 18, Training Loss: 0.4367
Epoch 19, Training Loss: 0.4364
Epoch 20, Training Loss: 0.4344
Epoch 21, Training Loss: 0.4332
Epoch 22, Training Loss: 0.4318
Epoch 23, Training Loss: 0.4317
Epoch 24, Training Loss: 0.4305
Epoch 25, Training Loss: 0.4303
Epoch 26, Training Loss: 0.4293
Epoch 27, Training Loss: 0.4281
Epoch 28, Training Loss: 0.4287
Epoch 29, Training Loss: 0.4295
Epoch 30, Training Loss: 0.4284
Epoch 31, Training Loss: 0.4276
Epoch 32, Trainin

In [None]:
# Full 50-epoch training log saved as a string from an earlier run; the live
# output of the training cell above is truncated at epoch 32.
"""""
Epoch 1, Training Loss: 0.5471
Epoch 2, Training Loss: 0.4882
Epoch 3, Training Loss: 0.4750
Epoch 4, Training Loss: 0.4664
Epoch 5, Training Loss: 0.4617
Epoch 6, Training Loss: 0.4596
Epoch 7, Training Loss: 0.4551
Epoch 8, Training Loss: 0.4517
Epoch 9, Training Loss: 0.4507
Epoch 10, Training Loss: 0.4486
Epoch 11, Training Loss: 0.4457
Epoch 12, Training Loss: 0.4433
Epoch 13, Training Loss: 0.4449
Epoch 14, Training Loss: 0.4434
Epoch 15, Training Loss: 0.4413
Epoch 16, Training Loss: 0.4394
Epoch 17, Training Loss: 0.4381
Epoch 18, Training Loss: 0.4367
Epoch 19, Training Loss: 0.4364
Epoch 20, Training Loss: 0.4344
Epoch 21, Training Loss: 0.4332
Epoch 22, Training Loss: 0.4318
Epoch 23, Training Loss: 0.4317
Epoch 24, Training Loss: 0.4305
Epoch 25, Training Loss: 0.4303
Epoch 26, Training Loss: 0.4293
Epoch 27, Training Loss: 0.4281
Epoch 28, Training Loss: 0.4287
Epoch 29, Training Loss: 0.4295
Epoch 30, Training Loss: 0.4284
Epoch 31, Training Loss: 0.4276
Epoch 32, Training Loss: 0.4266
Epoch 33, Training Loss: 0.4266
Epoch 34, Training Loss: 0.4261
Epoch 35, Training Loss: 0.4256
Epoch 36, Training Loss: 0.4251
Epoch 37, Training Loss: 0.4249
Epoch 38, Training Loss: 0.4237
Epoch 39, Training Loss: 0.4230
Epoch 40, Training Loss: 0.4229
Epoch 41, Training Loss: 0.4222
Epoch 42, Training Loss: 0.4227
Epoch 43, Training Loss: 0.4218
Epoch 44, Training Loss: 0.4216
Epoch 45, Training Loss: 0.4208
Epoch 46, Training Loss: 0.4203
Epoch 47, Training Loss: 0.4204
Epoch 48, Training Loss: 0.4200
Epoch 49, Training Loss: 0.4197
Epoch 50, Training Loss: 0.4192
"""""

In [13]:
from sklearn.metrics import precision_score, recall_score, f1_score

# Decision threshold on the sigmoid probability; the same 0.8 (strict `>`) is
# the default of `inference` below.
THRESHOLD = 0.8

model.eval()
pred_chunks = []
label_chunks = []

with torch.no_grad():
    for X_batch, y_batch in val_loader:
        logits = model(X_batch.to(device))  # (B, 1)
        probs = torch.sigmoid(logits).squeeze(1)  # (B,)
        predicted = (probs > THRESHOLD).float()
        pred_chunks.append(predicted.cpu().numpy())
        label_chunks.append(y_batch.numpy())

# Build flat 1-D arrays once instead of extending a list with one small
# array per validation sample.
all_preds = np.concatenate(pred_chunks)
all_labels = np.concatenate(label_chunks)

correct = int(np.sum(all_preds == all_labels))
total = len(all_labels)

accuracy = correct / total
precision = precision_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)

print(f"Validation Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")


Validation Accuracy: 0.8937
Precision: 0.7897, Recall: 0.7639, F1-score: 0.7766


# Inference

In [16]:
from transformers import AutoTokenizer, AutoModel

In [17]:
# Tokenizer and encoder are loaded from one shared checkpoint name.
# NOTE(review): presumably the same checkpoint that produced the training
# embeddings in ner_trigger_dataset_embeddings.npz — confirm.
checkpoint = "dslim/bert-base-NER"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
bert_model = AutoModel.from_pretrained(checkpoint)

In [19]:
# TODO: evaluate on the real held-out validation set instead of the random split: np.load("data/ner_trigger_dataset_validation.npz")

In [39]:
def inference(sentence, model, window_size=6, tokenizer=tokenizer, bert_model=bert_model, threshold=0.8):
    """Slide a word window over `sentence` and print the classifier's score for each window.

    Each run of `window_size` consecutive whitespace-separated words is re-joined
    with single spaces and encoded by `bert_model`. The hidden state at position 0
    (the [CLS] token) is then passed to `model`.

    NOTE(review): this assumes the training embeddings were built the same way
    (position-0 vector of a space-joined window from this checkpoint) — confirm.

    Args:
        sentence: Raw text; split on whitespace.
        model: Window classifier mapping (batch, 768) embeddings to (batch, 1) logits.
        window_size: Number of words per window.
        tokenizer: Tokenizer paired with `bert_model`; defaults to the notebook global.
        bert_model: Encoder whose output exposes `last_hidden_state`; defaults to the
            notebook global.
        threshold: A window is labelled 1 when its probability is strictly greater.

    Returns:
        None. Prints a "Window i" / "Probability" pair per window.
    """
    model.eval()
    bert_model.eval()

    # The training cell may move `model` to the GPU while BERT stays on the CPU;
    # passing a CPU tensor to a CUDA model raises a device-mismatch error, so
    # route each tensor to the device of the model that consumes it.
    bert_device = next(bert_model.parameters()).device
    model_device = next(model.parameters()).device

    words = sentence.strip().split()
    if len(words) < window_size:
        # Previously this case printed nothing at all, which looked like success.
        print(f"Sentence has {len(words)} words, fewer than window_size={window_size}; nothing to score.")
        return

    for i in range(len(words) - window_size + 1):
        window = " ".join(words[i:i + window_size])
        inputs = tokenizer(window, return_tensors="pt").to(bert_device)

        with torch.no_grad():
            outputs = bert_model(**inputs)
            cls_token = outputs.last_hidden_state[:, 0, :]
            logits = model(cls_token.to(model_device))
            prob = torch.sigmoid(logits).item()

        label = int(prob > threshold)
        print(f"Window {i}: {window}")
        print(f"Probability: {prob:.4f}, Label: {label}")

In [48]:
sentence = "On August 17 , the Taiwan military held the Lianhsing 94 amphibious landing exercise , testing and enhancing the army 's response"
inference(sentence, model)  # Example usage

Window 0: On August 17 , the Taiwan
Probability: 0.8930, Label: 1
Window 1: August 17 , the Taiwan military
Probability: 0.9952, Label: 1
Window 2: 17 , the Taiwan military held
Probability: 0.9769, Label: 1
Window 3: , the Taiwan military held the
Probability: 0.9644, Label: 1
Window 4: the Taiwan military held the Lianhsing
Probability: 0.8235, Label: 1
Window 5: Taiwan military held the Lianhsing 94
Probability: 0.5979, Label: 0
Window 6: military held the Lianhsing 94 amphibious
Probability: 0.8810, Label: 1
Window 7: held the Lianhsing 94 amphibious landing
Probability: 0.9220, Label: 1
Window 8: the Lianhsing 94 amphibious landing exercise
Probability: 0.9233, Label: 1
Window 9: Lianhsing 94 amphibious landing exercise ,
Probability: 0.9280, Label: 1
Window 10: 94 amphibious landing exercise , testing
Probability: 0.8432, Label: 1
Window 11: amphibious landing exercise , testing and
Probability: 0.2078, Label: 0
Window 12: landing exercise , testing and enhancing
Probability: 0.0

In [47]:
model.eval()
# Run on whatever device the classifier lives on: the training cell may have
# moved it to the GPU, while X_tensor stays on the CPU.
model_device = next(model.parameters()).device

# Score in chunks: a single forward pass over all ~2.1M rows materialises
# (N, 256) float32 hidden activations of roughly 2 GB each at once.
SCORING_CHUNK = 65536
prob_chunks = []
with torch.no_grad():
    for X_chunk in torch.split(X_tensor, SCORING_CHUNK):
        logits = model(X_chunk.to(model_device))
        prob_chunks.append(torch.sigmoid(logits).squeeze(1).cpu())
probs = torch.cat(prob_chunks)
assert len(probs) == len(y), "X_tensor and y are not row-aligned"

# Same 0.8 threshold and strict `>` comparison as the validation cell and `inference`.
threshold = 0.8
Y_pred = (probs > threshold).int().numpy()

# NOTE(review): this scores the FULL dataset (training + validation rows), so
# the recall below is optimistic relative to held-out data.
total_timesteps = len(y)
bert_calls_baseline = total_timesteps
bert_calls_classifier = int(Y_pred.sum())
reduction = 1 - (bert_calls_classifier / bert_calls_baseline)

# Vectorised counts instead of Python generator loops over ~2M elements.
is_entity = (y == 1)
true_positives_captured = int(np.count_nonzero(is_entity & (Y_pred == 1)))
total_entity_completions = int(np.count_nonzero(is_entity))

entity_recall = true_positives_captured / total_entity_completions if total_entity_completions > 0 else 0


print("=== Inference Evaluation Summary ===")
print(f"Total Timesteps: {total_timesteps}")
print(f"BERT Calls (Baseline): {bert_calls_baseline}")
print(f"BERT Calls (Classifier-Gated): {bert_calls_classifier}")
print(f"Reduction in BERT Calls: {reduction:.2%}")
print(f"Entity Completion Recall (Y=1 captured): {entity_recall:.2%}")

=== Inference Evaluation Summary ===
Total Timesteps: 2148223
BERT Calls (Baseline): 2148223
BERT Calls (Classifier-Gated): 504149
Reduction in BERT Calls: 76.53%
Entity Completion Recall (Y=1 captured): 76.95%
