In [None]:
import os; os.environ["WANDB_DISABLED"] = "true" # turn off because otherwise it would ask for a password?
import pandas as pd, numpy as np
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)

In [None]:
# Two models with a QA head (start/end positions) and one model with IO token labeling.
# Evaluation is done per language (EM/F1); unanswerable questions map to an empty span / all-O labels.

# Restrict the experiment to three languages to keep the compute budget small.
langs = ["ar", "ko", "te"]
splits = {"train": "train.parquet", "validation": "validation.parquet"}

# Read both parquet splits of the dataset from the Hugging Face hub.
dataset_root = "hf://datasets/coastalcph/tydi_xor_rc/"
df_train = pd.read_parquet(dataset_root + splits["train"])
df_val = pd.read_parquet(dataset_root + splits["validation"])

# Drop every row whose language is not in `langs` and renumber the rows from 0.
df_train = df_train.loc[df_train["lang"].isin(langs)].reset_index(drop=True)
df_val = df_val.loc[df_val["lang"].isin(langs)].reset_index(drop=True)

# Only these columns are carried over into the Hugging Face datasets;
# the pandas index is not kept as an extra column.
qa_columns = ["lang", "question", "context", "answerable", "answer_start", "answer"]
train_ds = Dataset.from_pandas(df_train[qa_columns], preserve_index=False)
val_ds = Dataset.from_pandas(df_val[qa_columns], preserve_index=False)

In [None]:
#### Preprocessing QA-Head (Start/End-Targets)
def build_preprocess(tokenizer, max_length=384):
    """Build the per-example preprocessing function for the start/end QA head.

    Args:
        tokenizer: Callable invoked as ``tokenizer(question, context, ...)``.
            Its result must provide ``sequence_ids()`` and an
            ``"offset_mapping"`` entry of ``(char_start, char_end)`` pairs.
        max_length: Maximum length passed to the tokenizer; only the context
            (second sequence) is truncated.

    Returns:
        A function that takes ONE example (not a batch: ``examples["answerable"]``
        is read as a single truth value) with the keys ``question``, ``context``,
        ``answerable``, ``answer_start`` and ``answer``. It returns the tokenizer
        output without ``"offset_mapping"`` and with ``"start_positions"`` /
        ``"end_positions"`` added. Both are 0 when the example is unanswerable
        or when no context token overlaps the answer characters (for instance
        because the answer was truncated away or the answer text is empty).
    """
    no_answer_index = 0  # index 0 is used as the "no answer" target

    def preprocess_qa(examples):
        tokenized_examples = tokenizer(
            examples["question"],
            examples["context"],
            truncation="only_second",  # truncate only the context; the answer may be cut off
            max_length=max_length,
            return_offsets_mapping=True,
            return_token_type_ids=True,
        )

        # Default target: unanswerable.
        start = end = no_answer_index

        if examples["answerable"]:
            # Character span of the answer inside the context: [answer0, answer1).
            answer0 = examples["answer_start"]
            answer1 = answer0 + len(examples["answer"])

            seq_ids = tokenized_examples.sequence_ids()  # sequence id per token; 1 marks the context
            offset_mapping = tokenized_examples["offset_mapping"]
            context_token_indices = [idx for idx, s in enumerate(seq_ids) if s == 1]

            if context_token_indices:
                context_start = context_token_indices[0]
                context_end = context_token_indices[-1]

                # First context token that does not end at or before the answer start.
                i = context_start
                while i <= context_end and (
                    offset_mapping[i][0] <= answer0 and offset_mapping[i][1] <= answer0
                ):
                    i += 1

                if i <= context_end:
                    # Advance j past every token that begins before the answer end.
                    j = i
                    while j <= context_end and offset_mapping[j][0] < answer1:
                        j += 1
                    # j == i means no token overlaps the answer; keep the no-answer
                    # target instead of emitting a span whose end precedes its start.
                    if j > i:
                        start = i
                        end = j - 1  # last token that still overlaps the answer

        tokenized_examples["start_positions"] = start
        tokenized_examples["end_positions"] = end
        tokenized_examples.pop("offset_mapping", None)  # not needed after the targets are computed

        return tokenized_examples

    return preprocess_qa


In [None]:
### Metrics for QA-Head: predicted vs. gold span compared as sets of token indices (EM/F1 over the evaluated set)
def compute_metrics_qa(eval_pred):
    """Exact match and token-overlap F1 for start/end span predictions.

    ``eval_pred.predictions[0]`` and ``[1]`` are reduced with an argmax over the
    last axis to one start and one end index per example. ``eval_pred.label_ids``
    may be a dict with ``start_positions`` / ``end_positions``, a 2-element
    list/tuple ``(starts, ends)``, or a single object used for both.

    Every span is turned into the set of token indices ``start..end``; the pair
    ``(0, 0)`` stands for "no answer" and becomes the empty set.

    Returns:
        ``{"exact_match": ..., "f1_token": ...}``, each the mean over all examples.
    """
    pred_starts = np.argmax(eval_pred.predictions[0], -1)
    pred_ends = np.argmax(eval_pred.predictions[1], -1)

    gold = eval_pred.label_ids
    if isinstance(gold, dict):
        gold_starts, gold_ends = gold["start_positions"], gold["end_positions"]
    elif isinstance(gold, (list, tuple)) and len(gold) == 2:
        gold_starts, gold_ends = gold
    else:
        gold_starts = gold_ends = gold

    def token_span(first, last):
        # Reduce both bounds to plain ints so comparisons never see an array.
        lo = int(np.asarray(first).reshape(-1)[0])
        hi = int(np.asarray(last).reshape(-1)[0])
        if lo == 0 and hi == 0:  # unanswerable
            return set()
        return set(range(lo, hi + 1))

    def overlap_f1(predicted, reference):
        # Two empty spans agree perfectly; exactly one empty span scores zero.
        if not predicted and not reference:
            return 1
        if not predicted or not reference:
            return 0
        shared = len(predicted.intersection(reference))
        precision = shared / len(predicted)
        recall = shared / len(reference)
        return 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    span_pairs = [
        (token_span(ps, pe), token_span(gs, ge))
        for ps, pe, gs, ge in zip(pred_starts, pred_ends, gold_starts, gold_ends)
    ]
    exact_matches = [int(predicted == reference) for predicted, reference in span_pairs]
    f1_scores = [overlap_f1(predicted, reference) for predicted, reference in span_pairs]

    return {"exact_match": np.mean(exact_matches), "f1_token": np.mean(f1_scores)}


In [None]:
## Preprocessing Token-Labeling-Head (IO-Targets)
def build_preprocess_token(tokenizer, max_length=384, stride=128):
    def preprocess_token(examples):
        questions = examples["question"]
        contexts = examples["context"]
        answers = examples["answer"] if examples["answerable"] else ""
        answer_starts = examples["answer_start"] if examples["answerable"] else -1

        # tokenize questions and context
        tokenized_examples = tokenizer(
            questions,
            contexts,
            truncation="only_second", # trunc only the context, not question, risk is that important part of context is trunced
            max_length=max_length,
            return_offsets_mapping=True,
            return_token_type_ids=True
        )

        sequence_ids = tokenized_examples.sequence_ids() # also here mark which tokens come from questions, context and special tokens
        offset_mapping = tokenized_examples["offset_mapping"]
        labels = np.full(len(tokenized_examples["input_ids"]), -100) # nitialize all positions with -100 so the loss ignores them
        context_token_indices = [i for i, s in enumerate(sequence_ids) if s == 1]

        if context_token_indices:
            context_start = context_token_indices[0]
            context_end = context_token_indices[-1]
            labels[context_start:context_end+1] = 0  # default O-label

            if examples["answerable"]:
                answer0 = answer_starts
                answer1 = answer_starts + len(answers)

                for i in range(context_start, context_end + 1):
                    start, end = offset_mapping[i]

                    if not (end <= answer0 or start >= answer1):
                        labels[i] = 1  # mark as 1 if answer span is overlapped

        del tokenized_examples["offset_mapping"] # offsets are no longer needed by the Trainer
        tokenized_examples["labels"] = labels.tolist()

        return tokenized_examples

    return preprocess_token

In [None]:
## Metrics for Token-Labeling-Head (F1 per language)

def compute_metrics_token(eval_pred):
    """Token-level F1 for label 1 over all positions whose gold label is not -100.

    ``eval_pred.predictions`` is reduced with an argmax over the last axis to one
    predicted label per token; ``eval_pred.label_ids`` holds the gold labels.
    The boolean indexing below requires both to be numpy arrays.

    Returns:
        ``{"f1_token": f1}``; precision, recall and F1 fall back to 0 whenever
        their denominator is zero.
    """
    gold = eval_pred.label_ids
    guessed = np.argmax(eval_pred.predictions, axis=-1)

    # Positions labelled -100 carry no IO target and are left out of the counts.
    scored = gold != -100
    gold, guessed = gold[scored], guessed[scored]

    gold_is_answer = gold == 1
    guessed_is_answer = guessed == 1
    tp = np.sum(guessed_is_answer & gold_is_answer)
    fp = np.sum(guessed_is_answer & (gold == 0))
    fn = np.sum((guessed == 0) & gold_is_answer)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return {"f1_token": f1}

In [None]:
### Train and evaluate models for each language

from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)

# Checkpoints to fine-tune. The position in this list decides the head:
# the first two get a start/end QA head, the last one an IO token classifier.
MODELS = [
    "google-bert/bert-base-multilingual-cased",
    "distilbert/distilbert-base-multilingual-cased",
    "xlm-roberta-base", # IO-Token-Classifier
]

for i, model_name in enumerate(MODELS):
    print(f"Training model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    # Pick the preprocessing that matches the head, then tokenize both splits
    # with it; the raw text columns are dropped from the mapped datasets.
    uses_qa_head = i < 2
    prep = build_preprocess(tokenizer) if uses_qa_head else build_preprocess_token(tokenizer)
    train_prep = train_ds.map(prep, remove_columns=train_ds.column_names)
    val_prep = val_ds.map(prep, remove_columns=val_ds.column_names)

    if uses_qa_head:
        # QA-Head (Start/End): no explicit collator is handed to the Trainer.
        model = AutoModelForQuestionAnswering.from_pretrained(model_name)
        compute_metrics = compute_metrics_qa
        data_collator = None
    else:
        # IO-Token-Labeling: two classes (O / I) and a token-classification collator.
        model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=2)
        compute_metrics = compute_metrics_token
        data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

    # Same hyperparameters for every model; one output directory per checkpoint name.
    training_args = TrainingArguments(
        output_dir=f"wk40_{model_name.split('/')[-1]}",
        learning_rate=2e-5,
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        weight_decay=0.01,
        logging_steps=100,
        do_eval=True,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_prep,
        eval_dataset=val_prep,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        data_collator=data_collator,
    )

    trainer.train()

    # Evaluation per language: filter the raw validation set, preprocess it with
    # the same function used for training, and report the metrics of that subset.
    for lang in langs:
        val_lang = val_ds.filter(lambda ex, L=lang: ex["lang"] == L)
        val_lang_prep = val_lang.map(prep, remove_columns=val_lang.column_names)
        metrics = trainer.evaluate(eval_dataset=val_lang_prep)
        print(f"VAL [{lang}]: {metrics}")


Training model: google-bert/bert-base-multilingual-cased


Map:   0%|          | 0/6335 [00:00<?, ? examples/s]

Map:   0%|          | 0/1155 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at google-bert/bert-base-multilingual-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Step,Training Loss
100,3.2865
200,2.2815
300,2.0455
400,1.6993
500,1.7561
600,1.6813
700,1.5587
800,1.5105
900,1.2843
1000,1.1895


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Map:   0%|          | 0/415 [00:00<?, ? examples/s]

VAL [ar]: {'eval_loss': 1.568082332611084, 'eval_exact_match': 0.5831325301204819, 'eval_f1_token': 0.673734396556835, 'eval_runtime': 6.1008, 'eval_samples_per_second': 68.024, 'eval_steps_per_second': 8.523, 'epoch': 3.0}


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Map:   0%|          | 0/356 [00:00<?, ? examples/s]

VAL [ko]: {'eval_loss': 1.299232840538025, 'eval_exact_match': 0.5646067415730337, 'eval_f1_token': 0.6593058611377222, 'eval_runtime': 5.0213, 'eval_samples_per_second': 70.898, 'eval_steps_per_second': 8.962, 'epoch': 3.0}


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Map:   0%|          | 0/384 [00:00<?, ? examples/s]

VAL [te]: {'eval_loss': 1.9579099416732788, 'eval_exact_match': 0.46875, 'eval_f1_token': 0.5288466885358427, 'eval_runtime': 5.4796, 'eval_samples_per_second': 70.078, 'eval_steps_per_second': 8.76, 'epoch': 3.0}
Training model: distilbert/distilbert-base-multilingual-cased


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/466 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Map:   0%|          | 0/6335 [00:00<?, ? examples/s]

Map:   0%|          | 0/1155 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/542M [00:00<?, ?B/s]

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert/distilbert-base-multilingual-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Step,Training Loss
100,3.665
200,2.6402
300,2.6567
400,2.397
500,2.4046
600,2.3485
700,2.176
800,2.041
900,1.7535
1000,1.686


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Map:   0%|          | 0/415 [00:00<?, ? examples/s]

VAL [ar]: {'eval_loss': 1.6698248386383057, 'eval_exact_match': 0.4819277108433735, 'eval_f1_token': 0.5521112892688683, 'eval_runtime': 3.3172, 'eval_samples_per_second': 125.105, 'eval_steps_per_second': 15.676, 'epoch': 3.0}


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Map:   0%|          | 0/356 [00:00<?, ? examples/s]

VAL [ko]: {'eval_loss': 1.8707506656646729, 'eval_exact_match': 0.4157303370786517, 'eval_f1_token': 0.5132133176311409, 'eval_runtime': 2.753, 'eval_samples_per_second': 129.314, 'eval_steps_per_second': 16.346, 'epoch': 3.0}


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Map:   0%|          | 0/384 [00:00<?, ? examples/s]

VAL [te]: {'eval_loss': 2.0360074043273926, 'eval_exact_match': 0.4270833333333333, 'eval_f1_token': 0.48840393934986964, 'eval_runtime': 2.9825, 'eval_samples_per_second': 128.752, 'eval_steps_per_second': 16.094, 'epoch': 3.0}
Training model: xlm-roberta-base


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/615 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

Map:   0%|          | 0/6335 [00:00<?, ? examples/s]

Map:   0%|          | 0/1155 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Step,Training Loss
100,0.1652
200,0.1215
300,0.135
400,0.1183
500,0.1129
600,0.1134
700,0.115
800,0.0946
900,0.0876
1000,0.0862


Step,Training Loss
100,0.1652
200,0.1215
300,0.135
400,0.1183
500,0.1129
600,0.1134
700,0.115
800,0.0946
900,0.0876
1000,0.0862


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Map:   0%|          | 0/415 [00:00<?, ? examples/s]

VAL [ar]: {'eval_loss': 0.08154372125864029, 'eval_f1_token': 0.5489874110563766, 'eval_runtime': 6.8112, 'eval_samples_per_second': 60.929, 'eval_steps_per_second': 7.635, 'epoch': 3.0}


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Map:   0%|          | 0/356 [00:00<?, ? examples/s]

VAL [ko]: {'eval_loss': 0.0917949229478836, 'eval_f1_token': 0.5728643216080402, 'eval_runtime': 5.6355, 'eval_samples_per_second': 63.171, 'eval_steps_per_second': 7.985, 'epoch': 3.0}


Filter:   0%|          | 0/1155 [00:00<?, ? examples/s]

Map:   0%|          | 0/384 [00:00<?, ? examples/s]

VAL [te]: {'eval_loss': 0.0752517506480217, 'eval_f1_token': 0.45024763619991, 'eval_runtime': 6.0795, 'eval_samples_per_second': 63.163, 'eval_steps_per_second': 7.895, 'epoch': 3.0}
