# Window size model (approach 2) Baseline

This notebook implements the sliding window baseline for near real-time Named Entity Recognition (NER).

## 1. Setup and preparation

First, let's import the necessary libraries and set up our environment.

In [33]:
# Install the `datasets` and `transformers` packages used by the cells below.
# NOTE(review): versions are unpinned — consider pinning them for reproducible runs.
%pip install datasets transformers

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


In [34]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import torch
import random
from tqdm import tqdm
import numpy as np
import sys

In [55]:
# Make the project-local modules under ./src importable. The membership check
# keeps sys.path free of duplicate entries when this cell is re-run.
if "src" not in sys.path:
    sys.path.append("src")
from utils import convert_ids_to_bio, convert_predictions
from window_slide_model import WindowSlideModel
from metrics import Metrics

## 2. Loading Dataset

In [36]:
# Load the English portion of OntoNotes 5.0; the download is cached under
# ./dataset/ontonotes so later runs can reuse it.
ontonotes = load_dataset(
    path="conll2012_ontonotesv5",
    name="english_v12",
    cache_dir="./dataset/ontonotes",
)
split_names = ontonotes.keys()
print(f"Dataset loaded with splits: {split_names}")

Dataset loaded with splits: dict_keys(['train', 'validation', 'test'])


## 3. Creating all windows of size 6 across the test split

In [37]:
windows = []
bio_windows = []
SPAN_LENGTH = 6

# Build every contiguous SPAN_LENGTH-token window of each test document.
# The running window is reset per document, so windows may cross sentence
# boundaries but never document boundaries.
for doc in ontonotes["test"]:
    curr_window = []
    curr_bio_window = []
    sentences = doc["sentences"]
    # Sometimes doc['sentences'] is a list of lists, so flatten it one level.
    # The emptiness check avoids an IndexError on a document without sentences,
    # and working on a local name leaves the dataset row itself untouched.
    if isinstance(sentences, list) and sentences and isinstance(sentences[0], list):
        sentences = [sentence for sublist in sentences for sentence in sublist]
    for sentence in sentences:
        for idx, word in enumerate(sentence["words"]):
            curr_window.append(word)
            curr_bio_window.append(sentence["named_entities"][idx])
            # Once the window holds SPAN_LENGTH tokens, record it together with
            # its BIO tags, then slide forward by one token.
            if len(curr_window) == SPAN_LENGTH:
                windows.append(curr_window.copy())
                bio_windows.append(convert_ids_to_bio(curr_bio_window))
                curr_window = curr_window[1:]
                curr_bio_window = curr_bio_window[1:]

print(f"Total windows created: {len(windows)}")
# random.randint(0, -1) raises ValueError, so only show a sample window when
# at least one was created.
if windows:
    ix = random.randint(0, len(windows) - 1)
    print(f"Example window: {windows[ix]}, BIO: {bio_windows[ix]}")


Total windows created: 224128
Example window: ['equipment', 'and', 'the', 'engineers', 'and', 'they'], BIO: ['O', 'O', 'O', 'O', 'O', 'O']


In [38]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
bert_model = AutoModel.from_pretrained("dslim/bert-base-NER")

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
# torch.backends.mps.is_available() is used instead of torch.mps.is_available()
# because the latter is missing on older PyTorch releases, where reaching this
# branch on a machine without CUDA would raise AttributeError.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
print(f"Using device: {device}")

# Inference only: move the encoder to the chosen device and switch to eval mode.
bert_model.to(device)
bert_model.eval()

Using device: mps


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [39]:
# Encode the windows in batches and keep, for each window, the hidden state at
# sequence position 0 (the CLS token) as its embedding.
batch_size = 128  # Adjust batch size based on your GPU memory
embedding_chunks = []
batch_starts = range(0, len(windows), batch_size)
for start in tqdm(batch_starts, desc="Computing CLS token for windows"):
    # Windows are already split into words, so the tokenizer is told not to re-split.
    encoded = tokenizer(
        windows[start:start + batch_size],
        is_split_into_words=True,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        hidden_states = bert_model(**encoded).last_hidden_state

    # Move each batch back to host memory so the results can be concatenated with NumPy.
    embedding_chunks.append(hidden_states[:, 0, :].cpu().numpy())

embeddings = np.concatenate(embedding_chunks, axis=0)

Computing CLS token for windows: 100%|██████████| 1751/1751 [01:56<00:00, 15.07it/s]


In [42]:
# Cache the windows, their BIO tags and the embeddings so the expensive
# encoding step above does not have to be repeated.
import os

# np.savez does not create missing parent directories; make sure data/ exists
# so the save also works on a fresh checkout.
os.makedirs("data", exist_ok=True)
np.savez("data/ner_trigger_dataset_test_embeddings.npz", windows=windows, bio_windows=bio_windows, embeddings=embeddings)

In [43]:
# Reload the cached arrays. The archive is opened in a context manager so the
# underlying file handle is closed as soon as the three arrays have been read.
with np.load("data/ner_trigger_dataset_test_embeddings.npz") as data:
    windows, bio_windows, embeddings = data['windows'], data['bio_windows'], data['embeddings']
print(f"Shape of data: {windows.shape}; {bio_windows.shape}; {embeddings.shape}")

Shape of data: (224128, 6); (224128, 6); (224128, 768)


## 4. Implementing sliding window baseline

In [46]:
# Build the window classifier with an input width equal to the embedding size
# and restore its trained weights.
window_model = WindowSlideModel(embeddings.shape[1])
# map_location="cpu" lets the checkpoint deserialize even if it was saved on a
# device this machine does not have; the model is moved to `device` afterwards.
state_dict = torch.load("models/bert_cls_pooling_model.pkl", map_location="cpu", weights_only=True)
window_model.load_state_dict(state_dict)
window_model.to(device)
window_model.eval()

WindowSlideModel(
  (net): Sequential(
    (0): Linear(in_features=768, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [50]:
# Decide, for every window embedding, whether NER should be invoked (1) or not (0).
batch_size = 128  # Adjust batch size based on your GPU memory
batch_predictions = []
with torch.no_grad():
    for start in tqdm(range(0, len(embeddings), batch_size), desc="Predicting BIO tags"):
        batch_embeddings = torch.tensor(embeddings[start:start + batch_size]).to(device)
        invoke_probability = torch.sigmoid(window_model(batch_embeddings))
        # Threshold the sigmoid output at 0.5 to obtain binary decisions.
        batch_predictions.append((invoke_probability > 0.5).int().cpu().numpy())

# Stitch the per-batch results back into one array aligned with `embeddings`.
y_hat = np.concatenate(batch_predictions, axis=0)
print(f"Shape of predictions: {y_hat.shape}")

Predicting BIO tags: 100%|██████████| 1751/1751 [00:00<00:00, 1993.53it/s]

Shape of predictions: (224128,)





In [53]:
# y_hat holds 0/1 decisions, so its sum is the number of windows flagged for NER.
invocation_count = np.sum(y_hat)
print(f"Number of invocations: {invocation_count}")

Number of invocations: 52502


## 5. Evaluating the sliding window

In [65]:
# Accumulator for the evaluation results of the sliding window approach.
sliding_window_metrics = Metrics()

# The evaluation replays the stream window by window and scores the NER output
# at every point where the window classifier asks for an invocation.

# A token-classification model is needed to produce the actual BIO predictions
# whenever NER is invoked.
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import torch.nn.functional as F

# Tokenizer and token-classification head from the same checkpoint, placed on
# the selected device in inference mode.
ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
ner_model.to(device)
ner_model.eval()
ner_pipeline = pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
)

# Function to get NER predictions for a window
def get_ner_predictions(tokens_window):
    """Run the NER pipeline on one window and convert its output to BIO tags.

    Args:
        tokens_window: The window to tag. It is forwarded unchanged to both
            ``ner_pipeline`` and ``convert_predictions``.
            NOTE(review): the evaluation loop passes a space-joined string
            rather than a list of tokens — confirm that this is the input
            ``convert_predictions`` expects.

    Returns:
        The result of ``convert_predictions`` for this window.
    """
    return convert_predictions(tokens_window, ner_pipeline(tokens_window))

# Simulate the streaming process by visiting the windows in order: each window
# stands for one point in the stream, and NER is run only where the window
# classifier asked for an invocation.
# NOTE(review): the recorded output of this cell reports 0 true positives for
# every entity type — verify that the NER model's label set and the
# space-joined input match what convert_predictions / Metrics expect.
print("Evaluating sliding window model...")

# All windows are evaluated; lower this number to evaluate only a prefix.
num_windows_to_evaluate = len(windows)
print(f"Evaluating on {num_windows_to_evaluate} windows...")

for idx in tqdm(range(num_windows_to_evaluate), desc="Evaluating windows"):
    # y_hat[idx] is 1 if the model says invoke NER, 0 otherwise; windows
    # without an invocation contribute nothing to the metrics.
    if y_hat[idx] != 1:
        continue

    window_tokens = windows[idx]
    gold_bio_tags = bio_windows[idx]

    # The pipeline receives the window as a single space-joined string.
    predicted_bio_tags = get_ner_predictions(" ".join(window_tokens))

    # Ground truth is a (tokens, BIO tags) pair; predictions are passed as a
    # one-element list holding this window's predicted tags.
    sliding_window_metrics.evaluate_metrics(
        (window_tokens, gold_bio_tags),
        [predicted_bio_tags],
    )

print("Evaluation completed!")
sliding_window_metrics.print_metrics()


Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use mps:0


Evaluating sliding window model...
Evaluating on 224128 windows...


Evaluating windows: 100%|██████████| 224128/224128 [06:58<00:00, 535.36it/s]  

Evaluation completed!
Metrics:
Total NER invocations: 52502
Avg TTFD: N/A
Entity Type          TP         TN         FP (#B-/I-MISC)      FN        
----------------------------------------------------------------------
O                    N/A        221345     0                    N/A       
B-PERSON             0          N/A        0                    9694      
I-PERSON             0          N/A        0                    6910      
B-NORP               0          N/A        0                    5085      
I-NORP               0          N/A        0                    832       
B-FAC                0          N/A        0                    546       
I-FAC                0          N/A        0                    793       
B-ORG                0          N/A        0                    8190      
I-ORG                0          N/A        0                    9140      
B-GPE                0          N/A        0                    12962     
I-GPE                0        


