## Imports

In [1]:
import os
import zipfile
import nltk
import json
import optuna
import random
import joblib
import datetime
import numpy as np
from tqdm import tqdm
from collections import Counter
from datasets import load_dataset, load_from_disk
from pprint import pprint
from PIL import Image
from pathlib import Path
from matplotlib import pyplot as plt

import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from huggingface_hub import hf_hub_download

# NLTK resources. 'punkt_tab' is presumably the tokenizer data needed by
# nltk.word_tokenize, which VocabularyBuilder.tokenize calls.
nltk.download('punkt_tab', quiet=True)
# NOTE(review): 'wordnet' / 'omw-1.4' are not used in the visible cells —
# presumably needed by a later evaluation step; confirm or drop.
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)

  from .autonotebook import tqdm as notebook_tqdm


True

## Constants

In [2]:
# Target image size (width, height). NOTE(review): SlakeDataset resizes to
# 224x224 without reading this constant — keep the two in sync.
IMG_SIZE = (224, 224)
# NOTE(review): VOCAB_SIZE and MAX_NODES_PER_QUESTION are not referenced in the
# visible cells; vocabulary sizes come from the vocabularies built from the data.
VOCAB_SIZE = 5000
# Batch size for the train/validation DataLoaders built in the preparation cell.
BATCH_SIZE = 32
MAX_NODES_PER_QUESTION = 10

# Directory Information (paths are relative to the notebook's working directory)
DATA_DIR = "data/"
DATASET_PATH = os.path.join(DATA_DIR, 'dataset/')  # English-only dataset written with save_to_disk
IMAGE_PATH = os.path.join(DATA_DIR, 'imgs/')  # image files read by SlakeDataset
VOCABS_PATH = os.path.join(DATA_DIR, 'vocabs/')  # JSON vocabularies written by save_vocabs
HYPERPARAMETERS_RESULT_PATH = os.path.join(DATA_DIR, 'tuning/')  # Optuna tuning results
FINAL_MODEL_PATH = os.path.join(DATA_DIR, 'final_model/')  # FinalModelTrainer results directory

# Huggingface Repository Information
repo_id = "BoKelvin/SLAKE"
repo_type = "dataset"
img_file = "imgs.zip"  # image archive, extracted into DATA_DIR

# Seeding (used by set_global_seed and the Optuna TPE sampler)
GLOBAL_SEED = 42

# Device: CUDA when available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [3]:
def set_global_seed(seed=None):
    """Seed every RNG used in this notebook so runs are reproducible.

    Args:
        seed: Integer seed. When omitted, ``GLOBAL_SEED`` is used, so a bare
            ``set_global_seed()`` behaves exactly as before.
    """
    if seed is None:
        seed = GLOBAL_SEED
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seeds every visible GPU (this also covers the current device).
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for deterministic kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_global_seed()

## Dataset Setup

### Dataset Download

In [4]:
# Utility function for downloading and extracting ZIP file
def download_and_store_ZIP(filename, save_dir):
    """Download ``filename`` from the Hub repo (``repo_id``/``repo_type``) and unzip it.

    ``hf_hub_download`` caches the archive locally, so repeated calls only
    re-extract it into ``save_dir``. The directory is created (with parents) if
    needed.

    Returns:
        True on success, False if downloading or extracting failed. Failures are
        still reported with a printed message rather than raised, preserving the
        original best-effort behaviour. Callers can now detect them.
    """
    print(f"Fetching file {filename} from {repo_id} repo")

    try:
        # Caches the file locally and returns the path to the cached file
        cached_zip_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type=repo_type
        )
        print(f"{filename} download complete. Cached at: {cached_zip_path}")

        os.makedirs(save_dir, exist_ok=True)

        print(f"Extracting to {save_dir}...")
        with zipfile.ZipFile(cached_zip_path, 'r') as zip_ref:
            zip_ref.extractall(save_dir)
    except Exception as e:
        print(f"Failed to download or extract {filename}: {e}")
        return False

    print("Extraction complete.")
    print(f"{filename} files are located in: {os.path.abspath(save_dir)}")
    return True

# Scoping to English only
def filter_language(original, lang='en'):
    """Keep only records whose ``q_lang`` equals ``lang`` (English by default).

    Args:
        original: any object with a ``.filter(predicate)`` method, e.g. a
            Hugging Face ``Dataset``/``DatasetDict``.
        lang: language code to keep; SLAKE's ``q_lang`` field is compared to it.

    Returns:
        Whatever ``original.filter`` returns (the filtered dataset).
    """
    return original.filter(lambda data: data['q_lang'] == lang)

# Download and store the dataset
def download_and_store_english_dataset():
    """Fetch SLAKE from the Hub, keep English questions only, and save it to disk.

    Returns:
        The English-only dataset, as also written to ``DATASET_PATH``.
    """
    print(f"Downloading dataset from {repo_id} repo")

    english = filter_language(load_dataset(repo_id))
    pprint(english)  # show split names and sizes

    # DATASET_PATH sits under DATA_DIR; ensure both exist before saving.
    for directory in (DATA_DIR, DATASET_PATH):
        os.makedirs(directory, exist_ok=True)

    english.save_to_disk(DATASET_PATH)
    return english

# Download and store the image files
def download_and_store_image():
    """Download the image archive (``img_file``) and extract it under ``DATA_DIR``."""
    download_and_store_ZIP(filename=img_file, save_dir=DATA_DIR)

# Download necessary files
def download_and_store_slake():
    """Download the English question/answer splits and the image archive.

    Returns:
        The English-only dataset returned by ``download_and_store_english_dataset``.
    """
    english_dataset = download_and_store_english_dataset()
    download_and_store_image()  # images land on disk; nothing to return
    return english_dataset

### Vocabulary Builder

In [5]:
class VocabularyBuilder:
    """Bidirectional token <-> index mapping with four reserved special tokens.

    Indices 0-3 are always ``<pad>``, ``<start>``, ``<end>`` and ``<unk>``; words
    added by ``build_word_vocabs`` get the following indices in first-seen order.
    ``tokenize`` can be replaced per instance (``build_vocabs`` does this for answers).
    """

    def __init__(self, min_freq=1):
        self.min_freq = min_freq
        specials = ("<pad>", "<start>", "<end>", "<unk>")
        self.itos = dict(enumerate(specials))
        self.stoi = {token: index for index, token in self.itos.items()}

    def tokenize(self, text):
        """Lower-case ``text`` and split it with NLTK's word tokenizer."""
        lowered = text.lower()
        return nltk.word_tokenize(lowered)

    def __len__(self):
        return len(self.stoi)

    def build_word_vocabs(self, sentences):
        """Add every token seen at least ``min_freq`` times across ``sentences``."""
        next_index = len(self.stoi)
        frequencies = Counter(
            token for sentence in sentences for token in self.tokenize(sentence)
        )

        for word, count in frequencies.items():
            # Skip rare words and anything already mapped (e.g. special tokens).
            if count < self.min_freq or word in self.stoi:
                continue
            self.stoi[word] = next_index
            self.itos[next_index] = word
            next_index += 1

        print(f"Vocabulary Built. Vocabulary Size: {len(self.stoi)}")

    def numericalize(self, text):
        """Map ``text`` to a list of indices, using ``<unk>`` for unseen tokens."""
        lookup = self.stoi
        return [
            lookup[token] if token in lookup else lookup["<unk>"]
            for token in self.tokenize(text)
        ]

In [6]:
# Build vocabularies for questions and answers
def build_vocabs(dataset):
    """Build the question (word-level) and answer (whole-answer) vocabularies.

    Args:
        dataset: iterable of records with ``question`` and ``answer`` fields
            (iterated twice, once per field).

    Returns:
        ``(question_vocab, answer_vocab)``. In the answer vocabulary every
        distinct lower-cased, stripped answer string is one class.
    """
    questions = [record['question'] for record in dataset]
    answers = [record['answer'] for record in dataset]

    question_vocab_builder = VocabularyBuilder(min_freq=1)
    question_vocab_builder.build_word_vocabs(questions)

    answer_vocab_builder = VocabularyBuilder(min_freq=1)

    # Answers are classified as a whole, so the "tokenizer" yields one token.
    def whole_answer_tokenizer(text):
        return [text.lower().strip()]

    answer_vocab_builder.tokenize = whole_answer_tokenizer
    answer_vocab_builder.build_word_vocabs(answers)

    return question_vocab_builder, answer_vocab_builder

# Save vocabularies to JSON files
def save_vocabs(quest_vocab, ans_vocab, vocabs_path=None):
    """Write both vocabularies as ``question_vocab.json`` / ``answer_vocab.json``.

    Args:
        quest_vocab: question vocabulary (object with ``stoi`` / ``itos`` dicts).
        ans_vocab: answer vocabulary (same interface).
        vocabs_path: output directory; defaults to ``VOCABS_PATH``. It is created,
            including missing parents, if it does not exist.

    Note:
        JSON object keys are always strings, so the integer keys of ``itos`` are
        written as e.g. ``"0"``; convert them back with ``int`` when loading.
    """
    target_dir = VOCABS_PATH if vocabs_path is None else vocabs_path
    os.makedirs(target_dir, exist_ok=True)

    for filename, vocab in (('question_vocab.json', quest_vocab),
                            ('answer_vocab.json', ans_vocab)):
        with open(os.path.join(target_dir, filename), 'w') as f:
            json.dump({'stoi': vocab.stoi, 'itos': vocab.itos}, f)

    print("Vocabularies saved successfully.")

### Dataset Class

In [7]:
class SlakeDataset(Dataset):
    """PyTorch dataset over SLAKE VQA records.

    Each item yields the (optionally transformed) resized image, the
    numericalized question, the answer class index, the raw strings and the
    record's ``qid``.

    Args:
        dataset: indexable records with ``img_name``, ``question``, ``qid`` and
            (optionally) ``answer`` keys, e.g. a Hugging Face split.
        question_vocab: object exposing ``numericalize(text) -> list[int]``.
        answer_vocab: same interface, used for the answer.
        transform: optional callable applied to the resized PIL image.
        cache_images: if True, decode and resize every unique image once up front.
        image_size: ``(width, height)`` passed to ``PIL.Image.resize``; the
            default 224x224 matches the previous hard-coded size.
        image_dir: directory containing the images; defaults to ``IMAGE_PATH``.
    """

    def __init__(self, dataset, question_vocab, answer_vocab, transform=None, cache_images=True,
                 image_size=(224, 224), image_dir=None):
        self.data = dataset
        self.question_vocab = question_vocab
        self.answer_vocab = answer_vocab
        self.transform = transform
        self.cache_images = cache_images
        self.image_size = tuple(image_size)
        self.image_dir = IMAGE_PATH if image_dir is None else image_dir

        # Caching: many questions share an image, so load each file only once.
        self.image_cache = {}
        if self.cache_images:
            print("Caching images into RAM...")
            unique_imgs = {item['img_name'] for item in self.data}
            for img_name in unique_imgs:
                self.image_cache[img_name] = self._load_image(img_name)
            print(f"Cached {len(self.image_cache)} images.")

    def _load_image(self, img_name):
        """Read ``img_name`` from ``image_dir`` as RGB, resized to ``image_size``."""
        path = os.path.join(self.image_dir, img_name)
        # The context manager closes the file; convert()/resize() return new images.
        with Image.open(path) as img:
            return img.convert('RGB').resize(self.image_size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # 1. Image: from RAM if cached, otherwise decoded from disk on demand.
        img_name = item['img_name']
        if self.cache_images:
            image = self.image_cache[img_name]
        else:
            image = self._load_image(img_name)

        if self.transform:
            image = self.transform(image)

        # 2. Question -> token indices. dtype is explicit so an empty question
        # gives an int64 tensor rather than torch's default float32.
        question = item['question']
        question_indices = self.question_vocab.numericalize(question)

        # 3. Answer -> class index. Answer may be missing in the test set.
        answer = str(item.get('answer', ''))
        answer_index = self.answer_vocab.numericalize(answer)

        return {
            'image': image,
            'question': torch.tensor(question_indices, dtype=torch.long),
            'answer': torch.tensor(answer_index, dtype=torch.long),
            # Original strings for reference
            'original_question': question,
            'original_answer': answer,
            # ID for tracking
            'id': item['qid']
        }

### Collate Function

In [8]:
def slake_collate_fn(batch, pad_index=0):
    """Merge a list of ``SlakeDataset`` items into one padded mini-batch.

    Args:
        batch: list of dicts as returned by ``SlakeDataset.__getitem__``.
        pad_index: value used to right-pad questions (0 is ``<pad>`` in
            ``VocabularyBuilder``).

    Returns:
        dict with:
        - ``image``: stacked images (e.g. ``[B, 3, H, W]`` after ``ToTensor``).
        - ``question``: ``[B, max_len]`` padded token ids.
        - ``question_lengths``: ``[B]`` unpadded lengths.
        - ``answer``: ``[B]`` int64 class indices.
        - ``original_question``, ``original_answer``, ``id``: per-item lists.
    """
    images = torch.stack([item['image'] for item in batch])
    questions = [item['question'] for item in batch]

    # Lengths are taken before padding; the question encoder packs/masks with them.
    question_lengths = torch.tensor([len(q) for q in questions])

    # Pad to the longest question in THIS batch.
    questions_padded = pad_sequence(questions, batch_first=True, padding_value=pad_index)

    # Each answer is a single class index (a 1-element tensor from the dataset).
    # reshape(()) turns it into a scalar and fails loudly if an answer ever has
    # more than one token, instead of silently producing a wrong shape.
    answers = torch.stack(
        [torch.as_tensor(item['answer']).reshape(()) for item in batch]
    ).long()

    return {
        'image': images,
        'question': questions_padded,
        'question_lengths': question_lengths,
        'answer': answers,
        'original_question': [item['original_question'] for item in batch],
        'original_answer': [item['original_answer'] for item in batch],
        'id': [item['id'] for item in batch]
    }

## Preparation

In [None]:
# Download SLAKE (English splits + images) on the first run; afterwards reuse the
# copies already on disk, so Restart & Run All works without editing this cell.
# NOTE(review): assumes imgs.zip unpacks into IMAGE_PATH — confirm.
dataset_on_disk = os.path.isdir(DATASET_PATH) and bool(os.listdir(DATASET_PATH))
images_on_disk = os.path.isdir(IMAGE_PATH)
if dataset_on_disk and images_on_disk:
    dataset = load_from_disk(DATASET_PATH)
else:
    dataset = download_and_store_slake()

# Vocabularies are built from the training split only
train_data = dataset['train']
validation_data = dataset['validation']
test_data = dataset['test']
question_vocab, answer_vocab = build_vocabs(train_data)

# Define image transformations (ImageNet normalisation for the pretrained ResNet)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create train dataset and dataloader
train_dataset = SlakeDataset(train_data, question_vocab, answer_vocab, transform=transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=slake_collate_fn
)

validation_dataset = SlakeDataset(validation_data, question_vocab, answer_vocab, transform=transform)
validation_loader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=slake_collate_fn
)

test_dataset = SlakeDataset(test_data, question_vocab, answer_vocab, transform=transform)

Vocabulary Built. Vocabulary Size: 281
Vocabulary Built. Vocabulary Size: 225
Caching images for into RAM...


## Modeling Baseline

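Baseline model: a ResNet-34 CNN image encoder combined with a bidirectional LSTM question encoder followed by multi-head self-attention; the two feature vectors are fused by an MLP and classified into answer classes.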

In [None]:
# Bidirectional LSTM with Self-Attention for question encoding
class BiLSTMWithSelfAttention(nn.Module):
    """Question encoder: embedding -> BiLSTM -> multi-head self-attention -> pooling.

    Args:
        vocab_size: number of question tokens; index 0 is the padding index.
        embed_dim: embedding size.
        hidden_dim: LSTM hidden size per direction (outputs are ``2 * hidden_dim``).
        num_layers: stacked LSTM layers (inter-layer dropout only when > 1).
        dropout: dropout on embeddings, attention weights and the attended output.
        pooling_strategy: ``'mean'`` or ``'max'`` over the non-padded time steps;
            any other value uses the concatenated final forward/backward states.
        attention_heads: number of heads; must divide ``2 * hidden_dim``.
    """

    def __init__(self, vocab_size, embed_dim=300, hidden_dim=512, num_layers=1,
                 dropout=0.5, pooling_strategy='mean', attention_heads=8):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.pooling_strategy = pooling_strategy
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # Bidirectional LSTM
        self.bilstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Self-attention over BiLSTM outputs (hidden_dim * 2 = forward + backward)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,
            num_heads=attention_heads,
            dropout=dropout,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_dim * 2)

    def forward(self, questions, question_lengths=None):
        """Encode a right-padded batch of questions.

        Args:
            questions: ``[B, T]`` token ids, right-padded.
            question_lengths: optional ``[B]`` true lengths (all >= 1). When given,
                the LSTM runs on a packed sequence and padded positions are
                excluded from the attention keys and from mean/max pooling.

        Returns:
            ``(question_feature, attn_weights)``:
            - ``question_feature``: ``[B, 2 * hidden_dim]``.
            - ``attn_weights``: ``[B, T, T]``, averaged over heads.
        """
        seq_len = questions.size(1)
        embeds = self.dropout(self.embedding(questions))  # [B, T, embed_dim]

        padding_mask = None  # [B, T], True marks padding (key_padding_mask convention)
        if question_lengths is not None:
            packed = nn.utils.rnn.pack_padded_sequence(
                embeds, question_lengths.cpu(),
                batch_first=True, enforce_sorted=False
            )
            packed_out, (hidden, _) = self.bilstm(packed)
            # total_length keeps T aligned with `questions` even if the batch is
            # padded beyond its longest question.
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=seq_len
            )
            positions = torch.arange(seq_len, device=lstm_out.device)
            padding_mask = positions.unsqueeze(0) >= question_lengths.to(lstm_out.device).unsqueeze(1)
        else:
            lstm_out, (hidden, _) = self.bilstm(embeds)

        # lstm_out: [B, T, hidden_dim * 2]; query = key = value = lstm_out
        attn_out, attn_weights = self.attention(
            query=lstm_out,
            key=lstm_out,
            value=lstm_out,
            key_padding_mask=padding_mask,
            need_weights=True
        )

        # Residual connection + Layer Norm
        attn_out = self.dropout(self.layer_norm(lstm_out + attn_out))

        if self.pooling_strategy == 'mean':
            if padding_mask is None:
                question_feature = attn_out.mean(dim=1)
            else:
                valid = (~padding_mask).unsqueeze(-1).to(attn_out.dtype)  # [B, T, 1]
                question_feature = (attn_out * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1.0)
        elif self.pooling_strategy == 'max':
            if padding_mask is None:
                question_feature = attn_out.max(dim=1)[0]
            else:
                masked = attn_out.masked_fill(padding_mask.unsqueeze(-1), float('-inf'))
                question_feature = masked.max(dim=1)[0]
        else:
            # Final hidden states of the last layer (forward, backward)
            question_feature = torch.cat([hidden[-2], hidden[-1]], dim=1)

        return question_feature, attn_weights

In [None]:
# Complete VQA model: ResNet34 + BiLSTM with Self-Attention
class VQA_ResNet_BiLSTM_Attention(nn.Module):
    """Answer-classification VQA model: ResNet-34 image features + BiLSTM/attention question features.

    The pooled 512-d image vector and the ``2 * lstm_hidden`` question vector are
    concatenated. The result goes through a two-layer MLP (Linear/BatchNorm/ReLU/
    Dropout) and a linear classifier over ``num_classes`` answers.

    Args:
        vocab_size: question vocabulary size.
        num_classes: number of answer classes.
        embed_dim: question embedding size.
        lstm_hidden: LSTM hidden size per direction.
        fusion_dim: width of the first fusion layer (second is ``fusion_dim // 2``).
        lstm_dropout: dropout inside the question encoder.
        lstm_num_layers: stacked LSTM layers.
        attention_heads: self-attention heads (must divide ``2 * lstm_hidden``).
        fusion_dropout: dropout in the fusion MLP.
        pooling_strategy: see ``BiLSTMWithSelfAttention``.
        pretrained_backbone: load ImageNet weights for ResNet-34 (default True, as
            before); pass False to build the model offline.

    NOTE(review): BatchNorm1d in the fusion MLP raises in training mode when a
    batch holds a single sample, so a final training batch of size 1 would fail.
    Consider ``drop_last=True`` for training loaders.
    """

    def __init__(self, vocab_size, num_classes, embed_dim=300,
                 lstm_hidden=512, fusion_dim=1024, lstm_dropout=0.5,
                 lstm_num_layers=1, attention_heads=8, fusion_dropout=0.5,
                 pooling_strategy='mean', pretrained_backbone=True):
        super().__init__()

        # Image encoder: ResNet34 without its final FC layer (ends in global avg-pool)
        resnet = self._build_resnet34(pretrained_backbone)
        self.image_encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.image_feature_dim = 512  # ResNet34 final layer

        # Question encoder: BiLSTM + Self-Attention
        self.question_encoder = BiLSTMWithSelfAttention(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            hidden_dim=lstm_hidden,
            num_layers=lstm_num_layers,
            dropout=lstm_dropout,
            attention_heads=attention_heads,
            pooling_strategy=pooling_strategy
        )
        self.question_feature_dim = lstm_hidden * 2  # Bidirectional

        # Multimodal fusion
        self.fusion = nn.Sequential(
            nn.Linear(self.image_feature_dim + self.question_feature_dim, fusion_dim),
            nn.BatchNorm1d(fusion_dim),
            nn.ReLU(),
            nn.Dropout(fusion_dropout),
            nn.Linear(fusion_dim, fusion_dim // 2),
            nn.BatchNorm1d(fusion_dim // 2),
            nn.ReLU(),
            nn.Dropout(fusion_dropout)
        )

        # Classifier
        self.classifier = nn.Linear(fusion_dim // 2, num_classes)

    @staticmethod
    def _build_resnet34(pretrained):
        """Return a torchvision ResNet-34, with ImageNet weights when ``pretrained``."""
        if not pretrained:
            return models.resnet34()
        weights_enum = getattr(models, 'ResNet34_Weights', None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy flag.
            return models.resnet34(pretrained=True)
        # Same checkpoint the deprecated `pretrained=True` resolves to.
        return models.resnet34(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, images, questions, question_lengths=None):
        """Return answer logits ``[B, num_classes]`` for a batch of images and questions."""
        # [B, 512, 1, 1] -> [B, 512]
        img_features = torch.flatten(self.image_encoder(images), 1)

        # Question features with attention: [B, lstm_hidden * 2]
        q_features, _ = self.question_encoder(questions, question_lengths)

        # [B, 512 + lstm_hidden * 2]
        combined = torch.cat([img_features, q_features], dim=1)

        fused = self.fusion(combined)  # [B, fusion_dim // 2]
        return self.classifier(fused)  # [B, num_classes]

## Hyperparameter Tuning

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, max_grad_norm=5.0):
    """Run one optimisation pass over ``dataloader`` on the global ``device``.

    Args:
        model: callable as ``model(images, questions, question_lengths)`` -> logits.
        dataloader: iterable of batches produced by ``slake_collate_fn``.
        criterion: loss on ``(logits, answers)``; assumed ``reduction='mean'``.
        optimizer: optimizer over ``model``'s parameters.
        max_grad_norm: gradient-norm clipping threshold (default 5.0, as before).

    Returns:
        ``(mean_loss, accuracy_percent)``. The loss is averaged per sample, so a
        smaller final batch is not over-weighted.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc='Training')
    for batch in pbar:
        images = batch['image'].to(device)
        questions = batch['question'].to(device)
        question_lengths = batch['question_lengths'].to(device)
        answers = batch['answer'].to(device)

        # Forward
        logits = model(images, questions, question_lengths)
        loss = criterion(logits, answers)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        # Metrics (loss re-weighted by batch size for a per-sample mean)
        batch_size = answers.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == answers).sum().item()
        total += batch_size

        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100*correct/total:.2f}%'
        })

    if total == 0:
        raise ValueError("train_epoch received an empty dataloader")
    return total_loss / total, 100 * correct / total

def validate(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` (global ``device``) without gradients.

    Args:
        model: callable as ``model(images, questions, question_lengths)`` -> logits.
        dataloader: iterable of batches produced by ``slake_collate_fn``.
        criterion: loss on ``(logits, answers)``; assumed ``reduction='mean'``.

    Returns:
        ``(mean_loss, accuracy_percent)``. The loss is averaged per sample, so a
        smaller final batch is not over-weighted.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validating'):
            images = batch['image'].to(device)
            questions = batch['question'].to(device)
            question_lengths = batch['question_lengths'].to(device)
            answers = batch['answer'].to(device)

            logits = model(images, questions, question_lengths)
            loss = criterion(logits, answers)

            batch_size = answers.size(0)
            total_loss += loss.item() * batch_size
            correct += (logits.argmax(dim=1) == answers).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("validate received an empty dataloader")
    return total_loss / total, 100 * correct / total

In [None]:
class HyperparameterTuner:
    """Optuna search over ``VQA_ResNet_BiLSTM_Attention`` hyperparameters.

    Each trial works as follows:
    - Builds a fresh model with the image encoder frozen.
    - Trains it with AdamW + StepLR.
    - Early-stops on validation accuracy.
    - Reports per-epoch accuracy so the median pruner can stop weak trials.

    Args:
        train_dataset: training dataset compatible with ``slake_collate_fn``.
        validation_dataset: validation dataset, same format.
        vocab_size: question vocabulary size.
        num_classes: number of answer classes.
        n_trials: number of Optuna trials.
        max_epochs: epoch budget per trial (default 30, as before).
        patience: epochs without a new best validation accuracy before early
            stopping (default 5, as before).
    """

    def __init__(self, train_dataset, validation_dataset, vocab_size, num_classes,
                 n_trials=50, max_epochs=30, patience=5):
        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset
        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.n_trials = n_trials
        self.max_epochs = max_epochs
        self.patience = patience
        self.results_dir = Path(HYPERPARAMETERS_RESULT_PATH)
        # parents=True so a missing DATA_DIR does not make this fail.
        self.results_dir.mkdir(parents=True, exist_ok=True)

        # One summary dict per completed (non-pruned) trial
        self.trial_results = []

    def config_BLSTM(self, trial):
        """Sample one configuration from ``trial``.

        ``suggest_float(..., log=True)`` replaces the deprecated
        ``suggest_loguniform`` and samples the same log-uniform ranges. The
        suggestion order is unchanged, which keeps seeded sampling comparable.
        """
        return {
            # Embedding parameters
            'embed_dim': trial.suggest_categorical('embed_dim', [200, 300, 512]),

            # LSTM parameters
            'lstm_hidden': trial.suggest_categorical('lstm_hidden', [256, 512, 768, 1024]),
            'lstm_num_layers': trial.suggest_int('lstm_num_layers', 1, 3),
            'lstm_dropout': trial.suggest_float('lstm_dropout', 0.1, 0.6),
            'pooling_strategy': trial.suggest_categorical('pooling_strategy', ['mean', 'max', 'last']),

            # Attention parameters (every choice divides 2 * lstm_hidden)
            'attention_heads': trial.suggest_categorical('attention_heads', [4, 8, 16]),

            # Fusion parameters
            'fusion_dim': trial.suggest_categorical('fusion_dim', [512, 1024, 2048]),
            'fusion_dropout': trial.suggest_float('fusion_dropout', 0.2, 0.6),

            # Training parameters
            'batch_size': trial.suggest_categorical('batch_size', [16, 32, 64]),
            'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True),
            'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True),
            'scheduler_step_size': trial.suggest_int('scheduler_step_size', 5, 15),
            'scheduler_gamma': trial.suggest_float('scheduler_gamma', 0.3, 0.7),
        }

    def _build_model(self, config):
        """Instantiate the model for ``config`` on ``device`` with a frozen image encoder."""
        model = VQA_ResNet_BiLSTM_Attention(
            vocab_size=self.vocab_size,
            num_classes=self.num_classes,
            embed_dim=config['embed_dim'],
            lstm_hidden=config['lstm_hidden'],
            lstm_num_layers=config['lstm_num_layers'],
            attention_heads=config['attention_heads'],
            fusion_dim=config['fusion_dim'],
            lstm_dropout=config['lstm_dropout'],
            fusion_dropout=config['fusion_dropout'],
            pooling_strategy=config['pooling_strategy']
        ).to(device)

        # NOTE(review): train_epoch calls model.train(), so the frozen ResNet's
        # BatchNorm running statistics still update during tuning.
        for param in model.image_encoder.parameters():
            param.requires_grad = False
        return model

    def _build_loaders(self, batch_size):
        """Create the shuffled train loader and the ordered validation loader."""
        pin = torch.cuda.is_available()  # pinning only helps host -> GPU copies
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=slake_collate_fn,
            pin_memory=pin
        )
        val_loader = DataLoader(
            self.validation_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=slake_collate_fn,
            pin_memory=pin
        )
        return train_loader, val_loader

    def objective(self, trial):
        """Train one sampled configuration and return its best validation accuracy (%).

        Raises:
            optuna.TrialPruned: when the pruner decides to stop the trial.
        """
        print(f"Trial {trial.number + 1}/{self.n_trials}")

        config = self.config_BLSTM(trial)
        model = self._build_model(config)
        print(f"Config: {json.dumps(config, indent=2)}")

        train_loader, val_loader = self._build_loaders(config['batch_size'])

        criterion = nn.CrossEntropyLoss()
        # Only trainable parameters; the image encoder is frozen.
        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config['scheduler_step_size'],
            gamma=config['scheduler_gamma']
        )

        best_val_acc = 0.0
        epochs_without_improvement = 0
        epochs_run = 0

        for epoch in range(self.max_epochs):
            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
            val_loss, val_acc = validate(model, val_loader, criterion)

            scheduler.step()
            epochs_run = epoch + 1
            print(f"Epoch {epoch+1}: Train Acc={train_acc:.2f}%, Val Acc={val_acc:.2f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            if epochs_without_improvement >= self.patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

            trial.report(val_acc, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()

        self.trial_results.append({
            'trial_number': trial.number,
            'config': config,
            'best_val_acc': best_val_acc,
            'final_epoch': epochs_run
        })

        return best_val_acc

    def save_results(self, study):
        """Write the best parameters, all trial summaries and the pickled study."""
        best_params_path = self.results_dir / 'best_params_BLSTM.json'
        with open(best_params_path, 'w') as f:
            json.dump({
                'best_params': study.best_params,
                'best_value': study.best_value,
                'best_trial': study.best_trial.number
            }, f, indent=2)

        all_results_path = self.results_dir / 'all_trials_BLSTM.json'
        with open(all_results_path, 'w') as f:
            json.dump(self.trial_results, f, indent=2)

        study_path = self.results_dir / 'study_BLSTM.pkl'
        joblib.dump(study, study_path)

        print(f"\nResults saved to: {self.results_dir}")

    def run(self):
        """Run the full study (seeded TPE sampler, median pruner), save it, and return it."""
        print("STARTING HYPERPARAMETER TUNING FOR BLSTM MODEL\n")

        study = optuna.create_study(
            direction='maximize',
            pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
            sampler=optuna.samplers.TPESampler(seed=GLOBAL_SEED)
        )

        study.optimize(self.objective, n_trials=self.n_trials)
        self.save_results(study)

        print("HYPERPARAMETER TUNING COMPLETE")
        print(f"Best Trial: {study.best_trial.number}")
        print(f"Best Validation Accuracy: {study.best_value:.2f}%\n")
        print("Best Hyperparameters:")
        for key, value in study.best_params.items():
            print(f"  {key}: {value}")

        return study

In [None]:
# Tune hyperparameters for the BLSTM model. Each trial is a full training run, so this is slow.
tuner = HyperparameterTuner(
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    vocab_size=len(question_vocab),
    num_classes=len(answer_vocab),
    n_trials=50,
)

# Run tuning; results are also written to HYPERPARAMETERS_RESULT_PATH by save_results.
study = tuner.run()

In [None]:
# Reload the tuning summary written by HyperparameterTuner.save_results.
# NOTE(review): this is the whole summary dict ({'best_params', 'best_value',
# 'best_trial'}), not just the parameter dict — check how FinalModelTrainer uses it.
best_params_path = os.path.join(HYPERPARAMETERS_RESULT_PATH, 'best_params_BLSTM.json')
best_params = json.loads(Path(best_params_path).read_text())

best_params

## Train the final model

In [None]:
class FinalModelTrainer:
    """Train, checkpoint and evaluate the final VQA model with tuned hyperparameters.

    ``best_params`` is the dict loaded from the tuner's JSON summary; the model,
    loader and optimizer settings are read from ``best_params['best_params']``.
    Checkpoints, a timestamped results JSON and PNG plots are written to
    ``FINAL_MODEL_PATH``.
    """

    def __init__(self, train_dataset, val_dataset, test_dataset,
                 best_params, vocab_size, num_classes):
        """
        Args:
            train_dataset: Dataset used for optimization.
            val_dataset: Dataset used for model selection / early stopping.
            test_dataset: Held-out dataset. ``final_evaluation`` reads its ``data``
                attribute as an iterable of dicts with a ``'qid'`` key and an
                optional ``'answer_type'`` key.
            best_params: Tuning summary dict containing a ``'best_params'`` entry.
            vocab_size: Size of the question vocabulary.
            num_classes: Number of answer classes.
        """
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.best_params = best_params
        self.vocab_size = vocab_size
        self.num_classes = num_classes

        # Per-epoch training history (reset at the start of every train() call)
        self.history = self._empty_history()

        # parents=True so a fresh checkout without the data/ directory still works
        self.results_dir = Path(FINAL_MODEL_PATH)
        self.results_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _empty_history():
        """Return a fresh history dict with one empty list per tracked metric."""
        return {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'learning_rates': [],
        }

    @staticmethod
    def _make_loader(dataset, batch_size, shuffle):
        """Build a DataLoader that batches with ``slake_collate_fn``."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=slake_collate_fn,
            pin_memory=True
        )

    def _build_model(self, hp):
        """Instantiate the VQA model from the tuned hyperparameters ``hp`` on ``device``."""
        return VQA_ResNet_BiLSTM_Attention(
            vocab_size=self.vocab_size,
            num_classes=self.num_classes,
            embed_dim=hp['embed_dim'],
            lstm_hidden=hp['lstm_hidden'],
            lstm_num_layers=hp['lstm_num_layers'],
            lstm_dropout=hp['lstm_dropout'],
            pooling_strategy=hp['pooling_strategy'],
            attention_heads=hp['attention_heads'],
            fusion_dim=hp['fusion_dim'],
            fusion_dropout=hp['fusion_dropout'],
        ).to(device)

    @staticmethod
    def _predict(model, loader):
        """Run ``model`` over ``loader`` in eval mode.

        Returns:
            (predictions, targets, ids) as flat Python lists in loader order;
            predictions are the argmax over dim 1 of the model's output.
        """
        model.eval()
        predictions, targets, ids = [], [], []
        with torch.no_grad():
            for batch in tqdm(loader, desc='Testing'):
                images = batch['image'].to(device)
                questions = batch['question'].to(device)
                question_lengths = batch['question_lengths'].to(device)

                logits = model(images, questions, question_lengths)
                predictions.extend(torch.argmax(logits, dim=1).cpu().tolist())
                # Targets are only compared on the host, so skip the device round-trip
                targets.extend(batch['answer'].tolist())
                ids.extend(batch['id'])
        return predictions, targets, ids

    @staticmethod
    def _index_by_qid(items):
        """Map each ``item['qid']`` to its item, keeping the first item for duplicate qids."""
        index = {}
        for item in items:
            index.setdefault(item['qid'], item)
        return index

    @staticmethod
    def _tally_by_answer_type(predictions, targets, ids, items_by_qid):
        """Group prediction correctness by each question's (upper-cased) answer type.

        Questions whose id is not in ``items_by_qid`` are skipped. A missing or
        empty ``'answer_type'`` is treated as ``'OPEN'``. 'CLOSED' and 'OPEN' are
        always present in the result, even with zero samples.

        Returns:
            dict answer_type -> {'correct', 'total', 'predictions', 'targets'}.
        """
        type_stats = {
            'CLOSED': {'correct': 0, 'total': 0, 'predictions': [], 'targets': []},
            'OPEN': {'correct': 0, 'total': 0, 'predictions': [], 'targets': []}
        }
        for pred, target, qid in zip(predictions, targets, ids):
            item = items_by_qid.get(qid)
            if item is None:
                continue

            answer_type = (item.get('answer_type') or 'OPEN').upper()
            stats = type_stats.setdefault(
                answer_type,
                {'correct': 0, 'total': 0, 'predictions': [], 'targets': []}
            )
            stats['total'] += 1
            stats['predictions'].append(pred)
            stats['targets'].append(target)
            if pred == target:
                stats['correct'] += 1
        return type_stats

    def final_evaluation(self, model):
        """Evaluate ``model`` on the test set, overall and per answer type.

        Args:
            model: Called as ``model(images, questions, question_lengths)``; the
                argmax over dim 1 of its output is taken as the predicted class.

        Returns:
            dict with ``overall_accuracy`` (percent), ``overall_correct``,
            ``overall_total``, ``type_accuracies``, ``type_stats`` and the raw
            ``predictions``, ``targets`` and ``ids`` lists.
        """
        test_loader = self._make_loader(self.test_dataset, BATCH_SIZE, shuffle=False)
        all_predictions, all_targets, all_ids = self._predict(model, test_loader)

        # Index the raw test items once instead of scanning the whole dataset for
        # every prediction (which was O(n^2) in the number of test questions).
        items_by_qid = self._index_by_qid(self.test_dataset.data)
        type_stats = self._tally_by_answer_type(
            all_predictions, all_targets, all_ids, items_by_qid
        )

        overall_correct = sum(p == t for p, t in zip(all_predictions, all_targets))
        overall_total = len(all_predictions)
        overall_acc = 100 * overall_correct / overall_total if overall_total > 0 else 0

        type_accuracies = {
            answer_type: 100 * stats['correct'] / stats['total']
            for answer_type, stats in type_stats.items()
            if stats['total'] > 0
        }

        print("DETAILED EVALUATION RESULTS")
        print(f"Overall Accuracy: {overall_acc:.2f}% ({overall_correct}/{overall_total})")
        print("\nPerformance on Answer Types:")
        for answer_type in sorted(type_stats):
            stats = type_stats[answer_type]
            if stats['total'] > 0:
                acc = type_accuracies[answer_type]
                print(f"  {answer_type:12s}: {acc:6.2f}% ({stats['correct']:4d}/{stats['total']:4d})")

        return {
            'overall_accuracy': overall_acc,
            'overall_correct': overall_correct,
            'overall_total': overall_total,
            'type_accuracies': type_accuracies,
            'type_stats': {
                answer_type: {
                    'accuracy': type_accuracies.get(answer_type, 0),
                    'correct': stats['correct'],
                    'total': stats['total']
                }
                for answer_type, stats in type_stats.items()
            },
            'predictions': all_predictions,
            'targets': all_targets,
            'ids': all_ids
        }

    def train(self, num_epochs=100, threshold=15, save_every=10):
        """Train with early stopping, then evaluate the best checkpoint on the test set.

        Args:
            num_epochs: Maximum number of epochs; must be >= 1.
            threshold: Stop after this many consecutive epochs without a
                validation-accuracy improvement.
            save_every: Also write ``checkpoint_epoch_{N}.pth`` every ``save_every``
                epochs; 0 or None disables periodic checkpoints.

        Returns:
            (model, test_results): the model with the best-validation weights loaded,
            and the dict returned by ``final_evaluation``.

        Raises:
            ValueError: if ``num_epochs`` < 1, since no checkpoint would exist to evaluate.
        """
        if num_epochs < 1:
            raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

        print("TRAINING FINAL MODEL WITH BEST HYPERPARAMETERS")
        print(f"Training for up to {num_epochs} epochs")
        print(f"Early stopping threshold: {threshold} epochs")

        print("\nBest hyperparameters:")
        print(json.dumps(self.best_params, indent=2))

        # Start from a clean history so repeated calls don't mix runs in the plots
        self.history = self._empty_history()
        hp = self.best_params['best_params']

        # 1. Model, 2. dataloaders, 3. optimizer + scheduler from tuned values
        model = self._build_model(hp)

        batch_size = hp['batch_size']
        train_loader = self._make_loader(self.train_dataset, batch_size, shuffle=True)
        val_loader = self._make_loader(self.val_dataset, batch_size, shuffle=False)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hp['learning_rate'],
            weight_decay=hp['weight_decay']
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=hp['scheduler_step_size'],
            gamma=hp['scheduler_gamma']
        )

        # 4. Training loop. -inf (rather than 0.0) guarantees the first epoch is
        # checkpointed, so best_model.pth from THIS run always exists even if
        # validation accuracy never rises above 0.
        best_val_acc = float('-inf')
        best_epoch = 0
        threshold_counter = 0

        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")

            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
            val_loss, val_acc = validate(model, val_loader, criterion)

            # LR actually used for this epoch (read before the scheduler steps)
            current_lr = optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['learning_rates'].append(current_lr)

            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
            print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
            print(f"Learning Rate: {current_lr:.6f}")

            scheduler.step()

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch + 1
                threshold_counter = 0
                self.save_checkpoint(
                    model, optimizer, epoch, val_acc,
                    filename='best_model.pth'
                )
                print(f"New best model found with Val Acc: {val_acc:.2f}%")
            else:
                threshold_counter += 1
                print(f"No improvement ({threshold_counter}/{threshold})")

            # Guard against save_every == 0 (would be a ZeroDivisionError)
            if save_every and (epoch + 1) % save_every == 0:
                self.save_checkpoint(
                    model, optimizer, epoch, val_acc,
                    filename=f'checkpoint_epoch_{epoch+1}.pth'
                )

            if threshold_counter >= threshold:
                print(f"Early stopping triggered at epoch {epoch+1}")
                print(f"Best validation accuracy: {best_val_acc:.2f}% at epoch {best_epoch}")
                break

        # 5. Reload the best weights and evaluate on the held-out test set
        print("FINAL EVALUATION ON TEST SET")
        self.load_checkpoint(model, 'best_model.pth')
        test_results = self.final_evaluation(model)

        # 6. Persist results, 7. plot curves
        self.save_results(test_results, best_epoch, best_val_acc)
        self.plot_training_curves()

        print("TRAINING COMPLETE!")
        print(f"Best Val Acc: {best_val_acc:.2f}% (Epoch {best_epoch})")
        print("\nTest Set Results:")
        print(f"  Overall Accuracy: {test_results['overall_accuracy']:.2f}%")
        if 'type_accuracies' in test_results:
            print("\n  By Answer Type:")
            for answer_type in sorted(test_results['type_accuracies']):
                acc = test_results['type_accuracies'][answer_type]
                total = test_results['type_stats'][answer_type]['total']
                print(f"    {answer_type:12s}: {acc:6.2f}% ({total:4d} samples)")
        print(f"\nResults saved to: {self.results_dir}")
        print(f"{'='*70}")

        return model, test_results

    def save_checkpoint(self, model, optimizer, epoch, val_acc, filename):
        """Save model/optimizer state plus metadata to ``results_dir / filename``.

        ``epoch`` is stored as given (the training loop passes the 0-based index).
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'best_params': self.best_params,
            'history': self.history
        }
        torch.save(checkpoint, self.results_dir / filename)

    def load_checkpoint(self, model, filename):
        """Load the model weights stored in ``results_dir / filename`` into ``model``."""
        # map_location keeps loading working if the checkpoint was written on another device
        checkpoint = torch.load(self.results_dir / filename, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # Checkpoints store the 0-based epoch; print 1-based to match the training log
        print(f"Loaded model from {filename} (Epoch {checkpoint['epoch'] + 1}, "
              f"Val Acc: {checkpoint['val_acc']:.2f}%)")

    def save_results(self, test_results, best_epoch, best_val_acc):
        """Write a timestamped JSON summary of the run and plot per-type accuracies.

        Args:
            test_results: Dict returned by ``final_evaluation``.
            best_epoch: 1-based epoch of the best validation accuracy.
            best_val_acc: Best validation accuracy (percent).
        """
        # The notebook does `import datetime` (the module), hence datetime.datetime
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

        results = {
            'timestamp': timestamp,
            'best_hyperparameters': self.best_params,
            'best_epoch': best_epoch,
            'best_val_acc': best_val_acc,
            'test_results': {
                'overall_accuracy': test_results['overall_accuracy'],
                'overall_correct': test_results['overall_correct'],
                'overall_total': test_results['overall_total'],
                'type_accuracies': test_results['type_accuracies'],
                'type_stats': test_results['type_stats']
            },
            'training_history': self.history
        }

        with open(self.results_dir / f'final_results_{timestamp}.json', 'w') as f:
            json.dump(results, f, indent=2)

        self.plot_type_accuracies(test_results['type_accuracies'], test_results['type_stats'])

    def plot_training_curves(self):
        """Save a 2x2 figure of loss, accuracy, learning rate and best val accuracy."""
        if not self.history['val_acc']:
            # Nothing was trained; np.argmax/max would fail on empty lists
            return

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        epochs = range(1, len(self.history['train_loss']) + 1)

        # Loss curves
        axes[0, 0].plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        axes[0, 0].plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Training and Validation Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Accuracy curves
        axes[0, 1].plot(epochs, self.history['train_acc'], 'b-', label='Train Acc', linewidth=2)
        axes[0, 1].plot(epochs, self.history['val_acc'], 'r-', label='Val Acc', linewidth=2)
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy (%)')
        axes[0, 1].set_title('Training and Validation Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Learning rate (log scale: StepLR decays geometrically)
        axes[1, 0].plot(epochs, self.history['learning_rates'], 'g-', linewidth=2)
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].set_title('Learning Rate Schedule')
        axes[1, 0].set_yscale('log')
        axes[1, 0].grid(True, alpha=0.3)

        # Validation accuracy with best-epoch marker
        axes[1, 1].plot(epochs, self.history['val_acc'], 'r-', linewidth=2)
        best_epoch = int(np.argmax(self.history['val_acc'])) + 1
        best_acc = max(self.history['val_acc'])
        axes[1, 1].scatter([best_epoch], [best_acc], color='gold', s=200,
                           marker='*', edgecolors='black', linewidths=2,
                           label=f'Best: {best_acc:.2f}% (Epoch {best_epoch})', zorder=5)
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Validation Accuracy (%)')
        axes[1, 1].set_title('Validation Accuracy Progress')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(self.results_dir / 'training_curves.png', dpi=300, bbox_inches='tight')
        plt.close(fig)

    def plot_type_accuracies(self, type_accuracies, type_stats):
        """Save a bar chart of per-type accuracy and a pie chart of sample counts."""
        if not type_accuracies:
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        answer_types = sorted(type_accuracies)
        accuracies = [type_accuracies[t] for t in answer_types]
        totals = [type_stats[t]['total'] for t in answer_types]

        # Bars colored green (>=80%), orange (>=70%) or red
        colors = ['#2ecc71' if acc >= 80 else '#f39c12' if acc >= 70 else '#e74c3c'
                  for acc in accuracies]
        bars = ax1.bar(answer_types, accuracies, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
        ax1.set_xlabel('Answer Type', fontsize=12, fontweight='bold')
        ax1.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
        ax1.set_title('Accuracy by Answer Type', fontsize=14, fontweight='bold')
        ax1.set_ylim([0, 100])
        ax1.grid(axis='y', alpha=0.3)

        for bar, acc in zip(bars, accuracies):
            ax1.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                     f'{acc:.1f}%',
                     ha='center', va='bottom', fontweight='bold', fontsize=10)

        # Pie chart of sample distribution (matplotlib cycles colors if types > 4)
        colors_pie = ['#3498db', '#e67e22', '#9b59b6', '#1abc9c'][:len(answer_types)]
        ax2.pie(totals, labels=answer_types, autopct='%1.1f%%',
                colors=colors_pie, startangle=90,
                textprops={'fontsize': 11, 'fontweight': 'bold'})
        ax2.set_title('Sample Distribution by Answer Type', fontsize=14, fontweight='bold')

        legend_labels = [f'{t}: {type_stats[t]["total"]} samples' for t in answer_types]
        ax2.legend(legend_labels, loc='best', fontsize=10)

        fig.tight_layout()
        fig.savefig(self.results_dir / 'type_specific_accuracy.png', dpi=300, bbox_inches='tight')
        plt.close(fig)

        print("Type-specific accuracy plot saved")

In [None]:
# Wire the tuned hyperparameters and the three dataset splits into the final trainer
final_model_trainer = FinalModelTrainer(
    train_dataset=train_dataset,
    val_dataset=validation_dataset,
    test_dataset=test_dataset,
    best_params=best_params,
    vocab_size=len(question_vocab),
    num_classes=len(answer_vocab),
)

In [None]:
# Train for up to 100 epochs (early stop after 15 stale epochs, checkpoint every 10),
# then evaluate the best-validation weights on the test split.
final_model, test_results = final_model_trainer.train(num_epochs=100, threshold=15, save_every=10)