In [11]:
import pandas as pd
import spacy

# Prefer the transformer pipeline and fall back to the small English pipeline
# when the transformer one cannot be loaded.
try:
    nlp = spacy.load("en_core_web_trf")
except (OSError, ImportError, ValueError) as load_error:
    # A bare `except:` would also swallow KeyboardInterrupt / SystemExit.
    # spaCy reports a missing pipeline package as OSError; ImportError and
    # ValueError are included so a broken transformer install still falls
    # back (NOTE(review): confirm against the installed spaCy version).
    # The notice matters because the two pipelines give different scores.
    print(f"Could not load en_core_web_trf ({load_error}); falling back to en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")


# Gold-standard test data. The rest of this script reads the columns
# "sentence_id", "token" and "BIO_NER_tag" from it.
test_df = pd.read_csv("NER-test.tsv", sep="\t", encoding="utf-8")

# Maps entity types found in the gold annotations onto the four evaluation
# labels (PER / LOC / ORG / MISC). Types missing from this table fall back to
# "MISC" inside convert_bio_tag.
label_map = {
    "PERSON": "PER",
    "LOCATION": "LOC",
    "ORGANIZATION": "ORG",
    "ORG": "ORG",
    "WORK_OF_ART": "MISC",
    "EVENT": "MISC",
    "PRODUCT": "MISC",
    "DATE": "MISC",
    "TIME": "MISC",
    "MONEY": "MISC",
    "PERCENT": "MISC",
    # Target labels map to themselves so that converting an already
    # converted tag is a no-op instead of collapsing PER/LOC into MISC.
    "PER": "PER",
    "LOC": "LOC",
    "MISC": "MISC",
    # Geo-political entities are scored as locations, matching how the
    # model predictions are mapped later in this script.
    "GPE": "LOC",
}


def convert_bio_tag(tag: str) -> str:
    """Rewrite a BIO tag so its entity type uses the evaluation label set.

    Args:
        tag: Either "O" or "<prefix>-<type>", e.g. "B-PERSON".

    Returns:
        "O" unchanged, otherwise "<prefix>-<mapped type>" where the mapped
        type comes from ``label_map`` and defaults to "MISC".

    Raises:
        ValueError: If ``tag`` is not "O" and contains no "-" separator.
    """
    if tag == "O":
        return tag
    # Split on the first hyphen only: types such as WORK_OF_ART contain
    # underscores, but a type could in principle contain further hyphens.
    prefix, separator, typ = tag.partition("-")
    if not separator:
        raise ValueError(f"Malformed BIO tag (expected 'O' or '<prefix>-<type>'): {tag!r}")
    typ_mapped = label_map.get(typ, "MISC")
    return f"{prefix}-{typ_mapped}"

# Bring the gold tags onto the PER/LOC/ORG/MISC label set.
test_df["BIO_NER_tag"] = test_df["BIO_NER_tag"].apply(convert_bio_tag)

# One (sentence_id, text) pair per sentence; the text is the gold tokens
# joined by single spaces.
test_sentences = [
    (sentence_id, " ".join(rows["token"]))
    for sentence_id, rows in test_df.groupby("sentence_id")
]

# Model entity label -> evaluation label; anything not listed becomes "MISC".
model_label_to_eval = {
    "PERSON": "PER",
    "PER": "PER",
    "ORG": "ORG",
    "ORGANIZATION": "ORG",
    "GPE": "LOC",
    "LOC": "LOC",
    "LOCATION": "LOC",
}

# Run the pipeline on every sentence and record each entity it finds.
predictions = []
for sid, sent in test_sentences:
    for ent in nlp(sent).ents:
        mapped = model_label_to_eval.get(ent.label_, "MISC")
        record = (sid, ent.text, ent.start_char, ent.end_char, mapped)
        predictions.append(record)
        print(f"Sentence {sid} | Entity: '{ent.text}' | Label: {mapped} | Span: ({ent.start_char}, {ent.end_char})")

# Persist the raw predictions (character offsets refer to the joined text).
prediction_columns = ["sentence_id", "entity", "start_char", "end_char", "label"]
pred_df = pd.DataFrame(predictions, columns=prediction_columns)
pred_df.to_csv("ner_predictions_pretrained.csv", index=False)

# Evaluation: Precision, Recall, F1 

def bio_to_spans(df: pd.DataFrame) -> set:
    """Collect entity spans from BIO tags, sentence by sentence.

    Args:
        df: Frame with a "sentence_id" column and a "BIO_NER_tag" column;
            rows of one sentence are taken in their order within the frame.

    Returns:
        A set of ``(sentence_id, start, end, label)`` tuples where ``start``
        and ``end`` are inclusive token positions within the sentence and
        ``label`` is the tag with its two-character prefix removed.
    """
    spans = set()
    for sent_id, group in df.groupby("sentence_id"):
        tags = list(group["BIO_NER_tag"])
        start = None
        label = None
        for i, tag in enumerate(tags):
            # An I- tag only extends the open span when the types agree.
            if tag.startswith("I-") and start is not None and tag[2:] == label:
                continue
            # Every other tag closes whatever span is currently open.
            if start is not None:
                spans.add((sent_id, start, i - 1, label))
                start = None
                label = None
            # B- always opens a span. An I- tag reaching this point is either
            # orphaned or of a different type than the span just closed; it
            # opens a new span too (the conlleval convention) instead of
            # being merged or dropped.
            if tag.startswith(("B-", "I-")):
                start = i
                label = tag[2:]
        # A span still open at the end of the sentence runs to the last token.
        if start is not None:
            spans.add((sent_id, start, len(tags) - 1, label))
    return spans

gold_spans = bio_to_spans(test_df)

# Map predictions to token indices.
# The pipeline tokenises the joined sentence on its own, so its token indices
# (ent.start / ent.end) do not line up with the gold tokens. Entities are
# therefore projected onto the gold tokens through character offsets
# (ent.start_char / ent.end_char): a prediction covers every gold token whose
# character range overlaps the entity.
sentence_tokens = {sid: list(group["token"]) for sid, group in test_df.groupby("sentence_id")}
pred_spans = []
for sid, tokens in sentence_tokens.items():
    sent_text = " ".join(tokens)
    doc = nlp(sent_text)

    # Half-open character range [start, end) of each gold token in sent_text.
    char_to_token = []
    pointer = 0
    for tok in tokens:
        char_to_token.append((pointer, pointer + len(tok)))
        pointer += len(tok) + 1  # +1 skips the single joining space

    for ent in doc.ents:
        if ent.label_ in ["PERSON", "PER"]:
            mapped = "PER"
        elif ent.label_ in ["ORG", "ORGANIZATION"]:
            mapped = "ORG"
        elif ent.label_ in ["GPE", "LOC", "LOCATION"]:
            mapped = "LOC"
        else:
            mapped = "MISC"

        # Gold token positions overlapping the entity's character range.
        overlapping = [
            idx
            for idx, (start_char, end_char) in enumerate(char_to_token)
            if start_char < ent.end_char and end_char > ent.start_char
        ]
        if overlapping:
            # Inclusive end index, matching the spans bio_to_spans produces.
            pred_spans.append((sid, overlapping[0], overlapping[-1], mapped))

pred_spans = set(pred_spans)

# Exact-match scoring: a prediction counts only if sentence id, both token
# boundaries and the label all equal a gold span.
matched_spans = gold_spans.intersection(pred_spans)
spurious_spans = pred_spans.difference(gold_spans)
missed_spans = gold_spans.difference(pred_spans)

tp = len(matched_spans)
fp = len(spurious_spans)
fn = len(missed_spans)

# Guard each ratio against an empty denominator.
predicted_total = tp + fp
gold_total = tp + fn
if predicted_total:
    precision = tp / predicted_total
else:
    precision = 0.0
if gold_total:
    recall = tp / gold_total
else:
    recall = 0.0
if precision + recall:
    f1 = 2 * precision * recall / (precision + recall)
else:
    f1 = 0.0

print("\n--- Evaluation ---")
print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
print(f"F1-score:  {f1:.3f}")


Sentence 0 | Entity: 'Paris' | Label: LOC | Span: (19, 24)
Sentence 0 | Entity: 'Louvre' | Label: PER | Span: (48, 54)
Sentence 0 | Entity: 'the Mona Lisa' | Label: MISC | Span: (73, 86)
Sentence 1 | Entity: 'Amazon' | Label: ORG | Span: (0, 6)
Sentence 1 | Entity: 'Google' | Label: ORG | Span: (9, 15)
Sentence 1 | Entity: 'Meta' | Label: ORG | Span: (20, 24)
Sentence 2 | Entity: 'Pharoah Sanders' | Label: PER | Span: (13, 28)
Sentence 2 | Entity: 'Floating Points' | Label: MISC | Span: (52, 67)
Sentence 4 | Entity: 'Kevin' | Label: PER | Span: (10, 15)
Sentence 4 | Entity: 'Succession' | Label: MISC | Span: (39, 49)
Sentence 4 | Entity: 'Kieran Culkin 's' | Label: PER | Span: (81, 97)
Sentence 5 | Entity: 'Venus Williams' | Label: PER | Span: (0, 14)
Sentence 6 | Entity: 'Elizabeth' | Label: PER | Span: (12, 21)
Sentence 6 | Entity: 'King Charles' | Label: PER | Span: (29, 41)
Sentence 6 | Entity: 'the British Royal Family' | Label: ORG | Span: (63, 87)
Sentence 7 | Entity: 'Dark Soul