# Multimodal VQA Fine-tuning Notebook

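This notebook fine-tunes a previously trained multimodal Visual Question Answering (VQA) model using one of several selectable strategies (conservative, layer-wise, progressive unfreezing, aggressive, regularization, and architecture tweaks) to improve its accuracy.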

**Environment Support:**
- Local VS Code/Jupyter
- Google Colab  
- Other Jupyter environments

## 0. Environment Setup & Path Configuration

In [2]:
# ============================================================================
# UNIVERSAL ENVIRONMENT SETUP - Works on Colab, VS Code, and Jupyter
# ============================================================================

import os
import sys
from pathlib import Path

def setup_environment():
    """
    Automatically detect environment and setup project paths.
    Supports: Google Colab, VS Code, Local Jupyter
    """

    # Detect if running on Google Colab
    try:
        import google.colab
        IN_COLAB = True
        print("Environment detected: Google Colab")
    except ImportError:
        IN_COLAB = False
        print("Environment detected: Local (VS Code/Jupyter)")

    if IN_COLAB:
        # ========== GOOGLE COLAB SETUP ==========
        print("Setting up Google Colab environment...")

        # Mount Google Drive
        from google.colab import drive
        drive.mount('/content/drive')

        # Install required packages
        print("Installing packages...")
        os.system('pip install -q torch torchvision tqdm pyyaml scikit-learn pandas matplotlib seaborn Pillow')

        # Set project path (adjust this path to your Google Drive structure)
        PROJECT_PATH = "/content/drive/MyDrive/Colab Notebooks/WOA7015 Advanced Machine Learning"

        # Alternative paths - uncomment the one that matches your Drive structure
        # PROJECT_PATH = "/content/drive/MyDrive/data"
        # PROJECT_PATH = "/content/drive/MyDrive/WOA7015 Advanced Machine Learning/data"

        if os.path.exists(PROJECT_PATH):
            os.chdir(PROJECT_PATH)
            project_root = Path(PROJECT_PATH)
            print(f"SUCCESS: Colab project root: {PROJECT_PATH}")
        else:
            print(f"ERROR: Project path not found: {PROJECT_PATH}")
            print("Available paths in Drive:")
            base_path = "/content/drive/MyDrive"
            if os.path.exists(base_path):
                for item in os.listdir(base_path):
                    print(f"   - {os.path.join(base_path, item)}")
            raise FileNotFoundError(f"Please update PROJECT_PATH in the code to match your Google Drive structure")

    else:
        # ========== LOCAL ENVIRONMENT SETUP ==========
        print("Setting up local environment...")

        # Determine project root (works from notebooks/ subdirectory or project root)
        current_dir = Path().absolute()

        if current_dir.name == 'notebooks':
            project_root = current_dir.parent
            print("Running from notebooks/ directory")
        else:
            project_root = current_dir
            print("Running from project root directory")

        print(f"SUCCESS: Local project root: {project_root}")

    # ========== COMMON SETUP FOR ALL ENVIRONMENTS ==========

    # Add paths to Python path
    sys.path.insert(0, str(project_root))
    sys.path.insert(0, str(project_root / 'src'))

    # Verify project structure
    required_dirs = ['src', 'data', 'notebooks', 'checkpoints']
    missing_dirs = []

    for req_dir in required_dirs:
        dir_path = project_root / req_dir
        if dir_path.exists():
            print(f"Found: {req_dir}/")
        else:
            missing_dirs.append(req_dir)
            print(f"Missing: {req_dir}/")

    if missing_dirs:
        print(f"\nWarning: Some directories are missing: {missing_dirs}")
        print("   Make sure you're running from the correct project directory.")

    return project_root, IN_COLAB

# Run environment setup
project_root, is_colab = setup_environment()

print(f"\nEnvironment ready!")
print(f"   Project root: {project_root}")
print(f"   Running on Colab: {is_colab}")
print("=" * 60)

Environment detected: Google Colab
Setting up Google Colab environment...
Mounted at /content/drive
Installing packages...
SUCCESS: Colab project root: /content/drive/MyDrive/Colab Notebooks/WOA7015 Advanced Machine Learning
Found: src/
Found: data/
Found: notebooks/
Found: checkpoints/

Environment ready!
   Project root: /content/drive/MyDrive/Colab Notebooks/WOA7015 Advanced Machine Learning
   Running on Colab: True


### Copy this setup to other notebooks:

**For consistent environment setup across all notebooks, copy the cell above to:**
- `01_data_exploration.ipynb`  
- `02_text_baseline_training.ipynb`
- `03_multimodal_training.ipynb`
- `improved_multimodal_training.ipynb`
- Any new notebooks

**Customization for Colab:**
- Update the `PROJECT_PATH` variable in the setup cell to match your Google Drive folder structure
- Common paths:
  - `/content/drive/MyDrive/WOA7015 Advanced Machine Learning/my_project`
  - `/content/drive/MyDrive/data`
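If you switch between several Drive layouts, a small helper can try each candidate path instead of editing `PROJECT_PATH` by hand. This is a minimal sketch under assumptions, not part of the project's `setup_environment.py`: `find_project_root` and `REQUIRED_DIRS` are hypothetical names, and the helper reuses the same checks as the setup cell (a `notebooks/` working directory maps to its parent, and a valid root contains `src/`, `data/`, `notebooks/` and `checkpoints/`).

```python
from pathlib import Path
from typing import Iterable, Optional

# Hypothetical constant: the same directories the setup cell checks for.
REQUIRED_DIRS = ("src", "data", "notebooks", "checkpoints")


def find_project_root(candidates: Iterable[str],
                      required_dirs: Iterable[str] = REQUIRED_DIRS) -> Optional[Path]:
    """Return the first candidate that contains every required subdirectory, else None."""
    required = tuple(required_dirs)
    for candidate in candidates:
        root = Path(candidate)
        # Running from notebooks/ means the project root is one level up.
        if root.name == "notebooks":
            root = root.parent
        if root.is_dir() and all((root / d).is_dir() for d in required):
            return root
    return None


# Example candidates (adjust to your own Drive layout).
candidates = [
    "/content/drive/MyDrive/Colab Notebooks/WOA7015 Advanced Machine Learning",
    "/content/drive/MyDrive/WOA7015 Advanced Machine Learning",
    str(Path.cwd()),
]
print(find_project_root(candidates))
```

Returning `None` instead of raising leaves the decision (list the Drive folders, or stop with an error) to the setup cell, as it does today.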

### 🚀 Alternative: Simple One-Line Setup

If you prefer a simpler approach, you can use the standalone setup module:

In [3]:
# ============================================================================
# 🚀 ALTERNATIVE: Simple One-Line Setup (using setup_environment.py)
# ============================================================================
# Uncomment and run this instead of the detailed setup above

# from setup_environment import quick_setup
# project_root, device, is_colab, modules = quick_setup()

# # Access imported modules
# create_multimodal_dataloaders = modules['create_multimodal_dataloaders']
# ImprovedMultimodalVQA = modules['ImprovedMultimodalVQA']

# # Import other required packages
# import torch.nn as nn
# import torch.optim as optim
# from torchvision import transforms
# import matplotlib.pyplot as plt
# import seaborn as sns
# from sklearn.metrics import accuracy_score
# from tqdm import tqdm
# import json, time, numpy as np

# print("✅ One-line setup complete!")

# ============================================================================
# For Google Colab with custom path:
# project_root, device, is_colab, modules = quick_setup(
#     colab_project_path="/content/drive/MyDrive/your-custom-path"
# )
# ============================================================================

## 1. Setup and Configuration

In [4]:
# Import required libraries (using standardized project_root from above)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import yaml
import json
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Import project modules (project_root already set in cell above)
try:
    from src.data.dataset import create_multimodal_dataloaders
    from src.models.improved_multimodal_model import ImprovedMultimodalVQA
    print("SUCCESS: Project modules imported successfully")
except ImportError as e:
    print(f"ERROR: Import error: {e}")
    print("   Make sure the environment setup cell above ran successfully")
    print("   and the project structure is correct")
    raise

# Device setup
print(f"GPU available: {torch.cuda.is_available()}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

print(f"Project root: {project_root}")
print("Setup complete!")
print("=" * 60)

SUCCESS: Project modules imported successfully
GPU available: True
Using device: cuda
   GPU: Tesla T4
   Memory: 14.7 GB
Project root: /content/drive/MyDrive/Colab Notebooks/WOA7015 Advanced Machine Learning
Setup complete!


## 2. Fine-tuning Configuration

In [None]:
# Fine-tuning strategy selector
FINE_TUNING_STRATEGY = "layerwise"  # Changed from "conservative" - try layerwise with better learning rates

# Enhanced configuration with better hyperparameters
FINE_TUNING_CONFIG = {
    "conservative": {
        "name": "Conservative Fine-tuning",
        "description": "Lower learning rates, maintain model stability",
        "learning_rate": 5e-6,  # Increased from 1e-6 - was too low
        "vision_lr_factor": 0.1,
        "epochs": 5,
        "batch_size": 12,
        "weight_decay": 1e-4,
        "label_smoothing": 0.05,
        "dropout_factor": 1.0,
        "freeze_layers": [],
        "augmentation_strength": "light"
    },
    "layerwise": {
        "name": "Layer-wise Fine-tuning",
        "description": "Different learning rates for vision, text, and fusion components",
        "learning_rate": 2e-5,  # Increased from 5e-6 - better base learning rate
        "vision_lr_factor": 0.1,  # Vision: 2e-6 (careful with pretrained weights)
        "text_lr_factor": 0.5,   # Text: 1e-5 (moderate for LSTM)
        "fusion_lr_factor": 1.0, # Fusion: 2e-5 (highest for new layers)
        "epochs": 6,
        "batch_size": 10,
        "weight_decay": 1e-4,
        "label_smoothing": 0.1,
        "dropout_factor": 1.0,
        "freeze_layers": [],
        "augmentation_strength": "medium"
    },
    "progressive": {
        "name": "Progressive Unfreezing",
        "description": "Gradually unfreeze layers during training",
        "learning_rate": 1e-5,
        "vision_lr_factor": 0.1,
        "epochs": 8,
        "batch_size": 12,
        "weight_decay": 5e-5,
        "label_smoothing": 0.1,
        "dropout_factor": 1.0,
        "freeze_schedule": {1: ["vision_encoder.layer4"], 3: ["vision_encoder.layer3"], 5: []},
        "augmentation_strength": "medium"
    },
    "aggressive": {
        "name": "Aggressive Fine-tuning",
        "description": "Higher learning rates for faster convergence",
        "learning_rate": 5e-5,  # Much higher learning rate
        "vision_lr_factor": 0.2,  # Vision: 1e-5
        "text_lr_factor": 0.8,   # Text: 4e-5
        "fusion_lr_factor": 1.5,  # Fusion: 7.5e-5
        "epochs": 4,
        "batch_size": 8,
        "weight_decay": 1e-5,
        "label_smoothing": 0.05,
        "dropout_factor": 0.8,
        "freeze_layers": [],
        "augmentation_strength": "light"
    },
    "regularization": {
        "name": "Regularization Tuning",
        "description": "Optimize regularization parameters",
        "learning_rate": 1e-5,
        "vision_lr_factor": 0.1,
        "epochs": 6,
        "batch_size": 12,
        "weight_decay": 5e-4,
        "label_smoothing": 0.15,
        "dropout_factor": 1.1,
        "freeze_layers": [],
        "augmentation_strength": "medium"
    },
    "architecture": {
        "name": "Architecture Tweaking",
        "description": "Minor architectural modifications with optimal learning rates",
        "learning_rate": 1.5e-5,
        "vision_lr_factor": 0.1,
        "epochs": 6,
        "batch_size": 10,
        "weight_decay": 1e-4,
        "label_smoothing": 0.1,
        "dropout_factor": 1.0,
        "freeze_layers": [],
        "augmentation_strength": "medium",
        "add_batch_norm": True,
        "increase_attention_heads": True
    }
}

config = FINE_TUNING_CONFIG[FINE_TUNING_STRATEGY]
print(f"Selected Strategy: {config['name']}")
print(f"Description: {config['description']}")
print(f"Epochs: {config['epochs']}, Learning Rate: {config['learning_rate']}")
print(f"Vision LR: {config['learning_rate'] * config['vision_lr_factor']:.2e}")
if 'text_lr_factor' in config:
    print(f"Text LR: {config['learning_rate'] * config['text_lr_factor']:.2e}")
    print(f"Fusion LR: {config['learning_rate'] * config['fusion_lr_factor']:.2e}")

Selected Strategy: Conservative Fine-tuning
Description: Lower learning rates, maintain model stability
Epochs: 5, Learning Rate: 1e-06
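Each component's learning rate is simply the base `learning_rate` multiplied by its factor. The sketch below uses a hypothetical helper, `effective_learning_rates`, to tabulate those products for a `FINE_TUNING_CONFIG` entry. It assumes, as the print statements above do, that strategies defining `text_lr_factor` get separate text and fusion rates, while the others use one shared rate at the base value. It is a quick way to check the inline comments in the config, e.g. `layerwise` gives vision 2e-6, text 1e-5 and fusion 2e-5.

```python
def effective_learning_rates(cfg: dict) -> dict:
    """Per-component learning rates implied by one FINE_TUNING_CONFIG entry.

    Assumption: entries with text_lr_factor get vision/text/fusion rates;
    the others get a vision rate plus one shared rate at the base LR.
    """
    base = cfg["learning_rate"]
    lrs = {"vision": base * cfg["vision_lr_factor"]}
    if "text_lr_factor" in cfg:
        lrs["text"] = base * cfg["text_lr_factor"]
        lrs["fusion"] = base * cfg.get("fusion_lr_factor", 1.0)
    else:
        lrs["other"] = base
    return lrs


# Copies of two entries from the config cell, so the sketch runs on its own.
layerwise = {"learning_rate": 2e-5, "vision_lr_factor": 0.1,
             "text_lr_factor": 0.5, "fusion_lr_factor": 1.0}
aggressive = {"learning_rate": 5e-5, "vision_lr_factor": 0.2,
              "text_lr_factor": 0.8, "fusion_lr_factor": 1.5}

for name, cfg in [("layerwise", layerwise), ("aggressive", aggressive)]:
    print(name, {k: f"{v:.2e}" for k, v in effective_learning_rates(cfg).items()})
```

In the notebook itself you could loop over `FINE_TUNING_CONFIG.items()` instead of the two copied entries.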


## 3. Load Data (Dataloader's Built-in Augmentation)

### 🎯 Quick Strategy Selection

**Available Strategies & When to Use:**

1. **`layerwise`** ⭐ **(RECOMMENDED FIRST)** - Different LRs for model components
   - Vision: 2e-6, Text: 1e-5, Fusion: 2e-5
   - Best balance of stability and improvement

2. **`aggressive`** - Higher learning rates for faster convergence
   - Use if layerwise is too slow or conservative

3. **`progressive`** - Gradually unfreeze layers
   - Good for preventing catastrophic forgetting

4. **`architecture`** - Add batch norm + more attention heads
   - Try after finding good learning rates

5. **`regularization`** - Optimize dropout and regularization
   - Use if overfitting is an issue

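6. **`conservative`** - Lower learning rates to maintain model stability
   - Smallest base learning rate of all strategies (5e-6)

**Quick Switch:** Change `FINE_TUNING_STRATEGY` in Section 2 and re-run the notebook from Section 2 onward; the batch size, model modifications, optimizer, and scheduler all depend on the selected strategy.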
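The `progressive` entry carries a `freeze_schedule`, but the notebook does not define how it is consumed: the optimizer cell in Section 5 only passes `freeze_layers` to `apply_layer_freezing`, and `progressive` has no `freeze_layers` key. One possible reading, sketched below as an assumption: each key is the 1-based epoch from which its list of parameter-name prefixes replaces the previous one, so `5: []` leaves everything trainable from epoch 5 on. `frozen_prefixes_at` and `is_frozen` are hypothetical helpers; `is_frozen` uses the same substring match as `apply_layer_freezing`.

```python
from typing import Dict, List


def frozen_prefixes_at(schedule: Dict[int, List[str]], epoch: int) -> List[str]:
    """Prefixes to keep frozen at `epoch` (1-based).

    Assumed semantics: the latest schedule key <= epoch wins; before the
    first key, nothing is frozen.
    """
    active = [k for k in schedule if k <= epoch]
    return list(schedule[max(active)]) if active else []


def is_frozen(param_name: str, prefixes: List[str]) -> bool:
    """Substring match, as in apply_layer_freezing."""
    return any(prefix in param_name for prefix in prefixes)


# The schedule from the "progressive" config entry.
schedule = {1: ["vision_encoder.layer4"], 3: ["vision_encoder.layer3"], 5: []}
for epoch in range(1, 7):
    print(epoch, frozen_prefixes_at(schedule, epoch))
```

Under this reading, the start of each epoch would set `param.requires_grad = not is_frozen(name, prefixes)` for every `(name, param)` in `model.named_parameters()`.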

In [6]:
# Temporary monkey-patch that disables pin_memory, to help debug a "CUDA error: device-side assert triggered"
import torch.utils.data._utils.pin_memory as pin_memory_utils

# Store original function to potentially restore later
_original_pin_memory_fn = pin_memory_utils.pin_memory

def no_op_pin_memory(data, device=None):
    # This function simply returns the data, bypassing actual pinning
    return data

# Apply the monkey patch
pin_memory_utils.pin_memory = no_op_pin_memory
print("WARNING: torch.utils.data._utils.pin_memory.pin_memory has been temporarily monkey-patched to disable pinning for debugging.")
print("         This allows data errors to be caught at a higher level. Remember to revert this patch for optimal performance!")

# Load data with standard augmentation (the dataloader has its own transforms)
data_dir = project_root / 'data'
dataset_path = data_dir / 'train'

print(f"Loading data with {config['augmentation_strength']} augmentation...")

# Check if reduced vocabulary file exists, otherwise use full vocabulary
answers_file_path = data_dir / 'answers_top_1000.txt'
if not answers_file_path.exists():
    print("Reduced vocabulary file not found, using full vocabulary")
    answers_file_path = data_dir / 'answers.txt'

# Use standard dataloader without custom transforms (it has built-in augmentation)
train_loader, val_loader, test_loader, vocab_size, num_classes, vocab, answer_to_idx = create_multimodal_dataloaders(
    train_csv=str(data_dir / 'trainrenamed.csv'),
    test_csv=str(data_dir / 'testrenamed.csv'),
    image_dir=str(dataset_path),
    answers_file=str(answers_file_path),
    batch_size=config['batch_size'],
    val_split=0.1,
    num_workers=0,
    image_size=224
)

# Explicitly disable pin_memory on the DataLoader instances to prevent CUDA asserts
# (This is a more robust fix than just monkey-patching torch.utils.data._utils.pin_memory.pin_memory if DataLoader was already instantiated)
if hasattr(train_loader, 'pin_memory'):
    train_loader.pin_memory = False
    val_loader.pin_memory = False
    test_loader.pin_memory = False
    print("INFO: Explicitly disabled pin_memory for DataLoader instances.")

print(f"Data loaded: {len(train_loader)} train, {len(val_loader)} val, {len(test_loader)} test batches")
print(f"Vocabulary size: {vocab_size}, Classes: {num_classes}")
print(f"Using vocabulary file: {answers_file_path.name}")
print("Note: Using standard augmentation from dataloader. Custom transforms defined but not applied yet.")

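WARNING: torch.utils.data._utils.pin_memory.pin_memory has been temporarily monkey-patched to disable pinning for debugging.
         This allows data errors to be caught at a higher level. Remember to revert this patch for optimal performance!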
Loading data with light augmentation...
Reduced vocabulary file not found, using full vocabulary
Loaded 19755 samples
  Vocab size: 5245
  Num classes: 4142
  Image size: 224x224
Train size: 17780, Val size: 1975
Loaded 3164 samples
  Vocab size: 5245
  Num classes: 4142
  Image size: 224x224
INFO: Explicitly disabled pin_memory for DataLoader instances.
Data loaded: 1482 train, 165 val, 264 test batches
Vocabulary size: 5245, Classes: 4142
Using vocabulary file: answers.txt
Note: Using standard augmentation from dataloader. Custom transforms defined but not applied yet.
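The pin-memory patch above has to be reverted by hand (`pin_memory_utils.pin_memory = _original_pin_memory_fn`). A context manager makes the revert automatic, even if a cell raises partway through. This is a minimal sketch: `temporarily_replaced` is a hypothetical helper, and the stand-in object lets it run without PyTorch. In the notebook, the target would be `pin_memory_utils` (the private `torch.utils.data._utils.pin_memory` module) with the attribute `"pin_memory"`, wrapped around the cells that iterate over the loaders. The standard library's `unittest.mock.patch.object` offers the same behavior.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def temporarily_replaced(obj, attr_name, replacement):
    """Swap obj.<attr_name> for `replacement`; always restore the original on exit."""
    original = getattr(obj, attr_name)
    setattr(obj, attr_name, replacement)
    try:
        yield original
    finally:
        setattr(obj, attr_name, original)


# Stand-in for the private pin_memory module, so the sketch runs without torch.
fake_module = SimpleNamespace(pin_memory=lambda data, device=None: ("pinned", data))


def no_op_pin_memory(data, device=None):
    return data


with temporarily_replaced(fake_module, "pin_memory", no_op_pin_memory):
    print(fake_module.pin_memory([1, 2]))   # -> [1, 2] (patched)
print(fake_module.pin_memory([1, 2]))       # -> ('pinned', [1, 2]) (restored)
```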


## 4. Load Best Model and Create Fine-tuned Version

In [7]:
# The model is already imported at the top of the notebook

def create_fine_tuned_model(base_model, config):
    """Create a fine-tuned version of the model with optional architectural tweaks"""

    if config.get('add_batch_norm', False):
        # Add batch normalization to the classifier. Reuse the existing (already loaded)
        # Linear layers so their pretrained weights are kept; only BatchNorm1d is new.
        print("Adding batch normalization to classifier")
        old_classifier = base_model.classifier
        first_linear, last_linear = old_classifier[0], old_classifier[4]
        base_model.classifier = nn.Sequential(
            first_linear,
            nn.BatchNorm1d(first_linear.out_features),
            nn.ReLU(),
            nn.Dropout(old_classifier[2].p * config.get('dropout_factor', 1.0)),
            last_linear
        )

    if config.get('increase_attention_heads', False):
        # Increase attention heads from 8 to 16. nn.MultiheadAttention's projection
        # weight shapes do not depend on num_heads, so the loaded weights are copied over.
        print("Increasing attention heads to 16")
        old_attention = base_model.cross_attention
        new_attention = nn.MultiheadAttention(
            embed_dim=old_attention.embed_dim,
            num_heads=16,
            dropout=0.3 * config.get('dropout_factor', 1.0),
            batch_first=old_attention.batch_first
        )
        new_attention.load_state_dict(old_attention.state_dict())
        base_model.cross_attention = new_attention

    # Adjust dropout in other layers if needed
    if config.get('dropout_factor', 1.0) != 1.0:
        factor = config['dropout_factor']
        base_model.text_dropout.p = min(0.5, base_model.text_dropout.p * factor)

        # Update classifier dropout if not already modified
        if not config.get('add_batch_norm', False):
            base_model.classifier[2].p = min(0.7, base_model.classifier[2].p * factor)

    return base_model

# Load the best checkpoint
checkpoint_dir = project_root / 'checkpoints' / 'multimodal_concat'
checkpoint_path = checkpoint_dir / 'best_model.pth'

if checkpoint_path.exists():
    print(f"Loading best model from {checkpoint_path}")

    # Create base model
    model = ImprovedMultimodalVQA(
        vocab_size=vocab_size,
        num_classes=num_classes,
        embedding_dim=300,
        text_hidden_dim=512,
        fusion_hidden_dim=512,
        dropout=0.3
    )

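    # Load checkpoint (strict=False tolerates key mismatches, so report them explicitly)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"WARNING: {len(load_result.missing_keys)} missing and "
              f"{len(load_result.unexpected_keys)} unexpected keys when loading the checkpoint")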

    # Apply fine-tuning modifications
    model = create_fine_tuned_model(model, config)
    model = model.to(device)

    print(f"Model loaded successfully!")
    print(f"Original best validation accuracy from checkpoint: {checkpoint.get('best_val_acc', 'N/A')}")

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

else:
    raise FileNotFoundError(
        f"Best model checkpoint not found at {checkpoint_path}. "
        "Please run the main training notebook first to create the best model checkpoint."
    )

Loading best model from /content/drive/MyDrive/Colab Notebooks/WOA7015 Advanced Machine Learning/checkpoints/multimodal_concat/best_model.pth
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 195MB/s]


Model loaded successfully!
Original best validation accuracy from checkpoint: N/A
Total parameters: 33,285,611
Trainable parameters: 33,285,611
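`load_state_dict(..., strict=False)` silently skips entries whose names do not match between model and checkpoint, which can leave layers at their fresh initialization without any error. Comparing the key sets before loading makes that visible. The sketch below works on plain `{name: shape}` dictionaries, which you could build with `{k: tuple(v.shape) for k, v in model.state_dict().items()}` and the same expression over `checkpoint['model_state_dict']`. `compare_state_dict_keys` and the parameter names in the example are hypothetical. Shape mismatches are listed separately because PyTorch reports them as errors even with `strict=False`.

```python
from typing import Mapping, Tuple

Shape = Tuple[int, ...]


def compare_state_dict_keys(model_shapes: Mapping[str, Shape],
                            checkpoint_shapes: Mapping[str, Shape]):
    """Return (missing, unexpected, shape_mismatch) as sorted lists of names."""
    model_keys, ckpt_keys = set(model_shapes), set(checkpoint_shapes)
    missing = sorted(model_keys - ckpt_keys)      # would keep their fresh initialization
    unexpected = sorted(ckpt_keys - model_keys)   # would be ignored from the checkpoint
    shape_mismatch = sorted(
        k for k in model_keys & ckpt_keys if model_shapes[k] != checkpoint_shapes[k]
    )
    return missing, unexpected, shape_mismatch


# Hypothetical parameter names and shapes, for illustration only.
model_shapes = {"classifier.0.weight": (512, 1024), "classifier.4.weight": (4142, 512)}
ckpt_shapes = {"classifier.0.weight": (512, 1024), "fusion.weight": (512, 512)}
missing, unexpected, mismatched = compare_state_dict_keys(model_shapes, ckpt_shapes)
print("missing:", missing)
print("unexpected:", unexpected)
print("shape mismatch:", mismatched)
```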


## 5. Setup Fine-tuning Optimizer and Training Strategy

In [8]:
def setup_fine_tuning_optimizer(model, config):
    """Setup optimizer with strategy-specific parameters"""

    # Categorize parameters
    vision_params = []
    text_params = []
    fusion_params = []

    for name, param in model.named_parameters():
        if any(x in name for x in ['vision_encoder', 'spatial_attention', 'vision_proj']):
            vision_params.append(param)
        elif any(x in name for x in ['text_embedding', 'text_lstm', 'text_proj', 'text_dropout']):
            text_params.append(param)
        else:
            fusion_params.append(param)

    print(f"Vision parameters: {len(vision_params)}")
    print(f"Text parameters: {len(text_params)}")
    print(f"Fusion parameters: {len(fusion_params)}")

    # Setup parameter groups: strategies that define separate text/fusion factors
    # (e.g. "layerwise", "aggressive") get three groups; the rest get two
    if 'text_lr_factor' in config:
        param_groups = [
            {'params': vision_params, 'lr': config['learning_rate'] * config['vision_lr_factor'], 'name': 'vision'},
            {'params': text_params, 'lr': config['learning_rate'] * config.get('text_lr_factor', 0.3), 'name': 'text'},
            {'params': fusion_params, 'lr': config['learning_rate'] * config.get('fusion_lr_factor', 1.0), 'name': 'fusion'}
        ]
        print(f"Layer-wise LR: vision={param_groups[0]['lr']:.2e}, text={param_groups[1]['lr']:.2e}, fusion={param_groups[2]['lr']:.2e}")
    else:
        param_groups = [
            {'params': vision_params, 'lr': config['learning_rate'] * config['vision_lr_factor'], 'name': 'vision'},
            {'params': text_params + fusion_params, 'lr': config['learning_rate'], 'name': 'other'}
        ]
        print(f"Standard LR: vision={param_groups[0]['lr']:.2e}, other={param_groups[1]['lr']:.2e}")

    optimizer = optim.AdamW(param_groups, weight_decay=config['weight_decay'])

    return optimizer

def apply_layer_freezing(model, freeze_layers):
    """Freeze specified layers"""
    for layer_name in freeze_layers:
        for name, param in model.named_parameters():
            if layer_name in name:
                param.requires_grad = False
                print(f"Frozen: {name}")

# Setup optimizer
optimizer = setup_fine_tuning_optimizer(model, config)

# Apply initial layer freezing if specified
if config.get('freeze_layers', []):
    apply_layer_freezing(model, config['freeze_layers'])

# Setup loss function with label smoothing
criterion = nn.CrossEntropyLoss(label_smoothing=config['label_smoothing'])
print(f"Using label smoothing: {config['label_smoothing']}")

# Setup scheduler
if FINE_TUNING_STRATEGY == "progressive":
    scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.5, total_iters=config['epochs']//2)
else:
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=max(2, config['epochs']//3), T_mult=2, eta_min=1e-7
    )

print(f"Setup complete for {config['name']} strategy")

Vision parameters: 165
Text parameters: 11
Fusion parameters: 8
Standard LR: vision=1.00e-07, other=1.00e-06
Using label smoothing: 0.05
Setup complete for Conservative Fine-tuning strategy
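With `CosineAnnealingWarmRestarts` stepped once per epoch, each group's learning rate follows a cosine curve that restarts after `T_0` steps, and each new cycle is `T_mult` times longer than the previous one. The sketch below recomputes that schedule in plain Python from the formula in the PyTorch documentation, `lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * T_cur / T_i)) / 2`, where `T_cur` counts steps since the last restart and `T_i` is the current cycle length. `cosine_warm_restarts_lr` is a hypothetical inspection helper and does not replace the scheduler. For the 5-epoch `conservative` run, `T_0 = max(2, 5 // 3) = 2`, so the rate drops to about half in epoch 2 and restarts at full value in epoch 3. This scheduler's `step()` takes an optional `epoch` as its first positional argument, so any value passed positionally, such as a validation loss, is interpreted as a position in this schedule.

```python
import math


def cosine_warm_restarts_lr(base_lr, step, T_0, T_mult=1, eta_min=0.0):
    """Learning rate after `step` scheduler steps under cosine warm restarts."""
    t_cur, t_i = step, T_0
    while t_cur >= t_i:       # skip over completed cycles, each T_mult times longer
        t_cur -= t_i
        t_i *= T_mult
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i)) / 2


epochs = 5
T_0 = max(2, epochs // 3)     # same expression as the scheduler setup cell
for epoch in range(1, epochs + 1):
    # scheduler.step() runs after each epoch, so epoch e trains at step e - 1
    lr = cosine_warm_restarts_lr(5e-6, epoch - 1, T_0, T_mult=2, eta_min=1e-7)
    print(f"epoch {epoch}: lr = {lr:.2e}")
```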


## 6. Fine-tuning Training Loop with Advanced Monitoring
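The batch helpers in the cell below pick the first available tensor among several candidates. With tensors, that choice must use explicit `is None` checks, as `_safe_get` does. The expression `a or b` calls `bool(a)`, which PyTorch rejects with a `RuntimeError` for tensors holding more than one element. A one-element tensor holding `0` (for example a single answer with class index 0) is falsy, so `or` would silently skip it. The sketch below uses a hypothetical stand-in class, `AmbiguousTruth`, that raises on `bool()` like a multi-element tensor, so it runs without PyTorch; `first_not_none` is likewise a hypothetical name.

```python
class AmbiguousTruth:
    """Stand-in for a multi-element tensor: bool() raises, as PyTorch does."""

    def __bool__(self):
        raise RuntimeError("Boolean value of Tensor with more than one value is ambiguous")


def first_not_none(*values):
    """Return the first value that is not None, never calling bool() on it."""
    for value in values:
        if value is not None:
            return value
    return None


a, b = AmbiguousTruth(), AmbiguousTruth()
print(first_not_none(None, a, b) is a)   # True: only identity checks are used
print(first_not_none(None, 0, 5))        # 0, whereas `None or 0 or 5` gives 5
try:
    a or b                               # `or` evaluates bool(a) first
except RuntimeError as err:
    print("`or` failed:", err)
```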

In [11]:
import torch
from tqdm import tqdm


def _validate_and_process_tensor_cpu(
    tensor,
    target_dtype,
    tensor_name,
    vocab_size_check=None,
    num_classes_check=None
):
    if tensor is None or not torch.is_tensor(tensor):
        return None

    # NaN / Inf checks
    if torch.is_floating_point(tensor):
        if torch.isnan(tensor).any() or torch.isinf(tensor).any():
            print(f"ERROR: Invalid NaN/Inf in {tensor_name}")
            return None

    if target_dtype == "long":
        if tensor.min().item() < 0:
            print(f"ERROR: Negative index in {tensor_name}")
            return None

        if tensor_name == "questions" and vocab_size_check is not None:
            if tensor.max().item() >= vocab_size_check:
                print(f"ERROR: Question ID out of bounds")
                return None

        if tensor_name == "answers" and num_classes_check is not None:
            if tensor.max().item() >= num_classes_check:
                print(f"ERROR: Answer ID out of bounds")
                return None

        return tensor.long()

    if target_dtype == "float":
        return tensor.float()

    # attention mask handling
    if tensor_name == "attention_mask" and tensor.dtype == torch.bool:
        return tensor.long()

    return tensor


def _safe_get(batch, *keys):
    """Safely retrieve first existing non-None key from dict"""
    for k in keys:
        if k in batch and batch[k] is not None:
            return batch[k]
    return None


def _extract_and_validate_batch_cpu(batch, vocab_size, num_classes):
    questions = attention_mask = images = answers = None

    if isinstance(batch, dict):
        questions = _safe_get(batch, "question", "input_ids", "questions")
        attention_mask = _safe_get(batch, "attention_mask", "mask")
        images = _safe_get(batch, "image", "images", "pixel_values")
        answers = _safe_get(batch, "answer", "answers", "labels", "target")

    elif isinstance(batch, (list, tuple)):
        for item in batch:
            if not torch.is_tensor(item):
                continue

            # Keep the first match of each kind; `x or y` would call bool() on a
            # tensor, which raises for tensors with more than one element
            if item.ndim == 4 and item.dtype == torch.float:
                if images is None:
                    images = item
            elif item.ndim == 2 and item.dtype == torch.long:
                if questions is None:
                    questions = item
            elif item.ndim == 2 and item.dtype in (torch.bool, torch.uint8, torch.float):
                if attention_mask is None:
                    attention_mask = item
            elif item.ndim == 1 and item.dtype == torch.long:
                if answers is None:
                    answers = item

    else:
        return None

    questions = _validate_and_process_tensor_cpu(
        questions, "long", "questions", vocab_size, num_classes
    )
    attention_mask = _validate_and_process_tensor_cpu(
        attention_mask, None, "attention_mask"
    )
    images = _validate_and_process_tensor_cpu(
        images, "float", "images"
    )
    answers = _validate_and_process_tensor_cpu(
        answers, "long", "answers", None, num_classes
    )

    if questions is None or images is None or answers is None:
        return None

    # Safe shape normalization
    if questions.ndim == 3 and questions.size(1) == 1:
        questions = questions.squeeze(1)

    if images.ndim == 5 and images.size(1) == 1:
        images = images.squeeze(1)

    if questions.ndim not in (2, 3):
        return None

    return questions, attention_mask, images, answers


def fine_tune_epoch(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    scheduler,
    device,
    epoch,
    config,
    vocab_size,
    num_classes
):
    model.train()
    train_losses = []
    train_correct = 0
    train_total = 0

    print(f"\nEpoch {epoch}/{config['epochs']}")
    print("-" * 50)

    for batch in tqdm(train_loader, desc="Training"):
        batch_data = _extract_and_validate_batch_cpu(
            batch, vocab_size, num_classes
        )
        if batch_data is None:
            continue

        questions, attention_mask, images, answers = batch_data

        questions = questions.to(device)
        images = images.to(device)
        answers = answers.to(device)
        attention_mask = attention_mask.to(device) if attention_mask is not None else None

        optimizer.zero_grad()

        try:
            outputs = (
                model(questions, images, attention_mask=attention_mask)
                if attention_mask is not None
                else model(questions, images)
            )
        except Exception as e:
            print(f"Forward error: {e}")
            continue

        if outputs.ndim != 2 or outputs.size(1) != num_classes:
            continue

        loss = criterion(outputs, answers)

        if torch.isnan(loss) or torch.isinf(loss):
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        train_losses.append(loss.item())
        _, preds = torch.max(outputs, 1)
        train_total += answers.size(0)
        train_correct += (preds == answers).sum().item()

    if not train_losses:
        return dict(
            train_loss=float("inf"),
            train_accuracy=0.0,
            val_loss=float("inf"),
            val_accuracy=0.0,
        )

    avg_train_loss = sum(train_losses) / len(train_losses)
    train_accuracy = 100.0 * train_correct / train_total

    # ================= VALIDATION =================

    model.eval()
    val_losses = []
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            batch_data = _extract_and_validate_batch_cpu(
                batch, vocab_size, num_classes
            )
            if batch_data is None:
                continue

            questions, attention_mask, images, answers = batch_data

            questions = questions.to(device)
            images = images.to(device)
            answers = answers.to(device)
            attention_mask = attention_mask.to(device) if attention_mask is not None else None

            try:
                outputs = (
                    model(questions, images, attention_mask=attention_mask)
                    if attention_mask is not None
                    else model(questions, images)
                )
            except Exception:
                continue

            if outputs.ndim != 2 or outputs.size(1) != num_classes:
                continue

            loss = criterion(outputs, answers)
            val_losses.append(loss.item())

            _, preds = torch.max(outputs, 1)
            val_total += answers.size(0)
            val_correct += (preds == answers).sum().item()

    avg_val_loss = sum(val_losses) / len(val_losses) if val_losses else float("inf")
    val_accuracy = 100.0 * val_correct / val_total if val_total else 0.0

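    # Step the LR scheduler once per epoch. Only ReduceLROnPlateau takes a metric;
    # for other schedulers the first positional argument is an epoch, not a loss.
    if scheduler is not None:
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        else:
            scheduler.step()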

    print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"Val   Loss: {avg_val_loss:.4f} | Val   Acc: {val_accuracy:.2f}%")
    print("=" * 50)

    return dict(
        train_loss=avg_train_loss,
        train_accuracy=train_accuracy,
        val_loss=avg_val_loss,
        val_accuracy=val_accuracy,
    )
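`fine_tune_epoch` reports the loss as the mean of per-batch mean losses, so a short final batch counts as much as a full one, whereas accuracy is computed per sample (`correct / total`). The sketch below contrasts the two averages; the helper names are hypothetical. The effect is small in the logged `batch_size=12` run: of 17,780 training samples in 1,482 batches, only the last batch is short (8 samples). For a per-sample loss inside the loop, you would accumulate `loss.item() * answers.size(0)` and divide by `train_total`.

```python
from typing import Sequence


def mean_of_batch_means(batch_losses: Sequence[float]) -> float:
    """What fine_tune_epoch reports: every batch counts equally."""
    return sum(batch_losses) / len(batch_losses)


def sample_weighted_mean(batch_losses: Sequence[float], batch_sizes: Sequence[int]) -> float:
    """Per-sample average: weight each batch's mean loss by its number of samples."""
    total = sum(batch_sizes)
    return sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / total


# Toy numbers: two full batches of 12 and a short final batch of 2.
losses, sizes = [2.0, 2.0, 5.0], [12, 12, 2]
print(mean_of_batch_means(losses))          # 3.0
print(sample_weighted_mean(losses, sizes))  # 58 / 26, about 2.23
```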



## 7. Execute Fine-tuning

In [None]:
print(f"\n{'='*20} Starting Fine-tuning ({config['name']}) {'='*20}")

fine_tune_history = {
    'train_losses': [],
    'train_accuracies': [],
    'val_losses': [],
    'val_accuracies': []
}

best_fine_tune_acc = 0.0
ft_checkpoint_path = checkpoint_dir / f"fine_tuned_best_model_{FINE_TUNING_STRATEGY}.pth"

start_time = time.time()

for epoch in range(1, config['epochs'] + 1):

    epoch_results = fine_tune_epoch(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        device=device,
        epoch=epoch,
        config=config,
        vocab_size=vocab_size,
        num_classes=num_classes
    )

    fine_tune_history['train_losses'].append(epoch_results['train_loss'])
    fine_tune_history['train_accuracies'].append(epoch_results['train_accuracy'])
    fine_tune_history['val_losses'].append(epoch_results['val_loss'])
    fine_tune_history['val_accuracies'].append(epoch_results['val_accuracy'])

    # Save best model
    if epoch_results['val_accuracy'] > best_fine_tune_acc:
        best_fine_tune_acc = epoch_results['val_accuracy']

        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_fine_tune_acc,
                'config': config,
                'fine_tuning_strategy': FINE_TUNING_STRATEGY,
            },
            ft_checkpoint_path
        )

        print(f"--> Saved best model with validation accuracy: {best_fine_tune_acc:.2f}%")

training_time = time.time() - start_time

print(f"\n{'='*20} Fine-tuning Complete {'='*20}")
print(f"Total Fine-tuning Time: {training_time / 60:.2f} minutes")
print(f"Best Validation Accuracy Achieved: {best_fine_tune_acc:.2f}%")
print(f"Best model saved to: {ft_checkpoint_path}")





Epoch 1/5
--------------------------------------------------


Training: 100%|██████████| 1482/1482 [09:00<00:00,  2.74it/s]
Validating: 100%|██████████| 165/165 [00:45<00:00,  3.59it/s]


Train Loss: 6.2360 | Train Acc: 27.13%
Val   Loss: 5.7744 | Val   Acc: 24.96%
--> Saved best model with validation accuracy: 24.96%

Epoch 2/5
--------------------------------------------------


Training: 100%|██████████| 1482/1482 [09:00<00:00,  2.74it/s]
Validating: 100%|██████████| 165/165 [00:46<00:00,  3.56it/s]


Train Loss: 5.8915 | Train Acc: 26.11%
Val   Loss: 5.7027 | Val   Acc: 24.96%

Epoch 3/5
--------------------------------------------------


Training: 100%|██████████| 1482/1482 [08:59<00:00,  2.75it/s]
Validating: 100%|██████████| 165/165 [00:46<00:00,  3.52it/s]


Train Loss: 5.8600 | Train Acc: 25.77%
Val   Loss: 5.6688 | Val   Acc: 24.96%

Epoch 4/5
--------------------------------------------------


Training:  13%|█▎        | 190/1482 [01:09<07:31,  2.86it/s]
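In the captured log, validation accuracy stays at exactly 24.96% for epochs 1 to 3, while both training and validation loss keep decreasing. The output then ends partway through epoch 4. The loop above does not include an early-stopping check. A patience-based stop is one option worth considering. This is a minimal sketch that assumes validation accuracy is reported in percent, as in `epoch_results['val_accuracy']`. `EarlyStopping`, `patience` and `min_delta` are illustrative names, not part of any library.

```python
class EarlyStopping:
    """Signal a stop when a monitored metric has not improved for `patience` epochs.

    Hypothetical helper; higher is better (e.g. validation accuracy in %).
    """

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # smallest increase that counts as an improvement
        self.best = None
        self.epochs_without_improvement = 0

    def step(self, metric):
        """Record this epoch's metric; return True if training should stop."""
        if self.best is None or metric > self.best + self.min_delta:
            self.best = metric
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Usage inside the fine-tuning loop (sketch):
# stopper = EarlyStopping(patience=2)
# for epoch in range(1, config['epochs'] + 1):
#     epoch_results = fine_tune_epoch(...)
#     if stopper.step(epoch_results['val_accuracy']):
#         print(f"Early stopping at epoch {epoch}")
#         break
```

With `patience=2`, the three identical validation accuracies in the log would trigger a stop after epoch 3.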

## 8. Comprehensive Evaluation and Comparison

In [None]:
# Comprehensive evaluation function
def evaluate_fine_tuned_model(model, test_loader, device, strategy_name, vocab_size, num_classes):
    """Comprehensive evaluation of fine-tuned model"""
    model.eval()
    all_predictions = []
    all_targets = []
    test_loss = 0.0

    criterion = nn.CrossEntropyLoss()

    print(f"Evaluating {strategy_name} fine-tuned model on test set...")

    with torch.no_grad():
        pbar = tqdm(test_loader, desc="Testing")
        for batch_idx, batch in enumerate(pbar):
            _SKIP_BATCH = False
            questions_cpu, attention_mask_cpu, images_cpu, answers_cpu = None, None, None, None

            # Attempt to extract tensors flexibly (on CPU first)
            try:
                if isinstance(batch, dict):
                    # Take the first key that is present; `a or b` would call bool() on a tensor and raise
                    def _first_present(*keys):
                        return next((batch[k] for k in keys if batch.get(k) is not None), None)

                    questions_raw = _first_present('question', 'input_ids', 'questions')
                    attention_mask_raw = _first_present('attention_mask', 'mask')
                    images_raw = _first_present('image', 'images', 'pixel_values')
                    answers_raw = _first_present('answer', 'answers', 'labels', 'target')

                    questions_cpu = _validate_and_process_tensor_cpu(questions_raw, 'long', 'questions', vocab_size, num_classes)
                    attention_mask_cpu = _validate_and_process_tensor_cpu(attention_mask_raw, None, 'attention_mask', vocab_size, num_classes)
                    images_cpu = _validate_and_process_tensor_cpu(images_raw, 'float', 'images', vocab_size, num_classes)
                    answers_cpu = _validate_and_process_tensor_cpu(answers_raw, 'long', 'answers', vocab_size, num_classes)

                elif isinstance(batch, (list, tuple)):
                    found_questions = None
                    found_attention_mask = None
                    found_images = None
                    found_answers = None

                    for item in batch:
                        if not torch.is_tensor(item):
                            continue # Skip non-tensor items

                        if item.ndim == 4 and item.dtype == torch.float: # Likely images (Batch, C, H, W)
                            if found_images is None: found_images = item
                        elif item.ndim == 2 and item.dtype == torch.long: # Likely questions (Batch, SeqLen)
                            if found_questions is None: found_questions = item
                        elif item.ndim == 2 and (item.dtype == torch.bool or item.dtype == torch.uint8 or item.dtype == torch.float): # Likely attention mask
                            if found_attention_mask is None: found_attention_mask = item
                        elif item.ndim == 1 and item.dtype == torch.long: # Likely answers (Batch)
                            if found_answers is None: found_answers = item

                    questions_cpu = _validate_and_process_tensor_cpu(found_questions, 'long', 'questions', vocab_size, num_classes)
                    attention_mask_cpu = _validate_and_process_tensor_cpu(found_attention_mask, None, 'attention_mask', vocab_size, num_classes)
                    images_cpu = _validate_and_process_tensor_cpu(found_images, 'float', 'images', vocab_size, num_classes)
                    answers_cpu = _validate_and_process_tensor_cpu(found_answers, 'long', 'answers', vocab_size, num_classes)

                # If any CPU validation failed, skip batch
                if questions_cpu is None or images_cpu is None or answers_cpu is None:
                    _SKIP_BATCH = True

                if not _SKIP_BATCH:
                    # Safe squeeze operations on CPU tensors
                    questions_cpu = questions_cpu.squeeze() if questions_cpu.ndim > 2 else questions_cpu
                    images_cpu = images_cpu.squeeze() if images_cpu.ndim > 4 else images_cpu

            except Exception as e:
                print(f"ERROR: Exception during batch unpacking/CPU validation at evaluation batch {batch_idx}: {e}")
                print(f"Batch content type: {type(batch)}")
                if isinstance(batch, dict): print(f"Batch keys: {batch.keys()}")
                else: print(f"Batch length: {len(batch) if hasattr(batch, '__len__') else 'N/A'}")
                _SKIP_BATCH = True

            if _SKIP_BATCH:
                continue

            # --- Move validated CPU tensors to device ---
            questions = questions_cpu.to(device)
            attention_mask = attention_mask_cpu.to(device) if attention_mask_cpu is not None else None
            images = images_cpu.to(device)
            answers = answers_cpu.to(device)

            # Ensure questions tensor is the right shape for LSTM (2D or 3D)
            if questions.ndim != 2 and questions.ndim != 3:
                print(f"ERROR: Questions tensor has {questions.ndim} dimensions (expected 2 or 3) at evaluation batch {batch_idx} after squeeze. Shape: {questions.shape}")
                _SKIP_BATCH = True

            if _SKIP_BATCH:
                continue

            try:
                if attention_mask is not None:
                    outputs = model(questions, images, attention_mask=attention_mask)
                else:
                    outputs = model(questions, images)
            except TypeError:
                outputs = model(questions, images)
            except Exception as e:
                print(f"ERROR: Exception during model forward pass in evaluation batch {batch_idx}: {e}")
                _SKIP_BATCH = True
            if _SKIP_BATCH:
                continue

            # Validate model outputs
            if torch.isnan(outputs).any() or torch.isinf(outputs).any():
                print(f"ERROR: Invalid values in model outputs at evaluation batch {batch_idx}")
                _SKIP_BATCH = True
            if outputs.shape[1] != num_classes:
                print(f"ERROR: Model output shape mismatch: {outputs.shape[1]} vs {num_classes} at evaluation batch {batch_idx}")
                _SKIP_BATCH = True
            if _SKIP_BATCH:
                continue

            loss = criterion(outputs, answers)

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"ERROR: Invalid loss value at evaluation batch {batch_idx}: {loss}")
                _SKIP_BATCH = True
            if _SKIP_BATCH:
                continue

            # Weight each batch's mean loss by its size so skipped or short batches don't skew the average
            test_loss += loss.item() * answers.size(0)

            _, predicted = torch.max(outputs, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(answers.cpu().numpy())

    # Check if any batches were processed successfully
    if len(test_loader) == 0 or len(all_predictions) == 0:
        print("WARNING: No valid batches processed during evaluation. Returning default metrics.")
        return {
            'accuracy': 0.0,
            'avg_loss': float('inf'),
            'predictions': [],
            'targets': []
        }

    # Calculate metrics over the samples that were actually evaluated
    accuracy = accuracy_score(all_targets, all_predictions)
    avg_loss = test_loss / len(all_targets)  # per-sample mean loss

    return {
        'accuracy': accuracy,
        'avg_loss': avg_loss,
        'predictions': all_predictions,
        'targets': all_targets
    }

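# Restore the best checkpoint (by validation accuracy) so the test set is scored on those weights,
# not on whatever state the model was in when training ended or was interrupted
if ft_checkpoint_path.exists():
    # weights_only=False because the checkpoint also stores the config dict; this notebook wrote the file
    best_checkpoint = torch.load(ft_checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(best_checkpoint['model_state_dict'])
    print(f"Loaded best checkpoint from epoch {best_checkpoint['epoch']}")

# Evaluate fine-tuned model
ft_results = evaluate_fine_tuned_model(model, test_loader, device, config['name'], vocab_size, num_classes)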

print("\n" + "=" * 60)
print(f"FINE-TUNING RESULTS - {config['name'].upper()}")
print("=" * 60)
print(f"Test Accuracy: {ft_results['accuracy']:.4f} ({ft_results['accuracy']*100:.2f}%)")
print(f"Test Loss: {ft_results['avg_loss']:.4f}")

# Load original results for comparison
baseline_results_path = project_root / 'results' / 'text_baseline_results.json'
original_multimodal_path = project_root / 'results' / 'improved_multimodal_results.json'

print("\nPerformance Comparison:")
print("-" * 30)

if baseline_results_path.exists():
    with open(baseline_results_path, 'r') as f:
        baseline_results = json.load(f)
    baseline_acc = baseline_results['test_accuracy']  # stored as a fraction in [0, 1]
    print(f"Text Baseline:        {baseline_acc:.4f} ({baseline_acc*100:.2f}%)")

    improvement_vs_baseline = (ft_results['accuracy'] - baseline_acc) * 100
    print(f"vs Text Baseline:     {improvement_vs_baseline:+.2f} pp")

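if original_multimodal_path.exists():
    with open(original_multimodal_path, 'r') as f:
        original_results = json.load(f)
    original_acc = original_results['test_metrics']['accuracy']
    print(f"Original Multimodal:  {original_acc:.4f} ({original_acc*100:.2f}%)")

    improvement_vs_original = (ft_results['accuracy'] - original_acc) * 100
    print(f"vs Original:          {improvement_vs_original:+.2f} pp")
    improvement = improvement_vs_original  # used by the summary and report cells
else:
    # Fall back to the original model's best accuracy (in %), defined earlier in the notebook
    improvement = ft_results['accuracy'] * 100 - original_best_acc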

print(f"Fine-tuned Model:     {ft_results['accuracy']:.4f} ({ft_results['accuracy']*100:.2f}%)")
print("=" * 60)
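The comparison prints absolute differences in percentage points (pp), meaning the difference between two accuracies that are both expressed in percent. A relative improvement divides that gap by the reference accuracy instead, and the two are easy to confuse when reporting results. The sketch below computes both. It assumes accuracies are fractions in [0, 1], as returned by `accuracy_score`, and the helper names are illustrative.

```python
def pp_difference(new_acc, ref_acc):
    """Absolute gap in percentage points between two accuracies given as fractions."""
    return (new_acc - ref_acc) * 100


def relative_improvement(new_acc, ref_acc):
    """Gap as a percentage of the reference accuracy (undefined if ref_acc is 0)."""
    if ref_acc == 0:
        raise ValueError("reference accuracy must be non-zero")
    return (new_acc - ref_acc) / ref_acc * 100


# Example with the fallback values hard-coded in the plotting cell (47.36% and 55.39%)
print(f"{pp_difference(0.5539, 0.4736):+.2f} pp")        # +8.03 pp
print(f"{relative_improvement(0.5539, 0.4736):+.2f} %")  # +16.96 %
```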

## 9. Visualization and Analysis

In [None]:
# Training history visualization
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

epochs_range = range(1, len(fine_tune_history['train_losses']) + 1)

# Loss curves
ax1.plot(epochs_range, fine_tune_history['train_losses'], 'b-', label='Training Loss', linewidth=2)
ax1.plot(epochs_range, fine_tune_history['val_losses'], 'r-', label='Validation Loss', linewidth=2)
ax1.set_title(f'Fine-tuning Loss Curves - {config["name"]}', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy curves
ax2.plot(epochs_range, fine_tune_history['train_accuracies'], 'b-', label='Training Accuracy', linewidth=2)
ax2.plot(epochs_range, fine_tune_history['val_accuracies'], 'r-', label='Validation Accuracy', linewidth=2)
ax2.axhline(y=original_best_acc, color='g', linestyle='--', alpha=0.7, label=f'Original Best ({original_best_acc:.1f}%)')
ax2.set_title('Fine-tuning Accuracy Curves', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Performance comparison bar chart
models = ['Text Baseline', 'Original Multimodal', f'Fine-tuned\n({FINE_TUNING_STRATEGY})']
accuracies = []

if baseline_results_path.exists():
    accuracies.append(baseline_acc * 100)
else:
    accuracies.append(47.36)  # Fallback: previously recorded text-baseline accuracy (%)

if original_multimodal_path.exists():
    accuracies.append(original_acc * 100)
else:
    accuracies.append(55.39)  # Fallback: previously recorded best multimodal accuracy (%)

accuracies.append(ft_results['accuracy'] * 100)

bars = ax3.bar(models, accuracies, color=['skyblue', 'lightgreen', 'lightcoral'], alpha=0.8)

# Add value labels on bars
for bar, acc in zip(bars, accuracies):
    height = bar.get_height()
    ax3.text(bar.get_x() + bar.get_width()/2., height + 0.3,
            f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold')

ax3.set_title('Model Performance Comparison', fontsize=14, fontweight='bold')
ax3.set_ylabel('Accuracy (%)')
ax3.set_ylim(0, max(accuracies) + 5)
ax3.grid(True, alpha=0.3, axis='y')

# Fine-tuning strategy summary
ax4.axis('off')
summary_text = f"""
Fine-tuning Strategy: {config['name']}

Configuration:
• Learning Rate: {config['learning_rate']:.2e}
• Vision LR Factor: {config['vision_lr_factor']}
• Epochs: {config['epochs']}
• Label Smoothing: {config['label_smoothing']}
• Augmentation: {config['augmentation_strength']}

Results:
• Best Val Acc: {best_fine_tune_acc:.2f}%
• Test Accuracy: {ft_results['accuracy']*100:.2f}%
• Improvement: {improvement:+.2f} pp
• Training Time: {training_time/60:.1f} min
"""

ax4.text(0.05, 0.95, summary_text, fontsize=11, verticalalignment='top',
         bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.7))
ax4.set_title('Fine-tuning Summary', fontsize=14, fontweight='bold')

plt.tight_layout()

# Save the figure, then display it
results_dir = project_root / 'results' / 'figures'
results_dir.mkdir(exist_ok=True, parents=True)
figure_path = results_dir / f'fine_tuning_{FINE_TUNING_STRATEGY}_results.png'
fig.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"\nVisualization saved to: {figure_path}")
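The log in Section 7 shows validation accuracy at exactly 24.96% across several epochs. That pattern can mean the model is predicting the same answer for most samples. A quick diagnostic is to check how concentrated the predictions are and compare test accuracy with a majority-class baseline. This is a minimal sketch that assumes `predictions` and `targets` are equal-length lists of class indices, such as `ft_results['predictions']` and `ft_results['targets']`. `prediction_concentration` is a hypothetical helper.

```python
from collections import Counter


def prediction_concentration(predictions, targets):
    """Summarise how concentrated predictions are compared with the true labels."""
    if not predictions:
        raise ValueError("predictions must be non-empty")
    pred_counts = Counter(predictions)
    top_class, top_count = pred_counts.most_common(1)[0]
    majority_target, majority_count = Counter(targets).most_common(1)[0]
    return {
        "num_distinct_predictions": len(pred_counts),
        "most_predicted_class": top_class,
        "most_predicted_share": top_count / len(predictions),
        # Accuracy obtained by always predicting the most frequent true label
        "majority_baseline_acc": majority_count / len(targets),
    }


# Usage (sketch):
# report = prediction_concentration(ft_results['predictions'], ft_results['targets'])
# print(report)
```

Suppose `most_predicted_share` is close to 1 and test accuracy is close to `majority_baseline_acc`. In that case the model is most likely not using the question or image in any meaningful way.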

## 10. Save Fine-tuning Results and Generate Report

In [None]:
# Save comprehensive fine-tuning results
results_dir = project_root / 'results'
results_dir.mkdir(exist_ok=True)

fine_tuning_results = {
    'strategy': FINE_TUNING_STRATEGY,
    'strategy_name': config['name'],
    'strategy_description': config['description'],
    'configuration': config,
    'training_history': fine_tune_history,
    'results': {
        'best_validation_accuracy': best_fine_tune_acc / 100,
        'test_accuracy': ft_results['accuracy'],
        'test_loss': ft_results['avg_loss'],
        'training_time_minutes': training_time / 60,
        'total_epochs': len(fine_tune_history['train_losses'])
    },
    'improvements': {
        'vs_original_multimodal': improvement,
        'vs_text_baseline': improvement_vs_baseline if 'improvement_vs_baseline' in locals() else None
    },
    'model_info': {
        'total_parameters': total_params,
        'trainable_parameters': trainable_params
    }
}

# Save results
results_file = results_dir / f'fine_tuning_{FINE_TUNING_STRATEGY}_results.json'
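with open(results_file, 'w') as f:
    # tolist() turns NumPy/PyTorch scalars and arrays into plain Python values; anything else (e.g. Path) becomes a string
    json.dump(fine_tuning_results, f, indent=2, default=lambda o: o.tolist() if hasattr(o, 'tolist') else str(o))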

print(f"Fine-tuning results saved to: {results_file}")

# Generate confusion matrix if we have good performance
if ft_results['accuracy'] > 0.5:  # Only if accuracy is decent
    plt.figure(figsize=(10, 8))

    # Use a subset of classes for cleaner visualization
    unique_targets = sorted(set(ft_results['targets']))
    if len(unique_targets) > 20:  # Too many classes for clean visualization
        # Show top 20 most common classes
        from collections import Counter
        target_counts = Counter(ft_results['targets'])
        top_classes = [cls for cls, _ in target_counts.most_common(20)]

        # Filter data for top classes only
        filtered_targets = []
        filtered_preds = []
        for i, target in enumerate(ft_results['targets']):
            if target in top_classes:
                filtered_targets.append(target)
                filtered_preds.append(ft_results['predictions'][i])

        cm = confusion_matrix(filtered_targets, filtered_preds, labels=top_classes)
        title_suffix = " (Top 20 Classes)"
    else:
        cm = confusion_matrix(ft_results['targets'], ft_results['predictions'])
        title_suffix = ""

    sns.heatmap(cm, annot=False, fmt='d', cmap='Blues', cbar=True)
    plt.title(f'Fine-tuned Model Confusion Matrix{title_suffix}\n{config["name"]}')
    plt.xlabel('Predicted Class')
    plt.ylabel('True Class')
    plt.tight_layout()

    # Save confusion matrix
    cm_path = results_dir / 'figures' / f'fine_tuned_{FINE_TUNING_STRATEGY}_confusion_matrix.png'
    plt.savefig(cm_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Confusion matrix saved to: {cm_path}")

# Print final summary
print("\n" + "=" * 70)
print(f"FINE-TUNING COMPLETED: {config['name'].upper()}")
print("=" * 70)
print(f"Strategy Description: {config['description']}")
print(f"Training Time: {training_time/60:.2f} minutes")
print(f"Best Validation Accuracy: {best_fine_tune_acc:.2f}%")
print(f"Test Accuracy: {ft_results['accuracy']*100:.2f}%")
print(f"Improvement over Original: {improvement:+.2f} percentage points")

if improvement > 0:
    print("\nCONGRATULATIONS! Fine-tuning was successful!")
    print(f"Your model has improved by {improvement:.2f} percentage points.")
else:
    print("\nFine-tuning did not improve performance.")
    print("Consider trying a different strategy or adjusting hyperparameters.")

print("\nAll results and checkpoints saved in:")
print(f"- Results: {results_file}")
print(f"- Checkpoint: {ft_checkpoint_path}")
print(f"- Visualizations: {results_dir / 'figures'}")
print("=" * 70)
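The confusion-matrix cell keeps only samples whose true label is among the 20 most frequent classes. The same filtering can be written with `zip` and a set, which makes each membership test constant-time instead of a scan of the top-class list. This sketch uses a hypothetical `filter_top_k` helper and assumes targets and predictions are equal-length sequences of class indices.

```python
from collections import Counter


def filter_top_k(targets, predictions, k=20):
    """Keep (target, prediction) pairs whose true label is among the k most frequent targets.

    Returns the ordered list of top classes plus the filtered targets and predictions.
    """
    if len(targets) != len(predictions):
        raise ValueError("targets and predictions must have the same length")
    top_classes = [cls for cls, _ in Counter(targets).most_common(k)]
    keep = set(top_classes)  # set membership is O(1) per sample
    pairs = [(t, p) for t, p in zip(targets, predictions) if t in keep]
    filtered_targets = [t for t, _ in pairs]
    filtered_preds = [p for _, p in pairs]
    return top_classes, filtered_targets, filtered_preds


# Usage (sketch), mirroring the cell above:
# top_classes, y_true, y_pred = filter_top_k(ft_results['targets'], ft_results['predictions'])
# cm = confusion_matrix(y_true, y_pred, labels=top_classes)
```

scikit-learn's `confusion_matrix` drops samples whose prediction is not in `labels`. As a result, rows of the top-20 matrix can sum to less than each class's sample count.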

## 11. Strategy Comparison and Recommendations

**Available Fine-tuning Strategies:**

1. **Conservative:** Safe, minimal changes with very low learning rates
2. **Layerwise:** Different learning rates for different model components  
3. **Progressive:** Gradually unfreeze layers during training
4. **Augmented:** Enhanced data augmentation for better generalization
5. **Regularization:** Optimize dropout and weight decay parameters
6. **Architecture:** Minor architectural modifications
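The progressive strategy unfreezes layers gradually. A common way to schedule this is to order layer groups from the output head towards the input and unfreeze one more group every `epochs_per_stage` epochs. This is a minimal sketch of that schedule. The group names and the `trainable_groups` helper are hypothetical and are not taken from this notebook's implementation.

```python
def trainable_groups(epoch, groups, epochs_per_stage=1):
    """Return the layer groups that should be trainable at a 1-indexed epoch.

    `groups` is ordered from the output head towards the input (illustrative names).
    Stage 1 trains only groups[0]; every `epochs_per_stage` epochs one more group is unfrozen.
    """
    if epoch < 1 or epochs_per_stage < 1:
        raise ValueError("epoch and epochs_per_stage must be >= 1")
    stage = (epoch - 1) // epochs_per_stage + 1
    return groups[:min(stage, len(groups))]


# Usage (sketch, assuming the model has submodules with these hypothetical names):
# groups = ['classifier', 'fusion', 'text_encoder', 'vision_encoder']
# active = set(trainable_groups(epoch, groups, epochs_per_stage=2))
# for name in groups:
#     for p in getattr(model, name).parameters():
#         p.requires_grad = name in active
```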

**To try different strategies:**
1. Change `FINE_TUNING_STRATEGY` at the top of this notebook
2. Re-run cells 2 onwards
3. Compare results across strategies
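Each run writes `results/fine_tuning_{FINE_TUNING_STRATEGY}_results.json`, with the strategy key under `strategy` and the test accuracy (a fraction) under `results.test_accuracy`. The sketch below ranks finished runs from that file layout. `rank_strategies` is a hypothetical helper.

```python
import json
from pathlib import Path


def rank_strategies(results_dir):
    """Return (strategy, test_accuracy) pairs sorted from best to worst.

    Reads every fine_tuning_*_results.json file in `results_dir` (non-recursive).
    """
    rows = []
    for path in Path(results_dir).glob("fine_tuning_*_results.json"):
        with open(path) as f:
            data = json.load(f)
        rows.append((data["strategy"], data["results"]["test_accuracy"]))
    return sorted(rows, key=lambda row: row[1], reverse=True)


# Usage (sketch):
# for strategy, acc in rank_strategies(project_root / 'results'):
#     print(f"{strategy:<15} {acc * 100:.2f}%")
```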

**Next Steps:**
- Try multiple strategies and compare results
- Ensemble the best fine-tuned models
- Use the best model for inference applications
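One simple way to ensemble fine-tuned models is a per-sample majority vote over their predicted classes. When logits are available, averaging softmax probabilities is a common alternative. This is a minimal sketch that assumes each model's predictions are an equal-length list of class indices over the same test samples. `majority_vote` is hypothetical, and ties go to the earliest model listed.

```python
from collections import Counter


def majority_vote(*prediction_lists):
    """Combine per-model class predictions by majority vote, sample by sample."""
    if not prediction_lists:
        raise ValueError("need at least one list of predictions")
    if len({len(preds) for preds in prediction_lists}) != 1:
        raise ValueError("all prediction lists must have the same length")
    ensembled = []
    for votes in zip(*prediction_lists):
        # most_common orders equal counts by first occurrence, so ties favour earlier models
        ensembled.append(Counter(votes).most_common(1)[0][0])
    return ensembled


# Usage (sketch, with hypothetical per-model evaluation results):
# ensemble_preds = majority_vote(results_a['predictions'], results_b['predictions'], results_c['predictions'])
```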