In [1]:
from transformers import BertTokenizerFast
from datasets import load_dataset, DatasetDict, Dataset
import json
from collections import Counter
import numpy as np
from tqdm import tqdm

# --- Configuration ---
BERT_MODEL_NAME = 'bert-base-cased'

# Max lengths for BERT token sequences (these will be the sequence lengths for BiDAF)
MAX_BERT_QUESTION_LEN = 64  # Max BERT tokens for the question
MAX_BERT_CONTEXT_LEN = 512 # Max BERT tokens for the context (where answer span is predicted)

# Max characters per BERT token (for the character CNN)
MAX_BERT_TOKEN_CHAR_LEN = 25 # Most BERT tokens (subwords) are short. Adjust if needed.

# Paths for saving processed data and char vocab
CHAR_VOCAB_BERT_PATH = "./squad_char_vocab_for_bert_tokens.json"
PROCESSED_TRAIN_BERT_CHAR_PATH = "./squad_train_processed_bert_char.hf"
PROCESSED_VAL_BERT_CHAR_PATH = "./squad_val_processed_bert_char.hf"

# Special character tokens
CHAR_PAD_TOKEN = "<C_PAD>"
CHAR_UNK_TOKEN = "<C_UNK>" # For characters not in our char vocab (rare for BERT tokens)

# Load BERT Tokenizer
print(f"Loading BERT tokenizer: {BERT_MODEL_NAME}...")
tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

Loading BERT tokenizer: bert-base-cased...


In [2]:
import os

def build_char_vocab_from_bert_tokens(squad_dataset_split, bert_tokenizer, special_char_tokens):
    print("Building character vocabulary from BERT tokens...")
    char_counter = Counter()
    # Process a subset for speed, or all of train if needed
    # For a robust char vocab, process all unique BERT tokens from tokenizer.vocab
    # or tokenize a large representative text sample.

    # Option 1: Characters from BERT's own vocabulary tokens
    # for bert_token_str in bert_tokenizer.vocab.keys():
    #     char_counter.update(bert_token_str)

    # Option 2: Characters from tokenizing actual SQuAD data (more targeted)
    # Let's use a sample from the provided SQuAD split for this example
    sample_size = min(10000, len(squad_dataset_split)) # Adjust sample size
    print(f"Using a sample of {sample_size} examples to build character vocabulary...")
    for i in tqdm(range(sample_size)):
        example = squad_dataset_split[i]
        context_tokens = bert_tokenizer.tokenize(example['context'])
        question_tokens = bert_tokenizer.tokenize(example['question'])
        for bert_token_str in context_tokens + question_tokens:
            char_counter.update(bert_token_str)

    char_to_idx = {token: idx for idx, token in enumerate(special_char_tokens)}
    idx_to_char = {idx: token for token, idx in char_to_idx.items()}

    for char, count in char_counter.most_common():
        if char not in char_to_idx: # Avoid re-adding special tokens if they are actual chars
            idx = len(char_to_idx)
            char_to_idx[char] = idx
            idx_to_char[idx] = char

    print(f"Character vocabulary size (for BERT tokens): {len(char_to_idx)}")
    return char_to_idx, idx_to_char

# Load raw SQuAD data
squad_raw: DatasetDict = load_dataset("squad")
squad_train_raw: Dataset = squad_raw['train']
# squad_validation_raw: Dataset = squad_raw['validation'] # if needed for char vocab

# Build and save character vocabulary
# (Ideally run this once and save/load)
if not os.path.exists(CHAR_VOCAB_BERT_PATH):
    char_to_idx, idx_to_char = build_char_vocab_from_bert_tokens(squad_train_raw, tokenizer, [CHAR_PAD_TOKEN, CHAR_UNK_TOKEN])
    print(f"Saving character vocabulary to {CHAR_VOCAB_BERT_PATH}...")
    with open(CHAR_VOCAB_BERT_PATH, 'w', encoding='utf-8') as f:
        json.dump({'char_to_idx': char_to_idx, 'idx_to_char': idx_to_char}, f, ensure_ascii=False)
    print("Character vocabulary saved.")
else:
    print(f"Loading existing character vocabulary from {CHAR_VOCAB_BERT_PATH}...")
    with open(CHAR_VOCAB_BERT_PATH, 'r', encoding='utf-8') as f:
        char_vocab_data = json.load(f)
        char_to_idx = char_vocab_data['char_to_idx']
        # idx_to_char = char_vocab_data['idx_to_char'] # If needed
print(f"Character vocab size: {len(char_to_idx)}")
CHAR_PAD_ID = char_to_idx[CHAR_PAD_TOKEN]
CHAR_UNK_ID = char_to_idx.get(CHAR_UNK_TOKEN, CHAR_PAD_ID) # Fallback for UNK if not explicitly made

Loading existing character vocabulary from ./squad_char_vocab_for_bert_tokens.json...
Character vocab size: 266


In [3]:
def preprocess_for_bert_bidaf_with_chars(examples):
    questions = [q.strip() for q in examples['question']]
    contexts = examples['context']
    answers = examples['answers']
    example_ids = examples['id']

    # Tokenize questions with BERT
    tokenized_questions = tokenizer(
        questions,
        max_length=MAX_BERT_QUESTION_LEN,
        truncation=True,
        padding="max_length", # Pad to MAX_BERT_QUESTION_LEN
        return_attention_mask=True,
        return_offsets_mapping=False # Not needed for question if processed alone
    )

    # Tokenize contexts with BERT
    tokenized_contexts = tokenizer(
        contexts,
        max_length=MAX_BERT_CONTEXT_LEN,
        truncation=True,
        padding="max_length", # Pad to MAX_BERT_CONTEXT_LEN
        return_attention_mask=True,
        return_offsets_mapping=True # Essential for answer mapping
    )

    # Prepare lists for results
    start_positions_bert_final = []
    end_positions_bert_final = []
    context_bert_token_char_ids_final = []
    question_bert_token_char_ids_final = []

    for i in range(len(contexts)):
        # --- Answer Span Mapping ---
        answer = answers[i]
        context_char_to_token_offset = tokenized_contexts.offset_mapping[i]
        # context_bert_ids_for_current_example = tokenized_contexts.input_ids[i] # For debugging

        answer_char_start = answer['answer_start'][0]
        answer_text = answer['text'][0]
        answer_char_end = answer_char_start + len(answer_text)

        # Find start BERT token index for the answer
        start_token_idx_bert = -1
        for token_idx, (start_char, end_char) in enumerate(context_char_to_token_offset):
            if start_char <= answer_char_start < end_char:
                start_token_idx_bert = token_idx
                break

        # Find end BERT token index for the answer
        end_token_idx_bert = -1
        if start_token_idx_bert != -1: # Only search for end if start was found
            for token_idx in range(start_token_idx_bert, len(context_char_to_token_offset)):
                start_char, end_char = context_char_to_token_offset[token_idx]
                if start_char < answer_char_end <= end_char : # Answer ends within or at the end of this token
                    end_token_idx_bert = token_idx
                    break
                if start_char >= answer_char_end : # Answer ended before this token (should have been caught)
                    # This case means previous token was the end, or something is off.
                    # For safety, if it passed the start, but now start_char is already beyond answer_char_end,
                    # it implies the answer might be fully contained in start_token_idx_bert or previous tokens.
                    # The above loop usually handles this better by checking answer_char_end <= end_char.
                    # If answer text is empty, end_token_idx_bert might remain -1.
                    if end_token_idx_bert == -1 and token_idx > start_token_idx_bert: # If not set and current token is past ans
                         end_token_idx_bert = token_idx -1
                    break
            if end_token_idx_bert == -1 and start_token_idx_bert != -1: # If answer fully in start token or at end of context
                 if answer_char_end <= context_char_to_token_offset[start_token_idx_bert][1]:
                    end_token_idx_bert = start_token_idx_bert
                 else: # Reaches end of context tokens, and answer also ends there or got truncated
                    for token_idx in range(len(context_char_to_token_offset)-1, -1, -1):
                        if context_char_to_token_offset[token_idx] != (0,0): # find last actual token
                            end_token_idx_bert = token_idx
                            break


        # Validate and handle unmappable answers (SQuAD 1.1 should always have an answer)
        # For BiDAF, logits are over MAX_BERT_CONTEXT_LEN
        if start_token_idx_bert == -1 or end_token_idx_bert == -1 or \
           start_token_idx_bert >= MAX_BERT_CONTEXT_LEN or \
           end_token_idx_bert >= MAX_BERT_CONTEXT_LEN or \
           end_token_idx_bert < start_token_idx_bert:
            # print(f"WARN: Unmappable/Out-of-bounds answer for ID {example_ids[i]}. Char span ({answer_char_start}-{answer_char_end}), Text: '{answer_text}'. Defaulting to CLS/0.")
            # Map to CLS token (index 0) for now. Your loss function should ignore this if possible.
            # Or use -1 if your CrossEntropyLoss ignore_index is -1.
            start_positions_bert_final.append(0) # Assuming CLS is at index 0 and ignored or handled
            end_positions_bert_final.append(0)
        else:
            start_positions_bert_final.append(start_token_idx_bert)
            end_positions_bert_final.append(end_token_idx_bert)

        # --- Character ID Generation for BERT Tokens ---
        # For Context
        current_context_bert_ids = tokenized_contexts.input_ids[i]
        context_word_char_ids_list = []
        for bert_token_id in current_context_bert_ids: # These are already padded to MAX_BERT_CONTEXT_LEN
            if bert_token_id == tokenizer.pad_token_id: # If it's a padding BERT token
                char_ids_for_token = [CHAR_PAD_ID] * MAX_BERT_TOKEN_CHAR_LEN
            else:
                bert_token_str = tokenizer.convert_ids_to_tokens(bert_token_id)
                char_ids_for_token = [char_to_idx.get(char, CHAR_UNK_ID) for char in bert_token_str]
                # Pad/truncate characters in this BERT token
                if len(char_ids_for_token) > MAX_BERT_TOKEN_CHAR_LEN:
                    char_ids_for_token = char_ids_for_token[:MAX_BERT_TOKEN_CHAR_LEN]
                else:
                    char_ids_for_token.extend([CHAR_PAD_ID] * (MAX_BERT_TOKEN_CHAR_LEN - len(char_ids_for_token)))
            context_word_char_ids_list.append(char_ids_for_token)
        context_bert_token_char_ids_final.append(context_word_char_ids_list)

        # For Question
        current_question_bert_ids = tokenized_questions.input_ids[i]
        question_word_char_ids_list = []
        for bert_token_id in current_question_bert_ids: # Padded to MAX_BERT_QUESTION_LEN
            if bert_token_id == tokenizer.pad_token_id:
                char_ids_for_token = [CHAR_PAD_ID] * MAX_BERT_TOKEN_CHAR_LEN
            else:
                bert_token_str = tokenizer.convert_ids_to_tokens(bert_token_id)
                char_ids_for_token = [char_to_idx.get(char, CHAR_UNK_ID) for char in bert_token_str]
                if len(char_ids_for_token) > MAX_BERT_TOKEN_CHAR_LEN:
                    char_ids_for_token = char_ids_for_token[:MAX_BERT_TOKEN_CHAR_LEN]
                else:
                    char_ids_for_token.extend([CHAR_PAD_ID] * (MAX_BERT_TOKEN_CHAR_LEN - len(char_ids_for_token)))
            question_word_char_ids_list.append(char_ids_for_token)
        question_bert_token_char_ids_final.append(question_word_char_ids_list)

    # Prepare final dictionary for the dataset's .map() function
    processed_output = {
        'context_input_ids': tokenized_contexts.input_ids,
        'context_attention_mask': tokenized_contexts.attention_mask,
        'context_token_char_ids': context_bert_token_char_ids_final, # (batch, MAX_BERT_CONTEXT_LEN, MAX_BERT_TOKEN_CHAR_LEN)

        'question_input_ids': tokenized_questions.input_ids,
        'question_attention_mask': tokenized_questions.attention_mask,
        'question_token_char_ids': question_bert_token_char_ids_final, # (batch, MAX_BERT_QUESTION_LEN, MAX_BERT_TOKEN_CHAR_LEN)

        'start_token_bert': start_positions_bert_final,
        'end_token_bert': end_positions_bert_final,
        'id': example_ids
    }
    return processed_output


# --- Apply Preprocessing to SQuAD dataset ---
print("\nStarting SQuAD preprocessing for BERT with Character Embeddings...")
# squad_train_raw and squad_validation_raw should be your loaded Hugging Face squad datasets
# Example:
# squad_raw: DatasetDict = load_dataset("squad")
# squad_train_raw: Dataset = squad_raw['train']
# squad_validation_raw: Dataset = squad_raw['validation']

# Important: Remove original columns to avoid issues when saving or using the dataset later
train_cols_to_remove = squad_train_raw.column_names
val_cols_to_remove = squad_raw['validation'].column_names # Assuming 'validation' split exists

processed_squad_train = squad_train_raw.map(
    preprocess_for_bert_bidaf_with_chars,
    batched=True, # Process in batches for efficiency
    remove_columns=train_cols_to_remove # Remove old columns
)
processed_squad_val = squad_raw['validation'].map(
    preprocess_for_bert_bidaf_with_chars,
    batched=True,
    remove_columns=val_cols_to_remove
)

print("\nPreprocessing complete.")
print("Sample processed training example:")
print(processed_squad_train[0])

# Save processed datasets
print(f"Saving processed training data to {PROCESSED_TRAIN_BERT_CHAR_PATH}...")
processed_squad_train.save_to_disk(PROCESSED_TRAIN_BERT_CHAR_PATH)
print(f"Saving processed validation data to {PROCESSED_VAL_BERT_CHAR_PATH}...")
processed_squad_val.save_to_disk(PROCESSED_VAL_BERT_CHAR_PATH)
print("Processed datasets saved.")


Starting SQuAD preprocessing for BERT with Character Embeddings...


Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CharEmbedding(nn.Module):
    def __init__(self, char_vocab_size, char_embedding_dim, char_cnn_out_channels,
                 char_cnn_kernel_size, char_padding_idx, dropout_rate):
        super(CharEmbedding, self).__init__()

        self.char_embedding = nn.Embedding(
            num_embeddings=char_vocab_size,
            embedding_dim=char_embedding_dim,
            padding_idx=char_padding_idx
        )

        self.conv1d = nn.Conv1d(
            in_channels=char_embedding_dim,
            out_channels=char_cnn_out_channels,
            kernel_size=char_cnn_kernel_size,
            # Padding to maintain length for easier max-pooling across the sequence dimension later,
            # or use specific padding based on kernel size.
            # For kernel_size=5, padding=2 would keep length same if stride=1.
        )
        # The paper implies max-pooling over the resulting length of the convolution for each word.
        # If conv output length is L_out = L_in - kernel_size + 1 + 2*padding.
        # Let's use padding such that the output length is reasonable.
        # A common approach: padding = (kernel_size - 1) // 2 for 'same' style padding.

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x_char_ids):
        # x_char_ids: (batch_size, seq_len, max_word_len)
        batch_size, seq_len, max_word_len = x_char_ids.shape

        # Reshape for embedding: (batch_size * seq_len, max_word_len)
        x_char_ids_reshaped = x_char_ids.view(-1, max_word_len)

        # Embed characters: (batch_size * seq_len, max_word_len, char_embedding_dim)
        char_emb = self.char_embedding(x_char_ids_reshaped)
        char_emb = self.dropout(char_emb) # Apply dropout to character embeddings

        # Permute for Conv1d: (batch_size * seq_len, char_embedding_dim, max_word_len)
        char_emb_permuted = char_emb.permute(0, 2, 1)

        # Convolution: (batch_size * seq_len, char_cnn_out_channels, convolved_len)
        # The convolved_len depends on max_word_len, kernel_size, and padding.
        # Example: if max_word_len=16, kernel=5, padding=0 -> convolved_len = 16-5+1 = 12
        # Example: if max_word_len=16, kernel=5, padding=2 -> convolved_len = 16-5+1+2*2 = 16
        char_conv_out = self.conv1d(char_emb_permuted)
        char_conv_out = F.relu(char_conv_out) # Apply ReLU

        # Max-pool over the convolved length dimension: (batch_size * seq_len, char_cnn_out_channels)
        # The kernel_size for max_pool1d should be the full length of the convolved dimension.
        char_pooled = F.max_pool1d(char_conv_out, kernel_size=char_conv_out.shape[2]).squeeze(2)

        # Reshape back to (batch_size, seq_len, char_cnn_out_channels)
        final_char_emb = char_pooled.view(batch_size, seq_len, -1)

        return final_char_emb


# HighwayNetwork class remains the same as provided in the previous full model response
class HighwayNetwork(nn.Module):
    def __init__(self, input_dim, num_layers):
        super(HighwayNetwork, self).__init__()
        self.num_layers = num_layers
        self.transform_gates = nn.ModuleList([nn.Linear(input_dim, input_dim) for _ in range(num_layers)])
        self.normal_layers = nn.ModuleList([nn.Linear(input_dim, input_dim) for _ in range(num_layers)])

    def forward(self, x):
        for i in range(self.num_layers):
            transform_gate_output = torch.sigmoid(self.transform_gates[i](x))
            normal_layer_output = F.relu(self.normal_layers[i](x))
            x = transform_gate_output * normal_layer_output + (1 - transform_gate_output) * x
        return x

# BiDAFAttention class remains the same as provided in the previous full model response
class BiDAFAttention(nn.Module):
    def __init__(self, hidden_size_times_2d): # effectively 2 * d, e.g., 2 * 300 = 600
        super(BiDAFAttention, self).__init__()
        self.hidden_size_times_2d = hidden_size_times_2d
        self.similarity_weight = nn.Linear(self.hidden_size_times_2d * 3, 1, bias=False)

    def forward(self, C_contextual, Q_contextual, C_mask, Q_mask):
        # C_contextual: (batch, C_len, 2d)
        # Q_contextual: (batch, Q_len, 2d)
        batch_size, C_len, _ = C_contextual.shape
        _, Q_len, _ = Q_contextual.shape

        C_expanded = C_contextual.unsqueeze(2).expand(-1, -1, Q_len, -1)
        Q_expanded = Q_contextual.unsqueeze(1).expand(-1, C_len, -1, -1)
        elementwise_prod = C_expanded * Q_expanded
        concat_features = torch.cat((C_expanded, Q_expanded, elementwise_prod), dim=3)
        S = self.similarity_weight(concat_features).squeeze(3)

        S_masked_q = S.masked_fill(Q_mask.unsqueeze(1) == 0, -float('inf'))
        S_masked_c = S.masked_fill(C_mask.unsqueeze(2) == 0, -float('inf'))

        alpha = F.softmax(S_masked_q, dim=2)
        A = torch.bmm(alpha, Q_contextual)

        max_S_c = torch.max(S, dim=2)[0]
        max_S_c_masked = max_S_c.masked_fill(C_mask == 0, -float('inf'))
        b_weights = F.softmax(max_S_c_masked, dim=1)
        C_prime = torch.bmm(b_weights.unsqueeze(1), C_contextual).squeeze(1)
        B = C_prime.unsqueeze(1).expand(-1, C_len, -1)

        g_c_a = C_contextual * A
        g_c_b = C_contextual * B
        G = torch.cat((C_contextual, A, g_c_a, g_c_b), dim=2)
        return G

# Assume CharEmbedding, HighwayNetwork, BiDAFAttention classes are defined as before.
# CharEmbedding parameters will be:
CHAR_EMB_DIM_FOR_BERT_TOKENS = 8 # As per original BiDAF paper's char embedding dim
CHAR_CNN_OUT_CHANNELS_FOR_BERT_TOKENS = 50 # Output dim of char CNN
CHAR_CNN_KERNEL_SIZE_FOR_BERT_TOKENS = 5

class BiDAF_BERT_Char(nn.Module):
    def __init__(self,
                 bert_model_name,
                 # Char embedding params (for chars of BERT tokens)
                 char_vocab_size, # Size of your new char_to_idx for BERT token chars
                 char_embedding_dim, # e.g., 8
                 char_cnn_out_channels, # e.g., 100
                 char_cnn_kernel_size, # e.g., 5
                 char_padding_idx,
                 # BiDAF specific params
                 hidden_size, # 'd' for BiDAF's LSTMs (e.g., 128 or 300)
                 num_highway_layers=2,
                 dropout_rate=0.2,
                 num_modeling_lstm_layers=2
                ):
        super(BiDAF_BERT_Char, self).__init__()

        self.bert = BertModel.from_pretrained(bert_model_name)
        for param in self.bert.parameters():
            param.requires_grad = False # Freeze BERT parameters
        self.bert_hidden_dim = self.bert.config.hidden_size # 768 for bert-base

        self.char_embedding_layer = CharEmbedding(
            char_vocab_size=char_vocab_size,
            char_embedding_dim=char_embedding_dim,
            char_cnn_out_channels=char_cnn_out_channels,
            char_cnn_kernel_size=char_cnn_kernel_size,
            char_padding_idx=char_padding_idx,
            dropout_rate=dropout_rate
        )

        self.combined_input_dim = self.bert_hidden_dim + char_cnn_out_channels # e.g., 768 + 100 = 868

        self.hidden_size = hidden_size # 'd' for BiDAF LSTMs
        self.dropout_rate = dropout_rate

        self.highway_network = HighwayNetwork(input_dim=self.combined_input_dim, num_layers=num_highway_layers)

        self.contextual_lstm = nn.LSTM(
            input_size=self.combined_input_dim,
            hidden_size=self.hidden_size, # d
            num_layers=1,
            bidirectional=True,
            batch_first=True,
            dropout=0
        )

        self.attention = BiDAFAttention(hidden_size_times_2d=2 * self.hidden_size)

        self.modeling_lstm = nn.LSTM(
            input_size=8 * self.hidden_size,
            hidden_size=self.hidden_size, # d
            num_layers=num_modeling_lstm_layers,
            bidirectional=True,
            batch_first=True,
            dropout=self.dropout_rate if num_modeling_lstm_layers > 1 else 0
        )

        self.start_output_linear = nn.Linear(10 * self.hidden_size, 1)
        self.end_output_linear = nn.Linear(10 * self.hidden_size, 1)

        self.dropout_layer = nn.Dropout(self.dropout_rate)

    def forward(self,
                context_bert_ids, context_bert_mask, context_bert_char_ids,
                question_bert_ids, question_bert_mask, question_bert_char_ids):

        # 1. Get BERT Embeddings
        C_bert_emb = self.bert(input_ids=context_bert_ids, attention_mask=context_bert_mask).last_hidden_state
        Q_bert_emb = self.bert(input_ids=question_bert_ids, attention_mask=question_bert_mask).last_hidden_state

        # 2. Get Character Embeddings for BERT tokens
        C_char_level_emb = self.char_embedding_layer(context_bert_char_ids)
        Q_char_level_emb = self.char_embedding_layer(question_bert_char_ids)

        # 3. Concatenate BERT embeddings and Character-level embeddings for BERT tokens
        C_combined_emb = torch.cat((C_bert_emb, C_char_level_emb), dim=2)
        Q_combined_emb = torch.cat((Q_bert_emb, Q_char_level_emb), dim=2)

        C_combined_emb = self.dropout_layer(C_combined_emb)
        Q_combined_emb = self.dropout_layer(Q_combined_emb)

        # 4. Highway Network
        C_highway = self.highway_network(C_combined_emb)
        Q_highway = self.highway_network(Q_combined_emb)

        # 5. Contextual Embedding Layer (BiDAF's BiLSTM)
        C_contextual, _ = self.contextual_lstm(C_highway)
        Q_contextual, _ = self.contextual_lstm(Q_highway)

        C_contextual = self.dropout_layer(C_contextual)
        Q_contextual = self.dropout_layer(Q_contextual)

        # BiDAF masks (from BERT attention masks)
        C_bidaf_mask = context_bert_mask.float()
        Q_bidaf_mask = question_bert_mask.float()

        # 6. Attention Flow Layer
        G = self.attention(C_contextual, Q_contextual, C_bidaf_mask, Q_bidaf_mask)
        G = self.dropout_layer(G)

        # 7. Modeling Layer
        M, _ = self.modeling_lstm(G)
        M = self.dropout_layer(M)

        # 8. Output Layer
        output_features = torch.cat((G, M), dim=2)
        start_logits = self.start_output_linear(output_features).squeeze(2)
        end_logits = self.end_output_linear(output_features).squeeze(2)

        start_logits_masked = start_logits.masked_fill(C_bidaf_mask == 0, -float('inf'))
        end_logits_masked = end_logits.masked_fill(C_bidaf_mask == 0, -float('inf'))

        return start_logits_masked, end_logits_masked

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_from_disk # To load Hugging Face datasets
import json
import numpy as np
import os
from tqdm import tqdm # For progress bars
from transformers import BertModel, BertTokenizerFast # Needed for model definition and tokenizer info

# --- Configuration & Hyperparameters ---
# Paths
PROCESSED_TRAIN_PATH = "./squad_train_processed_bert_char.hf" # From previous preprocessing
PROCESSED_VAL_PATH = "./squad_val_processed_bert_char.hf"   # From previous preprocessing
CHAR_VOCAB_PATH = "./squad_char_vocab_for_bert_tokens.json" # From previous preprocessing
MODEL_SAVE_DIR = "./saved_models"
BEST_MODEL_PATH = os.path.join(MODEL_SAVE_DIR, "bidaf_bert_char_best.pt")

# Ensure save directory exists
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# Model & Embedding Hyperparameters
BERT_MODEL_NAME = 'bert-base-cased' # Should match preprocessing
# Character Embeddings (for chars of BERT tokens)
CHAR_EMBEDDING_DIM = 25   # Dimension of individual char embedding (e.g. 8, 16, 20)
CHAR_CNN_OUT_CHANNELS = 50 # Output dim of char CNN
CHAR_CNN_KERNEL_SIZE = 5
# BiDAF Structure
# HIDDEN_SIZE = 300 # 'd' in BiDAF paper, for BiDAF's LSTMs
HIDDEN_SIZE = 128 # User indicated a smaller model worked well
NUM_HIGHWAY_LAYERS = 2
NUM_MODELING_LSTM_LAYERS = 2 # User indicated a smaller model worked well
DROPOUT_RATE = 0.2

# Training Hyperparameters
LEARNING_RATE = 1e-3 # Often smaller for BERT-based models
BATCH_SIZE = 16      # Adjust based on GPU memory (BERT is memory intensive)
NUM_EPOCHS = 10       # Start with a few for BERT fine-tuning
CLIP_GRAD_NORM = 5.0 # Common for BERT fine-tuning

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Memory Allocated: {torch.cuda.memory_allocated(0)/1024**2:.2f} MB")
    print(f"CUDA Memory Cached: {torch.cuda.memory_reserved(0)/1024**2:.2f} MB")


# --- Load Character Vocabulary ---
try:
    with open(CHAR_VOCAB_PATH, 'r', encoding='utf-8') as f:
        char_vocab_data = json.load(f)
        char_to_idx = char_vocab_data['char_to_idx']
    CHAR_VOCAB_SIZE = len(char_to_idx)
    CHAR_PADDING_IDX = char_to_idx.get("<C_PAD>", 0) # Ensure your CHAR_PAD token name
    print(f"Character vocabulary loaded. Size: {CHAR_VOCAB_SIZE}, CHAR_PADDING_IDX: {CHAR_PADDING_IDX}")
except FileNotFoundError as e:
    print(f"Error: Character vocabulary file not found at {CHAR_VOCAB_PATH}. {e}")
    exit()
except KeyError as e:
    print(f"Error: Special character token {e} not found in loaded char vocabulary.")
    exit()

# --- Load Processed Datasets ---
try:
    print(f"Loading processed training data from {PROCESSED_TRAIN_PATH}...")
    train_dataset_processed = load_from_disk(PROCESSED_TRAIN_PATH)
    print(f"Loading processed validation data from {PROCESSED_VAL_PATH}...")
    val_dataset_processed = load_from_disk(PROCESSED_VAL_PATH)
    print("Processed datasets loaded.")
except Exception as e:
    print(f"Error loading processed datasets: {e}")
    print("Please ensure your processed data paths are correct and data exists from the previous step.")
    exit()

# Set format for PyTorch
# These columns were created by `preprocess_for_bert_bidaf_with_chars`
columns_to_torch = [
    'context_input_ids', 'context_attention_mask', 'context_token_char_ids',
    'question_input_ids', 'question_attention_mask', 'question_token_char_ids',
    'start_token_bert', 'end_token_bert'
]
train_dataset_processed.set_format(type='torch', columns=columns_to_torch)
val_dataset_processed.set_format(type='torch', columns=columns_to_torch)

# --- DataLoaders ---
train_dataloader = DataLoader(train_dataset_processed, batch_size=BATCH_SIZE, shuffle=True, num_workers=2 if device.type == 'cuda' else 0)
val_dataloader = DataLoader(val_dataset_processed, batch_size=BATCH_SIZE, shuffle=False, num_workers=2 if device.type == 'cuda' else 0)
print(f"DataLoaders created. Train batches: {len(train_dataloader)}, Val batches: {len(val_dataloader)}")

# --- Model Instantiation ---
model = BiDAF_BERT_Char(
    bert_model_name=BERT_MODEL_NAME,
    char_vocab_size=CHAR_VOCAB_SIZE,
    char_embedding_dim=CHAR_EMBEDDING_DIM,
    char_cnn_out_channels=CHAR_CNN_OUT_CHANNELS,
    char_cnn_kernel_size=CHAR_CNN_KERNEL_SIZE,
    char_padding_idx=CHAR_PADDING_IDX,
    hidden_size=HIDDEN_SIZE,
    num_highway_layers=NUM_HIGHWAY_LAYERS,
    dropout_rate=DROPOUT_RATE,
    num_modeling_lstm_layers=NUM_MODELING_LSTM_LAYERS
).to(device)

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"BiDAF_BERT_Char Model instantiated. Trainable Parameters: {total_params:,}")

# --- Loss Function and Optimizer ---
# IMPORTANT: Regarding ignore_index for CrossEntropyLoss:
# The preprocessing step `preprocess_for_bert_bidaf_with_chars` mapped unmappable/out-of-bounds
# answers to index 0 (CLS token index). If 0 can also be a valid start/end token for an answer,
# then ignore_index=0 will incorrectly ignore those valid answers.
# A safer approach is to ensure preprocessing maps such cases to a dedicated value like -100
# and set ignore_index=-100. For now, assuming 0 is used for unmappable and it's okay.
# If many answers are at token 0, this might need revisiting.
# If your preprocessing was updated to use -1 for unmappable spans, use ignore_index=-1.
criterion = nn.CrossEntropyLoss(ignore_index=0) # Or -1 if your preprocessing maps unmappable to -1.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE) # AdamW is often preferred for BERT

# --- Metrics Calculation (Simplified Token-Level) ---
def compute_metrics(pred_starts, pred_ends, true_starts, true_ends):
    em_sum = 0
    f1_sum = 0
    num_examples = len(pred_starts)
    if num_examples == 0: return 0.0, 0.0

    for i in range(num_examples):
        ps, pe = pred_starts[i], pred_ends[i]
        ts, te = true_starts[i], true_ends[i]

        # Handle potential ignore_index values in true labels if they weren't filtered
        if ts == criterion.ignore_index or te == criterion.ignore_index:
            num_examples -=1 # Don't count this for EM/F1 if it was an ignored sample
            continue

        # Ensure predicted end is not before predicted start
        if pe < ps: pe = ps

        # Exact Match (token level)
        if ps == ts and pe == te:
            em_sum += 1

        # F1 Score (token level)
        pred_tokens = set(range(ps, pe + 1))
        true_tokens = set(range(ts, te + 1))

        common_tokens = len(pred_tokens.intersection(true_tokens))
        if common_tokens == 0:
            f1 = 0.0
        else:
            precision = common_tokens / len(pred_tokens) if len(pred_tokens) > 0 else 0
            recall = common_tokens / len(true_tokens) if len(true_tokens) > 0 else 0
            f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        f1_sum += f1

    return (em_sum / num_examples) * 100 if num_examples > 0 else 0.0, \
           (f1_sum / num_examples) * 100 if num_examples > 0 else 0.0

# --- Training Loop ---
def train_epoch(model, dataloader, optimizer, criterion, device, clip_norm):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(dataloader, desc="Training", leave=False, dynamic_ncols=True)

    for batch in progress_bar:
        try:
            context_ids = batch['context_input_ids'].to(device)
            context_mask = batch['context_attention_mask'].to(device)
            context_char_ids = batch['context_token_char_ids'].to(device)
            question_ids = batch['question_input_ids'].to(device)
            question_mask = batch['question_attention_mask'].to(device)
            question_char_ids = batch['question_token_char_ids'].to(device)
            true_start_indices = batch['start_token_bert'].to(device)
            true_end_indices = batch['end_token_bert'].to(device)
        except KeyError as e:
            print(f"KeyError in batch: {e}. Available keys: {batch.keys()}")
            raise e

        optimizer.zero_grad()

        start_logits, end_logits = model(
            context_bert_ids=context_ids, context_bert_mask=context_mask, context_bert_char_ids=context_char_ids,
            question_bert_ids=question_ids, question_bert_mask=question_mask, question_bert_char_ids=question_char_ids
        )

        # Filter out ignored indices before loss calculation if they are not handled by ignore_index
        # This is important if ignore_index is, e.g., 0, and 0 can be a valid target.
        # However, CrossEntropyLoss with ignore_index should handle this.
        # Let's assume ignore_index in criterion handles it.

        loss_start = criterion(start_logits, true_start_indices)
        loss_end = criterion(end_logits, true_end_indices)
        total_loss = loss_start + loss_end

        # Handle cases where loss might be NaN (e.g., if all targets in a batch are ignored)
        if torch.isnan(total_loss):
            print("Warning: NaN loss detected. Skipping batch.")
            # Optionally, print details of the batch that caused NaN
            # print("Problematic batch start_logits:", start_logits)
            # print("Problematic batch true_start_indices:", true_start_indices)
            continue # Skip optimizer step and loss accumulation for this batch

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()

        epoch_loss += total_loss.item()
        progress_bar.set_postfix({'loss': f'{total_loss.item():.4f}'})

    return epoch_loss / len(dataloader) if len(dataloader) > 0 else 0.0

# --- Evaluation Loop ---
def evaluate(model, dataloader, criterion, device):
    model.eval()
    epoch_loss = 0
    all_pred_starts, all_pred_ends = [], []
    all_true_starts, all_true_ends = [], []

    progress_bar = tqdm(dataloader, desc="Evaluating", leave=False, dynamic_ncols=True)

    with torch.no_grad():
        for batch in progress_bar:
            context_ids = batch['context_input_ids'].to(device)
            context_mask = batch['context_attention_mask'].to(device)
            context_char_ids = batch['context_token_char_ids'].to(device)
            question_ids = batch['question_input_ids'].to(device)
            question_mask = batch['question_attention_mask'].to(device)
            question_char_ids = batch['question_token_char_ids'].to(device)
            true_start_indices = batch['start_token_bert'].to(device)
            true_end_indices = batch['end_token_bert'].to(device)

            start_logits, end_logits = model(
                context_bert_ids=context_ids, context_bert_mask=context_mask, context_bert_char_ids=context_char_ids,
                question_bert_ids=question_ids, question_bert_mask=question_mask, question_bert_char_ids=question_char_ids
            )

            loss_start = criterion(start_logits, true_start_indices)
            loss_end = criterion(end_logits, true_end_indices)
            total_loss = loss_start + loss_end

            if not torch.isnan(total_loss): # Only accumulate if loss is valid
                 epoch_loss += total_loss.item()

            pred_start_batch = torch.argmax(start_logits, dim=1)
            pred_end_batch = torch.argmax(end_logits, dim=1)

            all_pred_starts.extend(pred_start_batch.cpu().tolist())
            all_pred_ends.extend(pred_end_batch.cpu().tolist())
            all_true_starts.extend(true_start_indices.cpu().tolist())
            all_true_ends.extend(true_end_indices.cpu().tolist())

            progress_bar.set_postfix({'loss': f'{total_loss.item() if not torch.isnan(total_loss) else "NaN":.4f}'})

    avg_loss = epoch_loss / len(dataloader) if len(dataloader) > 0 else 0.0
    em, f1 = compute_metrics(all_pred_starts, all_pred_ends, all_true_starts, all_true_ends)

    return avg_loss, em, f1

# --- Main Training Orchestration ---
best_val_f1 = -1.0

print("\nStarting training with BiDAF_BERT_Char model...")
for epoch in range(NUM_EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
    train_loss = train_epoch(model, train_dataloader, optimizer, criterion, device, CLIP_GRAD_NORM)
    val_loss, val_em, val_f1 = evaluate(model, val_dataloader, criterion, device)

    print(f"Epoch {epoch+1} Summary:")
    print(f"\tTrain Loss: {train_loss:.4f}")
    print(f"\tVal Loss  : {val_loss:.4f} | Val EM: {val_em:.2f}% | Val F1: {val_f1:.2f}%")

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print(f"\tNew best model saved to {BEST_MODEL_PATH} (F1: {best_val_f1:.2f}%)")

print("\nTraining complete.")
if os.path.exists(BEST_MODEL_PATH):
    print(f"Best Validation F1: {best_val_f1:.2f}% (Model saved at {BEST_MODEL_PATH})")
else:
    print(f"No model was saved. Best Validation F1: {best_val_f1:.2f}%")


Using device: cuda
CUDA Device Name: NVIDIA GeForce GTX 1080
CUDA Memory Allocated: 0.00 MB
CUDA Memory Cached: 0.00 MB
Character vocabulary loaded. Size: 266, CHAR_PADDING_IDX: 0
Loading processed training data from ./squad_train_processed_bert_char.hf...


Loading dataset from disk:   0%|          | 0/23 [00:00<?, ?it/s]

Loading processed validation data from ./squad_val_processed_bert_char.hf...
Processed datasets loaded.
DataLoaders created. Train batches: 5475, Val batches: 661
BiDAF_BERT_Char Model instantiated. Trainable Parameters: 5,243,760

Starting training with BiDAF_BERT_Char model...

--- Epoch 1/10 ---


                                                                            

Epoch 1 Summary:
	Train Loss: 5.1313
	Val Loss  : 3.2444 | Val EM: 42.05% | Val F1: 59.12%
	New best model saved to ./saved_models\bidaf_bert_char_best.pt (F1: 59.12%)

--- Epoch 2/10 ---


                                                                            

Epoch 2 Summary:
	Train Loss: 3.3016
	Val Loss  : 2.6746 | Val EM: 49.77% | Val F1: 67.66%
	New best model saved to ./saved_models\bidaf_bert_char_best.pt (F1: 67.66%)

--- Epoch 3/10 ---


                                                                            

Epoch 3 Summary:
	Train Loss: 2.9623
	Val Loss  : 2.5492 | Val EM: 51.78% | Val F1: 69.41%
	New best model saved to ./saved_models\bidaf_bert_char_best.pt (F1: 69.41%)

--- Epoch 4/10 ---


                                                                            

Epoch 4 Summary:
	Train Loss: 2.8030
	Val Loss  : 2.4688 | Val EM: 52.21% | Val F1: 70.37%
	New best model saved to ./saved_models\bidaf_bert_char_best.pt (F1: 70.37%)

--- Epoch 5/10 ---


                                                                            

Epoch 5 Summary:
	Train Loss: 2.6992
	Val Loss  : 2.4195 | Val EM: 52.84% | Val F1: 70.92%
	New best model saved to ./saved_models\bidaf_bert_char_best.pt (F1: 70.92%)

--- Epoch 6/10 ---


                                                                            

Epoch 6 Summary:
	Train Loss: 2.6292
	Val Loss  : 2.4440 | Val EM: 52.90% | Val F1: 71.32%
	New best model saved to ./saved_models\bidaf_bert_char_best.pt (F1: 71.32%)

--- Epoch 7/10 ---


                                                                            

Epoch 7 Summary:
	Train Loss: 2.5791
	Val Loss  : 2.3963 | Val EM: 53.55% | Val F1: 71.47%
	New best model saved to ./saved_models\bidaf_bert_char_best.pt (F1: 71.47%)

--- Epoch 8/10 ---


                                                                            

Epoch 8 Summary:
	Train Loss: 2.5427
	Val Loss  : 2.3255 | Val EM: 54.05% | Val F1: 72.32%
	New best model saved to ./saved_models\bidaf_bert_char_best.pt (F1: 72.32%)

--- Epoch 9/10 ---


                                                                            

Epoch 9 Summary:
	Train Loss: 2.5027
	Val Loss  : 2.3352 | Val EM: 54.15% | Val F1: 72.68%
	New best model saved to ./saved_models\bidaf_bert_char_best.pt (F1: 72.68%)

--- Epoch 10/10 ---


                                                                            

Epoch 10 Summary:
	Train Loss: 2.4699
	Val Loss  : 2.3566 | Val EM: 53.84% | Val F1: 72.23%

Training complete.
Best Validation F1: 72.68% (Model saved at ./saved_models\bidaf_bert_char_best.pt)




In [8]:
import gc
import torch

gc.collect()
torch.cuda.empty_cache()