In [25]:
import os

import spacy
from datasets import load_dataset, ClassLabel, Sequence
from spacy.tokens import Doc

from consts import DATA_PATH

In [26]:
# config — everything a reader might want to tune for this run
MODEL_NAME = "en_core_web_trf"  # spaCy pipeline used to POS-tag the WikiANN tokens
SAVE_PATH = os.path.join(
    DATA_PATH,
    "wikiann+spacy_pos",  # output directory for the augmented dataset
)

In [27]:
# load
# prefer_gpu() must run BEFORE spacy.load(): it only affects arrays allocated after the
# call, so a pipeline loaded first keeps its weights on the CPU. It returns False and
# leaves everything on the CPU when no GPU is available.
spacy.prefer_gpu()
# `exclude` (rather than `disable`) means these unused components are never loaded at all,
# saving memory. Assumes token.pos_ is produced by the remaining components
# (tagger / attribute_ruler) — TODO confirm with nlp.pipe_names.
nlp = spacy.load(MODEL_NAME, exclude=["parser", "ner", "lemmatizer"])
dataset = load_dataset('wikiann', 'en', cache_dir=os.path.join(DATA_PATH, "raw_cache"))

In [28]:
# Build the extended label set. The POS labels are appended after WikiANN's own labels,
# so every original label keeps its original integer id.
original_features = dataset["train"].features["ner_tags"].feature
original_names = original_features.names
new_names = [*original_names, "B-NOUN", "I-NOUN", "B-PRON", "I-PRON"]
label2id = dict(zip(new_names, range(len(new_names))))

In [29]:
def doc_generator(tokens_list):
    """Lazily wrap each pre-tokenized word list in a spaCy ``Doc``.

    Constructing the Doc from ``words`` keeps the dataset's own tokenization
    (one spaCy token per word), so the tokens stay aligned with per-word tags.
    """
    yield from (Doc(nlp.vocab, words=words) for words in tokens_list)

In [30]:
# POS-augmentation logic
def augment_batch(batch):
    """Add NOUN/PRON tags to tokens that WikiANN labels as "O".

    Intended for ``Dataset.map(..., batched=True)``. Each row's pre-tokenized
    words are wrapped in a spaCy ``Doc`` (via ``doc_generator``) and run through
    ``nlp``. Existing entity labels are kept. An "O" token whose coarse POS is
    NOUN or PRON becomes "B-NOUN" / "B-PRON". Everything else stays "O".

    Args:
        batch: batched example dict with "tokens" (list of word lists) and
            "ner_tags" (list of label-id lists indexing ``original_names``).

    Returns:
        ``{"augmented_tags": [...]}`` — one list of label ids (indexing
        ``new_names``) per row, each the same length as that row's tokens.

    Raises:
        ValueError: if a processed Doc does not have exactly one token per
            original tag. Without this check, zip() would silently truncate
            the row and misalign tags with tokens.
    """
    # Label ids looked up once per batch instead of once per token.
    pos_to_id = {
        "NOUN": label2id["B-NOUN"],
        "PRON": label2id["B-PRON"],
    }
    outside_id = label2id["O"]

    new_batch_tags = []
    # nlp.pipe consumes the lazily built Docs and yields processed Docs in input order.
    processed_docs = nlp.pipe(doc_generator(batch["tokens"]), batch_size=32)

    for row_idx, doc in enumerate(processed_docs):
        original_ids = batch["ner_tags"][row_idx]
        if len(doc) != len(original_ids):
            raise ValueError(
                f"Row {row_idx}: spaCy produced {len(doc)} tokens "
                f"but the row has {len(original_ids)} ner_tags"
            )

        row_tags = []
        for token, original_id in zip(doc, original_ids):
            original_label = original_names[original_id]
            if original_label != "O":
                # 1. Keep entity labels (re-mapped into the extended label space).
                row_tags.append(label2id[original_label])
            else:
                # 2. Tag NOUN/PRON tokens outside entities; anything else stays "O".
                row_tags.append(pos_to_id.get(token.pos_, outside_id))

        new_batch_tags.append(row_tags)

    # Return a NEW column rather than overwriting ner_tags inside map().
    return {"augmented_tags": new_batch_tags}

In [31]:
# Ask spaCy/thinc to use the GPU when available; the displayed bool reports whether it did.
# NOTE: only allocations made after this call go to the GPU; a pipeline loaded earlier is not moved.
gpu_active = spacy.prefer_gpu()
gpu_active

True

In [32]:
# run
# Tag every split, then swap the augmented column in for the original ner_tags.
augmented_dataset = (
    dataset.map(augment_batch, batched=True, batch_size=50)
    .remove_columns("ner_tags")
    .rename_column("augmented_tags", "ner_tags")
)

# Sanity check: every row must still carry exactly one tag per token.
for split_name, split in augmented_dataset.items():
    for row_idx, (tokens, tags) in enumerate(zip(split["tokens"], split["ner_tags"])):
        assert len(tokens) == len(tags), (
            f"{split_name}[{row_idx}]: {len(tokens)} tokens vs {len(tags)} tags"
        )

  dlpack_tensor = xp_tensor.toDlpack()  # type: ignore
Map: 100%|██████████| 10000/10000 [00:20<00:00, 497.99 examples/s]
Map: 100%|██████████| 10000/10000 [00:18<00:00, 529.45 examples/s]
Map: 100%|██████████| 20000/20000 [00:37<00:00, 531.79 examples/s]


In [33]:
# save
print(f"Saving to {SAVE_PATH}...")

# Re-declare ner_tags as a ClassLabel over the extended label list, so the saved
# dataset carries the label names alongside the integer ids.
ner_tags_feature = Sequence(feature=ClassLabel(names=new_names))
new_features = augmented_dataset["train"].features.copy()
new_features["ner_tags"] = ner_tags_feature
augmented_dataset = augmented_dataset.cast(new_features)

augmented_dataset.save_to_disk(SAVE_PATH)
print("Success! Dataset created.")

Saving to /home/dan/Work/utcn/an4/sem1/pso/proj/knowledge-graph-extraction/train/data/conll2003_augmented_lg...


Casting the dataset: 100%|██████████| 10000/10000 [00:00<00:00, 1328362.31 examples/s]
Casting the dataset: 100%|██████████| 10000/10000 [00:00<00:00, 1483606.52 examples/s]
Casting the dataset: 100%|██████████| 20000/20000 [00:00<00:00, 2116464.74 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 10000/10000 [00:00<00:00, 1875387.44 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 10000/10000 [00:00<00:00, 1582159.19 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 20000/20000 [00:00<00:00, 2175638.15 examples/s]

Success! Dataset created.



