# Confidence Score Baseline

This notebook implements the confidence score (approach 1) baseline for near real-time Named Entity Recognition (NER).

## 1. Setup and imports

In [None]:
# Install the Hugging Face libraries used below. torch, numpy and tqdm are also imported
# by this notebook but are not installed by this cell.
# NOTE(review): versions are unpinned; pin them (e.g. `datasets==X.Y`) for reproducible runs.
%pip install datasets transformers

Note: you may need to restart the kernel to use updated packages.


In [35]:
from datasets import load_dataset
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import sys

In [51]:
# Make the project-local modules in ./src (relative to the current working directory)
# importable. This must run before the local imports below.
sys.path.append("src")
from utils import convert_ids_to_bio, convert_predictions
from confidence_model import confidence_model
from metrics import Metrics

In [59]:
# Development helper: re-execute the `metrics` module (presumably src/metrics.py) so edits
# to it are picked up without restarting the kernel, then rebind `Metrics` to the reloaded
# class. In a fresh kernel this is redundant with the import cell above.
from importlib import reload
import metrics
reload(metrics)
from metrics import Metrics

## 2. Load OntoNotes dataset

In [5]:
# English v12 configuration of OntoNotes 5.0 (CoNLL-2012), cached under ./dataset/ontonotes
ontonotes = load_dataset(
    path="conll2012_ontonotesv5",
    name="english_v12",
    cache_dir="./dataset/ontonotes",
)
print(f"Dataset loaded with splits: {ontonotes.keys()}")

Dataset loaded with splits: dict_keys(['train', 'validation', 'test'])


## 3. Preprocessing
Here we build every word-level prefix of each sentence in the OntoNotes test split and pre-compute the BERT [CLS] embedding for each prefix.

In [12]:
# Build every word-level prefix of every sentence in the OntoNotes test split.
# `prefixes` holds one (true_bio, sentence_prefixes) tuple per sentence, where
# sentence_prefixes[k] is the list of the first k+1 words of that sentence.
# The flattened order of these prefixes is the order used for embeddings and predictions below.
prefixes = []
prefix_count = 0

for doc in tqdm(ontonotes["test"], desc="Processing documents", unit="doc"):
    sentences = doc['sentences']
    # Sometimes doc['sentences'] is a list of lists, so flatten it.
    # The emptiness check avoids an IndexError when peeking at the first element.
    if isinstance(sentences, list) and sentences and isinstance(sentences[0], list):
        sentences = [sentence for sublist in sentences for sentence in sublist]
    for sentence in sentences:
        words = sentence['words']
        # Slicing creates an independent list per prefix, so prefixes never alias each other.
        sentence_prefixes = [words[:end] for end in range(1, len(words) + 1)]
        prefix_count += len(sentence_prefixes)

        true_bio = convert_ids_to_bio(sentence['named_entities'])

        # Store the gold BIO tags alongside all prefixes of the sentence
        prefixes.append((true_bio, sentence_prefixes))

print(f"Total prefixes created: {prefix_count}")
if prefixes:
    example_tags, example_prefixes = prefixes[0]
    # Joined outside the f-string: a backslash or reused quote character inside an
    # f-string replacement field is a SyntaxError before Python 3.12.
    example_lines = "\n".join(str(prefix) for prefix in example_prefixes)
    print(f"Example prefix: {example_tags}\n{example_lines}")

Processing documents: 100%|██████████| 1200/1200 [00:00<00:00, 1336.49doc/s]

Total prefixes created: 230118
Example prefix: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['--']
['--', 'basically']
['--', 'basically', ',']
['--', 'basically', ',', 'it']
['--', 'basically', ',', 'it', 'was']
['--', 'basically', ',', 'it', 'was', 'unanimously']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various', 'relevant']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various', 'relevant', 'parties']
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various'




In [48]:
from transformers import AutoTokenizer, AutoModel, AutoModelForTokenClassification, pipeline
# Tokenizer and encoder come from the same checkpoint used for the NER invocations later.
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
bert_model = AutoModel.from_pretrained("dslim/bert-base-NER")

# Prefer CUDA, then Apple MPS, then CPU.
# torch.backends.mps.is_available() is the documented availability check. torch.mps.is_available()
# is not present in every PyTorch release and would raise AttributeError on non-CUDA machines.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
print(f"Using device: {device}")

bert_model.to(device)
# Inference only: switch off dropout
bert_model.eval()

Using device: mps


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [34]:
# Flatten the per-sentence prefix lists into one list, preserving sentence order,
# so that row i of the embeddings/predictions computed below corresponds to all_prefixes[i].
all_prefixes = [prefix for _, sentence_prefixes in prefixes for prefix in sentence_prefixes]

print(f"Total prefixes to process: {len(all_prefixes)}")

Total prefixes to process: 230118


In [36]:
# Encode every prefix with BERT in fixed-size batches and keep the final-layer hidden
# state at position 0 (the [CLS] token) as that prefix's representation.
batch_size = 128  # Adjust batch size based on your GPU memory
cls_batches = []
for start in tqdm(range(0, len(all_prefixes), batch_size), desc="Computing CLS token for prefixes"):
    batch = all_prefixes[start:start + batch_size]
    encoded = tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        is_split_into_words=True,
    ).to(device)

    with torch.no_grad():
        hidden = bert_model(**encoded).last_hidden_state

    cls_batches.append(hidden[:, 0, :].cpu().numpy())

# Stitch the per-batch arrays back together in prefix order
embeddings = np.concatenate(cls_batches, axis=0)
print(f"Shape of embeddings: {embeddings.shape}")

Computing CLS token for windows: 100%|██████████| 1798/1798 [07:45<00:00,  3.86it/s]

Shape of embeddings: (230118, 768)





In [39]:
# Persist the CLS embeddings so the expensive encoding step can be skipped on re-runs.
# np.savez does not create missing directories, so ensure data/ exists first.
from pathlib import Path
Path("data").mkdir(parents=True, exist_ok=True)
np.savez("data/ontonotes_embeddings_test_all_prefixes.npz", embeddings=embeddings)

In [41]:
# Reload the cached CLS embeddings. For .npz archives np.load returns an NpzFile that
# keeps the archive open; the context manager closes it once the array has been read.
with np.load("data/ontonotes_embeddings_test_all_prefixes.npz") as data:
    embeddings = data["embeddings"]
print(f"Shape of embeddings: {embeddings.shape}")

Shape of embeddings: (230118, 768)


## 4. Implementing the confidence score baseline

In [43]:
# Load the trained confidence head (per its repr: Linear(768->256) -> ReLU -> Dropout -> Linear(256->1)).
# map_location remaps tensors saved on another device (e.g. CUDA) onto the current one;
# without it, a CUDA-saved checkpoint fails to load on MPS/CPU-only machines.
conf_model = confidence_model()
conf_model.load_state_dict(
    torch.load("models/confidence_model.pth", weights_only=True, map_location=device)
)
conf_model.to(device)
# Inference only: switch off dropout
conf_model.eval()

confidence_model(
  (model): Sequential(
    (0): Linear(in_features=768, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [46]:
# For every prefix, decide whether to invoke the NER model: invoke when the sigmoid of
# the confidence head's logit exceeds `threshold`. The result is a (num_prefixes, 1) 0/1 array.
batch_size = 128  # Adjust batch size based on your GPU memory
threshold = 0.8
decision_batches = []
for start in tqdm(range(0, len(embeddings), batch_size), desc="Determining invocation times"):
    features = torch.tensor(embeddings[start:start + batch_size]).to(device)
    with torch.no_grad():
        probabilities = torch.sigmoid(conf_model(features))
    decisions = (probabilities > threshold).long()
    decision_batches.append(decisions.cpu().numpy())

# Merge the per-batch decisions back into one array in prefix order
y_hat = np.concatenate(decision_batches, axis=0)
print(f"Shape of predictions: {y_hat.shape}")
print(f"Number of invocations: {np.sum(y_hat)}")

Determining invocation times: 100%|██████████| 1798/1798 [00:01<00:00, 1078.57it/s]

Shape of predictions: (230118, 1)
Number of invocations: 14733





## 5. Evaluating the confidence score baseline


In [49]:
# Token-classification model on the same checkpoint, moved to the selected device in
# inference mode and wrapped in an "ner" pipeline with aggregation_strategy="simple".
ner_classifier = (
    AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    .to(device)
    .eval()
)
ner_pipeline = pipeline("ner", model=ner_classifier, tokenizer=tokenizer, aggregation_strategy="simple")

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use mps:0


In [56]:
# Replay each sentence prefix by prefix. Whenever the confidence head fired for a prefix,
# run the NER pipeline on it and record the converted predictions. `ix` walks y_hat in the
# same flattened order in which the prefixes were embedded.
confidence_metrics = Metrics()
ix = 0

for true_bio, sentence_prefixes in tqdm(prefixes):
    invocations = []
    for prefix in sentence_prefixes:
        if y_hat[ix] == 1:
            results = ner_pipeline(" ".join(prefix))
            invocations.append(convert_predictions(prefix, results))
        ix += 1

    # NOTE(review): the meaning of the empty list in the first argument is defined by
    # Metrics.evaluate_metrics (src/metrics) — confirm there.
    confidence_metrics.evaluate_metrics(([], true_bio), invocations)

# Every prediction must be consumed exactly once. Otherwise `prefixes` and `y_hat` are
# misaligned (e.g. after re-running an upstream cell) and the metrics would be invalid.
assert ix == len(y_hat), f"consumed {ix} predictions but y_hat has {len(y_hat)}"

confidence_metrics.save_metrics("baselines/confidence_score.pkl")

100%|██████████| 12217/12217 [02:11<00:00, 93.26it/s] 


In [70]:
# Reload the metrics saved by the evaluation cell, so this cell works without re-running
# that loop, and print the summary table.
# NOTE(review): if load_metrics unpickles the .pkl file, only load files you trust.
confidence_metrics = Metrics()
confidence_metrics.load_metrics("baselines/confidence_score.pkl")
confidence_metrics.print_metrics()

Metrics:
Total NER invocations: 14733
Avg TTFD: 1.46
FPR@FNR: 0.0858@0.7293
Entity Type          TP         TN         FP (#B-/I-MISC)      FN        
----------------------------------------------------------------------
O                    N/A        52632      1463                 N/A       
B-PERSON             1198       N/A        81                   855       
I-PERSON             736        N/A        62                   714       
B-NORP               0          N/A        681 (663/3)          309       
I-NORP               0          N/A        107 (37/62)          55        
B-FAC                0          N/A        73 (2/0)             76        
I-FAC                0          N/A        163 (3/6)            80        
B-ORG                933        N/A        106                  963       
I-ORG                1270       N/A        347                  1086      
B-GPE                0          N/A        1852 (36/1)          694       
I-GPE                0      

In [71]:
# NOTE(review): these call private Metrics helpers (_calculate_tp / _calculate_fp). What
# "rectified" means is defined in the metrics module — confirm before reporting the numbers.
print(f"Rectified TP: {confidence_metrics._calculate_tp()}")
print(f"Rectified FP: {confidence_metrics._calculate_fp()}")

Rectified TP: 5372
Rectified FP: 4940


In [69]:
from flops_calculator import FlopsCalculator
flops_calculator = FlopsCalculator("baselines/flops_coefficients.pkl")

# Estimate compute cost. The confidence path ("model_1") runs on every prefix; the NER
# model runs only on prefixes where y_hat == 1. Sequence length is approximated as the
# word count + 2 for [CLS] and [SEP].
# NOTE(review): word count underestimates the WordPiece token count, and evaluating FLOPs
# at the mean length assumes calculate_flops is (close to) linear in length — confirm both.
num_prefixes = len(y_hat)
num_invocations = int(np.sum(y_hat))  # Python int: avoids fixed-width overflow in the products below

prefix_lengths = []
invoked_lengths = []
ix = 0
for _bio, sentence_prefixes in tqdm(prefixes):
    for prefix in sentence_prefixes:
        n_tokens = len(prefix) + 2  # +2 for [CLS] and [SEP]
        prefix_lengths.append(n_tokens)
        if y_hat[ix] == 1:
            invoked_lengths.append(n_tokens)
        ix += 1

# Convert once at the end. Calling np.append inside the loop copied the whole array on
# every iteration, which is quadratic in the number of prefixes.
token_lengths = np.array(prefix_lengths, dtype=int)
invocation_token_length = np.array(invoked_lengths, dtype=int)

avg_prefix_tokens = np.mean(token_lengths) if token_lengths.size > 0 else 0
avg_invocation_tokens = np.mean(invocation_token_length) if invocation_token_length.size > 0 else 0

print(f"Average prefix tokens: {avg_prefix_tokens} tokens * {num_prefixes} prefixes")
print(f"Average invocation tokens: {avg_invocation_tokens} tokens * {num_invocations} invocations")

confidence_model_flops = flops_calculator.calculate_flops("model_1", avg_prefix_tokens) * num_prefixes
ner_running_flops = flops_calculator.calculate_flops("ner", avg_invocation_tokens) * num_invocations

total_flops = confidence_model_flops + ner_running_flops
print(f"Total FLOPs for confidence model: {confidence_model_flops:,}")
print(f"Total FLOPs for NER invocations: {ner_running_flops:,}")
print(f"Total FLOPs (confidence model + NER): {total_flops:,}")

100%|██████████| 12217/12217 [00:04<00:00, 2446.57it/s]

Average prefix tokens: 16.637286088007023 tokens * 230118 prefixes
Average invocation tokens: 14.684789248625535 tokens * 14733 invocations
Total FLOPs for confidence model: 325,724,260,216,320



