In [None]:
import pandas as pd
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
import torch
import math

# Load the TyDi-XOR reading-comprehension corpus and keep only the languages
# this experiment trains and evaluates on.
dataset = load_dataset("coastalcph/tydi_xor_rc")

languages = ['ar', 'ko', 'te']


def _in_target_languages(example):
    """Return True when the example's ``lang`` code is one of ``languages``."""
    return example['lang'] in languages


# Filter the train split first, then the validation split (same order as before).
train_dataset, val_dataset = (
    dataset[split_name].filter(_in_target_languages)
    for split_name in ("train", "validation")
)

# Peek at one training record to show the column layout and the answer's type.
print("Sample from train dataset:")
sample = train_dataset[0]
for label, shown in (
    ("Keys", sample.keys()),
    ("Answer structure", sample['answer']),
    ("Answer type", type(sample['answer'])),
):
    print(f"{label}: {shown}")

# Multilingual cased BERT backbone and its tokenizer.
model_checkpoint = "google-bert/bert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Sliding-window settings used by the tokenizer: windows of `max_length` tokens,
# consecutive windows sharing `doc_stride` tokens of context.
max_length, doc_stride = 384, 128

def preprocess_function(examples):
    """Tokenize a batch of QA examples into overlapping windows with span labels.

    Each (question, context) pair is tokenized with ``truncation="only_second"``
    and ``return_overflowing_tokens=True``. A long context is therefore cut into
    several ``max_length``-token windows that overlap by ``doc_stride`` tokens,
    so one example can produce several feature rows.

    Args:
        examples: A batch dict as passed by ``Dataset.map(batched=True)`` with
            list-valued ``"question"``, ``"context"``, ``"answer"`` and
            ``"answer_start"`` columns. ``answer_start`` is a character offset
            into the *unmodified* context string.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...) without
        ``offset_mapping``/``overflow_to_sample_mapping``, plus
        ``start_positions`` / ``end_positions``. These are the token indices of
        the answer span in each window. When the answer is not fully inside that
        window, both are set to the ``[CLS]`` index. The same applies when there
        is no locatable answer (e.g. a negative ``answer_start``) or the window
        has no context tokens.
    """
    # Only questions are stripped. Contexts must stay intact because
    # `answer_start` indexes characters of the original context; stripping
    # leading whitespace would shift every answer offset.
    questions = [q.strip() for q in examples["question"]]
    contexts = examples["context"]

    tokenized_examples = tokenizer(
        questions,
        contexts,
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = tokenized_examples.pop("offset_mapping")
    # For each produced window, the index of the example it was cut from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)
        # 0 = question token, 1 = context token, None = special/padding token.
        sequence_ids = tokenized_examples.sequence_ids(i)
        sample_index = sample_mapping[i]

        start_char = examples["answer_start"][sample_index]
        end_char = start_char + len(examples["answer"][sample_index])

        # Locate the first and last context tokens inside this window.
        context_start = 0
        while context_start < len(sequence_ids) and sequence_ids[context_start] != 1:
            context_start += 1
        context_end = len(input_ids) - 1
        while context_end >= 0 and sequence_ids[context_end] != 1:
            context_end -= 1

        # A window without any context token (e.g. empty context) has nothing
        # to point at; label it as "no answer" instead of indexing past the end.
        if context_start > context_end:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):
            # Answer not fully contained in this window -> point at [CLS].
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Both walks are bounded to the context tokens. [SEP] and padding carry
        # (0, 0) offsets, so an unbounded forward walk would run past the
        # context whenever the answer starts in its last token.
        idx = context_start
        while idx <= context_end and offsets[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offsets[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples

print("\nTokenizing datasets...")
# One example can expand into several overlapping windows, so the original
# columns are removed (their row count no longer matches the features).
tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)
tokenized_val_dataset = val_dataset.map(preprocess_function, batched=True, remove_columns=val_dataset.column_names)

# mBERT encoder with a newly initialised span-prediction head (qa_outputs).
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="steps",
    eval_steps=50,
    logging_strategy="steps",
    logging_steps=50,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    # `tokenizer=` is deprecated on Trainer (it emits a FutureWarning);
    # `processing_class` is its replacement.
    processing_class=tokenizer,
)

trainer.train()

# NOTE: the eval loss of a QA head is a cross-entropy over answer start/end
# token positions. In transformers' BERT QA head it is the mean of the start
# and end losses; confirm this for your version. So exp(loss) below is a
# position-level "perplexity", not a language-model perplexity. Span EM/F1
# are the standard QA metrics.
print("\nOverall Evaluation")
eval_results = trainer.evaluate()
print(f"Overall Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
print(f"Overall Loss: {eval_results['eval_loss']:.4f}")

print("\nLanguage-specific Evaluations")

for lang in languages:
    print(f"\nEvaluating {lang.upper()}")
    lang_val_dataset = val_dataset.filter(lambda example: example['lang'] == lang)
    print(f"Number of {lang.upper()} validation examples: {len(lang_val_dataset)}")

    if len(lang_val_dataset) == 0:
        print(f"No validation examples found for language: {lang}")
        continue

    tokenized_lang_val = lang_val_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=lang_val_dataset.column_names,
    )

    # Reuse the trained Trainer and pass the per-language subset to this call,
    # rather than constructing a new Trainer for every language. Metrics keep
    # the default "eval_" prefix, so 'eval_loss' is still the key.
    lang_eval_results = trainer.evaluate(eval_dataset=tokenized_lang_val)
    print(f"{lang.upper()} Perplexity: {math.exp(lang_eval_results['eval_loss']):.2f}")
    print(f"{lang.upper()} Loss: {lang_eval_results['eval_loss']:.4f}")

# English Context Only Evaluation
# Question-ablation: evaluate on the same validation contexts with every
# question blanked out. The contexts are presumably English passages, per this
# section's title; verify against the dataset card. The gold answer spans are
# kept, so the loss shows how well the model finds the answer without the
# (language-specific) question.
print("\nEvaluating English Contexts Only")

en_contexts = list(val_dataset["context"])
en_context_only_dataset = Dataset.from_dict({
    "question": [""] * len(en_contexts),  # empty question (no lang)
    "context": en_contexts,
    # Keep the gold labels. A placeholder answer ("" at offset 0) yields span
    # labels that correspond to no real answer, which makes the loss
    # uninterpretable.
    "answer": list(val_dataset["answer"]),
    "answer_start": list(val_dataset["answer_start"]),
})

tokenized_en_context_val = en_context_only_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=en_context_only_dataset.column_names,
)

# Reuse the trained Trainer with a per-call eval set instead of building a new one.
en_eval_results = trainer.evaluate(eval_dataset=tokenized_en_context_val)
print(f"English Context Perplexity: {math.exp(en_eval_results['eval_loss']):.2f}")
print(f"English Context Loss: {en_eval_results['eval_loss']:.4f}")

# Report any remaining metrics, skipping the loss (already printed) and
# runtime/bookkeeping entries.
_skipped_metric_keys = {
    'eval_loss',
    'eval_runtime',
    'eval_samples_per_second',
    'eval_steps_per_second',
    'epoch',
    'eval_model_preparation_time',
}
for key, value in en_eval_results.items():
    if key not in _skipped_metric_keys:
        print(f"English Context {key}: {value:.4f}")


Sample from train dataset:
Keys: dict_keys(['question', 'context', 'lang', 'answerable', 'answer_start', 'answer', 'answer_inlang'])
Answer structure: France
Answer type: <class 'str'>


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]


Tokenizing datasets...


Map:   0%|          | 0/6335 [00:00<?, ? examples/s]

Map:   0%|          | 0/1155 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at google-bert/bert-base-multilingual-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(
wandb: Currently logged in as: aarushsinha60 (chungimungi) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Step,Training Loss,Validation Loss
50,3.9293,2.812973
100,2.6137,2.686838
150,2.4786,2.355334
200,2.2368,2.183851
250,2.041,1.975945
300,1.9599,1.698209
350,1.8166,1.666332
400,1.6583,1.659788
450,1.4064,1.585595
500,1.3607,1.591094



Overall Evaluation


Overall Perplexity: 4.25
Overall Loss: 1.4472

Language-specific Evaluations

Evaluating AR


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Number of AR validation examples: 415


Map:   0%|          | 0/415 [00:00<?, ? examples/s]

  lang_trainer = Trainer(


AR Perplexity: 3.92
AR Loss: 1.3671

Evaluating KO


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Number of KO validation examples: 356


Map:   0%|          | 0/356 [00:00<?, ? examples/s]

KO Perplexity: 3.62
KO Loss: 1.2862

Evaluating TE


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Number of TE validation examples: 384


Map:   0%|          | 0/384 [00:00<?, ? examples/s]

TE Perplexity: 5.41
TE Loss: 1.6876

Evaluating English Contexts Only


Map:   0%|          | 0/1155 [00:00<?, ? examples/s]

  en_trainer = Trainer(


English Context Perplexity: 9.67
English Context Loss: 2.2691
English Context eval_model_preparation_time: 0.0018
