# Confidence Score Baseline

This notebook implements the confidence score (approach 1) baseline for near real-time Named Entity Recognition (NER).

## 1. Setup and imports

In [None]:
# Install the `datasets` and `transformers` packages used by this notebook.
# NOTE(review): versions are unpinned — consider pinning them for reproducibility.
%pip install datasets transformers

Note: you may need to restart the kernel to use updated packages.


In [35]:
import os
import sys

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

In [42]:
# Make the local `src` directory importable. The membership check keeps the cell
# idempotent: re-running it does not append duplicate entries to sys.path.
if "src" not in sys.path:
    sys.path.append("src")
from utils import convert_ids_to_bio
from confidence_model import confidence_model

## 2. Load OntoNotes dataset

In [5]:
# Fetch the `english_v12` configuration of OntoNotes 5.0, caching it under ./dataset/ontonotes
ontonotes = load_dataset("conll2012_ontonotesv5", "english_v12", cache_dir="./dataset/ontonotes")
print(f"Dataset loaded with splits: {ontonotes.keys()}")

Dataset loaded with splits: dict_keys(['train', 'validation', 'test'])


## 3. Preprocessing
Here we build every word-level prefix of each sentence in the OntoNotes test split and pre-compute the BERT [CLS] token embedding for each prefix.

In [12]:
# For every sentence in the test split, collect all of its word prefixes
# (first word, first two words, ...) together with the sentence's BIO tags.
# `prefixes` holds one (true_bio, sentence_prefixes) tuple per sentence.
prefixes = []
prefix_count = 0

# Iterate through the test split
for doc in tqdm(ontonotes["test"], desc="Processing documents", unit="doc"):
    sentences = doc['sentences']
    # Sometimes doc['sentences'] is a list of lists, so we need to flatten it.
    # The emptiness check avoids an IndexError on a document without sentences.
    if isinstance(sentences, list) and sentences and isinstance(sentences[0], list):
        sentences = [sentence for sublist in sentences for sentence in sublist]
    for sentence in sentences:
        words = sentence['words']
        # One prefix per word: words[:1], words[:2], ..., words[:len(words)]
        sentence_prefixes = [list(words[:end]) for end in range(1, len(words) + 1)]
        prefix_count += len(sentence_prefixes)

        true_bio = convert_ids_to_bio(sentence['named_entities'])

        # Store the BIO tags alongside all prefixes of this sentence
        prefixes.append((true_bio, sentence_prefixes))

print(f"Total prefixes created: {prefix_count}")
# Build the multi-line example outside the f-string: reusing quotes or using a
# backslash inside an f-string expression is a SyntaxError before Python 3.12.
example_bio, example_prefixes = prefixes[0]
example_lines = "\n".join(str(prefix) for prefix in example_prefixes)
print(f"Example prefix: {example_bio}\n{example_lines}")

Processing documents: 100%|██████████| 1200/1200 [00:00<00:00, 1336.49doc/s]

Total prefixes created: 230118
Example prefix: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['--']
['--', 'basically']
['--', 'basically', ',']
['--', 'basically', ',', 'it']
['--', 'basically', ',', 'it', 'was']
['--', 'basically', ',', 'it', 'was', 'unanimously']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various', 'relevant']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various', 'relevant', 'parties']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various'




In [33]:
# Load the tokenizer and base model of the pretrained checkpoint
# (AutoTokenizer / AutoModel come from the notebook's import cell).
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
bert_model = AutoModel.from_pretrained("dslim/bert-base-NER")

# Device preference: CUDA first, then Apple MPS, otherwise CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Move the model to the chosen device and switch it to evaluation mode.
bert_model.to(device)
bert_model.eval()

Using device: mps


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [34]:
# Flatten the per-sentence prefix lists into a single list, preserving order.
all_prefixes = [
    prefix
    for _, sentence_prefixes in prefixes
    for prefix in sentence_prefixes
]

print(f"Total prefixes to process: {len(all_prefixes)}")

Total prefixes to process: 230118


In [36]:
# Run every prefix through the model in mini-batches, keeping only the hidden
# state at position 0 (the CLS token) of each sequence.

batch_size = 128  # Adjust batch size based on your GPU memory
embeddings = []
batch_starts = range(0, len(all_prefixes), batch_size)
with torch.no_grad():
    for start in tqdm(batch_starts, desc="Computing CLS token for prefixes"):
        encoded = tokenizer(
            all_prefixes[start:start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            is_split_into_words=True,
        ).to(device)
        hidden_states = bert_model(**encoded).last_hidden_state
        # Bring the per-batch CLS vectors back to the CPU as a NumPy array
        embeddings.append(hidden_states[:, 0, :].cpu().numpy())

# Stack the per-batch arrays into one array with a row per prefix
embeddings = np.concatenate(embeddings, axis=0)
print(f"Shape of embeddings: {embeddings.shape}")

Computing CLS token for windows: 100%|██████████| 1798/1798 [07:45<00:00,  3.86it/s]

Shape of embeddings: (230118, 768)





In [39]:
# Create the output directory first so the save cannot fail with a missing-directory
# error after the long embedding computation.
os.makedirs("data", exist_ok=True)
np.savez("data/ontonotes_embeddings_test_all_prefixes.npz", embeddings=embeddings)

In [41]:
# Reload the saved embeddings. The context manager closes the .npz archive once
# the array has been read instead of leaving the file handle open.
with np.load("data/ontonotes_embeddings_test_all_prefixes.npz") as data:
    embeddings = data["embeddings"]
print(f"Shape of embeddings: {embeddings.shape}")

Shape of embeddings: (230118, 768)


## 4. Implementing confidence score baseline

In [43]:
# Instantiate the confidence model and load its saved weights.
conf_model = confidence_model()
# map_location remaps the stored tensors onto `device`, so a checkpoint saved on a
# different device (e.g. CUDA) still loads on this machine.
state_dict = torch.load(
    "models/confidence_model.pth",
    map_location=device,
    weights_only=True,
)
conf_model.load_state_dict(state_dict)
# Move the model to the inference device and switch it to evaluation mode.
conf_model.to(device)
conf_model.eval()

confidence_model(
  (model): Sequential(
    (0): Linear(in_features=768, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
  )
)