In [None]:
import spacy
from datasets import ClassLabel, Sequence, load_dataset

import consts

# 1. Setup spaCy for POS tagging
# Only part-of-speech information is consumed further down, so the dependency
# parser and the entity recognizer are switched off to keep processing fast.
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])

# 2. Load standard NER dataset (CoNLL-2003)
dataset = load_dataset("conll2003")

# 3. Define the Label Mappings
# The dataset stores NER tags as integers (e.g. 0: 'O', 1: 'B-PER', ...).
# Grab the string names behind those integers so tags can be compared and
# rewritten by name.
original_feature = dataset['train'].features['ner_tags'].feature
label_names = original_feature.names

# Append the custom POS-derived labels after the original ones, so every
# original id keeps its meaning and the new labels get the next free ids.
new_label_names = label_names + ["B-NOUN", "I-NOUN", "B-PRON", "I-PRON"]

# Two-way lookup tables between label strings and their integer ids.
id2label = dict(enumerate(new_label_names))
label2id = {name: idx for idx, name in id2label.items()}

def augment_data(example):
    """Extend one example's NER tags with coarse POS-derived labels.

    Tokens that already carry an entity label (anything other than ``O``)
    keep that label. Tokens labelled ``O`` become ``B-NOUN`` or ``B-PRON``
    when spaCy assigns them the coarse POS ``NOUN`` or ``PRON`` respectively;
    all remaining tokens stay ``O``.

    Simplification: every POS-derived token is tagged as a *Beginning*
    (``B-``) label; the ``I-NOUN`` / ``I-PRON`` labels are never emitted here.

    Args:
        example: A dataset row providing ``tokens`` (list of strings) and
            ``ner_tags`` (list of integer ids indexing ``label_names``).

    Returns:
        A dict ``{'ner_tags': [...]}`` with one id from ``label2id`` per token.
    """
    tokens = example['tokens']
    ner_ids = example['ner_tags']

    # Build the Doc from the dataset's own tokens so spaCy does not
    # re-tokenize and the POS tags stay aligned 1:1 with ``ner_ids``.
    doc = spacy.tokens.Doc(nlp.vocab, words=tokens)

    # Run every enabled pipeline component in order rather than calling
    # ``nlp.tagger`` directly: that shortcut attribute only exists in
    # spaCy v2, and in v3 pipelines the tagger depends on the components
    # around it (e.g. tok2vec before, attribute_ruler after) for
    # ``token.pos_`` to be populated.
    for _, component in nlp.pipeline:
        doc = component(doc)

    # Coarse POS tag -> custom label used for tokens that were 'O'.
    pos_to_label = {"NOUN": "B-NOUN", "PRON": "B-PRON"}

    new_tags = []
    for token, ner_id in zip(doc, ner_ids):
        original_tag = label_names[ner_id]
        if original_tag == "O":
            # No entity here: fall back to the POS-derived label, if any.
            new_tag = pos_to_label.get(token.pos_, "O")
        else:
            # Existing entity labels always win over POS labels.
            new_tag = original_tag
        new_tags.append(label2id[new_tag])

    return {'ner_tags': new_tags}

# 4. Apply augmentation to the dataset
print("Augmenting dataset with POS tags...")

# The augmented column contains ids beyond the original label set, so the
# output schema must declare the enlarged label list; otherwise the stored
# ClassLabel metadata would still describe only the original labels and
# disagree with the data.
augmented_features = dataset['train'].features.copy()
augmented_features['ner_tags'] = Sequence(ClassLabel(names=new_label_names))
augmented_dataset = dataset.map(augment_data, features=augmented_features)

# Verify one example (fetch the row once and reuse it for all three views)
sample = augmented_dataset['train'][0]
print(f"Tokens: {sample['tokens']}")
print(f"New IDs: {sample['ner_tags']}")
print(f"Readable: {[id2label[i] for i in sample['ner_tags']]}")