In [None]:
! pip install datasets

In [None]:
# Authenticate against the Hugging Face Hub from inside the notebook.
import huggingface_hub

notebook_login = huggingface_hub.notebook_login
notebook_login()

In [None]:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Part 1: Reading Comprehension System - BiDAF Implementation
#          with SQuAD-Azerbaijani Dataset
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import warnings
from datasets import load_dataset
from transformers import AutoTokenizer # Using a transformer tokenizer for convenience
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm # For progress bars

# Suppress dropout warning for LSTM
warnings.filterwarnings("ignore", message="dropout option adds dropout after all but last recurrent layer")


# --- Helper Functions (carried over from the earlier BiDAF implementation) ---
def get_mask_from_lengths(lengths, max_len=None):
    """
    Build a boolean validity mask from per-row sequence lengths.

    Position j of row i is True exactly when j < lengths[i]; everything at or beyond the
    row's length (padding) is False.

    Args:
        lengths (torch.Tensor): 1-D tensor of sequence lengths, one per batch row.
        max_len (int, optional): Width of the mask. Defaults to the largest length.
    Returns:
        torch.Tensor: Boolean mask of shape (len(lengths), max_len).
    """
    width = lengths.max().item() if max_len is None else max_len
    positions = torch.arange(width, device=lengths.device)
    return positions[None, :] < lengths[:, None]

def masked_softmax(logits, mask, dim=-1):
    """
    Softmax along `dim` that ignores masked-out positions.

    Args:
        logits (torch.Tensor): Input logits.
        mask (torch.Tensor or None): Mask broadcastable to `logits`; truthy entries are valid.
            Non-bool masks are converted with `.bool()`. None disables masking.
        dim (int): Dimension along which to perform softmax.
    Returns:
        torch.Tensor: Probabilities shaped like `logits`. Masked positions are exactly 0, and a
        slice whose positions are all masked is returned as all zeros (instead of NaN).
    """
    if mask is None:
        return F.softmax(logits, dim=dim)
    if mask.dtype != torch.bool:
        mask = mask.bool()
    probs = F.softmax(logits.masked_fill(~mask, -float('inf')), dim=dim)
    # Softmax over an all -inf slice is 0/0 = NaN; zeroing the masked slots turns those
    # slices into zeros. masked_fill's backward zeroes gradients at masked slots, so no NaN
    # gradient reaches `logits` either.
    return probs.masked_fill(~mask, 0.0)


# --- BiDAF Model Components (carried over from the earlier BiDAF implementation) ---

class EmbeddingLayer(nn.Module):
    """
    Word-embedding lookup followed by dropout.

    Args:
        vocab_size (int): Number of rows in the embedding table.
        embedding_dim (int): Width of each embedding vector.
        dropout_prob (float): Dropout probability applied to the looked-up vectors.
        padding_idx (int): Token id whose embedding is zero-initialised and not updated.
    """

    def __init__(self, vocab_size, embedding_dim, dropout_prob=0.1, padding_idx=0):
        super().__init__()
        self.word_embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
        )
        self.dropout = nn.Dropout(p=dropout_prob)
        print(f"EmbeddingLayer: Vocab size={vocab_size}, Dim={embedding_dim}, Pad Idx={padding_idx}")

    def forward(self, input_ids):
        """Return dropout(embedding(input_ids)), shaped (*input_ids.shape, embedding_dim)."""
        return self.dropout(self.word_embedding(input_ids))

class ContextualEmbeddingLayer(nn.Module):
    """
    Bidirectional LSTM that contextualises a padded batch of embeddings.

    Args:
        input_dim (int): Width of the incoming embeddings.
        hidden_dim (int): LSTM hidden size per direction; the output width is 2 * hidden_dim.
        num_layers (int): Number of stacked LSTM layers.
        dropout_prob (float): Dropout on the output (and between layers when num_layers > 1).
    """

    def __init__(self, input_dim, hidden_dim, num_layers=1, dropout_prob=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # nn.LSTM's dropout only acts between stacked layers, so it is disabled for a single layer.
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
                            bidirectional=True, batch_first=True,
                            dropout=dropout_prob if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout_prob)
        print(f"ContextualEmbeddingLayer: Input Dim={input_dim}, Hidden Dim={hidden_dim}, Output Dim={hidden_dim*2}")

    def forward(self, embedded_sequence, sequence_lengths):
        """
        Run the BiLSTM over only the valid (unpadded) part of each row.

        Args:
            embedded_sequence (torch.Tensor): (batch, seq_len, input_dim) padded embeddings.
            sequence_lengths (torch.Tensor): (batch,) number of valid positions per row.
        Returns:
            torch.Tensor: (batch, seq_len, 2 * hidden_dim). Before dropout, positions past each
            row's length are 0, and rows whose length is <= 0 are entirely 0.
        """
        # pack_padded_sequence requires the lengths on the CPU.
        lengths_cpu = sequence_lengths.cpu()
        empty_rows = lengths_cpu <= 0
        has_empty = bool(empty_rows.any())
        if has_empty:
            warnings.warn(f"Zero length sequence found in batch; its output is zeroed. Lengths: {lengths_cpu}")

        # pack_padded_sequence rejects lengths of 0, so pack empty rows as length 1 and
        # zero their outputs afterwards.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded_sequence, lengths_cpu.clamp(min=1), batch_first=True, enforce_sorted=False
        )
        packed_outputs, _ = self.lstm(packed_embedded)
        lstm_outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, batch_first=True, total_length=embedded_sequence.shape[1]  # keep original seq len
        )
        if has_empty:
            lstm_outputs = lstm_outputs.masked_fill(empty_rows.to(lstm_outputs.device).view(-1, 1, 1), 0.0)
        return self.dropout(lstm_outputs)

class AttentionFlowLayer(nn.Module):
    """
    BiDAF attention flow: context-to-question and question-to-context attention.

    Args:
        hidden_dim (int): LSTM hidden size per direction; H and U are 2 * hidden_dim wide.

    forward(H, U, context_mask, question_mask) returns
    G = [H; U~; H*U~; H*H~] of shape (batch, c_len, 8 * hidden_dim).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.input_dim = 2 * hidden_dim
        # Trilinear similarity weights w = [w_h; w_u; w_hu]. The (1, 1, 3*input_dim) shape is kept
        # so checkpoints saved with the original layout still load.
        self.similarity_weight = nn.Parameter(torch.empty(1, 1, 3 * self.input_dim))
        nn.init.xavier_uniform_(self.similarity_weight)
        print(f"AttentionFlowLayer: Input Dim (H/U)={self.input_dim}, Similarity Weight Dim={3 * self.input_dim}")

    def _similarity(self, H, U):
        """
        S[b, t, j] = w . [H[b,t]; U[b,j]; H[b,t] * U[b,j]], computed as
        H.w_h + U.w_u + (H * w_hu) @ U^T so the (batch, c_len, q_len, 3*input_dim)
        concatenation is never materialised.
        """
        w_h, w_u, w_hu = self.similarity_weight.view(-1).split(self.input_dim)
        s_context = torch.matmul(H, w_h).unsqueeze(2)         # (batch, c_len, 1)
        s_question = torch.matmul(U, w_u).unsqueeze(1)        # (batch, 1, q_len)
        s_cross = torch.bmm(H * w_hu, U.transpose(1, 2))      # (batch, c_len, q_len)
        return s_context + s_question + s_cross

    def forward(self, H, U, context_mask, question_mask):
        """
        Args:
            H (torch.Tensor): (batch, c_len, input_dim) contextual context embeddings.
            U (torch.Tensor): (batch, q_len, input_dim) contextual question embeddings.
            context_mask (torch.Tensor): (batch, c_len), truthy for real context tokens.
            question_mask (torch.Tensor): (batch, q_len), truthy for real question tokens.
        Returns:
            torch.Tensor: G, (batch, c_len, 4 * input_dim).
        """
        context_len = H.shape[1]
        context_mask = context_mask.bool()
        question_mask = question_mask.bool()

        # 1. Similarity matrix, valid only where both the context and the question token are real.
        S = self._similarity(H, U)
        S_mask = context_mask.unsqueeze(2) & question_mask.unsqueeze(1)

        # 2. Context-to-question attention. Padded context rows are fully masked; nan_to_num
        # turns any NaN those rows produce into zeros.
        alpha = torch.nan_to_num(masked_softmax(S, S_mask, dim=2))
        U_tilde = torch.bmm(alpha, U)  # (batch, c_len, input_dim)

        # 3. Question-to-context attention over each context token's best question match.
        m = S.masked_fill(~S_mask, -float('inf')).max(dim=2).values  # (batch, c_len)
        beta = torch.nan_to_num(masked_softmax(m, context_mask, dim=1))
        H_tilde = torch.bmm(beta.unsqueeze(1), H)  # (batch, 1, input_dim)
        H_tilde_expanded = H_tilde.expand(-1, context_len, -1)

        # 4. Combine representations.
        return torch.cat([H, U_tilde, H * U_tilde, H * H_tilde_expanded], dim=2)

class ModelingLayer(nn.Module):
    """
    Stacked bidirectional LSTM over the attention output G.

    Args:
        input_dim (int): Width of G (8 * hidden_dim in BiDAF).
        hidden_dim (int): LSTM hidden size per direction; the output width is 2 * hidden_dim.
        num_layers (int): Number of stacked LSTM layers.
        dropout_prob (float): Dropout on the output (and between layers when num_layers > 1).
    """

    def __init__(self, input_dim, hidden_dim, num_layers=2, dropout_prob=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # nn.LSTM's dropout only acts between stacked layers; passing it with one layer only
        # triggers a warning (same convention as ContextualEmbeddingLayer).
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
                            bidirectional=True, batch_first=True,
                            dropout=dropout_prob if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout_prob)
        print(f"ModelingLayer: Input Dim={input_dim}, Hidden Dim={hidden_dim}, Num Layers={num_layers}, Output Dim={hidden_dim*2}")

    def forward(self, G, context_lengths):
        """
        Args:
            G (torch.Tensor): (batch, c_len, input_dim) attention-flow output.
            context_lengths (torch.Tensor): (batch,) number of valid context positions.
        Returns:
            torch.Tensor: M, (batch, c_len, 2 * hidden_dim). Before dropout, positions past each
            length are 0, and rows with length <= 0 are entirely 0.
        """
        lengths_cpu = context_lengths.cpu()  # pack_padded_sequence needs CPU lengths
        empty_rows = lengths_cpu <= 0
        has_empty = bool(empty_rows.any())
        if has_empty:
            warnings.warn(f"Zero length context found in ModelingLayer; its output is zeroed. Lengths: {lengths_cpu}")

        # pack_padded_sequence rejects zero lengths: pack as length 1, then zero those rows.
        packed_G = nn.utils.rnn.pack_padded_sequence(
            G, lengths_cpu.clamp(min=1), batch_first=True, enforce_sorted=False
        )
        packed_outputs, _ = self.lstm(packed_G)
        M, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, batch_first=True, total_length=G.shape[1]
        )
        if has_empty:
            M = M.masked_fill(empty_rows.to(M.device).view(-1, 1, 1), 0.0)
        return self.dropout(M)

class OutputLayer(nn.Module):
    """
    Span-prediction head: two linear scorers over [G; M] yielding start and end logits.

    Args:
        G_dim (int): Width of the attention-flow output G.
        M_dim (int): Width of the modeling-layer output M.
    """

    def __init__(self, G_dim, M_dim):
        super().__init__()
        self.input_dim = G_dim + M_dim
        self.start_linear = nn.Linear(self.input_dim, 1)
        self.end_linear = nn.Linear(self.input_dim, 1)
        print(f"OutputLayer: Input Dim for prediction ([G;M])={self.input_dim}")

    def forward(self, G, M, context_mask):
        """
        Return (start_logits, end_logits), each (batch, c_len), with -inf wherever
        context_mask is False/0 so those positions get zero probability.
        """
        features = torch.cat((G, M), dim=2)
        valid = context_mask if context_mask.dtype == torch.bool else context_mask.bool()
        padding = ~valid
        neg_inf = -float('inf')
        start_scores = self.start_linear(features).squeeze(-1).masked_fill(padding, neg_inf)
        end_scores = self.end_linear(features).squeeze(-1).masked_fill(padding, neg_inf)
        return start_scores, end_scores

# --- Full BiDAF Model (carried over from the earlier BiDAF implementation) ---
class BiDAF(nn.Module):
    """
    Bidirectional Attention Flow reader.

    Pipeline: word embedding -> contextual BiLSTM (shared by context and question)
    -> attention flow -> modeling BiLSTM -> start/end logits over the context.

    Args:
        vocab_size (int): Embedding table size.
        embedding_dim (int): Word-embedding width.
        hidden_dim (int): LSTM hidden size per direction.
        dropout_prob (float): Dropout used throughout.
        padding_idx (int): Token id treated as padding by the embedding.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, dropout_prob=0.1, padding_idx=0):
        super().__init__()
        print("\n--- Initializing BiDAF Model ---")
        self.padding_idx = padding_idx
        g_width = 8 * hidden_dim  # [H; U~; H*U~; H*H~], each 2 * hidden_dim wide
        m_width = 2 * hidden_dim  # bidirectional modeling-LSTM output

        self.embedding_layer = EmbeddingLayer(vocab_size, embedding_dim, dropout_prob, padding_idx)
        self.contextual_embedding_layer = ContextualEmbeddingLayer(
            embedding_dim, hidden_dim, num_layers=1, dropout_prob=dropout_prob
        )
        self.attention_flow_layer = AttentionFlowLayer(hidden_dim)
        self.modeling_layer = ModelingLayer(g_width, hidden_dim, num_layers=2, dropout_prob=dropout_prob)
        self.output_layer = OutputLayer(G_dim=g_width, M_dim=m_width)
        print("--- BiDAF Model Initialized ---")

    def forward(self, context_ids, question_ids, context_lengths, question_lengths):
        """
        Args:
            context_ids (torch.Tensor): (batch, c_len) padded context token ids.
            question_ids (torch.Tensor): (batch, q_len) padded question token ids.
            context_lengths (torch.Tensor): (batch,) unpadded context lengths.
            question_lengths (torch.Tensor): (batch,) unpadded question lengths.
        Returns:
            tuple: (logits_start, logits_end), each (batch, c_len), -inf at padded positions.
        """
        context_mask = get_mask_from_lengths(context_lengths, max_len=context_ids.shape[1])
        question_mask = get_mask_from_lengths(question_lengths, max_len=question_ids.shape[1])

        # Embed both sequences before contextualising either (fixed dropout RNG order).
        context_embedded = self.embedding_layer(context_ids)
        question_embedded = self.embedding_layer(question_ids)

        # LSTM packing needs lengths >= 1. The masks above still use the true lengths,
        # so rows that were really empty stay masked downstream.
        packed_context_lengths = context_lengths.clamp(min=1)
        packed_question_lengths = question_lengths.clamp(min=1)
        H = self.contextual_embedding_layer(context_embedded, packed_context_lengths)
        U = self.contextual_embedding_layer(question_embedded, packed_question_lengths)

        G = self.attention_flow_layer(H, U, context_mask, question_mask)
        M = self.modeling_layer(G, packed_context_lengths)
        return self.output_layer(G, M, context_mask)


# --- Data Loading and Preprocessing ---

# Constants
MODEL_NAME = "bert-base-multilingual-cased"  # Tokenizer choice
DATASET_NAME = "hajili/squad-azerbaijani-reindex-translation"
MAX_CONTEXT_LENGTH = 384  # Max tokens for context (longer contexts are truncated)
MAX_QUESTION_LENGTH = 64  # Max tokens for question
BATCH_SIZE = 8  # Adjust based on memory
LEARNING_RATE = 1e-4  # Example learning rate
EPOCHS = 1  # Number of training epochs (adjust for real training)
SEED = 42  # Seed for torch's global RNG (weight init, dropout, default shuffling)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED)  # seeds the RNG on all devices

# Load Tokenizer
print(f"Loading tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# mBERT already defines a PAD token; this fallback is only for tokenizers that do not.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    # An added token gets an id outside the base vocabulary (>= tokenizer.vocab_size),
    # so the embedding table must be sized with len(tokenizer), not tokenizer.vocab_size.
PAD_TOKEN_ID = tokenizer.pad_token_id
print(f"PAD Token ID: {PAD_TOKEN_ID}")


# Load Dataset
print("Loading dataset:", DATASET_NAME)
# For a quick demo, load a slice instead: load_dataset(DATASET_NAME, split='train[:1%]').
raw_datasets = load_dataset(path=DATASET_NAME)
print("Raw dataset loaded:", raw_datasets, sep="\n")


# Preprocessing Function
def _char_span_to_token_span(offsets, start_char, end_char):
    """
    Map the character span [start_char, end_char) onto inclusive token indices.

    Args:
        offsets: Sequence of (start, end) character offsets, one per (unpadded) token.
        start_char (int): First character of the answer.
        end_char (int): One past the last character of the answer.
    Returns:
        (start_token, end_token), or None when no token covers the start or the end
        (e.g. the answer was truncated away) or the span comes out inverted. Tokens with
        empty (0, 0) offsets, such as tokenizer-added special tokens, can never match.
    """
    token_start = next(
        (idx for idx, (start, end) in enumerate(offsets) if start <= start_char < end), -1
    )
    # The last token containing the answer's final character; end_char is exclusive.
    token_end = next(
        (idx for idx in range(len(offsets) - 1, -1, -1)
         if offsets[idx][0] < end_char <= offsets[idx][1]), -1
    )
    if token_start == -1 or token_end == -1 or token_end < token_start:
        return None
    return token_start, token_end


def preprocess_squad_for_bidaf(examples, tok=None, max_context_length=None, max_question_length=None):
    """
    Tokenize contexts and questions separately and locate token-level answer spans.

    Designed for the BiDAF model, which takes context and question as separate inputs.
    Both are padded to a fixed length so the default DataLoader collation works.

    Args:
        examples (dict): Batched columns "question", "context", "answer_text", "answer_start"
            (as passed by `Dataset.map(..., batched=True)`).
        tok: Tokenizer to use; defaults to the module-level `tokenizer`. Must support
            `max_length`, `truncation`, `padding` and `return_offsets_mapping`.
        max_context_length (int, optional): Defaults to MAX_CONTEXT_LENGTH.
        max_question_length (int, optional): Defaults to MAX_QUESTION_LENGTH.
    Returns:
        dict: context_ids, question_ids, context/question attention masks, start/end
        token positions, and unpadded context/question lengths. Answers that cannot be
        located (e.g. truncated away) get (0, 0), the first position of the context
        (the tokenizer's leading special token when it adds one).
    """
    tok = tokenizer if tok is None else tok
    max_context_length = MAX_CONTEXT_LENGTH if max_context_length is None else max_context_length
    max_question_length = MAX_QUESTION_LENGTH if max_question_length is None else max_question_length

    questions = [q.strip() for q in examples["question"]]
    answer_texts = examples["answer_text"]
    answer_starts = examples["answer_start"]

    tokenized_contexts = tok(
        examples["context"],
        max_length=max_context_length,
        truncation=True,
        padding="max_length",         # pad here to simplify collating
        return_offsets_mapping=True,  # needed to map character spans to tokens
    )
    tokenized_questions = tok(
        questions,
        max_length=max_question_length,
        truncation=True,
        padding="max_length",
    )
    # Offsets are only needed here, not by the model, so drop them from the output.
    offset_mapping = tokenized_contexts.pop("offset_mapping")

    start_positions, end_positions = [], []
    context_lengths, question_lengths = [], []

    for i, offsets in enumerate(offset_mapping):
        # Unpadded lengths = number of attended positions.
        context_len = sum(tokenized_contexts["attention_mask"][i])
        question_len = sum(tokenized_questions["attention_mask"][i])
        context_lengths.append(context_len)
        question_lengths.append(question_len)

        start_char = answer_starts[i]
        end_char = start_char + len(answer_texts[i])
        # Only non-padded tokens can hold the answer.
        span = _char_span_to_token_span(offsets[:context_len], start_char, end_char) if context_len > 0 else None
        token_start, token_end = span if span is not None else (0, 0)
        start_positions.append(token_start)
        end_positions.append(token_end)

    return {
        "context_ids": tokenized_contexts["input_ids"],
        "question_ids": tokenized_questions["input_ids"],
        "context_attention_mask": tokenized_contexts["attention_mask"],
        "question_attention_mask": tokenized_questions["attention_mask"],
        "start_positions": start_positions,
        "end_positions": end_positions,
        "context_lengths": context_lengths,
        "question_lengths": question_lengths,
    }


# Apply preprocessing.
# This dataset has a single "train" split with flat columns
# id, title, context, question, answer_text, answer_start, so preprocess_squad_for_bidaf
# reads "answer_text"/"answer_start" directly and a validation split is carved out of "train".

print("\nPreprocessing dataset...")
SPLIT_SEED = 42  # fixes the train/validation partition and the training shuffle order
# For a quick demo, subsample after splitting, e.g. train_dataset.select(range(1000)).
split_dataset = raw_datasets["train"].train_test_split(test_size=0.25, seed=SPLIT_SEED)
train_dataset = split_dataset["train"]
validation_dataset = split_dataset["test"]

tokenized_train_dataset = train_dataset.map(
    preprocess_squad_for_bidaf,
    batched=True,
    remove_columns=train_dataset.column_names,  # keep only the model-ready columns
)
tokenized_validation_dataset = validation_dataset.map(
    preprocess_squad_for_bidaf,
    batched=True,
    remove_columns=validation_dataset.column_names,
)
assert len(tokenized_train_dataset) + len(tokenized_validation_dataset) == raw_datasets["train"].num_rows

# Show a compact preview instead of dumping hundreds of padded token ids.
print("\nProcessed dataset example (first 10 entries of each sequence):")
first_example = tokenized_train_dataset[0]
print({key: value[:10] if isinstance(value, list) else value for key, value in first_example.items()})

# Set format for PyTorch
tokenized_train_dataset.set_format("torch")
tokenized_validation_dataset.set_format("torch")

# Every sequence was padded to a fixed max_length in .map(), so the default collate works.
train_dataloader = DataLoader(
    tokenized_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
validation_dataloader = DataLoader(tokenized_validation_dataset, batch_size=BATCH_SIZE)  # no shuffle for validation


# --- Model Instantiation ---
# len(tokenizer) includes special tokens added after loading (e.g. a fallback [PAD]);
# tokenizer.vocab_size counts only the base vocabulary. Every token id must index the embedding table.
VOCAB_SIZE = len(tokenizer)
EMBEDDING_DIM = 100  # Keep BiDAF original embedding dim, could be 768 if using BERT embeds
HIDDEN_DIM = 100     # Hidden dimension for LSTMs
DROPOUT_PROB = 0.2   # Slightly higher dropout can help

model = BiDAF(
    vocab_size=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    dropout_prob=DROPOUT_PROB,
    padding_idx=PAD_TOKEN_ID,  # use the tokenizer's pad id
).to(DEVICE)

# --- Training Setup ---
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# The model returns unnormalised (masked) logits, so CrossEntropyLoss (log-softmax + NLL) applies directly.
loss_fn = nn.CrossEntropyLoss()  # reduction='mean' by default

Loading tokenizer: bert-base-multilingual-cased
PAD Token ID: 0
Loading dataset: hajili/squad-azerbaijani-reindex-translation


train-00000-of-00001.parquet:   0%|          | 0.00/12.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/54092 [00:00<?, ? examples/s]

Raw dataset loaded:
DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answer_text', 'answer_start'],
        num_rows: 54092
    })
})

Preprocessing dataset...


Map:   0%|          | 0/40569 [00:00<?, ? examples/s]

Map:   0%|          | 0/13523 [00:00<?, ? examples/s]


Processed dataset example:
{'context_ids': [101, 10829, 86844, 13098, 32467, 62838, 10679, 140, 108467, 50109, 10711, 59591, 80457, 37514, 11170, 32467, 62838, 171, 108467, 58829, 13084, 25093, 188, 17871, 50208, 70437, 19648, 10711, 25898, 77865, 13941, 30588, 10132, 15756, 37203, 10330, 185, 13998, 97471, 10115, 187, 11562, 51180, 17393, 20267, 13941, 38981, 66952, 19277, 12843, 11957, 11562, 13897, 91208, 119, 11916, 171, 108467, 58829, 13084, 46879, 14573, 187, 11562, 51180, 17393, 20267, 19697, 11499, 10261, 10371, 25591, 13941, 16868, 97339, 15397, 10561, 181, 28708, 10206, 16880, 28041, 183, 39017, 22934, 53216, 119, 34842, 99437, 10115, 33884, 11249, 30204, 29839, 19432, 10711, 68447, 84931, 12146, 33100, 17392, 13744, 30859, 16497, 17007, 44855, 11170, 32467, 62838, 10679, 91783, 10143, 28613, 47538, 10116, 14573, 10562, 28708, 19277, 38708, 91208, 119, 29839, 19432, 37325, 72112, 29531, 11802, 49889, 10245, 65892, 74347, 10713, 15397, 194, 16540, 13091, 17629, 10219, 22944, 

In [None]:
# Same device choice as the configuration cell, repeated so this cell also works on its own.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# --- Training Loop ---
MAX_ANSWER_TOKENS = 30  # longest answer span (in tokens) considered when decoding predictions
MODEL_INPUT_KEYS = ("context_ids", "question_ids", "context_lengths", "question_lengths")


def _to_device(batch):
    """Return the model inputs and span targets of one DataLoader batch, moved to DEVICE."""
    keys = MODEL_INPUT_KEYS + ("start_positions", "end_positions")
    return {key: batch[key].to(DEVICE) for key in keys}


def _span_loss(logits_start, logits_end, start_positions, end_positions):
    """Summed start/end cross-entropy; targets are clamped into [0, c_len - 1] as a safety net."""
    last_index = logits_start.shape[1] - 1
    return (loss_fn(logits_start, start_positions.clamp(0, last_index))
            + loss_fn(logits_end, end_positions.clamp(0, last_index)))


def _decode_best_spans(logits_start, logits_end, max_answer_tokens=MAX_ANSWER_TOKENS):
    """
    Jointly choose (start, end) per example, maximising log p(start) + log p(end) subject to
    start <= end < start + max_answer_tokens. Independent argmaxes can give end < start.
    """
    scores = (F.log_softmax(logits_start, dim=1).unsqueeze(2)
              + F.log_softmax(logits_end, dim=1).unsqueeze(1))  # (batch, c_len, c_len)
    c_len = scores.shape[1]
    positions = torch.arange(c_len, device=scores.device)
    span_width = positions.unsqueeze(0) - positions.unsqueeze(1)  # [s, e] -> e - s
    allowed = (span_width >= 0) & (span_width < max_answer_tokens)
    best = scores.masked_fill(~allowed, -float("inf")).flatten(1).argmax(dim=1)
    return best // c_len, best % c_len


print(f"\n--- Starting Training on {DEVICE} ---")
progress_bar = tqdm(total=EPOCHS * len(train_dataloader), desc="Training")

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    trained_batches = 0
    for batch_num, batch in enumerate(train_dataloader):
        tensors = _to_device(batch)
        optimizer.zero_grad()

        try:
            logits_start, logits_end = model(*(tensors[key] for key in MODEL_INPUT_KEYS))
        except Exception as e:  # best-effort: report and skip a batch the model cannot process
            print(f"Error during forward pass in batch {batch_num}: {e}")
            print(f"Context lengths: {tensors['context_lengths']}")
            print(f"Question lengths: {tensors['question_lengths']}")
            progress_bar.update(1)
            continue

        batch_loss = _span_loss(logits_start, logits_end,
                                tensors["start_positions"], tensors["end_positions"])

        if not torch.isfinite(batch_loss):
            print(f"Warning: non-finite loss detected in batch {batch_num}. Skipping backward.")
            print(f"Logits Start sample: {logits_start[0, :10]}")
            print(f"Logits End sample: {logits_end[0, :10]}")
            print(f"Targets Start: {tensors['start_positions']}")
            print(f"Targets End: {tensors['end_positions']}")
        else:
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += batch_loss.item()
            trained_batches += 1

        progress_bar.update(1)
        progress_bar.set_postfix({"loss": batch_loss.item()})

    # Average only over batches that actually contributed an update.
    avg_epoch_loss = epoch_loss / max(trained_batches, 1)
    print(f"\nEpoch {epoch+1}/{EPOCHS} - Average Training Loss: {avg_epoch_loss:.4f} "
          f"({trained_batches}/{len(train_dataloader)} batches used)")

    # --- Evaluation ---
    model.eval()
    total_eval_loss = 0.0
    eval_batches = 0
    exact_span_matches = 0
    eval_examples = 0
    print("Running evaluation...")
    with torch.no_grad():
        for batch in tqdm(validation_dataloader, desc="Evaluating"):
            tensors = _to_device(batch)
            logits_start, logits_end = model(*(tensors[key] for key in MODEL_INPUT_KEYS))

            batch_loss = _span_loss(logits_start, logits_end,
                                    tensors["start_positions"], tensors["end_positions"])
            if torch.isfinite(batch_loss):
                total_eval_loss += batch_loss.item()
                eval_batches += 1

            pred_start, pred_end = _decode_best_spans(logits_start, logits_end)
            exact_span_matches += ((pred_start == tensors["start_positions"])
                                   & (pred_end == tensors["end_positions"])).sum().item()
            eval_examples += pred_start.numel()

    avg_eval_loss = total_eval_loss / max(eval_batches, 1)
    print(f"Epoch {epoch+1}/{EPOCHS} - Validation Loss: {avg_eval_loss:.4f} - "
          f"Token-span exact match: {exact_span_matches / max(eval_examples, 1):.4f}")

progress_bar.close()
print("\n--- Training Finished ---")

# --- Inference Example (on one batch) ---
print("\n--- Inference Example ---")
model.eval()
example_batch = next(iter(validation_dataloader))
tensors = _to_device(example_batch)
true_start = example_batch['start_positions']
true_end = example_batch['end_positions']

with torch.no_grad():
    logits_start, logits_end = model(*(tensors[key] for key in MODEL_INPUT_KEYS))
pred_start, pred_end = _decode_best_spans(logits_start, logits_end)

print(f"Context IDs shape: {tensors['context_ids'].shape}")
print(f"Predicted Start Indices (Batch): {pred_start.cpu().numpy()}")
print(f"True Start Indices (Batch):      {true_start.numpy()}")
print(f"Predicted End Indices (Batch):   {pred_end.cpu().numpy()}")
print(f"True End Indices (Batch):        {true_end.numpy()}")

# Map the first example's token spans back to text.
first_context = example_batch['context_ids'][0]
ps, pe = pred_start[0].item(), pred_end[0].item()
ts, te = true_start[0].item(), true_end[0].item()
print(f"Predicted answer: {tokenizer.decode(first_context[ps:pe + 1].tolist())!r}")
print(f"True answer:      {tokenizer.decode(first_context[ts:te + 1].tolist())!r}")
# TODO: text-level EM/F1 against the original answer strings.


--- Starting Training on cuda ---


Training:   0%|          | 0/5072 [00:00<?, ?it/s]


Epoch 1/1 - Average Training Loss: 9.4601
Running evaluation...


Evaluating:   0%|          | 0/1691 [00:00<?, ?it/s]

Epoch 1/1 - Validation Loss: 8.3124

--- Training Finished ---

--- Inference Example ---
Context IDs shape: torch.Size([8, 384])
Predicted Start Indices (Batch): [ 44  60 116  12 174  11 167   0]
True Start Indices (Batch):      [ 28  78 149  88 147   2 162 144]
Predicted End Indices (Batch):   [ 48   0 116  18 174  23 163  18]
True End Indices (Batch):        [ 33  81 151  89 154   3 173 145]


In [None]:
def get_bert_embeddings(text_batch, bert_model=None, tok=None, device=None):
    """
    Encode a batch of texts and return the encoder's last hidden states.

    Args:
        text_batch (list[str]): Texts to encode; padded/truncated to a common length.
        bert_model: Model returning an object with `.last_hidden_state`. Defaults to a global
            `bert` if one has been defined (e.g. `bert = AutoModel.from_pretrained(MODEL_NAME)`).
        tok: Tokenizer; defaults to the module-level `tokenizer`.
        device: Where to place the inputs; defaults to the device of the model's first parameter
            (CPU if it has none).
    Returns:
        torch.Tensor: `last_hidden_state`, presumably (batch, seq_len, hidden_dim) for BERT.
    Raises:
        NameError: If no model is passed and no global `bert` exists.
    """
    if bert_model is None:
        bert_model = globals().get("bert")
        if bert_model is None:
            raise NameError(
                "get_bert_embeddings needs a model: pass bert_model=... or define a global `bert` "
                "(e.g. bert = AutoModel.from_pretrained(MODEL_NAME))."
            )
    tok = tokenizer if tok is None else tok
    if device is None:
        first_param = next(iter(bert_model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    inputs = tok(text_batch, return_tensors="pt", padding=True, truncation=True)
    # Inputs must live on the same device as the model weights.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = bert_model(**inputs)
    return outputs.last_hidden_state

In [None]:
! pip install datasets

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2025.3.0 which is incompatible.
torch 2.6.0+cu124 requires nvidia-cublas-cu12==12.4.5.8; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cublas-cu12 12.5.3.2 which is incompatible.
torch 2.6.0+cu124 requires nvidia-cuda-cupti-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cuda-cupti-cu12 12.5.82 which is incompatible.
torch 2.6.0+cu124 requires nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cuda-nvrtc-cu12 12.5.82 which is incompatible.
torch 2.6.0+cu124 requires nvidia-cuda-runtime-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cuda-runtime-cu12 12.5

In [None]:
# Standalone reload of the raw SQuAD-Azerbaijani dataset.
from datasets import load_dataset

dataset = load_dataset(path="hajili/squad-azerbaijani-reindex-translation")

README.md:   0%|          | 0.00/483 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/12.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/54092 [00:00<?, ? examples/s]

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm

def train_bidaf(model, dataloader, epochs=2, lr=2e-5,
                device="cuda" if torch.cuda.is_available() else "cpu", max_grad_norm=1.0):
    """
    Train a BiDAF model with AdamW on the averaged start/end cross-entropy.

    Args:
        model: BiDAF-style module called as model(context_ids, question_ids,
            context_lengths, question_lengths) -> (start_logits, end_logits).
        dataloader: Yields dicts with the keys produced by preprocess_squad_for_bidaf
            (context_ids, question_ids, context_lengths, question_lengths,
            start_positions, end_positions) as tensors.
        epochs (int): Number of passes over `dataloader`.
        lr (float): AdamW learning rate.
        device: Device for the model and batches.
        max_grad_norm (float or None): Gradient-norm clip; None disables clipping.
    Returns:
        The trained model (on `device`).
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}")
        total_loss = 0
        for batch in tqdm(dataloader):
            context_ids = batch['context_ids'].to(device)
            question_ids = batch['question_ids'].to(device)
            context_lengths = batch['context_lengths'].to(device)
            question_lengths = batch['question_lengths'].to(device)

            # Forward pass: context and question are separate inputs with their own lengths.
            start_logits, end_logits = model(context_ids, question_ids, context_lengths, question_lengths)

            # Targets must be valid class indices for the (batch, c_len) logits.
            last_index = start_logits.shape[1] - 1
            start_pos = batch['start_positions'].to(device).clamp(0, last_index)
            end_pos = batch['end_positions'].to(device).clamp(0, last_index)
            loss = (loss_fn(start_logits, start_pos) + loss_fn(end_logits, end_pos)) / 2

            optimizer.zero_grad()
            loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            total_loss += loss.item()
        print(f"Loss: {total_loss / max(len(dataloader), 1):.4f}")
    return model
