In [9]:
from datasets import load_dataset, DatasetDict, Dataset
from transformers import AutoTokenizer
import json

# Load the raw dataset from the JSON file on disk.
with open("../data/dataset.json", mode="r", encoding="utf-8") as f:
    data = json.load(f)

# Wrap the "train" and "validation" entries of the JSON payload in Dataset
# objects and group them in a DatasetDict.
raw_datasets = DatasetDict(
    {split: Dataset.from_list(data[split]) for split in ("train", "validation")}
)

# Load the pretrained tokenizer. (A tokenizer could instead be trained from
# scratch if the domain does not match the pretrained vocabulary.)
model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# The label alignment below relies on word_ids(), which needs a fast tokenizer.
# Raise explicitly instead of using `assert`, so the check is not stripped
# when Python runs with -O.
if not tokenizer.is_fast:
    raise ValueError("Tokenizer must be fast for word_ids to work!")

# Let's write the main label-alignment function.
def align_labels_with_tokens(labels, word_ids):
    """Expand word-level labels into one label per token.

    Args:
        labels: Sequence of integer label ids, indexed by word id.
        word_ids: Iterable with one entry per token giving the index of the
            word the token belongs to, or ``None`` for tokens that belong to
            no word.

    Returns:
        A list with one entry per element of ``word_ids``:
        ``-100`` for ``None`` entries, the word's label for the first token
        of a word, and for each further token of the same word the word's
        label, incremented by one when it is odd.
    """
    aligned = []
    previous_word = None
    for wid in word_ids:
        if wid is None:
            # NOTE(review): -100 is presumably the loss ignore index - confirm.
            aligned.append(-100)
            continue

        tag = labels[wid]
        # Continuation token of the word seen last: odd ids are bumped to the
        # next id.
        # NOTE(review): presumably odd ids are B-XXX tags and id + 1 is the
        # matching I-XXX tag - confirm against the label list.
        if wid == previous_word and tag % 2 == 1:
            tag += 1
        previous_word = wid
        aligned.append(tag)
    return aligned


def tokenize_and_align_labels(example):
    """Tokenize one pre-split example and attach token-level labels.

    Args:
        example: Mapping with a ``"tokens"`` entry (the words) and a
            ``"ner_tags"`` entry (one label per word).

    Returns:
        An empty dict, after printing a message, when the two entries differ
        in length. Otherwise the output of the module-level ``tokenizer``
        with a ``"labels"`` entry added, computed by
        ``align_labels_with_tokens``.
    """
    words = example["tokens"]
    tags = example["ner_tags"]

    if len(words) != len(tags):
        print("Skipping due to mismatched length:", words)
        return {}

    # NOTE(review): padding=True is applied to a single example here; confirm
    # whether padding is intended at this stage or at batch-collation time.
    encoding = tokenizer(
        words,
        is_split_into_words=True,
        truncation=True,
        padding=True,
    )

    # word_ids() maps every produced token back to the index of its word.
    encoding["labels"] = align_labels_with_tokens(tags, encoding.word_ids())
    return encoding


def is_clean(sample):
    """Return True when ``sample`` has equally long "tokens" and "ner_tags".

    A sample missing either key is reported as not clean.
    """
    if "tokens" not in sample or "ner_tags" not in sample:
        return False
    return len(sample["tokens"]) == len(sample["ner_tags"])

# Drop malformed samples from both splits before tokenizing.
for split_name in ("train", "validation"):
    raw_datasets[split_name] = raw_datasets[split_name].filter(is_clean)



# Tokenize every split one example at a time. The columns listed for the
# train split are removed from the result.
tokenized_datasets = raw_datasets.map(
    tokenize_and_align_labels,
    desc="Tokenizing and aligning",
    remove_columns=raw_datasets["train"].column_names,
    batched=False,
)

# Remove empty examples if any (important!)
def is_valid(example):
    """Return True when ``example`` holds a non-empty list under "input_ids"."""
    if "input_ids" not in example:
        return False
    ids = example["input_ids"]
    return isinstance(ids, list) and len(ids) > 0

# Keep only examples whose "input_ids" is a non-empty list.
tokenized_datasets = tokenized_datasets.filter(is_valid)



# === Save for reuse ===
# Persist the tokenized splits to disk and report where they were written.
tokenized_datasets.save_to_disk("../data/tokenized_dataset")
print("Tokenized dataset saved to '../data/tokenized_dataset'")


Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/21 [00:00<?, ? examples/s]

Tokenizing and aligning:   0%|          | 0/500 [00:00<?, ? examples/s]

Tokenizing and aligning:   0%|          | 0/15 [00:00<?, ? examples/s]

Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/15 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/15 [00:00<?, ? examples/s]

Tokenized dataset saved to '../data/tokenized_dataset'
