# Window size model (approach 2) Baseline

This notebook implements the sliding window baseline for near real-time Named Entity Recognition (NER).

## 1. Setup and preparation

First, let's import the necessary libraries and set up our environment.

In [70]:
%pip install datasets transformers

Note: you may need to restart the kernel to use updated packages.


In [71]:
import os
import random
import sys

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

In [72]:
sys.path.append("src")
from utils import convert_ids_to_bio, convert_predictions
from window_slide_model import WindowSlideModel
from metrics import Metrics

In [88]:
# Development convenience: reload the `metrics` module so that changes made to it
# are picked up without restarting the kernel.
from importlib import reload
import metrics
reload(metrics)
# Re-import the class so that the name `Metrics` is bound to the reloaded module's
# definition rather than to the one imported before the reload.
from metrics import Metrics

## 2. Loading Dataset

In [73]:
# Fetch the "english_v12" configuration of OntoNotes 5.0, using a project-local cache directory
ontonotes = load_dataset(
    path="conll2012_ontonotesv5",
    name="english_v12",
    cache_dir="./dataset/ontonotes",
)
available_splits = ontonotes.keys()
print(f"Dataset loaded with splits: {available_splits}")

Dataset loaded with splits: dict_keys(['train', 'validation', 'test'])


## 3. Creating all windows of size 6 across the test split

In [74]:
SPAN_LENGTH = 6


def _sliding_windows(words, span_length):
    """Return the sliding word windows for one sentence.

    Words are pushed into a buffer one at a time. Whenever the buffer holds
    `span_length` words it is emitted as a window and its oldest word is
    dropped, i.e. the window slides by one position. After the last word, the
    words still in the buffer are emitted as one final, shorter window.

    Consequently a sentence with n >= span_length words yields
    n - span_length + 1 full windows followed by one trailing window holding
    its last span_length - 1 words, and a shorter sentence yields a single
    window holding all of its words (possibly none). In both cases the k-th
    window starts at word k of the sentence.

    Args:
        words: The words of the sentence, in order.
        span_length: Number of words in a full window.

    Returns:
        A list of windows, each a list of words.
    """
    sentence_windows = []
    curr_window = []
    for word in words:
        curr_window.append(word)
        # If the current window reaches the defined span length, add it to the list
        if len(curr_window) == span_length:
            sentence_windows.append(curr_window.copy())
            # Slide the window by one position
            curr_window = curr_window[1:]

    # NOTE(review): after the loop the buffer always holds fewer than span_length
    # words, so this branch runs for every sentence, including long ones, where it
    # adds a window with the last span_length - 1 words. Kept as is because the
    # window count and ordering feed the embeddings and predictions computed
    # later in this notebook. Confirm that the extra trailing window is intended.
    if len(curr_window) < span_length:
        sentence_windows.append(curr_window.copy())

    return sentence_windows


num_windows = 0
windows = []  # one (bio_tags, sentence_windows) tuple per sentence

# Iterate through the test split
for doc in ontonotes["test"]:
    # Work on a local reference instead of writing back into the dataset record
    sentences = doc["sentences"]
    # Sometimes doc['sentences'] is a list of lists, so we need to flatten it.
    # The emptiness check avoids an IndexError on a document without sentences.
    if isinstance(sentences, list) and sentences and isinstance(sentences[0], list):
        sentences = [sentence for sublist in sentences for sentence in sublist]

    for sentence in sentences:
        bio_tags = convert_ids_to_bio(sentence["named_entities"])
        sentence_windows = _sliding_windows(sentence["words"], SPAN_LENGTH)
        num_windows += len(sentence_windows)
        windows.append((bio_tags, sentence_windows))

print(f"Total windows created: {num_windows}")
# Show one randomly chosen sentence as a sanity check
ix = random.randint(0, len(windows) - 1)
print(f"Example window: {windows[ix][1]}, BIO: {windows[ix][0]}")


Total windows created: 183876
Example window: [['Aha', ',', 'aha', '.']], BIO: ['O', 'O', 'O', 'O']


In [38]:
# Encoder whose position-0 hidden state is later used as the window embedding.
# AutoTokenizer and AutoModel are already imported in the notebook's import cell.
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
bert_model = AutoModel.from_pretrained("dslim/bert-base-NER")

# Pick an accelerator when one is available, otherwise stay on the CPU.
# torch.backends.mps.is_available() is the documented MPS availability check;
# torch.mps.is_available() is not present in every PyTorch release
# (presumably missing from older ones; verify against the installed version),
# in which case the original check raised AttributeError on machines without CUDA.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
print(f"Using device: {device}")

bert_model.to(device)
# Inference only: switch off dropout. As the last expression of the cell this
# also displays the model structure.
bert_model.eval()

Using device: mps


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [75]:
# Flatten the per-sentence window lists into a single list, keeping sentence order
all_windows = [
    window
    for _, sentence_windows in tqdm(windows)
    for window in sentence_windows
]

print(f"Total windows to process: {len(all_windows)}")

100%|██████████| 12217/12217 [00:00<00:00, 2750204.59it/s]

Total windows to process: 183876





In [76]:
# Embed every window with the encoder, batch by batch, keeping only the hidden
# state at sequence position 0 (the CLS token) of each window.
batch_size = 128  # Adjust batch size based on your GPU memory
cls_batches = []
batch_starts = range(0, len(all_windows), batch_size)
for start in tqdm(batch_starts, desc="Computing CLS token for windows"):
    window_batch = all_windows[start:start + batch_size]
    # Windows are already lists of words, hence is_split_into_words=True
    encoded = tokenizer(
        window_batch,
        is_split_into_words=True,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # No gradients are needed for feature extraction
    with torch.no_grad():
        hidden_states = bert_model(**encoded).last_hidden_state

    # Move the position-0 vectors back to the CPU as a NumPy array
    cls_batches.append(hidden_states[:, 0, :].cpu().numpy())

# Stack the per-batch arrays into one array with a row per window
embeddings = np.concatenate(cls_batches, axis=0)

Computing CLS token for windows: 100%|██████████| 1437/1437 [01:39<00:00, 14.48it/s]


In [81]:
# Persist the window embeddings so they can be reloaded without re-running the encoder.
# np.savez does not create missing directories, so make sure the target folder exists.
os.makedirs("data", exist_ok=True)
np.savez("data/ner_trigger_dataset_test_embeddings_sentences.npz", embeddings=embeddings)

In [82]:
# Reload the saved window embeddings. np.load on an .npz archive returns an object
# that keeps the file open, so use it as a context manager to close the handle
# once the array has been read into memory.
with np.load("data/ner_trigger_dataset_test_embeddings_sentences.npz") as data:
    embeddings = data['embeddings']
print(f"Shape of embeddings: {embeddings.shape}")

Shape of embeddings: (183876, 768)


## 4. Implementing sliding window baseline

In [46]:
# Build the trigger classifier with an input size equal to the embedding width
window_model = WindowSlideModel(embeddings.shape[1])
# map_location remaps the stored tensors onto the device selected for this run, so
# a checkpoint saved on a different accelerator still loads here.
# weights_only=True restricts unpickling to tensors and plain containers.
state_dict = torch.load(
    "models/bert_cls_pooling_model.pkl",
    map_location=device,
    weights_only=True,
)
window_model.load_state_dict(state_dict)
window_model.to(device)
# Inference only. As the last expression of the cell this also displays the model.
window_model.eval()

WindowSlideModel(
  (net): Sequential(
    (0): Linear(in_features=768, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [83]:
# Determining invocation times: run the trigger classifier over every window
# embedding and threshold its sigmoid output at 0.5, giving a 0/1 value per window.
# (The progress-bar label below says "BIO tags", but no tags are predicted here.)
batch_size = 128  # Adjust batch size based on your GPU memory
trigger_batches = []
with torch.no_grad():
    for start in tqdm(range(0, len(embeddings), batch_size), desc="Predicting BIO tags"):
        features = torch.tensor(embeddings[start:start + batch_size]).to(device)
        probabilities = torch.sigmoid(window_model(features))
        triggered = (probabilities > 0.5).int()  # 1 where the probability exceeds 0.5
        trigger_batches.append(triggered.cpu().numpy())

# Join the per-batch results into one array
y_hat = np.concatenate(trigger_batches, axis=0)
print(f"Shape of predictions: {y_hat.shape}")

Predicting BIO tags: 100%|██████████| 1437/1437 [00:00<00:00, 1827.17it/s]

Shape of predictions: (183876,)





In [84]:
# y_hat holds 0/1 values, so its sum is the number of windows flagged with 1
triggered_total = np.sum(y_hat)
print(f"Number of invocations: {triggered_total}")

Number of invocations: 43404


## 5. Evaluating the sliding window baseline

In [86]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# Token-classification model and tokenizer, both loaded from the same checkpoint,
# wrapped in a Hugging Face "ner" pipeline.
NER_CHECKPOINT = "dslim/bert-base-NER"

ner_tokenizer = AutoTokenizer.from_pretrained(NER_CHECKPOINT)
# Move the model to the selected device and put it in evaluation mode
ner_model = AutoModelForTokenClassification.from_pretrained(NER_CHECKPOINT).to(device).eval()
ner_pipeline = pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
)

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use mps:0


In [104]:
sliding_window_metrics = Metrics()

# y_hat is consumed positionally: its i-th entry must belong to the i-th window when
# walking `windows` sentence by sentence. Fail loudly if the two are out of sync
# (for example after re-running only some of the cells above) instead of silently
# pairing predictions with the wrong windows.
total_sentence_windows = sum(len(sentence_windows) for _, sentence_windows in windows)
assert len(y_hat) == total_sentence_windows, (
    f"y_hat has {len(y_hat)} predictions but `windows` holds {total_sentence_windows} windows; "
    "re-run the window, embedding and prediction cells"
)

ix = 0  # index into y_hat of the window currently being processed
for bio_tags, sentence_windows in tqdm(windows, desc="Evaluating sliding window metrics"):
    # One entry per NER invocation in this sentence: the BIO tag sequence known
    # for the sentence after that invocation.
    invocations = []
    # Windows slide by one word, so the k-th window of a sentence starts at word k
    for sentence_window_offsets, window in enumerate(sentence_windows):
        if y_hat[ix] == 1:
            results = ner_pipeline(" ".join(window))
            # presumably one BIO tag per word of the window; verify in utils.convert_predictions
            pred_bio_tags = convert_predictions(window, results)
            # Start from the tags produced by the previous invocation, if there was one
            current_invocation = invocations[-1].copy() if invocations else []
            # Pad with "O" up to the end of the current window; a non-positive
            # padding count adds nothing.
            target_current_invocation_length = sentence_window_offsets + len(window)
            padding = target_current_invocation_length - len(current_invocation)
            current_invocation.extend(["O"] * padding)

            # overwrite the end of current_invocation with the new predictions
            window_end = sentence_window_offsets + len(pred_bio_tags)
            current_invocation[sentence_window_offsets:window_end] = pred_bio_tags
            invocations.append(current_invocation)

        ix += 1

    sliding_window_metrics.evaluate_metrics(([], bio_tags), invocations)

sliding_window_metrics.save_metrics("baseline/sliding_window_metrics.pkl")


Evaluating sliding window metrics: 100%|██████████| 12217/12217 [05:42<00:00, 35.71it/s] 


In [105]:
# Optional: uncomment the next two lines to load metrics from the file written by the
# evaluation cell instead of using the in-memory object
# (presumably restores the saved state; verify Metrics.load_metrics).
# sliding_window_metrics = Metrics()
# sliding_window_metrics.load_metrics("baseline/sliding_window_metrics.pkl")
# Print the collected metrics
sliding_window_metrics.print_metrics()

Metrics:
Total NER invocations: 43404
Avg TTFD: 1.33
FPR@FNR: 0.0647@0.7210
Entity Type          TP         TN         FP (#B-/I-MISC)      FN        
----------------------------------------------------------------------
O                    N/A        97327      1802                 N/A       
B-PERSON             1409       N/A        47                   678       
I-PERSON             428        N/A        432                  652       
B-NORP               0          N/A        843 (826/0)          147       
I-NORP               0          N/A        144 (106/33)         18        
B-FAC                0          N/A        87 (1/0)             62        
I-FAC                0          N/A        196 (3/2)            47        
B-ORG                1058       N/A        122                  822       
I-ORG                952        N/A        825                  926       
B-GPE                0          N/A        2095 (29/0)          451       
I-GPE                0      

In [107]:
from flops_calculator import FlopsCalculator
# NOTE(review): the metrics file in this notebook lives under "baseline/", while the
# coefficients are read from "baselines/". Confirm which directory name is correct.
flops_calculator = FlopsCalculator("baselines/flops_coefficients.pkl")

num_windows = len(y_hat)         # the trigger model runs once per window
num_invocations = np.sum(y_hat)  # the NER model runs once per window flagged with 1
avg_tokens = 8 # window of 6 words + 2 for [CLS] and [SEP]

print(f"Average prefix tokens: {avg_tokens} tokens * {num_windows} windows")
print(f"Average invocation tokens: {avg_tokens} tokens * {num_invocations} invocations")

# Per-call cost at avg_tokens tokens, scaled by how often each model is called
sliding_window_flops = flops_calculator.calculate_flops("model_2", avg_tokens) * num_windows
ner_running_flops = flops_calculator.calculate_flops("ner", avg_tokens) * num_invocations

total_flops = sliding_window_flops + ner_running_flops
print(f"Total FLOPs for sliding window model: {sliding_window_flops:,}")
# Also report the NER share and the combined cost: total_flops was computed but
# never shown, so the line above covered the trigger model only.
print(f"Total FLOPs for NER invocations: {ner_running_flops:,}")
print(f"Total FLOPs (sliding window model + NER invocations): {total_flops:,}")

Average prefix tokens: 8 tokens * 183876 windows
Average invocation tokens: 8 tokens * 43404 invocations
Total FLOPs for sliding window model: 125,249,513,445,672
