## Fine-Tuning BERT

In [1]:
import pandas as pd
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
import torch
import math

dataset = load_dataset("coastalcph/tydi_xor_rc")

languages = ['ar', 'ko', 'te']
train_dataset = dataset["train"].filter(lambda example: example['lang'] in languages)
val_dataset = dataset["validation"].filter(lambda example: example['lang'] in languages)

print("Sample from train dataset:")
sample = train_dataset[0]
print(f"Keys: {sample.keys()}")
print(f"Answer structure: {sample['answer']}")
print(f"Answer type: {type(sample['answer'])}")

model_checkpoint = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

max_length = 384
doc_stride = 128

def preprocess_function(examples):
    questions = [q.strip() for q in examples["question"]]
    contexts = [c.strip() for c in examples["context"]]
    
    tokenized_examples = tokenizer(
        questions,
        contexts,
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = tokenized_examples.pop("offset_mapping")
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        sequence_ids = tokenized_examples.sequence_ids(i)

        sample_index = sample_mapping[i]
        
        answer_data = examples["answer"][sample_index]
        
        if isinstance(answer_data, dict):
            if "text" in answer_data:
                answer_text = answer_data["text"]
                answer_starts = answer_data.get("answer_start", [])
            else:
                answer_text = answer_data.get("answer_text", answer_data.get("answers", ""))
                answer_starts = answer_data.get("answer_start", answer_data.get("answer_starts", []))
        elif isinstance(answer_data, str):
            answer_text = answer_data
            answer_starts = []
        else:
            answer_text = ""
            answer_starts = []
        
        if isinstance(answer_text, str):
            answer_texts = [answer_text] if answer_text else []
        else:
            answer_texts = answer_text if answer_text else []
        
        if not isinstance(answer_starts, list):
            answer_starts = [answer_starts] if answer_starts else []
        
        if not answer_texts or not answer_texts[0].strip():
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            answer_text = answer_texts[0]
            start_char = answer_starts[0] if answer_starts else 0
            end_char = start_char + len(answer_text)

            token_start_index = 0
            while token_start_index < len(sequence_ids) and sequence_ids[token_start_index] != 1:
                token_start_index += 1

            token_end_index = len(input_ids) - 1
            while token_end_index >= 0 and sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            if (token_start_index >= len(offsets) or token_end_index >= len(offsets) or 
                token_start_index > token_end_index or
                not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char)):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                
                while token_end_index >= 0 and offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)
tokenized_val_dataset = val_dataset.map(preprocess_function, batched=True, remove_columns=val_dataset.column_names)

model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="steps",
    eval_steps=25,
    logging_strategy="steps",
    logging_steps=25,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    tokenizer=tokenizer,
)

trainer.train()

print("Overall Evaluation")
eval_results = trainer.evaluate()
print(f"Overall Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
print(f"Overall Loss: {eval_results['eval_loss']:.4f}")

languages = ['ar', 'ko', 'te']
print("\nLanguage-specific Evaluations")

for lang in languages:
    print(f"\nEvaluating {lang.upper()}")
    
    lang_val_dataset = val_dataset.filter(lambda example: example['lang'] == lang)
    print(f"Number of {lang} validation examples: {len(lang_val_dataset)}")
    
    if len(lang_val_dataset) == 0:
        print(f"No validation examples found for language: {lang}")
        continue
    
    tokenized_lang_val = lang_val_dataset.map(
        preprocess_function, 
        batched=True, 
        remove_columns=lang_val_dataset.column_names
    )
    
    lang_trainer = Trainer(
        model=model,
        args=training_args,
        eval_dataset=tokenized_lang_val,
        tokenizer=tokenizer,
    )
    
    lang_eval_results = lang_trainer.evaluate()
    print(f"{lang.upper()} Perplexity: {math.exp(lang_eval_results['eval_loss']):.2f}")
    print(f"{lang.upper()} Loss: {lang_eval_results['eval_loss']:.4f}")
    
    for key, value in lang_eval_results.items():
        if key not in ['eval_loss', 'eval_runtime', 'eval_samples_per_second', 'eval_steps_per_second', 'epoch']:
            print(f"{lang.upper()} {key}: {value:.4f}")

print("Evaluation completed for all languages")

Sample from train dataset:
Keys: dict_keys(['question', 'context', 'lang', 'answerable', 'answer_start', 'answer', 'answer_inlang'])
Answer structure: France
Answer type: <class 'str'>


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(
wandb: Currently logged in as: aarushsinha60 (chungimungi) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Step,Training Loss,Validation Loss
25,4.5146,3.173701
50,2.1181,1.284127
75,1.4246,1.214832
100,1.2826,1.151535
125,1.2195,1.121558
150,1.1836,1.111004
175,1.1539,1.115361
200,1.2134,1.081321
225,1.074,1.087571
250,1.1417,1.079575


Overall Evaluation


Overall Perplexity: 2.88
Overall Loss: 1.0582

Language-specific Evaluations

Evaluating AR
Number of ar validation examples: 415


  lang_trainer = Trainer(


AR Perplexity: 2.94
AR Loss: 1.0793
AR eval_model_preparation_time: 0.0020

Evaluating KO
Number of ko validation examples: 356


Map:   0%|          | 0/356 [00:00<?, ? examples/s]

KO Perplexity: 2.86
KO Loss: 1.0503
KO eval_model_preparation_time: 0.0020

Evaluating TE
Number of te validation examples: 384


TE Perplexity: 2.83
TE Loss: 1.0420
TE eval_model_preparation_time: 0.0030
Evaluation completed for all languages


## RNN (LSTM)

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import numpy as np
from datasets import load_dataset
import math
from tqdm import tqdm
from collections import Counter
import re

class WordTokenizer:
    def __init__(self, vocab, unk_token='<unk>', pad_token='<pad>'):
        self.word2idx = {word: i for i, word in enumerate(vocab)}
        self.idx2word = {i: word for i, word in enumerate(vocab)}
        self.unk_token = unk_token
        self.pad_token = pad_token
        self.unk_token_id = self.word2idx.get(self.unk_token)
        self.pad_token_id = self.word2idx.get(self.pad_token)
        self.vocab_size = len(vocab)

    def encode(self, text, max_length=None, padding=False, truncation=False):
        words = re.findall(r"\w+|[^\s\w]", text.lower())
        token_ids = [self.word2idx.get(word, self.unk_token_id) for word in words]

        if truncation and max_length:
            token_ids = token_ids[:max_length]

        if padding and max_length:
            padded_length = max_length - len(token_ids)
            token_ids.extend([self.pad_token_id] * padded_length)
        
        return token_ids

def build_tokenizer(dataset, languages, vocab_size=30000):
    print("Building vocabulary from the training dataset...")
    word_counts = Counter()
    
    train_dataset = dataset["train"].filter(lambda example: example['lang'] in languages)

    for example in tqdm(train_dataset, desc="Counting words"):
        question_words = re.findall(r"\w+|[^\s\w]", example['question'].lower())
        context_words = re.findall(r"\w+|[^\s\w]", example['context'].lower())
        word_counts.update(question_words)
        word_counts.update(context_words)

    special_tokens = ['<pad>', '<unk>']
    vocab = special_tokens + [word for word, count in word_counts.most_common(vocab_size - len(special_tokens))]
    
    print(f"Vocabulary built. Total unique words found: {len(word_counts)}. Vocab size: {len(vocab)}")
    
    return WordTokenizer(vocab)

class QADataset(Dataset):
    def __init__(self, dataset, tokenizer, max_q_len=128, max_c_len=512):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_q_len = max_q_len
        self.max_c_len = max_c_len

    def __len__(self):
        return len(self.dataset)
        
    def __getitem__(self, idx):
        example = self.dataset[idx]
        question = example['question'].strip()
        context = example['context'].strip()
        
        question_ids = self.tokenizer.encode(question, max_length=self.max_q_len, padding=True, truncation=True)
        context_ids = self.tokenizer.encode(context, max_length=self.max_c_len, padding=True, truncation=True)
        
        answer_text = example.get('answer', '')
        answer_start = example.get('answer_start')

        answer_start_char = 0
        if isinstance(answer_start, list) and answer_start:
            answer_start_char = answer_start[0]
        elif isinstance(answer_start, int):
            answer_start_char = answer_start
        
        start_token_idx, end_token_idx = 0, 0
        
        if answer_text:
            context_lower = context.lower()
            answer_lower = answer_text.lower()
            answer_end_char = answer_start_char + len(answer_text)
            
            char_to_token = {}
            tokens_with_spans = [(m.group(0), m.start(), m.end()) for m in re.finditer(r"\w+|[^\s\w]", context_lower)]

            for i, (token, start, end) in enumerate(tokens_with_spans):
                if i >= self.max_c_len:
                    break
                for char_idx in range(start, end):
                    char_to_token[char_idx] = i

            start_token_idx = char_to_token.get(answer_start_char, 0)
            end_token_idx = char_to_token.get(answer_end_char - 1, 0)

            if start_token_idx > end_token_idx:
                start_token_idx = end_token_idx

            if start_token_idx >= self.max_c_len or end_token_idx >= self.max_c_len:
                start_token_idx, end_token_idx = 0, 0

        return {
            'question_ids': torch.tensor(question_ids, dtype=torch.long),
            'context_ids': torch.tensor(context_ids, dtype=torch.long),
            'start_position': torch.tensor(start_token_idx, dtype=torch.long),
            'end_position': torch.tensor(end_token_idx, dtype=torch.long),
        }

class LSTMQuestionAnswering(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=4, dropout=0.3, padding_idx=0):
        super(LSTMQuestionAnswering, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.question_lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout, bidirectional=True)
        self.context_lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout, bidirectional=True)
        
        self.bi_hidden_dim = hidden_dim * 2
        
        self.self_attention = nn.MultiheadAttention(self.bi_hidden_dim, num_heads=8, dropout=dropout, batch_first=True)
        self.cross_attention = nn.MultiheadAttention(self.bi_hidden_dim, num_heads=8, dropout=dropout, batch_first=True)
        
        self.layer_norm1 = nn.LayerNorm(self.bi_hidden_dim)
        self.layer_norm2 = nn.LayerNorm(self.bi_hidden_dim)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(self.bi_hidden_dim, self.bi_hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.bi_hidden_dim * 4, self.bi_hidden_dim)
        )
        self.start_classifier = nn.Linear(self.bi_hidden_dim, 1)
        self.end_classifier = nn.Linear(self.bi_hidden_dim, 1)

    def forward(self, question_ids, context_ids):
        question_embedded = self.embedding(question_ids)
        context_embedded = self.embedding(context_ids)
        
        question_output, _ = self.question_lstm(question_embedded)
        context_output, _ = self.context_lstm(context_embedded)
        
        self_attn_output, _ = self.self_attention(context_output, context_output, context_output)
        context_output = self.layer_norm1(context_output + self_attn_output)
        
        cross_attn_output, _ = self.cross_attention(context_output, question_output, question_output)
        final_output = self.layer_norm2(context_output + cross_attn_output)
        
        ff_output = self.feed_forward(final_output)
        final_output = final_output + ff_output
        
        start_logits = self.start_classifier(final_output).squeeze(-1)
        end_logits = self.end_classifier(final_output).squeeze(-1)
        
        return start_logits, end_logits

def train_epoch(model, train_loader, optimizer, device):
    model.train()
    total_loss = 0
    pbar = tqdm(train_loader, desc="Training")
    for batch in pbar:
        optimizer.zero_grad()
        start_logits, end_logits = model(
            question_ids=batch['question_ids'].to(device),
            context_ids=batch['context_ids'].to(device)
        )
        start_positions = batch['start_position'].to(device)
        end_positions = batch['end_position'].to(device)
        
        start_loss = F.cross_entropy(start_logits, start_positions)
        end_loss = F.cross_entropy(end_logits, end_positions)
        loss = start_loss + end_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
        pbar.set_postfix({'avg_loss': f'{total_loss / (pbar.n + 1):.4f}'})
    return total_loss / len(train_loader)

def evaluate_model(model, val_loader, device):
    model.eval()
    total_loss = 0
    total_start_acc, total_end_acc = 0, 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            start_logits, end_logits = model(
                question_ids=batch['question_ids'].to(device),
                context_ids=batch['context_ids'].to(device)
            )
            start_positions = batch['start_position'].to(device)
            end_positions = batch['end_position'].to(device)

            start_loss = F.cross_entropy(start_logits, start_positions)
            end_loss = F.cross_entropy(end_logits, end_positions)
            loss = start_loss + end_loss
            total_loss += loss.item()
            
            total_start_acc += (start_logits.argmax(dim=1) == start_positions).sum().item()
            total_end_acc += (end_logits.argmax(dim=1) == end_positions).sum().item()

    num_samples = len(val_loader.dataset)
    return total_loss / len(val_loader), total_start_acc / num_samples, total_end_acc / num_samples

def evaluate_by_language(model, dataset_raw, tokenizer, languages, device, batch_size=16):
    results = {}
    for lang in languages:
        print(f"\nEvaluating language: {lang}")
        lang_dataset_raw = dataset_raw.filter(lambda example: example['lang'] == lang)
        
        if len(lang_dataset_raw) == 0:
            print(f"No examples found for language: {lang}")
            continue

        lang_dataset = QADataset(lang_dataset_raw, tokenizer)
        lang_loader = DataLoader(lang_dataset, batch_size=batch_size, shuffle=False)
        
        avg_loss, start_acc, end_acc = evaluate_model(model, lang_loader, device)
        
        try:
            perplexity = math.exp(avg_loss)
        except OverflowError:
            perplexity = float('inf')
            
        results[lang] = {
            'loss': avg_loss,
            'perplexity': perplexity,
            'start_accuracy': start_acc,
            'end_accuracy': end_acc
        }
        
        print(f"{lang.upper()} | Loss: {avg_loss:.4f} | Perplexity: {perplexity:.2f} | Start Acc: {start_acc:.4f} | End Acc: {end_acc:.4f}")
    return results

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print("Loading dataset...")
    dataset = load_dataset("coastalcph/tydi_xor_rc")
    languages = ['ar', 'ko', 'te']
    
    tokenizer = build_tokenizer(dataset, languages, vocab_size=30000)
    
    train_dataset_raw = dataset["train"].filter(lambda example: example['lang'] in languages)
    val_dataset_raw = dataset["validation"].filter(lambda example: example['lang'] in languages)

    train_dataset = QADataset(train_dataset_raw, tokenizer)
    val_dataset = QADataset(val_dataset_raw, tokenizer)

    batch_size = 16
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = LSTMQuestionAnswering(
        vocab_size=tokenizer.vocab_size,
        padding_idx=tokenizer.pad_token_id
    ).to(device)
    
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=1)

    best_val_loss = float('inf')
    patience_counter = 0
    num_epochs = 2

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        train_loss = train_epoch(model, train_loader, optimizer, device)
        val_loss, start_acc, end_acc = evaluate_model(model, val_loader, device)
        scheduler.step(val_loss)
        
        try:
            perplexity = math.exp(val_loss)
        except OverflowError:
            perplexity = float('inf')
        
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Perplexity: {perplexity:.2f}")
        print(f"Start Accuracy: {start_acc:.4f} | End Accuracy: {end_acc:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= 3:
                print("Early stopping triggered.")
                break
    
    print("\n" + "="*20)
    print("Final Overall Evaluation")
    print("="*20)
    
    overall_loss, overall_start_acc, overall_end_acc = evaluate_model(model, val_loader, device)
    
    try:
        overall_perplexity = math.exp(overall_loss)
    except OverflowError:
        overall_perplexity = float('inf')
        
    print(f"Overall Loss: {overall_loss:.4f}")
    print(f"Overall Perplexity: {overall_perplexity:.2f}")
    print(f"Overall Start Position Accuracy: {overall_start_acc:.4f}")
    print(f"Overall End Position Accuracy: {overall_end_acc:.4f}")
    
    print("\n" + "="*20)
    print("Final Language-Specific Evaluation")
    print("="*20)
    
    lang_results = evaluate_by_language(model, val_dataset_raw, tokenizer, languages, device)

if __name__ == "__main__":
    main()

Using device: cuda
Loading dataset...
Building vocabulary from the training dataset...


Counting words: 100%|██████████| 6335/6335 [00:00<00:00, 11001.49it/s]


Vocabulary built. Total unique words found: 51783. Vocab size: 30000
Model parameters: 19,086,850

Epoch 1/2


Training: 100%|██████████| 396/396 [02:19<00:00,  2.84it/s, avg_loss=7.3613]
Evaluating: 100%|██████████| 73/73 [00:10<00:00,  7.03it/s]


Train Loss: 7.3613 | Val Loss: 6.3509 | Val Perplexity: 573.01
Start Accuracy: 0.2658 | End Accuracy: 0.2104

Epoch 2/2


Training: 100%|██████████| 396/396 [02:25<00:00,  2.72it/s, avg_loss=5.9130]
Evaluating: 100%|██████████| 73/73 [00:12<00:00,  6.02it/s]


Train Loss: 5.9130 | Val Loss: 5.9307 | Val Perplexity: 376.42
Start Accuracy: 0.2831 | End Accuracy: 0.2312

Final Overall Evaluation


Evaluating: 100%|██████████| 73/73 [00:12<00:00,  5.98it/s]


Overall Loss: 5.9307
Overall Perplexity: 376.42
Overall Start Position Accuracy: 0.2831
Overall End Position Accuracy: 0.2312

Final Language-Specific Evaluation

Evaluating language: ar


Evaluating: 100%|██████████| 26/26 [00:03<00:00,  7.62it/s]


AR | Loss: 5.9678 | Perplexity: 390.63 | Start Acc: 0.2458 | End Acc: 0.2386

Evaluating language: ko


Evaluating: 100%|██████████| 23/23 [00:04<00:00,  5.27it/s]


KO | Loss: 5.9851 | Perplexity: 397.47 | Start Acc: 0.2528 | End Acc: 0.2275

Evaluating language: te


Evaluating: 100%|██████████| 24/24 [00:04<00:00,  5.97it/s]

TE | Loss: 5.9362 | Perplexity: 378.50 | Start Acc: 0.3516 | End Acc: 0.2266



