In [1]:
!pip install datasets



In [2]:
from random import shuffle
from math import ceil

import torch
import torch.nn as nn

from transformers import AutoModel, AutoTokenizer
import datasets

from tqdm.auto import tqdm

from collections import defaultdict
from urllib import request
import json
import pandas as pd

In [3]:
from sklearn.utils import gen_batches


In [3]:
def parse_conllu_using_pandas(block):
    records = []
    for line in block.splitlines():
        if not line.startswith('#'):
            records.append(line.strip().split('\t'))
    return pd.DataFrame.from_records(
        records,
        columns=['ID', 'FORM', 'TAG', 'Misc1', 'Misc2'])

In [4]:
def tokens_to_labels(df):
    return (
        df.FORM.tolist(),
        df.TAG.tolist()
    )

In [5]:
PREFIX = "https://raw.githubusercontent.com/UniversalNER/"
DATA_URLS = {
    "en_ewt": {
        "train": "UNER_English-EWT/master/en_ewt-ud-train.iob2",
        "dev": "UNER_English-EWT/master/en_ewt-ud-dev.iob2",
        "test": "UNER_English-EWT/master/en_ewt-ud-test.iob2"
    },
    "en_pud": {
        "test": "UNER_English-PUD/master/en_pud-ud-test.iob2"
    }
}

In [6]:
# en_ewt is the main train-dev-test split
# en_pud is the OOD test set
data_dict = defaultdict(dict)
for corpus, split_dict in DATA_URLS.items():
    for split, url_suffix in split_dict.items():
        url = PREFIX + url_suffix
        with request.urlopen(url) as response:
            txt = response.read().decode('utf-8')
            data_frames = map(parse_conllu_using_pandas,
                              txt.split('\n\n'))
            token_label_alignments = list(map(tokens_to_labels,
                                              data_frames))
            data_dict[corpus][split] = token_label_alignments

In [7]:
# Saving the data so that you don't have to redownload it each time.
with open('ner_data_dict.json', 'w', encoding='utf-8') as out:
    json.dump(data_dict, out, indent=2, ensure_ascii=False)

In [8]:
# Each subset of each corpus is a list of tuples where each tuple
# is a list of tokens with a corresponding list of labels.

# Train on data_dict['en_ewt']['train']; validate on data_dict['en_ewt']['dev']
# and test on data_dict['en_ewt']['test'] and data_dict['en_pud']['test']
data_dict['en_ewt']['train'][0], data_dict['en_pud']['test'][1]


((['Where', 'in', 'the', 'world', 'is', 'Iguazu', '?'],
  ['O', 'O', 'O', 'O', 'O', 'B-LOC', 'O']),
 (['For',
   'those',
   'who',
   'follow',
   'social',
   'media',
   'transitions',
   'on',
   'Capitol',
   'Hill',
   ',',
   'this',
   'will',
   'be',
   'a',
   'little',
   'different',
   '.'],
  ['O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'B-LOC',
   'I-LOC',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O',
   'O']))

In [8]:
def convert_to_word_tag_tuples(data):
    return [list(zip(words, tags)) for words, tags in data]


In [9]:
train_data = convert_to_word_tag_tuples(data_dict['en_ewt']['train'])
val_data   = convert_to_word_tag_tuples(data_dict['en_ewt']['dev'])
test_data  = convert_to_word_tag_tuples(data_dict['en_ewt']['test'])
ood_data   = convert_to_word_tag_tuples(data_dict['en_pud']['test'])


In [14]:
train_data[1]

[('Iguazu', 'B-LOC'), ('Falls', 'I-LOC')]

In [15]:
# Let's see how many different labels we have to set up the classification_head
# accordingly:
labels = set()
for ex in train_data:
    # We assume that there will be no new POS tags in the dev and test sets.
    labels.update([el[1] for el in ex])
n_classes = len(labels)
sorted(labels)

['B-LOC', 'B-ORG', 'B-PER', 'I-LOC', 'I-ORG', 'I-PER', 'O']

In [16]:
# The models expect class numbers, not strings
label_to_i = {
    label: i
    for i, label in enumerate(sorted(labels))
}
i_to_label = {
    i: label
    for label, i in label_to_i.items()
}

In [None]:
# label_to_i

In [17]:
model_tag = 'google-bert/bert-base-uncased'

tokeniser = AutoTokenizer.from_pretrained(model_tag)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [18]:
class ClassificationHead(nn.Module):
    def __init__(self, model_dim=768, n_classes=7):
        super().__init__()
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(model_dim, n_classes)

    def forward(self, x):
        return self.linear(self.dropout(x))

In [None]:
# def process_sentence(sentence, label_to_i, tokeniser, encoder, clf_head,
#                      encoder_device, clf_head_device):
#     if not sentence:
#         return torch.zeros((0, clf_head.linear.in_features), device=clf_head_device), \
#                torch.tensor([], device=clf_head_device, dtype=torch.long)

#     words = [word for word, _ in sentence]
#     tokenisation = tokeniser(words, is_split_into_words=True, return_tensors='pt')

#     # Handle cases where tokenization produces no input IDs
#     if tokenisation.input_ids.size(1) < 2:  # Only [CLS] and [SEP]
#         return torch.zeros((0, clf_head.linear.in_features), device=clf_head_device), \
#                torch.tensor([], device=clf_head_device, dtype=torch.long)
#     inputs = {k: v.to(encoder_device) for k, v in tokenisation.items()}

#     # We don't need the embedding of the CLS token this time.
#     # Neither do we need the embedding of the SEP token added at the end.
#     outputs = encoder(**inputs).last_hidden_state[0, 1:-1, :]

#     # Now we need to decide how to combine subword embeddings into
#     # word embeddings. The simplest way is to take the first/last subword,
#     # and we will do this here. The logic is that we will fine-tune the
#     # encoder as well, and we hope that it will learn to channel all the
#     # necessary information into first subwords.
#     # Note that word_ids are found only in the original tokeniser output,
#     # in the dictionary with tensors copied to the GPU.
#     # We ignore the CLS and the SEP tokens
#     word_ids = tokenisation.word_ids()[1:-1]
#     processed_words = set()
#     first_subword_embeddings = []
#     # Indices of subwords in outputs are aligned with word_ids, so we can use
#     # the same indices in both arrays.
#     for i, word_id in enumerate(word_ids):
#         if word_id not in processed_words:
#             first_subword_embeddings.append(outputs[i])
#             processed_words.add(word_id)

#     # Check that we aligned words and labels correctly.
#     assert len(first_subword_embeddings) == gold_labels.size(0)

#     # Combine subword embeddings into a tensor and copy to the device
#     # where the classifier head resides.
#     clf_head_inputs = torch.vstack(
#         first_subword_embeddings).to(clf_head_device)
#     gold_labels = torch.tensor([label_to_i[label] for _, label in sentence], device=clf_head_device)


#     # Return the logits and gold labels for subsequent processing
#     return clf_head(clf_head_inputs), gold_labels

In [19]:
def process_sentence(sentence, label_to_i, tokeniser, encoder, clf_head,
                     encoder_device, clf_head_device):
    # Handle empty sentences
    if not sentence:
        return torch.zeros((0, clf_head.linear.in_features), device=clf_head_device), \
               torch.tensor([], device=clf_head_device, dtype=torch.long)

    words = [word for word, _ in sentence]
    tokenisation = tokeniser(words, is_split_into_words=True, return_tensors='pt')

    # Handle cases where tokenization produces no input IDs
    if tokenisation.input_ids.size(1) < 2:  # Only [CLS] and [SEP]
        return torch.zeros((0, clf_head.linear.in_features), device=clf_head_device), \
               torch.tensor([], device=clf_head_device, dtype=torch.long)

    inputs = {k: v.to(encoder_device) for k, v in tokenisation.items()}
    outputs = encoder(**inputs).last_hidden_state[0, 1:-1, :]  # Skip [CLS] and [SEP]

    word_ids = tokenisation.word_ids()[1:-1]  # Align with outputs after slicing
    first_subword_embeddings = []
    processed_words = set()

    for i, word_id in enumerate(word_ids):
        if word_id is not None and word_id not in processed_words:
            first_subword_embeddings.append(outputs[i])
            processed_words.add(word_id)

    # Handle empty embeddings (no valid words)
    if not first_subword_embeddings:
        return torch.zeros((0, clf_head.linear.in_features), device=clf_head_device), \
               torch.tensor([], device=clf_head_device, dtype=torch.long)

    clf_head_inputs = torch.vstack(first_subword_embeddings).to(clf_head_device)
    gold_labels = torch.tensor([label_to_i[label] for _, label in sentence], device=clf_head_device)

    return clf_head(clf_head_inputs), gold_labels


In [20]:
def train_epoch(data, label_to_i, tokeniser, encoder, clf_head,
                encoder_device, clf_head_device, loss_fn, optimiser):
    encoder.train()
    epoch_losses = torch.empty(len(data))
    for step_n, sentence in tqdm(
        enumerate(data),
        total=len(data),
        desc='Train',
        leave=False
    ):
        if not sentence:
          continue
        else:
          optimiser.zero_grad()
          logits, gold_labels = process_sentence(
              sentence, label_to_i, tokeniser,
              encoder, clf_head, encoder_device,
              clf_head_device)
          loss = loss_fn(logits, gold_labels)
          loss.backward()
          optimiser.step()
          epoch_losses[step_n] = loss.item()
    return epoch_losses.mean().item()

In [27]:
def validate_epoch(data, label_to_i, tokeniser, encoder, clf_head,
                   encoder_device, clf_head_device):
    encoder.eval()
    epoch_accuracies = torch.empty(len(data))
    for step_n, sentence in tqdm(
        enumerate(data),
        total=len(data),
        desc='Eval',
        leave=False
    ):
        with torch.no_grad():
          if not sentence:
            epoch_accuracies[step_n] = 0.0
            continue
          logits, gold_labels = process_sentence(
              sentence, label_to_i, tokeniser,
              encoder, clf_head, encoder_device,
              clf_head_device)
        if logits.size(0) == 0 or gold_labels.size(0) == 0:
          epoch_accuracies[step_n] = 0.0
          continue
        predicted_labels = torch.argmax(logits, dim=-1)

        epoch_accuracies[step_n] = (
            predicted_labels == gold_labels).sum().item() / len(sentence)
    return epoch_accuracies.mean().item()

In [25]:
encoder_device = 0  # Can also be 'cpu'
encoder = AutoModel.from_pretrained(
    model_tag).to(encoder_device)
# NB: pass the number of different POS tags
clf_head = ClassificationHead(n_classes=n_classes)
clf_head_device = 0
clf_head.to(clf_head_device);

In [28]:
n_epochs = 4
batch_size = 32
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.AdamW(
    list(encoder.parameters()) + list(clf_head.parameters()), lr=2 * 10**(-5))
for epoch_n in tqdm(range(n_epochs)):
    loss = train_epoch(train_data, label_to_i = label_to_i, tokeniser = tokeniser, encoder = encoder, clf_head = clf_head,
                       encoder_device = encoder_device, clf_head_device = clf_head_device, loss_fn = loss_fn, optimiser = optimiser)



    print(f'Epoch {epoch_n+1} training loss: {loss:.2f}')
    accuracy = validate_epoch(val_data, label_to_i, tokeniser, encoder,
                              clf_head, encoder_device, clf_head_device)
    print(f'Epoch {epoch_n+1} dev accuracy: {accuracy:.2f}')


  0%|          | 0/4 [00:00<?, ?it/s]

Train:   0%|          | 0/12544 [00:00<?, ?it/s]

Epoch 1 training loss: 0.05


Eval:   0%|          | 0/2002 [00:00<?, ?it/s]

Epoch 1 dev accuracy: 0.97


Train:   0%|          | 0/12544 [00:00<?, ?it/s]

Epoch 2 training loss: 0.03


Eval:   0%|          | 0/2002 [00:00<?, ?it/s]

Epoch 2 dev accuracy: 0.97


Train:   0%|          | 0/12544 [00:00<?, ?it/s]

Epoch 3 training loss: 0.02


Eval:   0%|          | 0/2002 [00:00<?, ?it/s]

Epoch 3 dev accuracy: 0.98


Train:   0%|          | 0/12544 [00:00<?, ?it/s]

Epoch 4 training loss: 0.02


Eval:   0%|          | 0/2002 [00:00<?, ?it/s]

Epoch 4 dev accuracy: 0.98


In [22]:
!pip install seqeval pandas


Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: seqeval
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16162 sha256=9781d04d457f8d70cc6ca942e0c1f1b72bf6afe950b8c8225341fe8769155b8c
  Stored in directory: /root/.cache/pip/wheels/bc/92/f0/243288f899c2eacdfa8c5f9aede4c71a9bad0ee26a01dc5ead
Successfully built seqeval
Installing collected packages: seqeval
Successfully installed seqeval-1.2.2


In [23]:
from collections import defaultdict
from seqeval.metrics import classification_report
from tqdm.auto import tqdm

def evaluate_full_tagset(
    model, tokeniser, data, label_to_i, i_to_label, encoder_device=0, clf_head_device=0
):
    """
    Evaluates a complex encoder-only model on all required NER metrics for the full BIO tagset.
    Returns token-level, labelled span, unlabelled span, per-label and macro F1 metrics.
    """
    def bio_to_spans(tags):
        spans = []
        current_start = None
        current_label = None
        for i, tag in enumerate(tags):
            if tag.startswith('B-'):
                if current_start is not None:
                    spans.append((current_start, i-1, current_label))
                current_start = i
                current_label = tag  # Keep full tag (e.g., "B-LOC")
            elif tag.startswith('I-'):
                if current_label != tag:  # Compare full tags
                    if current_start is not None:
                        spans.append((current_start, i-1, current_label))
                    current_start = i
                current_label = tag  # Keep full tag (e.g., "I-LOC")
            else:
                if current_start is not None:
                    spans.append((current_start, i-1, current_label))
                    current_start = None
                    current_label = None
        if current_start is not None:
            spans.append((current_start, len(tags)-1, current_label))
        return spans

    # Collect all spans and token-level labels
    all_true_token = []
    all_pred_token = []
    labelled_true = []
    labelled_pred = []
    unlabelled_true = []
    unlabelled_pred = []
    label_counts = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})

    for sentence in tqdm(data, desc="Evaluating"):
        with torch.no_grad():
            logits, gold_labels = process_sentence(
                sentence, label_to_i, tokeniser,
                model['encoder'], model['clf_head'],
                encoder_device, clf_head_device
            )
        pred_indices = torch.argmax(logits, dim=-1).cpu().numpy()
        true_tags = [i_to_label[i] for i in gold_labels.cpu().numpy()]
        pred_tags = [i_to_label[i] for i in pred_indices]
        all_true_token.append(true_tags)
        all_pred_token.append(pred_tags)

        # Span-level
        true_spans = bio_to_spans(true_tags)
        pred_spans = bio_to_spans(pred_tags)

        labelled_true.append([(s, e, l) for (s, e, l) in true_spans])
        labelled_pred.append([(s, e, l) for (s, e, l) in pred_spans])
        unlabelled_true.append([(s, e) for (s, e, l) in true_spans])
        unlabelled_pred.append([(s, e) for (s, e, l) in pred_spans])

        # Per-label F1
        for (s, e, l) in true_spans:
            label_counts[l]['fn'] += 1
        for (s, e, l) in pred_spans:
            label_counts[l]['fp'] += 1
        for span in set(true_spans) & set(pred_spans):
            label_counts[span[2]]['tp'] += 1
            label_counts[span[2]]['fn'] -= 1

    # Token-level metrics
    token_report = classification_report(
        all_true_token,
        all_pred_token,
        output_dict=True
    )

    # Span-level metrics
    def compute_span_metrics(true, pred):
        true_flat = set((i, s, e, *rest) for i, spans in enumerate(true) for span in spans for s,e,*rest in [span])
        pred_flat = set((i, s, e, *rest) for i, spans in enumerate(pred) for span in spans for s,e,*rest in [span])
        correct = true_flat & pred_flat
        precision = len(correct)/len(pred_flat) if pred_flat else 0
        recall = len(correct)/len(true_flat) if true_flat else 0
        f1 = 2*precision*recall/(precision+recall) if (precision+recall) else 0
        return {'precision': precision, 'recall': recall, 'f1': f1}

    # Macro F1 calculation (all 7 labels)
    labels = ['B-LOC', 'I-LOC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'O']
    per_label = {}
    for label in labels:
        tp = label_counts[label]['tp']
        fp = label_counts[label]['fp']
        fn = label_counts[label]['fn']
        p = tp / (tp + fp) if (tp + fp) > 0 else 0
        r = tp / (tp + fn) if (tp + fn) > 0 else 0
        f = 2 * p * r / (p + r) if (p + r) > 0 else 0
        per_label[label] = (p, r, f)
    macro_f1 = sum(f for p, r, f in per_label.values()) / len(per_label) if per_label else 0

    return {
        'token_level': token_report,
        'span_level': {
            'labelled': compute_span_metrics(labelled_true, labelled_pred),
            'unlabelled': compute_span_metrics(unlabelled_true, unlabelled_pred)
        },
        'per_label': per_label,
        'macro_f1': macro_f1
    }


In [24]:
results = evaluate_full_tagset(
    model={'encoder': encoder, 'clf_head': clf_head},
    tokeniser=tokeniser,
    data=test_data,
    label_to_i=label_to_i,
    i_to_label=i_to_label,
    encoder_device=encoder_device,
    clf_head_device=clf_head_device
)

print("Token-Level Metrics:")
print(results['token_level'])  # Or format as DataFrame

print("\nLabelled Span Metrics:")
print(results['span_level']['labelled'])

print("\nUnlabelled Span Metrics:")
print(results['span_level']['unlabelled'])

print("\nPer-label F1:")
for label, (p, r, f) in results['per_label'].items():
    print(f"{label}: P={p:.4f}, R={r:.4f}, F1={f:.4f}")

print(f"\nMacro F1: {results['macro_f1']:.4f}")


Evaluating:   0%|          | 0/2078 [00:00<?, ?it/s]

Token-Level Metrics:
{'LOC': {'precision': np.float64(0.8023952095808383), 'recall': np.float64(0.8454258675078864), 'f1-score': np.float64(0.8233486943164363), 'support': np.int64(317)}, 'ORG': {'precision': np.float64(0.5404699738903395), 'recall': np.float64(0.6428571428571429), 'f1-score': np.float64(0.5872340425531916), 'support': np.int64(322)}, 'PER': {'precision': np.float64(0.8353658536585366), 'recall': np.float64(0.9153674832962138), 'f1-score': np.float64(0.873538788522848), 'support': np.int64(449)}, 'micro avg': {'precision': np.float64(0.7328370554177006), 'recall': np.float64(0.8143382352941176), 'f1-score': np.float64(0.7714410100130605), 'support': np.int64(1088)}, 'macro avg': {'precision': np.float64(0.7260770123765714), 'recall': np.float64(0.8012168312204144), 'f1-score': np.float64(0.761373841797492), 'support': np.int64(1088)}, 'weighted avg': {'precision': np.float64(0.7384833468037665), 'recall': np.float64(0.8143382352941176), 'f1-score': np.float64(0.7741818

In [25]:
results = evaluate_full_tagset(
    model={'encoder': encoder, 'clf_head': clf_head},
    tokeniser=tokeniser,
    data=ood_data,
    label_to_i=label_to_i,
    i_to_label=i_to_label,
    encoder_device=encoder_device,
    clf_head_device=clf_head_device
)

print("Token-Level Metrics:")
print(results['token_level'])  # Or format as DataFrame

print("\nLabelled Span Metrics:")
print(results['span_level']['labelled'])

print("\nUnlabelled Span Metrics:")
print(results['span_level']['unlabelled'])

print("\nPer-label F1:")
for label, (p, r, f) in results['per_label'].items():
    print(f"{label}: P={p:.4f}, R={r:.4f}, F1={f:.4f}")

print(f"\nMacro F1: {results['macro_f1']:.4f}")


Evaluating:   0%|          | 0/1001 [00:00<?, ?it/s]

Token-Level Metrics:
{'LOC': {'precision': np.float64(0.7124183006535948), 'recall': np.float64(0.7694117647058824), 'f1-score': np.float64(0.7398190045248869), 'support': np.int64(425)}, 'ORG': {'precision': np.float64(0.5494505494505495), 'recall': np.float64(0.425531914893617), 'f1-score': np.float64(0.47961630695443647), 'support': np.int64(235)}, 'PER': {'precision': np.float64(0.8964705882352941), 'recall': np.float64(0.9180722891566265), 'f1-score': np.float64(0.907142857142857), 'support': np.int64(415)}, 'micro avg': {'precision': np.float64(0.7579737335834896), 'recall': np.float64(0.7516279069767442), 'f1-score': np.float64(0.7547874824848202), 'support': np.int64(1075)}, 'macro avg': {'precision': np.float64(0.7194464794464794), 'recall': np.float64(0.704338656252042), 'f1-score': np.float64(0.7088593895407268), 'support': np.int64(1075)}, 'weighted avg': {'precision': np.float64(0.7478455358291201), 'recall': np.float64(0.7516279069767442), 'f1-score': np.float64(0.7475322

In [None]:
def evaluate_full_tagset(model, tokeniser, data, label_to_i, i_to_label, encoder_device=0, clf_head_device=0):
    """Evaluates model on span-level metrics for the full 7-label tagset."""
    true_spans_all = []
    pred_spans_all = []

    for sentence in tqdm(data, desc="Evaluating"):
        # Get model predictions
        with torch.no_grad():
            logits, gold_labels = process_sentence(
                sentence, label_to_i, tokeniser,
                model['encoder'], model['clf_head'],
                encoder_device, clf_head_device
            )
        pred_indices = torch.argmax(logits, dim=-1).cpu().numpy()

        # Extract true/predicted tags
        true_tags = [i_to_label[i] for i in gold_labels.cpu().numpy()]
        pred_tags = [i_to_label[i] for i in pred_indices]

        # Extract spans
        true_spans = bio_to_spans(true_tags)
        pred_spans = bio_to_spans(pred_tags)

        true_spans_all.append(true_spans)
        pred_spans_all.append(pred_spans)

    # Compute metrics
    return compute_span_metrics(true_spans_all, pred_spans_all)

In [None]:
def bio_to_spans(tags):
    """Converts BIO tags to (start, end, label) spans."""
    spans = []
    current_start = None
    current_label = None

    for i, tag in enumerate(tags):
        if tag.startswith('B-'):
            if current_start is not None:
                spans.append((current_start, i-1, current_label))
            current_label = tag[2:]
            current_start = i
        elif tag.startswith('I-'):
            if current_label != tag[2:]:  # Mismatched I-tag
                if current_start is not None:
                    spans.append((current_start, i-1, current_label))
                current_start = None
                current_label = None
        else:
            if current_start is not None:
                spans.append((current_start, i-1, current_label))
                current_start = None
                current_label = None
    if current_start is not None:
        spans.append((current_start, len(tags)-1, current_label))
    return spans

In [None]:
def compute_span_metrics(true_spans_all, pred_spans_all):
    """Computes span-level precision/recall/F1 for exact matches (label + boundaries)."""
    # Flatten spans across sentences
    true_flat = [(i,)+span for i, spans in enumerate(true_spans_all) for span in spans]
    pred_flat = [(i,)+span for i, spans in enumerate(pred_spans_all) for span in spans]

    # Count matches
    correct = set(true_flat) & set(pred_flat)

    # Overall metrics
    precision = len(correct)/len(pred_flat) if pred_flat else 0
    recall = len(correct)/len(true_flat) if true_flat else 0
    f1 = 2*precision*recall/(precision+recall) if (precision+recall) else 0

    # Per-label metrics
    label_counts = defaultdict(lambda: {'tp':0, 'fp':0, 'fn':0})
    for span in true_flat:
        label_counts[span[-1]]['fn'] += 1
    for span in pred_flat:
        label_counts[span[-1]]['fp'] += 1
    for span in correct:
        label_counts[span[-1]]['tp'] += 1
        label_counts[span[-1]]['fn'] -= 1

    per_label = {}
    for label in ['LOC', 'PER', 'ORG']:  # Your actual labels may vary
        tp = label_counts[label]['tp']
        fp = label_counts[label]['fp']
        fn = label_counts[label]['fn']

        p = tp/(tp+fp) if (tp+fp) > 0 else 0
        r = tp/(tp+fn) if (tp+fn) > 0 else 0
        f = 2*p*r/(p+r) if (p+r) > 0 else 0
        per_label[label] = (p, r, f)

    macro_f1 = sum(f for p,r,f in per_label.values())/len(per_label)

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'per_label': per_label,
        'macro_f1': macro_f1
    }


In [None]:
# After training
results = evaluate_full_tagset(
    model={'encoder': encoder, 'clf_head': clf_head},
    tokeniser=tokeniser,
    data=test_data,  # Your test dataset
    label_to_i=label_to_i,
    i_to_label=i_to_label
)

print(f"Span-level Precision: {results['precision']:.4f}")
print(f"Span-level Recall:    {results['recall']:.4f}")
print(f"Span-level F1:        {results['f1']:.4f}")
print(f"Macro F1:             {results['macro_f1']:.4f}")

for label, (p, r, f) in results['per_label'].items():
    print(f"{label}: P={p:.4f}, R={r:.4f}, F1={f:.4f}")


Evaluating:   0%|          | 0/2078 [00:00<?, ?it/s]

Span-level Precision: 0.4944
Span-level Recall:    0.3217
Span-level F1:        0.3898
Macro F1:             0.2559
LOC: P=0.3408, R=0.8107, F1=0.4799
PER: P=0.3120, R=0.1849, F1=0.2322
ORG: P=0.2632, R=0.0311, F1=0.0556


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = encoder.to(device)
clf_head = clf_head.to(device)


In [None]:
# Test prediction on a known entity
test_sentence = [('San', 'B-LOC'), ('Francisco', 'I-LOC')]
logits, labels = process_sentence(
    test_sentence, label_to_i, tokeniser, encoder, clf_head, device, device
)
predicted_tags = [i_to_label[i] for i in torch.argmax(logits, dim=-1).cpu().numpy()]
print("Predicted:", predicted_tags)
print("Gold     :", [tag for _, tag in test_sentence])


Predicted: ['B-LOC', 'I-LOC']
Gold     : ['B-LOC', 'I-LOC']
