## Imports

In [None]:
import os
import zipfile
import nltk
import json
import optuna
import random
import joblib
import re
from datetime import datetime
import numpy as np
from tqdm import tqdm
from collections import Counter, defaultdict
from datasets import load_dataset, load_from_disk
from pprint import pprint
from PIL import Image
from pathlib import Path

import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from huggingface_hub import hf_hub_download

# Metrics libraries
from sklearn.metrics import accuracy_score, f1_score
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
from bert_score import score as bert_score

from utils import plot_training_curves, plot_type_specific_comparison, plot_ngram_analysis, print_all_metrics

# NLTK data needed later: 'punkt_tab' backs nltk.word_tokenize (question
# tokenization); 'wordnet' and 'omw-1.4' are WordNet data used by meteor_score.
for _nltk_resource in ('punkt_tab', 'wordnet', 'omw-1.4'):
    nltk.download(_nltk_resource, quiet=True)

## Constants

In [None]:
# Intended image size (width, height).
# NOTE(review): SlakeDataset's resize size is hard-coded to (224, 224) rather
# than read from this constant.
IMG_SIZE = (224, 224)
# NOTE(review): VOCAB_SIZE and MAX_NODES_PER_QUESTION are not referenced in the
# visible part of this notebook; actual vocabulary sizes come from VocabularyBuilder.
VOCAB_SIZE = 5000
# Batch size of the default train/validation DataLoaders built in "Preparation"
# (tuning trials sample their own batch size).
BATCH_SIZE = 32
MAX_NODES_PER_QUESTION = 10

# Directory Information
# Every artifact lives under DATA_DIR: the saved HF dataset, the extracted
# images (imgs.zip is unpacked into DATA_DIR; presumably it contains an imgs/
# folder matching IMAGE_PATH), vocab JSON files, tuning results, final model.
DATA_DIR = "data/"
DATASET_PATH = os.path.join(DATA_DIR, 'dataset/')
IMAGE_PATH = os.path.join(DATA_DIR, 'imgs/')
VOCABS_PATH = os.path.join(DATA_DIR, 'vocabs/')
HYPERPARAMETERS_RESULT_PATH = os.path.join(DATA_DIR, 'tuning/')
FINAL_MODEL_PATH = os.path.join(DATA_DIR, 'final_model/')

# Huggingface Repository Information
# Source of both the question/answer dataset and the image archive.
repo_id = "BoKelvin/SLAKE"
repo_type = "dataset"
img_file = "imgs.zip"

# Seeding
# Shared by set_global_seed and the Optuna TPE sampler.
GLOBAL_SEED = 42

# Device
# Prefer the GPU whenever PyTorch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
def set_global_seed(seed=None):
    """Seed every RNG the notebook relies on: Python, NumPy and PyTorch (CPU/CUDA).

    Args:
        seed: Integer seed to apply. When omitted, ``GLOBAL_SEED`` is used, so the
            existing zero-argument call behaves exactly as before. Passing an
            explicit value allows reseeding, e.g. before each tuning trial.
    """
    if seed is None:
        seed = GLOBAL_SEED

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for run-to-run determinism.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Seed all RNGs once at notebook start so shuffling and weight init are repeatable.
set_global_seed()

## Dataset Setup

### Dataset Download

In [None]:
# Utility function for downloading and extracting ZIP file
def download_and_store_ZIP(filename, save_dir):
    """Download ``filename`` from the ``repo_id`` Hugging Face repo and extract it.

    ``hf_hub_download`` caches the archive locally and returns the cached path;
    the archive is then extracted into ``save_dir``, which is created if needed.

    Args:
        filename: Path of the ZIP file inside the repository.
        save_dir: Directory the archive's contents are extracted into.

    Raises:
        Exception: Any download or extraction error is reported and then
            re-raised, so later cells do not silently run against missing files.
    """
    print(f"Fetching file {filename} from {repo_id} repo")

    try:
        cached_zip_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type=repo_type,
        )
        print(f"{filename} download complete. Cached at: {cached_zip_path}")

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(save_dir, exist_ok=True)

        print(f"Extracting to {save_dir}...")
        with zipfile.ZipFile(cached_zip_path, 'r') as zip_ref:
            zip_ref.extractall(save_dir)
    except Exception as e:
        print(f"Failed to download or extract {filename}: {e}")
        raise

    print("Extraction complete.")
    print(f"{filename} files are located in: {os.path.abspath(save_dir)}")

# Scoping to English only
def filter_language(original):
    """Return only the English rows (``q_lang == 'en'``) of ``original`` via its ``filter``."""
    def _is_english(row):
        return row['q_lang'] == 'en'

    return original.filter(_is_english)

# Download and store the dataset
def download_and_store_english_dataset():
    """Download SLAKE from the Hub, keep English rows and save them to ``DATASET_PATH``.

    Returns:
        The English-only dataset, exactly as saved to disk.
    """
    print(f"Downloading dataset from {repo_id} repo")

    english_only = filter_language(load_dataset(repo_id))

    # Display the split layout and features.
    pprint(english_only)

    # makedirs also creates DATA_DIR (the parent of DATASET_PATH) when missing.
    os.makedirs(DATASET_PATH, exist_ok=True)

    english_only.save_to_disk(DATASET_PATH)
    return english_only

# Download and store the image files
def download_and_store_image():
    """Fetch the image archive (``img_file``) and unpack it under ``DATA_DIR``."""
    download_and_store_ZIP(filename=img_file, save_dir=DATA_DIR)

# Download necessary files
def download_and_store_slake():
    """Download the English SLAKE questions and the images; return the dataset."""
    english_dataset = download_and_store_english_dataset()
    download_and_store_image()
    return english_dataset

### Vocabulary Builder

In [None]:
class VocabularyBuilder:
    """Bidirectional token <-> id mapping with a minimum-frequency filter.

    Ids 0-3 are reserved for ``<pad>``, ``<start>``, ``<end>`` and ``<unk>``.
    New tokens get consecutive ids in first-seen order.

    Args:
        min_freq: Minimum number of occurrences a token needs to be added.
        tokenizer: Optional callable ``str -> list[str]``. When omitted, text is
            lower-cased and split with ``nltk.word_tokenize``. Assigning a
            callable to an instance's ``tokenize`` attribute still overrides
            tokenization, as before.
    """

    def __init__(self, min_freq=1, tokenizer=None):
        self.min_freq = min_freq
        self._tokenizer = tokenizer
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def tokenize(self, text):
        """Split ``text`` into tokens with the configured tokenizer."""
        if self._tokenizer is not None:
            return self._tokenizer(text)
        return nltk.word_tokenize(text.lower())

    def __len__(self):
        return len(self.stoi)

    def build_word_vocabs(self, sentences):
        """Count tokens over ``sentences`` and add those seen at least ``min_freq`` times."""
        counter = Counter()
        for sentence in sentences:
            counter.update(self.tokenize(sentence))

        # Ids are contiguous, so the current size is the next free id.
        next_index = len(self.stoi)
        for word, count in counter.items():
            if count >= self.min_freq and word not in self.stoi:
                self.stoi[word] = next_index
                self.itos[next_index] = word
                next_index += 1

        print(f"Vocabulary Built. Vocabulary Size: {len(self.stoi)}")

    def numericalize(self, text):
        """Return the id of every token in ``text``; unknown tokens map to ``<unk>``."""
        unk_index = self.stoi["<unk>"]
        return [self.stoi.get(token, unk_index) for token in self.tokenize(text)]

In [None]:
# Build vocabularies for questions and answers
def build_vocabs(dataset):
    """Build the question and answer vocabularies from ``dataset`` rows.

    Questions are word-tokenized. Each answer string, lower-cased and stripped,
    is a single token, so every distinct answer becomes one class id.

    Returns:
        ``(question_vocab, answer_vocab)`` as ``VocabularyBuilder`` instances.
    """
    question_texts, answer_texts = [], []
    for row in dataset:
        question_texts.append(row['question'])
        answer_texts.append(row['answer'])

    question_vocab = VocabularyBuilder(min_freq=1)
    question_vocab.build_word_vocabs(question_texts)

    answer_vocab = VocabularyBuilder(min_freq=1)
    # Whole-answer "tokenizer": an answer is one class label, not a sentence.
    answer_vocab.tokenize = lambda text: [text.lower().strip()]
    answer_vocab.build_word_vocabs(answer_texts)

    return question_vocab, answer_vocab

# Save vocabularies to JSON files
def save_vocabs(quest_vocab, ans_vocab):
    """Write each vocabulary's ``stoi``/``itos`` maps as JSON under ``VOCABS_PATH``.

    Note: JSON object keys are always strings, so the integer ``itos`` keys are
    stored as strings (e.g. ``"4"``) and must be converted back when loading.
    """
    os.makedirs(VOCABS_PATH, exist_ok=True)

    targets = (
        ('question_vocab.json', quest_vocab),
        ('answer_vocab.json', ans_vocab),
    )
    for file_name, vocab in targets:
        with open(os.path.join(VOCABS_PATH, file_name), 'w') as out_file:
            json.dump({'stoi': vocab.stoi, 'itos': vocab.itos}, out_file)

    print("Vocabularies saved successfully.")

### Dataset Class

In [None]:
class SlakeDataset(Dataset):
    """Map-style dataset yielding image / question-ids / answer-id samples for SLAKE.

    Args:
        dataset: Indexable rows with ``img_name``, ``question``, ``qid`` and,
            optionally, ``answer`` fields.
        question_vocab: Object whose ``numericalize(str)`` returns question ids.
        answer_vocab: Object whose ``numericalize(str)`` returns answer ids.
        transform: Optional callable applied to the resized PIL image.
        cache_images: If True, every referenced image is loaded and resized once
            at construction time and kept in memory.
        img_size: ``(width, height)`` passed to ``PIL.Image.resize``.
    """

    def __init__(self, dataset, question_vocab, answer_vocab, transform=None,
                 cache_images=True, img_size=(224, 224)):
        self.data = dataset
        self.question_vocab = question_vocab
        self.answer_vocab = answer_vocab
        self.transform = transform
        self.cache_images = cache_images
        self.img_size = tuple(img_size)

        self.image_cache = {}
        if self.cache_images:
            print("Caching images into RAM...")
            # Several questions share one image, so load each file only once.
            unique_imgs = {item['img_name'] for item in self.data}
            for img_name in unique_imgs:
                self.image_cache[img_name] = self._load_image(img_name)
            print(f"Cached {len(self.image_cache)} images.")

    def _load_image(self, img_name):
        """Read ``img_name`` from IMAGE_PATH as RGB and resize it to ``img_size``."""
        path = os.path.join(IMAGE_PATH, img_name)
        # The context manager closes the file; convert/resize return new images.
        with Image.open(path) as img:
            return img.convert('RGB').resize(self.img_size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # 1. Image: from the in-memory cache, or read from disk on demand.
        img_name = item['img_name']
        if self.cache_images:
            image = self.image_cache[img_name]
        else:
            image = self._load_image(img_name)
        if self.transform:
            image = self.transform(image)

        # 2. Question -> token ids.
        question = item['question']
        question_indices = self.question_vocab.numericalize(question)

        # 3. Answer -> class id(s); presumably a single id when the answer vocab
        #    tokenizes whole answers.
        answer = str(item.get('answer', ''))  # Answer may be missing in test set
        answer_index = self.answer_vocab.numericalize(answer)

        return {
            'image': image,
            # Explicit dtype: an empty id list would otherwise become a float
            # tensor, which nn.Embedding rejects.
            'question': torch.tensor(question_indices, dtype=torch.long),
            'answer': torch.tensor(answer_index, dtype=torch.long),
            # Original text kept for reference / text metrics.
            'original_question': question,
            'original_answer': answer,
            # ID for tracking predictions back to rows.
            'id': item['qid'],
        }

### Collate Function

In [None]:
def slake_collate_fn(batch, pad_index=0):
    """Collate ``SlakeDataset`` samples into one padded mini-batch.

    Args:
        batch: List of sample dicts as produced by ``SlakeDataset.__getitem__``.
        pad_index: Id used to right-pad questions (id 0 is ``<pad>`` in
            ``VocabularyBuilder``).

    Returns:
        dict with ``image`` (stacked image tensors, ``[B, 3, H, W]`` after
        ``ToTensor``), ``question`` ``[B, max_len_in_batch]``,
        ``question_lengths`` ``[B]`` (lengths before padding), ``answer`` ``[B]``
        (long class ids), and per-sample ``original_question`` /
        ``original_answer`` / ``id`` lists.
    """
    images = torch.stack([item['image'] for item in batch])
    questions = [item['question'] for item in batch]

    # Lengths must be taken before padding; the question encoder uses them to pack.
    question_lengths = torch.tensor([len(q) for q in questions], dtype=torch.long)
    questions_padded = pad_sequence(questions, batch_first=True, padding_value=pad_index)

    # Each answer is a single class id, possibly stored as a 1-element tensor.
    # Reshape to a scalar so stacking yields a flat [B] target vector; a sample
    # carrying more than one id fails loudly here.
    answers = torch.stack([
        torch.as_tensor(item['answer'], dtype=torch.long).reshape(())
        for item in batch
    ])

    return {
        'image': images,
        'question': questions_padded,
        'question_lengths': question_lengths,
        'answer': answers,
        'original_question': [item['original_question'] for item in batch],
        'original_answer': [item['original_answer'] for item in batch],
        'id': [item['id'] for item in batch],
    }

## Preparation

In [None]:
# On a fresh machine, run the download line once instead of load_from_disk.
# dataset = download_and_store_slake()
dataset = load_from_disk(DATASET_PATH)

# Split out train/validation/test. Vocabularies are fitted on train only, so
# validation/test words and answers unseen in training map to <unk>.
train_data, validation_data, test_data = (
    dataset[split_name] for split_name in ('train', 'validation', 'test')
)
question_vocab, answer_vocab = build_vocabs(train_data)

# ToTensor + ImageNet mean/std normalisation, matching the pretrained ResNet34.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# All splits share the train vocabularies and the same transform; only the
# training loader shuffles.
train_dataset = SlakeDataset(train_data, question_vocab, answer_vocab, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=slake_collate_fn)

validation_dataset = SlakeDataset(validation_data, question_vocab, answer_vocab, transform=transform)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False,
                               collate_fn=slake_collate_fn)

# No loader for the test split here; final evaluation builds its own.
test_dataset = SlakeDataset(test_data, question_vocab, answer_vocab, transform=transform)

## Modeling Baseline

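Baseline model: a pretrained ResNet34 CNN encodes the image, a bidirectional LSTM with multi-head self-attention encodes the question, and the two feature vectors are concatenated and passed through an MLP classifier over the answer vocabulary.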

In [None]:
# Bidirectional LSTM with Self-Attention for question encoding
class BiLSTMWithSelfAttention(nn.Module):
    """Question encoder: embedding -> BiLSTM -> multi-head self-attention -> pooling.

    Padding positions are excluded from attention (``key_padding_mask``) and
    from mean/max pooling. A question's feature therefore does not depend on
    how much padding the rest of its batch introduced.

    Args:
        vocab_size: Number of token ids; id 0 is the padding id.
        embed_dim: Token embedding size.
        hidden_dim: LSTM hidden size per direction (outputs are ``2 * hidden_dim``).
        num_layers: Number of stacked BiLSTM layers.
        dropout: Dropout on embeddings, attention weights and the attention
            output; also between LSTM layers when ``num_layers > 1``.
        pooling_strategy: ``'mean'`` or ``'max'`` over non-padding positions;
            any other value uses the final forward/backward LSTM hidden states.
        attention_heads: Number of heads; must divide ``2 * hidden_dim``.

    Returns (forward):
        ``(question_feature [B, 2 * hidden_dim], attention weights)``.
    """

    def __init__(self, vocab_size, embed_dim=300, hidden_dim=512, num_layers=1,
                 dropout=0.5, pooling_strategy='mean', attention_heads=8):
        super(BiLSTMWithSelfAttention, self).__init__()

        self.hidden_dim = hidden_dim
        self.pooling_strategy = pooling_strategy
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        self.bilstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers.
            dropout=dropout if num_layers > 1 else 0
        )

        # BiLSTM outputs hidden_dim * 2 (forward + backward).
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,
            num_heads=attention_heads,
            dropout=dropout,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_dim * 2)

    def _padding_mask(self, questions, question_lengths, seq_len):
        """Return a ``[B, seq_len]`` bool mask that is True at padding positions."""
        if question_lengths is not None:
            positions = torch.arange(seq_len, device=questions.device)
            lengths = question_lengths.to(questions.device)
            mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)
        else:
            mask = questions[:, :seq_len].eq(self.embedding.padding_idx)
        # A row with no valid token would make the attention softmax NaN;
        # leave such rows unmasked instead.
        return mask & ~mask.all(dim=1, keepdim=True)

    def forward(self, questions, question_lengths=None):
        embeds = self.dropout(self.embedding(questions))  # [B, seq_len, embed_dim]

        if question_lengths is not None:
            # Packing lets the LSTM skip padding; lengths must live on the CPU.
            embeds = nn.utils.rnn.pack_padded_sequence(
                embeds, question_lengths.cpu(),
                batch_first=True, enforce_sorted=False
            )

        lstm_out, (hidden, _) = self.bilstm(embeds)

        if question_lengths is not None:
            # Unpacked length equals the longest sequence in the batch.
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        # lstm_out: [B, seq_len, hidden_dim * 2]

        pad_mask = self._padding_mask(questions, question_lengths, lstm_out.size(1))

        # Self-attention (query = key = value) that ignores padded keys.
        attn_out, attn_weights = self.attention(
            query=lstm_out,
            key=lstm_out,
            value=lstm_out,
            key_padding_mask=pad_mask,
            need_weights=True
        )

        # Residual connection + LayerNorm.
        attn_out = self.layer_norm(lstm_out + attn_out)
        attn_out = self.dropout(attn_out)

        if self.pooling_strategy == 'mean':
            valid = (~pad_mask).unsqueeze(-1).to(attn_out.dtype)
            question_feature = (attn_out * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1.0)
        elif self.pooling_strategy == 'max':
            masked = attn_out.masked_fill(pad_mask.unsqueeze(-1), float('-inf'))
            question_feature = masked.max(dim=1)[0]
        else:
            # Last layer's final forward (hidden[-2]) and backward (hidden[-1]) states.
            question_feature = torch.cat([hidden[-2], hidden[-1]], dim=1)

        return question_feature, attn_weights

In [None]:
# Complete VQA model: ResNet34 + BiLSTM with Self-Attention
class VQA_ResNet_BiLSTM_Attention(nn.Module):
    """VQA classifier over ResNet34 image features and BiLSTM/self-attention question features.

    The two feature vectors are concatenated and passed through two
    Linear -> BatchNorm1d -> ReLU -> Dropout stages, then a linear layer
    producing ``num_classes`` answer logits.

    Args:
        vocab_size: Question vocabulary size (id 0 is padding).
        num_classes: Number of answer classes.
        embed_dim, lstm_hidden, lstm_dropout, lstm_num_layers, attention_heads,
        pooling_strategy: Forwarded to ``BiLSTMWithSelfAttention``.
        fusion_dim: Width of the first fusion layer; the second is half of it.
        fusion_dropout: Dropout after each fusion stage.
    """

    def __init__(self, vocab_size, num_classes, embed_dim=300,
                 lstm_hidden=512, fusion_dim=1024, lstm_dropout=0.5,
                 lstm_num_layers=1, attention_heads=8, fusion_dropout=0.5,
                 pooling_strategy='mean'):
        super().__init__()

        # Image encoder: ImageNet-pretrained ResNet34 with its final FC layer
        # removed, leaving the globally pooled feature map.
        resnet = self._pretrained_resnet34()
        self.image_feature_dim = resnet.fc.in_features  # 512 for ResNet34
        self.image_encoder = nn.Sequential(*list(resnet.children())[:-1])

        # Question encoder: BiLSTM + self-attention.
        self.question_encoder = BiLSTMWithSelfAttention(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            hidden_dim=lstm_hidden,
            num_layers=lstm_num_layers,
            dropout=lstm_dropout,
            attention_heads=attention_heads,
            pooling_strategy=pooling_strategy
        )
        self.question_feature_dim = lstm_hidden * 2  # bidirectional

        # Multimodal fusion MLP.
        self.fusion = nn.Sequential(
            nn.Linear(self.image_feature_dim + self.question_feature_dim, fusion_dim),
            nn.BatchNorm1d(fusion_dim),
            nn.ReLU(),
            nn.Dropout(fusion_dropout),
            nn.Linear(fusion_dim, fusion_dim // 2),
            nn.BatchNorm1d(fusion_dim // 2),
            nn.ReLU(),
            nn.Dropout(fusion_dropout)
        )

        self.classifier = nn.Linear(fusion_dim // 2, num_classes)

    @staticmethod
    def _pretrained_resnet34():
        """Build an ImageNet-pretrained ResNet34 on both old and new torchvision."""
        weights_enum = getattr(models, 'ResNet34_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is
            # the checkpoint it used to load.
            return models.resnet34(weights=weights_enum.IMAGENET1K_V1)
        return models.resnet34(pretrained=True)

    def forward(self, images, questions, question_lengths=None):
        # [B, 512, 1, 1] pooled features -> [B, 512]
        img_features = torch.flatten(self.image_encoder(images), start_dim=1)

        # Attention weights are not needed for classification.
        q_features, _ = self.question_encoder(questions, question_lengths)  # [B, lstm_hidden * 2]

        combined = torch.cat([img_features, q_features], dim=1)  # [B, 512 + lstm_hidden * 2]
        fused = self.fusion(combined)  # [B, fusion_dim // 2]
        return self.classifier(fused)  # [B, num_classes]

## Hyperparameter Tuning

In [None]:
def train_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    Returns:
        ``(mean of per-batch losses, accuracy in percent)``.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(dataloader, desc='Training')
    for batch in progress:
        images, questions, question_lengths, answers = (
            batch[key].to(device)
            for key in ('image', 'question', 'question_lengths', 'answer')
        )

        logits = model(images, questions, question_lengths)
        loss = criterion(logits, answers)

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to guard against exploding LSTM gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        batch_loss = loss.item()
        loss_sum += batch_loss
        n_correct += (logits.argmax(dim=1) == answers).sum().item()
        n_seen += answers.size(0)

        progress.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{100*n_correct/n_seen:.2f}%'
        })

    return loss_sum / len(dataloader), 100 * n_correct / n_seen

@torch.no_grad()
def validate(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` with gradient tracking disabled.

    Returns:
        ``(mean of per-batch losses, accuracy in percent)``.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch in tqdm(dataloader, desc='Validating'):
        images, questions, question_lengths, answers = (
            batch[key].to(device)
            for key in ('image', 'question', 'question_lengths', 'answer')
        )

        logits = model(images, questions, question_lengths)
        loss_sum += criterion(logits, answers).item()
        n_correct += (logits.argmax(dim=1) == answers).sum().item()
        n_seen += answers.size(0)

    return loss_sum / len(dataloader), 100 * n_correct / n_seen

In [None]:
class HyperparameterTuner:
    """Optuna search over ``VQA_ResNet_BiLSTM_Attention`` hyperparameters.

    Each trial:
    - builds a model from sampled hyperparameters and freezes its ResNet image encoder;
    - trains with AdamW + StepLR for up to 30 epochs;
    - stops early after 5 epochs without validation-accuracy improvement, or when
      Optuna's median pruner cuts it;
    - is scored by its best validation accuracy (percent).

    Args:
        train_dataset: Dataset used for training each trial.
        validation_dataset: Dataset used to score each trial.
        vocab_size: Question vocabulary size.
        num_classes: Number of answer classes.
        n_trials: Number of Optuna trials to run.
    """

    def __init__(self, train_dataset, validation_dataset, vocab_size, num_classes,
                 n_trials=50):
        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset
        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.n_trials = n_trials
        self.results_dir = Path(HYPERPARAMETERS_RESULT_PATH)
        # parents=True: DATA_DIR itself may not exist yet.
        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Results of every completed (non-pruned) trial.
        self.trial_results = []

    def config_BLSTM(self, trial):
        """Sample one configuration from ``trial`` (an Optuna ``Trial``)."""
        return {
            # Embedding parameters
            'embed_dim': trial.suggest_categorical('embed_dim', [200, 300, 512]),

            # LSTM parameters
            'lstm_hidden': trial.suggest_categorical('lstm_hidden', [256, 512, 768, 1024]),
            'lstm_num_layers': trial.suggest_int('lstm_num_layers', 1, 3),
            'lstm_dropout': trial.suggest_float('lstm_dropout', 0.1, 0.6),
            'pooling_strategy': trial.suggest_categorical('pooling_strategy', ['mean', 'max', 'last']),

            # Attention parameters (all choices divide every 2 * lstm_hidden option)
            'attention_heads': trial.suggest_categorical('attention_heads', [4, 8, 16]),

            # Fusion parameters
            'fusion_dim': trial.suggest_categorical('fusion_dim', [512, 1024, 2048]),
            'fusion_dropout': trial.suggest_float('fusion_dropout', 0.2, 0.6),

            # Training parameters. suggest_float(log=True) replaces the deprecated
            # suggest_loguniform and samples the same log-uniform distribution.
            'batch_size': trial.suggest_categorical('batch_size', [16, 32, 64]),
            'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True),
            'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True),
            'scheduler_step_size': trial.suggest_int('scheduler_step_size', 5, 15),
            'scheduler_gamma': trial.suggest_float('scheduler_gamma', 0.3, 0.7),
        }

    @staticmethod
    def _make_loader(dataset, batch_size, shuffle):
        """Build a DataLoader over ``dataset`` with the SLAKE collate function."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=slake_collate_fn,
            pin_memory=True
        )

    def objective(self, trial):
        """Train one sampled configuration; return its best validation accuracy (%).

        Raises:
            optuna.TrialPruned: When the pruner stops the trial.
        """
        print(f"Trial {trial.number + 1}/{self.n_trials}")

        config = self.config_BLSTM(trial)
        model = VQA_ResNet_BiLSTM_Attention(
            vocab_size=self.vocab_size,
            num_classes=self.num_classes,
            embed_dim=config['embed_dim'],
            lstm_hidden=config['lstm_hidden'],
            lstm_num_layers=config['lstm_num_layers'],
            attention_heads=config['attention_heads'],
            fusion_dim=config['fusion_dim'],
            lstm_dropout=config['lstm_dropout'],
            fusion_dropout=config['fusion_dropout'],
            pooling_strategy=config['pooling_strategy']
        ).to(device)

        # Keep the pretrained ResNet weights fixed during the search.
        # NOTE(review): train_epoch calls model.train(), so the frozen encoder's
        # BatchNorm running statistics are still updated.
        for param in model.image_encoder.parameters():
            param.requires_grad = False

        print(f"Config: {json.dumps(config, indent=2)}")

        train_loader = self._make_loader(self.train_dataset, config['batch_size'], shuffle=True)
        val_loader = self._make_loader(self.validation_dataset, config['batch_size'], shuffle=False)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config['scheduler_step_size'],
            gamma=config['scheduler_gamma']
        )

        best_val_acc = 0.0
        threshold = 5        # early-stopping patience (epochs)
        threshold_count = 0
        max_epochs = 30

        for epoch in range(max_epochs):
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer
            )

            val_loss, val_acc = validate(
                model, val_loader, criterion
            )

            scheduler.step()
            print(f"Epoch {epoch+1}: Train Acc={train_acc:.2f}%, Val Acc={val_acc:.2f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                threshold_count = 0
            else:
                threshold_count += 1

            if threshold_count >= threshold:
                print(f"Early stopping at epoch {epoch+1}")
                break

            trial.report(val_acc, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()

        trial_result = {
            'trial_number': trial.number,
            'config': config,
            'best_val_acc': best_val_acc,
            'final_epoch': epoch + 1
        }
        self.trial_results.append(trial_result)

        return best_val_acc

    def save_results(self, study):
        """Persist the best params, per-trial results and the study object."""
        best_params_path = self.results_dir / 'best_params_BLSTM.json'
        with open(best_params_path, 'w') as f:
            json.dump({
                'best_params': study.best_params,
                'best_value': study.best_value,
                'best_trial': study.best_trial.number
            }, f, indent=2)

        all_results_path = self.results_dir / 'all_trials_BLSTM.json'
        with open(all_results_path, 'w') as f:
            json.dump(self.trial_results, f, indent=2)

        study_path = self.results_dir / 'study_BLSTM.pkl'
        joblib.dump(study, study_path)

        print(f"\nResults saved to: {self.results_dir}")

    def run(self):
        """Create a seeded TPE study with median pruning, optimize and save; return the study."""
        print("STARTING HYPERPARAMETER TUNING FOR BLSTM MODEL\n")

        study = optuna.create_study(
            direction='maximize',
            pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
            sampler=optuna.samplers.TPESampler(seed=GLOBAL_SEED)
        )

        study.optimize(self.objective, n_trials=self.n_trials)
        self.save_results(study)

        print("HYPERPARAMETER TUNING COMPLETE")
        print(f"Best Trial: {study.best_trial.number}")
        print(f"Best Validation Accuracy: {study.best_value:.2f}%\n")
        print("Best Hyperparameters:")
        for key, value in study.best_params.items():
            print(f"  {key}: {value}")

        return study

In [None]:
# Search BLSTM-model hyperparameters with 50 Optuna trials.
tuner = HyperparameterTuner(
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    vocab_size=len(question_vocab),
    num_classes=len(answer_vocab),
    n_trials=50,
)

# Runs every trial, then writes the best params, the trial log and the pickled
# study under HYPERPARAMETERS_RESULT_PATH.
study = tuner.run()

In [None]:
# Reload the tuning summary written by HyperparameterTuner.save_results.
# Note: this dict holds 'best_params', 'best_value' and 'best_trial'; the
# hyperparameters themselves are under best_params['best_params'].
best_params_path = os.path.join(HYPERPARAMETERS_RESULT_PATH, 'best_params_BLSTM.json')
best_params = json.loads(Path(best_params_path).read_text())

best_params

## Train the Final Model

In [None]:
class FinalModelTrainer:
    """Train the VQA model with tuned hyperparameters and evaluate it on the test split.

    ``best_params`` is the document written by the tuner: a dict whose
    ``'best_params'`` key holds the hyperparameter mapping (``embed_dim``,
    ``lstm_hidden``, ``batch_size``, ``learning_rate``, ...). Checkpoints,
    result JSON files and training-curve plots go to ``FINAL_MODEL_PATH``.
    """

    def __init__(self, train_dataset, val_dataset, test_dataset,
                 best_params, vocab_size, num_classes):
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.best_params = best_params
        self.vocab_size = vocab_size
        self.num_classes = num_classes

        # Per-epoch training history, appended to by ``train``.
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'learning_rates': []
        }

        # parents=True so a missing data directory does not break construction.
        self.results_dir = Path(FINAL_MODEL_PATH)
        self.results_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _answer_type_lookup(data):
        """Map each question id to its raw ``answer_type`` (``'OPEN'`` if absent).

        ``None`` entries are skipped. For duplicate qids the first entry wins,
        which matches a linear first-match search over ``data``.
        """
        lookup = {}
        for item in data:
            if item is None:
                continue
            lookup.setdefault(item['qid'], item.get('answer_type', 'OPEN'))
        return lookup

    @staticmethod
    def _tally_by_type(predictions, targets, ids, qid_to_type):
        """Group predictions by upper-cased answer type and count correct ones.

        Args:
            predictions: predicted class indices.
            targets: ground-truth class indices, aligned with ``predictions``.
            ids: question ids, aligned with ``predictions``.
            qid_to_type: mapping qid -> answer type string.

        Returns:
            dict: answer type -> {'correct', 'total', 'predictions', 'targets'}.
            'CLOSED' and 'OPEN' are always present; ids missing from
            ``qid_to_type`` are left out of every type.
        """
        def empty_stats():
            return {'correct': 0, 'total': 0, 'predictions': [], 'targets': []}

        type_stats = {'CLOSED': empty_stats(), 'OPEN': empty_stats()}
        for pred, target, qid in zip(predictions, targets, ids):
            if qid not in qid_to_type:
                continue
            stats = type_stats.setdefault(qid_to_type[qid].upper(), empty_stats())
            stats['total'] += 1
            stats['predictions'].append(pred)
            stats['targets'].append(target)
            if pred == target:
                stats['correct'] += 1
        return type_stats

    def _build_model(self):
        """Instantiate ``VQA_ResNet_BiLSTM_Attention`` from the tuned hyperparameters on ``device``."""
        hp = self.best_params['best_params']
        return VQA_ResNet_BiLSTM_Attention(
            vocab_size=self.vocab_size,
            num_classes=self.num_classes,
            embed_dim=hp['embed_dim'],
            lstm_hidden=hp['lstm_hidden'],
            lstm_num_layers=hp['lstm_num_layers'],
            lstm_dropout=hp['lstm_dropout'],
            pooling_strategy=hp['pooling_strategy'],
            attention_heads=hp['attention_heads'],
            fusion_dim=hp['fusion_dim'],
            fusion_dropout=hp['fusion_dropout'],
        ).to(device)

    def final_evaluation(self, model, batch_size=32):
        """Evaluate ``model`` on the test set, overall and per answer type.

        Args:
            model: the trained model; it is switched to eval mode.
            batch_size: test DataLoader batch size (default 32).

        Returns:
            dict: overall/per-type accuracies and counts, plus the raw
            predictions, targets and question ids.
        """
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=slake_collate_fn,
            pin_memory=True
        )

        model.eval()

        all_predictions = []
        all_targets = []
        all_ids = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc='Testing'):
                images = batch['image'].to(device)
                questions = batch['question'].to(device)
                question_lengths = batch['question_lengths'].to(device)

                logits = model(images, questions, question_lengths)
                predictions = torch.argmax(logits, dim=1)

                all_predictions.extend(predictions.cpu().tolist())
                all_targets.extend(batch['answer'].cpu().tolist())
                batch_ids = batch['id']
                # Tensor ids become plain Python values so they work as dict keys.
                all_ids.extend(batch_ids.tolist() if torch.is_tensor(batch_ids) else batch_ids)

        # Build the qid -> answer type map once (a per-sample search was O(n^2)).
        qid_to_type = self._answer_type_lookup(self.test_dataset.data)
        type_stats = self._tally_by_type(all_predictions, all_targets, all_ids, qid_to_type)

        overall_correct = sum(p == t for p, t in zip(all_predictions, all_targets))
        overall_total = len(all_predictions)
        overall_acc = 100 * overall_correct / overall_total if overall_total > 0 else 0

        type_accuracies = {
            answer_type: 100 * stats['correct'] / stats['total']
            for answer_type, stats in type_stats.items()
            if stats['total'] > 0
        }

        print("DETAILED EVALUATION RESULTS")
        print(f"Overall Accuracy: {overall_acc:.2f}% ({overall_correct}/{overall_total})")
        print(f"\nPerformance on Answer Types:")

        for answer_type in sorted(type_stats):
            stats = type_stats[answer_type]
            if stats['total'] > 0:
                acc = type_accuracies[answer_type]
                print(f"  {answer_type:12s}: {acc:6.2f}% ({stats['correct']:4d}/{stats['total']:4d})")

        return {
            'overall_accuracy': overall_acc,
            'overall_correct': overall_correct,
            'overall_total': overall_total,
            'type_accuracies': type_accuracies,
            'type_stats': {
                answer_type: {
                    'accuracy': type_accuracies.get(answer_type, 0),
                    'correct': stats['correct'],
                    'total': stats['total']
                }
                for answer_type, stats in type_stats.items()
            },
            'predictions': all_predictions,
            'targets': all_targets,
            'ids': all_ids
        }

    def train(self, num_epochs=100, threshold=15, save_every=10):
        """Train with early stopping, then reload the best checkpoint and evaluate it on the test set.

        Args:
            num_epochs: maximum number of epochs.
            threshold: epochs without validation-accuracy improvement before stopping.
            save_every: write a periodic checkpoint every this many epochs.

        Returns:
            tuple: (model with the best checkpoint loaded, test results dict).
        """
        print("TRAINING FINAL MODEL WITH BEST HYPERPARAMETERS")
        print(f"Training for up to {num_epochs} epochs")
        print(f"Early stopping threshold: {threshold} epochs")

        print(f"\nBest hyperparameters:")
        print(json.dumps(self.best_params, indent=2))

        hp = self.best_params['best_params']

        # 1. Create model with best hyperparameters
        model = self._build_model()

        # 2. Create dataloaders with best batch size
        batch_size = hp['batch_size']

        train_loader = DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=slake_collate_fn,
            pin_memory=True
        )

        val_loader = DataLoader(
            self.val_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=slake_collate_fn,
            pin_memory=True
        )

        # 3. Setup loss, optimizer and scheduler with best parameters
        criterion = nn.CrossEntropyLoss()

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hp['learning_rate'],
            weight_decay=hp['weight_decay']
        )

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=hp['scheduler_step_size'],
            gamma=hp['scheduler_gamma']
        )

        # 4. Training loop.
        # Starting at -inf guarantees the first epoch is checkpointed even at 0%
        # validation accuracy, so best_model.pth exists when it is reloaded below.
        best_val_acc = float('-inf')
        best_epoch = 0
        threshold_counter = 0

        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")

            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer
            )
            val_loss, val_acc = validate(
                model, val_loader, criterion
            )

            # Learning rate used during this epoch (read before scheduler.step()).
            current_lr = optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['learning_rates'].append(current_lr)

            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
            print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
            print(f"Learning Rate: {current_lr:.6f}")

            scheduler.step()

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch + 1
                threshold_counter = 0

                self.save_checkpoint(
                    model, optimizer, epoch, val_acc,
                    filename='best_model.pth'
                )
                print(f"New best model found with Val Acc: {val_acc:.2f}%")
            else:
                threshold_counter += 1
                print(f"No improvement ({threshold_counter}/{threshold})")

            if (epoch + 1) % save_every == 0:
                self.save_checkpoint(
                    model, optimizer, epoch, val_acc,
                    filename=f'checkpoint_epoch_{epoch+1}.pth'
                )

            if threshold_counter >= threshold:
                print(f"Early stopping triggered at epoch {epoch+1}")
                print(f"Best validation accuracy: {best_val_acc:.2f}% at epoch {best_epoch}")
                break

        # 5. Load best model and evaluate on test set
        print("FINAL EVALUATION ON TEST SET")

        self.load_checkpoint(model, 'best_model.pth')
        test_results = self.final_evaluation(model)

        # 6. Save training history and results
        self.save_results(test_results, best_epoch, best_val_acc)

        # 7. Plot training curves
        plot_training_curves(
            self.history,
            FINAL_MODEL_PATH,
            'baseline'
        )

        print("TRAINING COMPLETE!")
        print(f"Best Val Acc: {best_val_acc:.2f}% (Epoch {best_epoch})")
        print(f"\nTest Set Results:")
        print(f"  Overall Accuracy: {test_results['overall_accuracy']:.2f}%")
        if 'type_accuracies' in test_results:
            print(f"\n  By Answer Type:")
            for answer_type in sorted(test_results['type_accuracies'].keys()):
                acc = test_results['type_accuracies'][answer_type]
                total = test_results['type_stats'][answer_type]['total']
                print(f"    {answer_type:12s}: {acc:6.2f}% ({total:4d} samples)")
        print(f"\nResults saved to: {self.results_dir}")
        print(f"{'='*70}")

        return model, test_results

    def save_checkpoint(self, model, optimizer, epoch, val_acc, filename):
        """Save model/optimizer state, hyperparameters and history to ``results_dir / filename``.

        ``epoch`` is stored as the 0-based loop index.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'best_params': self.best_params,
            'history': self.history
        }
        torch.save(checkpoint, self.results_dir / filename)

    def load_checkpoint(self, model, filename):
        """Load model weights from ``results_dir / filename`` into ``model``."""
        # map_location allows a checkpoint saved on GPU to be restored on CPU.
        checkpoint = torch.load(self.results_dir / filename, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # Stored epoch is 0-based; report it 1-based like best_epoch.
        print(f"Loaded model from {filename} (Epoch {checkpoint['epoch'] + 1}, Val Acc: {checkpoint['val_acc']:.2f}%)")

    def save_results(self, test_results, best_epoch, best_val_acc):
        """Write hyperparameters, best epoch, test metrics and history to a timestamped JSON file."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        results = {
            'timestamp': timestamp,
            'best_hyperparameters': self.best_params,
            'best_epoch': best_epoch,
            'best_val_acc': best_val_acc,
            'test_results': {
                'overall_accuracy': test_results['overall_accuracy'],
                'overall_correct': test_results['overall_correct'],
                'overall_total': test_results['overall_total'],
                'type_accuracies': test_results['type_accuracies'],
                'type_stats': test_results['type_stats']
            },
            'training_history': self.history
        }

        with open(self.results_dir / f'final_results_{timestamp}.json', 'w') as f:
            json.dump(results, f, indent=2)

In [None]:
# Trainer for the final model, configured with the tuned hyperparameters.
final_model_trainer = FinalModelTrainer(
    train_dataset=train_dataset,
    val_dataset=validation_dataset,
    test_dataset=test_dataset,
    best_params=best_params,
    vocab_size=len(question_vocab),
    num_classes=len(answer_vocab),
)

In [None]:
# Up to 100 epochs, stop after 15 epochs without improvement, checkpoint every 10.
final_training_config = dict(num_epochs=100, threshold=15, save_every=10)
final_model, test_results = final_model_trainer.train(**final_training_config)

## Comprehensive Metrics Evaluation on Final Model

The model will be compared against a generative model, so we need additional metrics that put both models on common ground. <br>
The metrics calculated below, and the rationale for each: <br>
1. Classification Metrics
    * Accuracy: Overall correctness
    * F1 Score (macro and weighted): the harmonic mean of precision and recall
2. Text Generation Metrics
    * BLEU: standard NLP metric based on n-gram overlap with the reference; BLEU-1 through BLEU-4 are calculated.
    * METEOR: semantically oriented; also matches stems and synonyms (via WordNet)
    * BERTScore: measures semantic similarity using contextual embeddings (DistilBERT here)
    * ROUGE (ROUGE-1, ROUGE-2, ROUGE-L): emphasizes recall
    * Exact string matching: case-insensitive, whitespace-trimmed exact match of answer strings
3. Answer Type Specific Metrics
    * Accuracy, F1 and exact matching for OPEN and CLOSED answer types

In [None]:
class AllMetricsCalculator:
    """Compute classification and text-generation metrics for a trained VQA classifier on the test split.

    Predicted and target class indices are mapped to answer strings through
    ``answer_vocab.itos``. This lets text metrics (BLEU, ROUGE, METEOR,
    BERTScore, exact match) be compared with those of generative models.
    Metrics and plots are written to ``FINAL_MODEL_PATH``.
    """

    def __init__(self, model, test_dataset, answer_vocab):
        self.model = model
        self.test_dataset = test_dataset
        self.answer_vocab = answer_vocab

        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        self.results_dir = Path(FINAL_MODEL_PATH)

    @staticmethod
    def _answer_type_lookup(data):
        """Map each question id to its raw ``answer_type`` (``'UNKNOWN'`` if absent).

        ``None`` entries are skipped. For duplicate qids the first entry wins,
        which matches a linear first-match search over ``data``.
        """
        lookup = {}
        for item in data:
            if item is None:
                continue
            lookup.setdefault(item['qid'], item.get('answer_type', 'UNKNOWN'))
        return lookup

    def get_predictions(self, batch_size):
        """Run the model over the test set.

        Returns:
            tuple: (predicted indices, target indices, predicted texts,
            target texts, upper-cased answer types). All five lists are aligned
            with each other; indices missing from the vocab map to ``'<unk>'``
            and qids missing from the dataset map to ``'UNKNOWN'``.
        """
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=slake_collate_fn,
            pin_memory=True
        )

        self.model.eval()

        # Built once; a per-sample scan of the dataset was O(n^2).
        qid_to_type = self._answer_type_lookup(self.test_dataset.data)
        itos = self.answer_vocab.itos

        all_predictions = []
        all_targets = []
        all_pred_texts = []
        all_target_texts = []
        all_question_types = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc='Getting predictions'):
                images = batch['image'].to(device)
                questions = batch['question'].to(device)
                question_lengths = batch['question_lengths'].to(device)

                logits = self.model(images, questions, question_lengths)
                predictions = torch.argmax(logits, dim=1)

                batch_ids = batch['id']
                # Tensor ids become plain Python values so they work as dict keys.
                if torch.is_tensor(batch_ids):
                    batch_ids = batch_ids.tolist()

                for pred_idx, target_idx, qid in zip(predictions.cpu().tolist(),
                                                       batch['answer'].cpu().tolist(),
                                                       batch_ids):
                    all_predictions.append(pred_idx)
                    all_targets.append(target_idx)
                    all_pred_texts.append(itos.get(pred_idx, '<unk>'))
                    all_target_texts.append(itos.get(target_idx, '<unk>'))
                    all_question_types.append(qid_to_type.get(qid, 'UNKNOWN').upper())

        return all_predictions, all_targets, all_pred_texts, all_target_texts, all_question_types

    def calculate_classification_metrics(self, predictions, targets):
        """Return accuracy and macro/weighted F1 (in percent) over class indices."""
        # zero_division=0 keeps sklearn's default value but silences warnings for
        # classes that never occur; this matches calculate_type_specific_metrics.
        return {
            'accuracy': accuracy_score(targets, predictions) * 100,
            'macro_f1': f1_score(targets, predictions, average='macro', zero_division=0) * 100,
            'weighted_f1': f1_score(targets, predictions, average='weighted', zero_division=0) * 100,
        }

    def calculate_bleu_scores(self, predictions, references):
        """Return mean sentence-level cumulative BLEU-1..4 (in percent) with method1 smoothing.

        Answers are lower-cased and whitespace-tokenised. A pair whose scoring
        raises is scored 0 for all four variants.
        """
        smoothing = SmoothingFunction().method1
        # Cumulative n-gram weights; each tuple sums to exactly 1.
        weight_sets = {
            'bleu1': (1, 0, 0, 0),
            'bleu2': (0.5, 0.5, 0, 0),
            'bleu3': (1 / 3, 1 / 3, 1 / 3, 0),
            'bleu4': (0.25, 0.25, 0.25, 0.25),
        }
        bleu_scores = {name: [] for name in weight_sets}

        for pred, ref in zip(predictions, references):
            pred_tokens = pred.lower().split()
            ref_tokens = [ref.lower().split()]

            try:
                row = {
                    name: sentence_bleu(ref_tokens, pred_tokens, weights=weights,
                                        smoothing_function=smoothing)
                    for name, weights in weight_sets.items()
                }
            except Exception:
                # Best effort: an unscorable pair counts as 0 (but Ctrl-C is not swallowed).
                row = dict.fromkeys(weight_sets, 0.0)

            for name, value in row.items():
                bleu_scores[name].append(value)

        return {name: np.mean(scores) * 100 for name, scores in bleu_scores.items()}

    def calculate_rouge_scores(self, predictions, references):
        """Return mean ROUGE-1/2/L F-measures (in percent)."""
        rouge1_scores = []
        rouge2_scores = []
        rougeL_scores = []

        for pred, ref in zip(predictions, references):
            scores = self.rouge_scorer.score(ref, pred)
            rouge1_scores.append(scores['rouge1'].fmeasure)
            rouge2_scores.append(scores['rouge2'].fmeasure)
            rougeL_scores.append(scores['rougeL'].fmeasure)

        return {
            'rouge1': np.mean(rouge1_scores) * 100,
            'rouge2': np.mean(rouge2_scores) * 100,
            'rougeL': np.mean(rougeL_scores) * 100,
        }

    def calculate_meteor_score(self, predictions, references):
        """Return mean METEOR (in percent); a pair whose scoring raises counts as 0."""
        meteor_scores = []

        for pred, ref in zip(predictions, references):
            pred_tokens = pred.lower().split()
            ref_tokens = ref.lower().split()

            try:
                meteor_scores.append(meteor_score([ref_tokens], pred_tokens))
            except Exception:
                meteor_scores.append(0.0)

        return {
            'meteor': np.mean(meteor_scores) * 100
        }

    def calculate_bertscore(self, predictions, references):
        """Return mean BERTScore precision/recall/F1 (in percent)."""
        # A lightweight DistilBERT model is used instead of full BERT, for speed.
        P, R, F1 = bert_score(
            predictions,
            references,
            lang='en',
            model_type='distilbert-base-uncased',
            verbose=False
        )

        return {
            'bertscore_precision': P.mean().item() * 100,
            'bertscore_recall': R.mean().item() * 100,
            'bertscore_f1': F1.mean().item() * 100,
        }

    def calculate_exact_match(self, predictions, references):
        """Return the case-insensitive, whitespace-trimmed exact-match rate in percent (0.0 for no samples)."""
        exact_matches = sum(pred.lower().strip() == ref.lower().strip()
                            for pred, ref in zip(predictions, references))
        total = len(predictions)

        return {
            'exact_match': (exact_matches / total) * 100 if total else 0.0
        }

    def calculate_type_specific_metrics(self, predictions, targets, pred_texts, target_texts, question_types):
        """Return accuracy, macro F1, exact match (percent) and count for each answer type."""
        type_metrics = defaultdict(lambda: {
            'predictions': [],
            'targets': [],
            'pred_texts': [],
            'target_texts': []
        })

        for pred, target, pred_text, target_text, q_type in zip(
            predictions, targets, pred_texts, target_texts, question_types
        ):
            type_metrics[q_type]['predictions'].append(pred)
            type_metrics[q_type]['targets'].append(target)
            type_metrics[q_type]['pred_texts'].append(pred_text)
            type_metrics[q_type]['target_texts'].append(target_text)

        results = {}
        for q_type, data in type_metrics.items():
            if len(data['predictions']) > 0:
                # Same normalisation (lower + strip) as calculate_exact_match.
                matches = sum(p.lower().strip() == t.lower().strip()
                              for p, t in zip(data['pred_texts'], data['target_texts']))
                results[q_type] = {
                    'accuracy': accuracy_score(data['targets'], data['predictions']) * 100,
                    'f1': f1_score(data['targets'], data['predictions'], average='macro', zero_division=0) * 100,
                    'exact_match': matches / len(data['pred_texts']) * 100,
                    'count': len(data['predictions'])
                }

        return results

    def evaluate_all_metrics(self, batch_size):
        """Compute, print, save and plot every metric; return them as a nested dict."""
        print("CALCULATING ALL METRICS")

        predictions, targets, pred_texts, target_texts, question_types = self.get_predictions(batch_size)

        metrics = {}

        # 1. Classification metrics (accuracy, macro/weighted F1)
        print("\nCalculating Classification Metrics...")
        metrics['classification'] = self.calculate_classification_metrics(predictions, targets)

        # 2. BLEU scores
        print("Calculating BLEU (1-4) Scores...")
        metrics['bleu'] = self.calculate_bleu_scores(pred_texts, target_texts)

        # 3. ROUGE scores
        print("Calculating ROUGE Scores...")
        metrics['rouge'] = self.calculate_rouge_scores(pred_texts, target_texts)

        # 4. METEOR score
        print("Calculating METEOR Score...")
        metrics['meteor'] = self.calculate_meteor_score(pred_texts, target_texts)

        # 5. BERTScore
        print("Calculating BERTScore... (Might take some time)")
        metrics['bertscore'] = self.calculate_bertscore(pred_texts, target_texts)

        # 6. Exact match
        print("Calculating Exact Match...")
        metrics['exact_match'] = self.calculate_exact_match(pred_texts, target_texts)

        # 7. Type-specific metrics
        print("Calculating Type-Specific Metrics...")
        metrics['by_type'] = self.calculate_type_specific_metrics(
            predictions, targets, pred_texts, target_texts, question_types
        )

        print_all_metrics(metrics)
        self.save_metrics(metrics)

        plot_type_specific_comparison(metrics['by_type'], FINAL_MODEL_PATH)
        plot_ngram_analysis(metrics['bleu'], FINAL_MODEL_PATH)

        return metrics

    def save_metrics(self, metrics):
        """Write ``metrics`` to ``all_metrics_baseline.json`` in ``results_dir``."""
        output_file = self.results_dir / 'all_metrics_baseline.json'

        with open(output_file, 'w') as f:
            json.dump(metrics, f, indent=2)

        print(f"\nAll Metrics saved to: {output_file}")

    def load_metrics(self):
        """Read back the metrics written by ``save_metrics``."""
        input_file = self.results_dir / 'all_metrics_baseline.json'

        with open(input_file, 'r') as f:
            return json.load(f)

In [None]:
# Load the best model
model = VQA_ResNet_BiLSTM_Attention(
    vocab_size=len(question_vocab),
    num_classes=len(answer_vocab),
    embed_dim=best_params['best_params']['embed_dim'],
    lstm_hidden=best_params['best_params']['lstm_hidden'],
    lstm_num_layers=best_params['best_params']['lstm_num_layers'],
    lstm_dropout=best_params['best_params']['lstm_dropout'],
    pooling_strategy=best_params['best_params']['pooling_strategy'],
    attention_heads=best_params['best_params']['attention_heads'],
    fusion_dim=best_params['best_params']['fusion_dim'],
    fusion_dropout=best_params['best_params']['fusion_dropout'],
).to(device)

checkpoint = torch.load(os.path.join(FINAL_MODEL_PATH, 'best_model.pth'))
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Create evaluator
evaluator = AllMetricsCalculator(
    model=model,
    test_dataset=test_dataset,
    answer_vocab=answer_vocab
)

In [None]:
# Evaluate with the same batch size the model was tuned with.
eval_batch_size = best_params['best_params']['batch_size']
metrics = evaluator.evaluate_all_metrics(batch_size=eval_batch_size)

In [None]:
# Unshuffled test loader (tuned batch size) for inspecting individual batches.
test_loader = DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=best_params['best_params']['batch_size'],
    pin_memory=True,
    collate_fn=slake_collate_fn,
)

In [None]:
# Take the first batch from the (unshuffled) test loader for a quick manual check.
batch = next(iter(test_loader))

In [None]:
# Predict on the single batch fetched above and decode the results to answer text.
# eval() makes this cell independent of whether an earlier cell already switched
# the model out of training mode (otherwise dropout would make predictions vary).
model.eval()

# qid -> raw answer type, built once instead of scanning test_dataset.data per sample.
qid_to_answer_type = {}
for entry in test_dataset.data:
    if entry is not None:
        qid_to_answer_type.setdefault(entry['qid'], entry.get('answer_type', 'UNKNOWN'))

all_predictions = []
all_targets = []
all_pred_texts = []
all_target_texts = []
all_question_types = []

with torch.no_grad():
    images = batch['image'].to(device)
    questions = batch['question'].to(device)
    question_lengths = batch['question_lengths'].to(device)
    answers = batch['answer'].to(device)

    # Forward pass
    logits = model(images, questions, question_lengths)
    predictions = torch.argmax(logits, dim=1)

    # Tensor ids become plain Python values so they work as dict keys.
    batch_ids = batch['id'].tolist() if torch.is_tensor(batch['id']) else batch['id']

    # Convert to text
    for pred_idx, target_idx, qid in zip(predictions.cpu().tolist(),
                                         answers.cpu().tolist(),
                                         batch_ids):
        all_predictions.append(pred_idx)
        all_targets.append(target_idx)
        all_pred_texts.append(answer_vocab.itos.get(pred_idx, '<unk>'))
        all_target_texts.append(answer_vocab.itos.get(target_idx, '<unk>'))
        all_question_types.append(qid_to_answer_type.get(qid, 'UNKNOWN').upper())

In [None]:
# Get all predictions
# Re-run inference over the whole test split and keep the per-sample outputs.
eval_batch_size = best_params['best_params']['batch_size']
prediction_outputs = evaluator.get_predictions(batch_size=eval_batch_size)
predictions, targets, pred_texts, target_texts, question_types = prediction_outputs

In [None]:
# Question texts and image names in dataset order, skipping None entries.
# NOTE(review): these lists are paired with pred_texts/target_texts purely by
# position. That assumes the unshuffled test loader yields samples in
# test_dataset.data order with nothing dropped; confirm against slake_collate_fn.
present_items = [item for item in test_dataset.data if item is not None]
questions = [item['question'] for item in present_items]
images = [item['img_name'] for item in present_items]

# Print a few to verify
print("Sample Questions:")
for question_text, predicted_text, truth_text in zip(questions[:5], pred_texts, target_texts):
    print(f"  Question: {question_text}")
    print(f"  Predicted Answer: {predicted_text}")
    print(f"  Ground Truth Answer: {truth_text}")

In [None]:
# Get all wrong predictions
wrong_predictions = []
for i in range(len(pred_texts)):
    if pred_texts[i].lower().strip() != target_texts[i].lower().strip():
        wrong_predictions.append({
            'question': questions[i],
            'ground_truth': target_texts[i],
            'prediction': pred_texts[i],
            'image': images[i],
            'question_type': question_types[i]
        })

len(wrong_predictions)

In [None]:
class DirectionalErrorAnalyzer:
    """Categorise answer errors for answers that mention directional terms.

    Each answer is scanned for whole-word directional terms in three families
    (lateral, vertical, medial). Every (prediction, ground truth) pair is then
    put into one bucket: non-directional, correct, or one of several
    directional error types. Statistics and a JSON report can then be
    produced from those buckets.
    """

    def __init__(self):
        # Whole-word terms searched for in each answer, grouped by family.
        self.directional_terms = {
            'lateral': ['left', 'right'],
            'vertical': ['upper', 'lower', 'top', 'bottom'],
            'medial': ['central', 'center', 'middle']
        }
        # Synonyms collapsed to one canonical term before lateral/vertical
        # comparison, so e.g. 'top' vs 'upper' is not reported as an
        # upper/lower confusion.
        self.term_synonyms = {
            'vertical': {'top': 'upper', 'bottom': 'lower'},
        }

    def extract_directional_info(self, text):
        """Return the directional terms found in ``text``, grouped by family.

        Matching is case-insensitive and on whole words only, so 'leftover'
        does not match 'left'. Within each family the terms are listed in the
        order of ``self.directional_terms``, not in their order in ``text``.

        Returns:
            dict mapping each family name to a (possibly empty) list of terms.
        """
        text_lower = text.lower().strip()
        return {
            direction_type: [
                term for term in terms
                if re.search(rf'\b{re.escape(term)}\b', text_lower)
            ]
            for direction_type, terms in self.directional_terms.items()
        }

    def _canonical_terms(self, direction_type, terms):
        """Map ``terms`` to their canonical synonyms and return them as a set."""
        synonyms = self.term_synonyms.get(direction_type, {})
        return {synonyms.get(term, term) for term in terms}

    def compare_directions(self, gt_directions, pred_directions):
        """Return the first directional mismatch between two extractions, or None.

        Families are checked in the order lateral, vertical, medial. Lateral
        and vertical terms are compared after synonym normalisation:
        - more predicted terms than ground-truth terms -> ``extra_<family>_direction``
        - fewer predicted terms -> ``missing_<family>_direction``
        - same number but different terms -> ``left_right_confusion`` or
          ``upper_lower_confusion``
        Medial terms are only checked for presence vs. absence.

        Args:
            gt_directions: output of ``extract_directional_info`` for the ground truth.
            pred_directions: output of ``extract_directional_info`` for the prediction.

        Returns:
            An error-type key of ``analyze_predictions``' ``errors`` dict, or
            None when no lateral/vertical/medial mismatch is found.
        """
        confusion_labels = {
            'lateral': 'left_right_confusion',
            'vertical': 'upper_lower_confusion',
        }
        for direction_type in ('lateral', 'vertical'):
            gt_terms = self._canonical_terms(direction_type, gt_directions[direction_type])
            pred_terms = self._canonical_terms(direction_type, pred_directions[direction_type])
            if len(gt_terms) < len(pred_terms):
                return f'extra_{direction_type}_direction'
            if len(gt_terms) > len(pred_terms):
                return f'missing_{direction_type}_direction'
            if gt_terms != pred_terms:
                return confusion_labels[direction_type]

        # Medial terms ('center', 'central', 'middle') are near-synonyms of one
        # another, so only their presence or absence is compared.
        has_gt_medial = bool(gt_directions['medial'])
        has_pred_medial = bool(pred_directions['medial'])
        if has_pred_medial and not has_gt_medial:
            return 'extra_medial_direction'
        if has_gt_medial and not has_pred_medial:
            return 'missing_medial_direction'

        return None

    @staticmethod
    def _has_any_direction(directions):
        """True if any family in an extraction result holds at least one term."""
        return any(len(terms) > 0 for terms in directions.values())

    def has_directional_info(self, text):
        """Return True if ``text`` contains at least one directional term."""
        return self._has_any_direction(self.extract_directional_info(text))

    def analyze_predictions(self, predictions, ground_truths, questions, images):
        """Bucket each (prediction, ground truth) pair by directional correctness.

        The four sequences are consumed in lockstep with ``zip``, so they are
        expected to be aligned and of equal length; extra trailing items in a
        longer sequence are ignored.

        Returns:
            dict with keys:
            - ``'errors'``: error type -> list of example dicts
            - ``'correct'``: exact matches on directional answers. It also holds
              pairs whose directional terms match but whose text differs; those
              carry a ``'note'`` key.
            - ``'non_directional'``: pairs where neither side has a directional
              term, each flagged with exact-match ``'correct'``
        """
        results = {
            'errors': {
                'left_right_confusion': [],
                'upper_lower_confusion': [],
                'partial_match': [],
                'extra_medial_direction': [],
                'missing_medial_direction': [],
                'missing_lateral_direction': [],
                'extra_lateral_direction': [],
                'missing_vertical_direction': [],
                'extra_vertical_direction': [],
                'missing_direction': [],
                'extra_direction': []
            },
            'correct': [],
            'non_directional': []
        }
        errors = results['errors']

        for pred, gt, question, img in zip(predictions, ground_truths, questions, images):
            pred_lower = pred.lower().strip()
            gt_lower = gt.lower().strip()

            # Extract once per pair and reuse for every check below.
            gt_directions = self.extract_directional_info(gt)
            pred_directions = self.extract_directional_info(pred)
            has_gt_direction = self._has_any_direction(gt_directions)
            has_pred_direction = self._has_any_direction(pred_directions)

            record = {
                'image': img,
                'question': question,
                'ground_truth': gt,
                'prediction': pred
            }

            if not has_gt_direction and not has_pred_direction:
                results['non_directional'].append({**record, 'correct': pred_lower == gt_lower})
            elif not has_gt_direction:
                # The prediction mentions a direction the ground truth does not.
                errors['extra_direction'].append(record)
            elif pred_lower == gt_lower:
                results['correct'].append(record)
            elif not has_pred_direction:
                errors['missing_direction'].append({**record, 'gt_directions': gt_directions})
            else:
                detailed = {
                    **record,
                    'gt_directions': gt_directions,
                    'pred_directions': pred_directions
                }
                error_type = self.compare_directions(gt_directions, pred_directions)
                if error_type:
                    errors[error_type].append(detailed)
                elif gt_directions != pred_directions:
                    # Same canonical directions but different wording
                    # (e.g. 'center' vs 'middle', 'top' vs 'upper').
                    errors['partial_match'].append(detailed)
                else:
                    results['correct'].append({
                        **record,
                        'note': 'Directionally correct but overall wrong'
                    })

        return results

    def calculate_statistics(self, results):
        """Summarise ``analyze_predictions`` output.

        Non-directional pairs are excluded from every count.

        Returns:
            dict of counts and percentages, or None when there are no
            directional pairs at all.
        """
        total_errors = sum(len(errors) for errors in results['errors'].values())
        total_directional = len(results['correct']) + total_errors

        if total_directional == 0:
            return None

        stats = {
            'total_directional_questions': total_directional,
            'correct': len(results['correct']),
            'directional_accuracy': 100 * len(results['correct']) / total_directional,
            'error_breakdown': {
                error_type: {
                    'count': len(errors),
                    'percentage': 100 * len(errors) / total_directional
                }
                for error_type, errors in results['errors'].items()
            }
        }
        stats['total_errors'] = total_errors
        stats['error_rate'] = 100 * total_errors / total_directional

        return stats

    def print_analysis(self, results, stats):
        """Print summary statistics and up to three examples per error type."""
        print("DIRECTIONAL ERROR ANALYSIS")

        if stats is None:
            print("No directional questions found in dataset.")
            return

        print(f"\nTotal questions with directional info: {stats['total_directional_questions']}")
        print(f"Correct: {stats['correct']} ({stats['directional_accuracy']:.2f}%)")
        print(f"Errors: {stats['total_errors']} ({stats['error_rate']:.2f}%)")

        print("Error Breakdown:")

        sorted_errors = sorted(
            stats['error_breakdown'].items(),
            key=lambda x: x[1]['count'],
            reverse=True
        )

        for error_type, error_stats in sorted_errors:
            if error_stats['count'] > 0:
                error_name = error_type.replace('_', ' ').title()
                print(f"  {error_name:30s}: {error_stats['count']:3d} ({error_stats['percentage']:5.2f}%)")

        for error_type, errors in results['errors'].items():
            if len(errors) > 0:
                print(f"\n {error_type.replace('_', ' ').title()} Examples:")
                for i, example in enumerate(errors[:3]):
                    print(f"\n  Example {i+1}:")
                    print(f"    Question:     {example['question']}")
                    print(f"    Ground Truth: {example['ground_truth']}")
                    print(f"    Prediction:   {example['prediction']}")

                    if 'gt_directions' in example and 'pred_directions' in example:
                        print(f"    GT Directions: {example['gt_directions']}")
                        print(f"    Pred Directions: {example['pred_directions']}")

    def save_detailed_report(self, results, stats, model_name, output_dir=None, max_examples=10):
        """Write statistics plus sample examples to a JSON file.

        Args:
            results: output of ``analyze_predictions``.
            stats: output of ``calculate_statistics`` (may be None).
            model_name: stored in the report; its lower-cased form names the
                file ``<model_name>_directional_analysis.json``.
            output_dir: target directory. It is created, including missing
                parents, if needed. Defaults to ``FINAL_MODEL_PATH``, looked up
                at call time.
            max_examples: maximum number of examples kept per bucket.
        """
        if output_dir is None:
            output_dir = FINAL_MODEL_PATH
        output_dir = Path(output_dir)
        # parents=True: the previous exist_ok-only call failed when an
        # intermediate directory (e.g. data/) did not exist yet.
        output_dir.mkdir(parents=True, exist_ok=True)

        report = {
            'model': model_name,
            'statistics': stats,
            'examples': {
                error_type: errors[:max_examples]
                for error_type, errors in results['errors'].items()
            }
        }
        report['examples']['correct'] = results['correct'][:max_examples]

        output_file = output_dir / f'{model_name.lower()}_directional_analysis.json'
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)

        print(f"Detailed report saved to {output_file}")

In [None]:
def analyze_biomedblip_directional_errors(predictions, ground_truths, questions, images, model_name='Baseline'):
    """Run the directional error analysis end to end, print it, and save a report.

    Args:
        predictions: predicted answer strings.
        ground_truths: ground-truth answer strings, aligned with ``predictions``.
        questions: question strings, aligned with ``predictions``.
        images: image names, aligned with ``predictions``.
        model_name: label stored in the saved report. Its lower-cased form
            names the report file, so pass a distinct name per model to avoid
            overwriting another model's report. Defaults to 'Baseline', the
            value previously hard-coded here.

    Returns:
        ``(results, stats)`` from ``DirectionalErrorAnalyzer.analyze_predictions``
        and ``DirectionalErrorAnalyzer.calculate_statistics``. ``stats`` is None
        when no directional answers were found.
    """
    analyzer = DirectionalErrorAnalyzer()

    # Run analysis
    results = analyzer.analyze_predictions(predictions, ground_truths, questions, images)
    stats = analyzer.calculate_statistics(results)

    # Print results
    analyzer.print_analysis(results, stats)

    # Save detailed report (written to the analyzer's default output directory)
    analyzer.save_detailed_report(results, stats, model_name)

    return results, stats

In [None]:
# Directional error analysis over the full test-set predictions.
results, stats = analyze_biomedblip_directional_errors(
    predictions=pred_texts,
    ground_truths=target_texts,
    questions=questions,
    images=images,
)