In [7]:
import pandas as pd
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
import torch
import math

dataset = load_dataset("coastalcph/tydi_xor_rc")

languages = ['ar', 'ko', 'te']
train_dataset = dataset["train"].filter(lambda example: example['lang'] in languages)
val_dataset = dataset["validation"].filter(lambda example: example['lang'] in languages)

print("Sample from train dataset:")
sample = train_dataset[0]
print(f"Keys: {sample.keys()}")
print(f"Answer structure: {sample['answer']}")
print(f"Answer type: {type(sample['answer'])}")

model_checkpoint = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

max_length = 384
doc_stride = 128

def preprocess_function(examples):
    questions = [q.strip() for q in examples["question"]]
    contexts = [c.strip() for c in examples["context"]]
    
    tokenized_examples = tokenizer(
        questions,
        contexts,
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = tokenized_examples.pop("offset_mapping")
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        sequence_ids = tokenized_examples.sequence_ids(i)

        sample_index = sample_mapping[i]
        
        answer_data = examples["answer"][sample_index]
        
        if isinstance(answer_data, dict):
            if "text" in answer_data:
                answer_text = answer_data["text"]
                answer_starts = answer_data.get("answer_start", [])
            else:
                answer_text = answer_data.get("answer_text", answer_data.get("answers", ""))
                answer_starts = answer_data.get("answer_start", answer_data.get("answer_starts", []))
        elif isinstance(answer_data, str):
            answer_text = answer_data
            answer_starts = []
        else:
            answer_text = ""
            answer_starts = []
        
        if isinstance(answer_text, str):
            answer_texts = [answer_text] if answer_text else []
        else:
            answer_texts = answer_text if answer_text else []
        
        if not isinstance(answer_starts, list):
            answer_starts = [answer_starts] if answer_starts else []
        
        if not answer_texts or not answer_texts[0].strip():
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            answer_text = answer_texts[0]
            start_char = answer_starts[0] if answer_starts else 0
            end_char = start_char + len(answer_text)

            token_start_index = 0
            while token_start_index < len(sequence_ids) and sequence_ids[token_start_index] != 1:
                token_start_index += 1

            token_end_index = len(input_ids) - 1
            while token_end_index >= 0 and sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            if (token_start_index >= len(offsets) or token_end_index >= len(offsets) or 
                token_start_index > token_end_index or
                not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char)):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                
                while token_end_index >= 0 and offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)
tokenized_val_dataset = val_dataset.map(preprocess_function, batched=True, remove_columns=val_dataset.column_names)

model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    tokenizer=tokenizer,
)

trainer.train()

eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

Sample from train dataset:
Keys: dict_keys(['question', 'context', 'lang', 'answerable', 'answer_start', 'answer', 'answer_inlang'])
Answer structure: France
Answer type: <class 'str'>


Map:   0%|          | 0/1155 [00:00<?, ? examples/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(
wandb: Currently logged in as: aarushsinha60 (chungimungi) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch,Training Loss,Validation Loss
1,No log,1.056749


Perplexity: 2.88
