# Multimodal VQA Training

**Goal**: Train a Visual Question Answering model combining text + vision

This notebook walks through:
1. Loading images and questions together
2. Building a multimodal model (LSTM + ResNet CNN)
3. Training with fusion strategies
4. Evaluating performance
5. Comparing with text-only baseline (47%)

**Target**: 60-70% accuracy with vision!

**Note**: Can run on Google Colab (GPU recommended for faster training)

## 0. Setup

**For Google Colab**: This cell will automatically mount your Drive and install packages.
**For Local**: This cell will skip Colab-specific setup.

In [None]:
# Google Colab setup (auto-detects environment)
try:
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Install packages
    print("Installing packages...")
    !pip install -q torch torchvision tqdm pyyaml scikit-learn pandas matplotlib seaborn Pillow
    
    # Set project path
    import os
    PROJECT_PATH = "/content/drive/MyDrive/WOA7015 Advanced Machine Learning"
    os.chdir(PROJECT_PATH)
    print(f"  Running on Colab - Path: {PROJECT_PATH}")
    
except ImportError:
    # Running locally
    PROJECT_PATH = None
    print("  Running locally")

✓ Running locally


## 1. Imports and Setup

In [None]:
import sys
import os
from pathlib import Path

# Setup paths
# PROJECT_PATH is assigned by the Colab setup cell (None when running locally).
# The globals() check keeps this cell runnable even if that cell was skipped.
if 'PROJECT_PATH' in globals() and PROJECT_PATH:
    project_root = Path(PROJECT_PATH)
else:
    # Local run: use the parent of the current working directory
    # (presumably the notebook is launched from a sub-folder of the project — verify)
    project_root = Path().absolute().parent

# Put the project root first on sys.path so the `src.*` imports below resolve
sys.path.insert(0, str(project_root))

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import yaml
from tqdm.notebook import tqdm

# Import our modules
from src.data.dataset import create_multimodal_dataloaders
from src.models.multimodal_model import create_multimodal_model
from src.training.multimodal_trainer import MultimodalVQATrainer
from src.evaluation.metrics import VQAMetrics, calculate_accuracy

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# Report the environment so a reader can see which root, torch build and device are in use
print("  Imports successful")
print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

ModuleNotFoundError: No module named 'torch'

## 2. Configuration

In [None]:
# Read the project-wide YAML configuration
config_path = project_root / 'config.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

print("Configuration:")
print(yaml.dump(config, default_flow_style=False))

# Pick the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nUsing device: {device}")

# Fusion strategy for this run. Valid options: 'concat', 'attention', 'bilinear', 'cross_attention'
model_type = 'attention'

# Override the file-based settings for multimodal training:
# a smaller batch because images are included, and a lower LR for the vision features
config['training'].update({
    'batch_size': 16,
    'num_epochs': 10,
    'learning_rate': 1e-4,
})
config['model'].update({
    'vision_encoder': 'resnet50',
    'fusion_strategy': model_type,
})

# Echo the effective settings
print("\nMultimodal training config:")
print("\n".join([
    f"  Batch size: {config['training']['batch_size']}",
    f"  Epochs: {config['training']['num_epochs']}",
    f"  Learning rate: {config['training']['learning_rate']}",
    f"  Vision encoder: {config['model']['vision_encoder']}",
    f"  Fusion strategy: {config['model']['fusion_strategy']}",
]))

## 3. Data Loading

Loading the PathVQA dataset with **images** using the new `MultimodalVQADataset`.

In [1]:
# Build the train/val/test loaders together with the vocabularies they were built from
dataset_path = project_root / 'data'
loader_kwargs = dict(
    train_csv=str(dataset_path / 'trainrenamed.csv'),
    test_csv=str(dataset_path / 'testrenamed.csv'),
    image_dir=str(dataset_path / 'train'),
    answers_file=str(dataset_path / 'answers.txt'),
    batch_size=config['training']['batch_size'],
    val_split=0.1,
    num_workers=0,
    image_size=224,
)
(train_loader, val_loader, test_loader,
 vocab_size, num_classes, vocab, answer_to_idx) = create_multimodal_dataloaders(**loader_kwargs)

# Summary of what was loaded
print("  Data loaded successfully")
print("\n".join([
    f"  Vocabulary size: {vocab_size}",
    f"  Number of classes: {num_classes}",
    f"  Training batches: {len(train_loader)}",
    f"  Validation batches: {len(val_loader)}",
    f"  Test batches: {len(test_loader)}",
]))

# Pull one batch to confirm the tensor shapes (images are expected as [batch, 3, 224, 224])
sample_batch = next(iter(train_loader))
print("\nSample batch shapes:")
print("\n".join([
    f"  Questions: {sample_batch['question'].shape}",
    f"  Images: {sample_batch['image'].shape}",
    f"  Answers: {sample_batch['answer'].shape}",
]))

NameError: name 'project_root' is not defined

In [None]:
# Debug: confirm that answer indices fit inside the classifier's output range
max_vocab_index = max(answer_to_idx.values()) if answer_to_idx else 'N/A'
unk_index = answer_to_idx.get('<UNK>', 'Not present')

print("Answer vocabulary debug:")
print(f"  Number of classes: {num_classes}")
print(f"  Max answer index in answer_to_idx: {max_vocab_index}")
print(f"  '<UNK>' token index: {unk_index}")

# Inspect the answer indices of one training batch
sample_batch = next(iter(train_loader))
batch_answers = sample_batch['answer']
max_answer_in_batch = batch_answers.max().item()

print("\nSample batch answer indices:")
print(f"  Min answer index: {batch_answers.min().item()}")
print(f"  Max answer index: {max_answer_in_batch}")
print(f"  Answer indices range should be 0 to {num_classes-1}")

# An index >= num_classes cannot be represented by the classifier output
if max_answer_in_batch < num_classes:
    print("  All answer indices are within valid range")
else:
    print(f"  ERROR: Found answer index {max_answer_in_batch} but only {num_classes} classes!")
    print("This will cause the IndexError during training")

## 4. Visualize Sample Images

Let's verify the images are loading correctly.

In [None]:
# Display up to 4 sample images with their questions and answers
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.ravel()

# Answers are class indices taken from `answer_to_idx`, not question-vocabulary ids,
# so they need their own reverse lookup (decoding them with `vocab` shows unrelated words).
idx_to_answer = {idx: answer for answer, idx in answer_to_idx.items()}

# Statistics used to undo the image normalisation for display
# (assumes the dataset normalises with the ImageNet mean/std — TODO confirm)
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Get batch
batch = next(iter(train_loader))

# Never index past the end of a short batch
num_shown = min(len(axes), batch['image'].size(0))

for i in range(num_shown):
    # Denormalize image for display and move channels last for imshow
    img = batch['image'][i].cpu() * std + mean
    img = np.clip(img.permute(1, 2, 0).numpy(), 0, 1)
    
    # Get question and answer
    question = batch['question'][i].cpu().tolist()
    answer_idx = batch['answer'][i].item()
    
    # Decode question (first few words); index 0 is skipped as padding
    question_text = ' '.join(vocab.idx_to_word.get(idx, '<UNK>')
                             for idx in question[:15] if idx > 0)
    answer_text = idx_to_answer.get(answer_idx, '<UNK>')
    
    # Plot
    axes[i].imshow(img)
    axes[i].set_title(f"Q: {question_text}...\nA: {answer_text}", fontsize=10)
    axes[i].axis('off')

# Hide any panels that were left empty
for ax in axes[num_shown:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

print("  Images loading correctly!")

## 5. Create Multimodal Model

Creating a multimodal VQA model with:
- **Vision Encoder**: ResNet50 CNN (extracts features from images)
- **Text Encoder**: LSTM (processes questions)
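- **Fusion**: Strategy selected by `config['model']['fusion_strategy']` (set to `'attention'` in the configuration cell; `'concat'` is the simplest option)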

In [None]:
# Instantiate the multimodal model from the configured fusion strategy
# (valid strategies: 'concat', 'attention', 'bilinear', 'cross_attention')
baseline_cfg = config['model']['baseline']
model = create_multimodal_model(
    model_type=config['model']['fusion_strategy'],
    vocab_size=vocab_size,
    num_classes=num_classes,
    embedding_dim=config['text']['embedding_dim'],
    text_hidden_dim=baseline_cfg['hidden_dim'],
    fusion_hidden_dim=baseline_cfg['hidden_dim'],
    dropout=baseline_cfg['dropout'],
).to(device)

# Tally parameter counts in a single pass over the model
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, trainable in param_sizes if trainable)

print("  Model created successfully")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
# Size estimate assumes 4 bytes per parameter
print(f"  Model size: {total_params * 4 / 1024**2:.1f} MB")
print("\nModel architecture:")
print(model)

## 6. Test Forward Pass

Quick sanity check that the model works.

In [None]:
# Test forward pass on one training batch with the untrained model
model.eval()
with torch.no_grad():
    test_batch = next(iter(train_loader))
    test_questions = test_batch['question'].to(device)
    test_images = test_batch['image'].to(device)
    test_answers = test_batch['answer'].to(device)
    
    outputs = model(test_questions, test_images)
    predictions = torch.argmax(outputs, dim=1)
    
    accuracy = (predictions == test_answers).float().mean()
    
    print(f"  Forward pass successful")
    print(f"  Input questions shape: {test_questions.shape}")
    print(f"  Input images shape: {test_images.shape}")
    print(f"  Output shape: {outputs.shape}")
    # Uniform-guess chance level is derived from the real class count instead of a hard-coded value
    print(f"  Random accuracy: {accuracy.item():.4f} (should be ~{1 / num_classes:.3f} for random)")

# Restore training mode; assigning the result keeps the notebook from echoing the whole model repr
_ = model.train()

## 7. Setup Training

Initialize the trainer with early stopping and checkpointing.

In [None]:
import types

# Build the trainer; the experiment name doubles as the checkpoint sub-folder
experiment_name = f"multimodal_{model_type}"
trainer = MultimodalVQATrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device,
    checkpoint_dir=project_root / 'checkpoints',
    experiment_name=experiment_name,
)

# Make sure the four history lists exist without clobbering any the trainer already provides
for history_attr in ('train_losses', 'val_losses', 'train_accuracies', 'val_accuracies'):
    if not hasattr(trainer, history_attr):
        setattr(trainer, history_attr, [])


def track_epoch_history(self, train_loss, train_acc, val_loss, val_acc):
    """Append one epoch's loss and accuracy values to the trainer's history lists."""
    recorded = (
        (self.train_losses, train_loss),
        (self.train_accuracies, train_acc),
        (self.val_losses, val_loss),
        (self.val_accuracies, val_acc),
    )
    for series, value in recorded:
        series.append(value)


# Bind the function to this trainer instance so it can be called as a method
trainer.track_epoch_history = types.MethodType(track_epoch_history, trainer)

print("Trainer initialized with history tracking")
print(f"  Checkpoint directory: {trainer.checkpoint_dir}")
print(f"  Learning rate: {config['training']['learning_rate']}")
print("  Early stopping patience: 5 epochs")
print("  Training history will be automatically tracked")

## 8. Train Model

This will take some time (~1-2 hours on CPU, ~15-20 minutes on GPU).

In [None]:
# Custom training loop with guaranteed history tracking
# Runs train + validation passes each epoch, records metrics on `trainer`,
# checkpoints the best model by validation accuracy and stops early when
# validation accuracy has not improved for `patience` consecutive epochs.
print("Starting Custom Training with History Tracking")
print("=" * 60)

import time

# Training configuration
num_epochs = config['training']['num_epochs']
patience = 5  # epochs without validation improvement before stopping
best_val_acc = 0.0
patience_counter = 0
start_time = time.time()

# Optimizer and criterion
optimizer = torch.optim.AdamW(model.parameters(), 
                             lr=config['training']['learning_rate'], 
                             weight_decay=1e-4)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    
    # Training phase
    model.train()
    train_loss = 0.0   # sum of per-batch losses
    train_correct = 0  # correctly predicted samples so far
    train_total = 0    # samples seen so far
    
    train_pbar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")
    for batch in train_pbar:
        questions = batch['question'].to(device)
        images = batch['image'].to(device)
        answers = batch['answer'].to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(questions, images)
        loss = criterion(outputs, answers)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        train_loss += loss.item()
        predictions = outputs.argmax(dim=1)
        train_correct += (predictions == answers).sum().item()
        train_total += answers.size(0)
        
        # Update progress bar (running accuracy over the batches seen this epoch)
        current_acc = train_correct / train_total
        train_pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{current_acc:.4f}'
        })
    
    # Loss is averaged per batch; accuracy is averaged per sample
    train_loss_avg = train_loss / len(train_loader)
    train_acc = train_correct / train_total
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    
    # No gradients are needed for validation
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}"):
            questions = batch['question'].to(device)
            images = batch['image'].to(device)
            answers = batch['answer'].to(device)
            
            outputs = model(questions, images)
            loss = criterion(outputs, answers)
            
            val_loss += loss.item()
            predictions = outputs.argmax(dim=1)
            val_correct += (predictions == answers).sum().item()
            val_total += answers.size(0)
    
    val_loss_avg = val_loss / len(val_loader)
    val_acc = val_correct / val_total
    
    # Track history
    trainer.track_epoch_history(train_loss_avg, train_acc, val_loss_avg, val_acc)
    
    # Print epoch results
    print(f"Train Loss: {train_loss_avg:.4f}, Train Acc: {train_acc:.4f} ({train_acc*100:.2f}%)")
    print(f"Val Loss: {val_loss_avg:.4f}, Val Acc: {val_acc:.4f} ({val_acc*100:.2f}%)")
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0  # improvement resets the early-stopping counter
        
        # Save checkpoint under <checkpoint_dir>/<experiment_name>/best_model.pth
        checkpoint_path = trainer.checkpoint_dir / experiment_name / 'best_model.pth'
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'train_acc': train_acc,
            'val_loss': val_loss_avg,
            'train_loss': train_loss_avg
        }, checkpoint_path)
        
        trainer.best_val_acc = val_acc
        print(f"New best model saved! Val Acc: {val_acc:.4f} ({val_acc*100:.2f}%)")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping triggered after {patience_counter} epochs without improvement")
            break

training_time = time.time() - start_time
# `epoch` is the zero-based index of the last epoch that ran
# NOTE(review): it is undefined if num_epochs is 0 — confirm the config never allows that
trainer.current_epoch = epoch

print("\n" + "="*60)
print("Training Complete!")
print("="*60)
print(f"Best validation accuracy: {best_val_acc:.4f} ({best_val_acc*100:.2f}%)")
print(f"Total training time: {training_time/60:.2f} minutes")
print(f"Training epochs completed: {len(trainer.train_losses)}")
print(f"Training history successfully tracked with {len(trainer.train_losses)} data points")

## 9. Plot Training History

In [None]:
# Reference accuracies quoted in this notebook, named once so the plot line,
# its label and the messages below cannot drift apart.
TEXT_BASELINE_ACC = 0.4736        # text-only baseline
ORIGINAL_MULTIMODAL_ACC = 0.4125  # original multimodal run

# Resolve the training history: prefer the lists recorded by the custom training loop,
# then any history the trainer exposes itself, and finally a zero-filled placeholder.
if hasattr(trainer, 'train_losses') and len(trainer.train_losses) > 0:
    # Use actual training history from trainer
    epochs = list(range(1, len(trainer.train_losses) + 1))
    train_losses = trainer.train_losses
    val_losses = getattr(trainer, 'val_losses', [0] * len(epochs))
    train_accs = getattr(trainer, 'train_accuracies', [0] * len(epochs))
    val_accs = getattr(trainer, 'val_accuracies', [0] * len(epochs))
    
    print(f"Using actual training history from trainer")
    print(f"  Training completed: {len(epochs)} epochs")
    print(f"  Best validation accuracy: {max(val_accs, default=0):.4f} ({max(val_accs, default=0)*100:.2f}%)")
    
else:
    # Fallback: check whether the trainer exposes a history method or attribute
    if hasattr(trainer, 'get_history'):
        history, history_source = trainer.get_history(), 'trainer.get_history()'
    elif hasattr(trainer, 'history'):
        history, history_source = trainer.history, 'trainer.history'
    else:
        history, history_source = None, None
    
    if history is not None:
        train_losses = history['train_loss']
        val_losses = history['val_loss']
        train_accs = history['train_acc']
        val_accs = history['val_acc']
        epochs = list(range(1, len(train_losses) + 1))
        print(f"Using training history from {history_source}")
    else:
        # No training history available - likely training hasn't been run yet
        print("No training history found. Either:")
        print("  1. Training hasn't been completed yet")
        print("  2. Trainer doesn't store history")
        print("  3. Run the training cell first")
        
        # Create placeholder for visualization
        epochs = list(range(1, 11))
        train_losses = [0] * 10
        val_losses = [0] * 10
        train_accs = [0] * 10
        val_accs = [0] * 10
        print("  Using placeholder data for plot structure")

# `default=0` keeps an empty history (e.g. a trainer that has not run yet)
# from raising ValueError; it is treated like the placeholder instead.
has_loss_data = max(train_losses, default=0) > 0
has_acc_data = max(train_accs, default=0) > 0

# Create training history plots
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Plot loss
if has_loss_data:  # Only plot if we have real data
    axes[0].plot(epochs, train_losses, label='Train Loss', marker='o', color='blue')
    axes[0].plot(epochs, val_losses, label='Val Loss', marker='s', color='orange')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True)
else:
    axes[0].text(0.5, 0.5, 'Training not completed\nRun training cell first', 
                 ha='center', va='center', transform=axes[0].transAxes,
                 bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))
    axes[0].set_title('Training Loss (Pending)')

# Plot accuracy
if has_acc_data:  # Only plot if we have real data
    axes[1].plot(epochs, train_accs, label='Train Accuracy', marker='o', color='blue')
    axes[1].plot(epochs, val_accs, label='Val Accuracy', marker='s', color='orange')
    axes[1].axhline(y=TEXT_BASELINE_ACC, color='r', linestyle='--',
                    label=f'Text-only baseline ({TEXT_BASELINE_ACC*100:.2f}%)')
    
    # Add best validation accuracy line
    best_val_acc = max(val_accs, default=0)
    if best_val_acc > 0:
        axes[1].axhline(y=best_val_acc, color='g', linestyle=':', alpha=0.7, 
                       label=f'Best validation ({best_val_acc*100:.2f}%)')
    
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True)
    
    # Print key achievements
    if best_val_acc > TEXT_BASELINE_ACC:
        print(f"SUCCESS: Beat text baseline by {(best_val_acc - TEXT_BASELINE_ACC)*100:.2f} pp!")
    elif best_val_acc > ORIGINAL_MULTIMODAL_ACC:
        print(f"PROGRESS: Improved over original multimodal by {(best_val_acc - ORIGINAL_MULTIMODAL_ACC)*100:.2f} pp")
        
else:
    axes[1].text(0.5, 0.5, 'Training not completed\nRun training cell first', 
                 ha='center', va='center', transform=axes[1].transAxes,
                 bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))
    axes[1].set_title('Training Accuracy (Pending)')

plt.tight_layout()

# Save plot if we have real data
if has_acc_data:
    results_dir = project_root / 'results' / 'figures'
    results_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(results_dir / 'multimodal_training_history.png', dpi=150)
    print(f"Training history plot saved to {results_dir / 'multimodal_training_history.png'}")

plt.show()

# Print training summary if data is available
if has_acc_data:
    print(f"\nTraining Summary:")
    print(f"  Total epochs: {len(epochs)}")
    print(f"  Final train accuracy: {train_accs[-1]*100:.2f}%")
    print(f"  Final validation accuracy: {val_accs[-1]*100:.2f}%")
    print(f"  Best validation accuracy: {max(val_accs, default=0)*100:.2f}%")
    print(f"  Final train loss: {train_losses[-1]:.4f}")
    print(f"  Final validation loss: {val_losses[-1]:.4f}")

## 10. Evaluate on Test Set

In [None]:
# Load best model checkpoint
# The training cell saves under <checkpoint_dir>/<experiment_name>/best_model.pth, so the
# same name must be used here. A hard-coded "multimodal_concat" misses the checkpoint
# (or picks up a stale one) whenever another fusion strategy was trained.
checkpoint_path = trainer.checkpoint_dir / experiment_name / 'best_model.pth'
if checkpoint_path.exists():
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Keep the trainer's best_val_acc in sync with the checkpoint that was loaded
    trainer.best_val_acc = checkpoint.get('val_acc', getattr(trainer, 'best_val_acc', 0.0))
    print(f"  Loaded best model from: {checkpoint_path}")
else:
    print(f"   Best model not found at: {checkpoint_path}")
    print("Using current model state for evaluation")

# Ensure the model is in evaluation mode
model.eval()

# Evaluate; all_preds / all_labels are reused by the "Save Results" cell
test_loss, test_acc, all_preds, all_labels = trainer.evaluate(test_loader)

print("\n" + "="*60)
print("TEST SET RESULTS")
print("="*60)
# Report the metrics under the header
# NOTE(review): test_acc is printed as returned; confirm whether evaluate() gives a fraction or a percentage
print(f"Test loss: {test_loss:.4f}")
print(f"Test accuracy: {test_acc:.4f}")
print(f"Test samples: {len(all_labels)}")

## 11. Analyze Predictions

In [None]:
# Get some test samples and predictions
model.eval()
test_batch = next(iter(test_loader))

with torch.no_grad():
    questions = test_batch['question'].to(device)
    images = test_batch['image'].to(device)
    true_answers = test_batch['answer'].to(device)
    
    outputs = model(questions, images)
    pred_answers = torch.argmax(outputs, dim=1)

# Answers are class indices taken from `answer_to_idx`, not question-vocabulary ids,
# so they are decoded with the inverse answer mapping rather than `vocab`.
idx_to_answer = {idx: answer for answer, idx in answer_to_idx.items()}

# Statistics used to undo the image normalisation for display
# (assumes the dataset normalises with the ImageNet mean/std — TODO confirm)
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Visualize some predictions
fig, axes = plt.subplots(3, 3, figsize=(16, 14))
axes = axes.ravel()

# Never index past the end of a batch that has fewer than 9 samples
num_shown = min(len(axes), images.size(0))

for i in range(num_shown):
    # Denormalize image and move channels last for imshow
    img = images[i].cpu() * std + mean
    img = np.clip(img.permute(1, 2, 0).numpy(), 0, 1)
    
    # Get question; index 0 is skipped as padding
    question = questions[i].cpu().tolist()
    question_text = ' '.join(vocab.idx_to_word.get(idx, '<UNK>')
                             for idx in question[:20] if idx > 0)
    
    # Get answers
    true_idx = true_answers[i].item()
    pred_idx = pred_answers[i].item()
    true_ans = idx_to_answer.get(true_idx, '<UNK>')
    pred_ans = idx_to_answer.get(pred_idx, '<UNK>')
    
    # Title colour flags correct (green) versus incorrect (red) predictions
    color = 'green' if true_idx == pred_idx else 'red'
    
    # Plot
    axes[i].imshow(img)
    axes[i].set_title(
        f"Q: {question_text}...\n"
        f"True: {true_ans} | Pred: {pred_ans}",
        fontsize=9,
        color=color
    )
    axes[i].axis('off')

# Hide any panels that were left empty
for ax in axes[num_shown:]:
    ax.axis('off')

plt.tight_layout()

# The figures directory may not exist yet (the history plot only creates it after training)
figures_dir = project_root / 'results' / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / 'multimodal_predictions.png', dpi=150)
plt.show()

print(f"  Predictions visualization saved")

## 12. Save Results

In [None]:
# Save test predictions (class indices) together with a per-sample correctness flag
results_df = pd.DataFrame({
    'true_answer': all_labels,
    'predicted_answer': all_preds,
    'correct': (np.array(all_preds) == np.array(all_labels)).astype(int)
})

# Name the file after the experiment so predictions from different fusion strategies
# are neither mislabelled nor overwritten, and create the folder, which nothing else
# in the notebook does.
results_path = project_root / 'results' / 'predictions' / f'{experiment_name}_predictions.csv'
results_path.parent.mkdir(parents=True, exist_ok=True)
results_df.to_csv(results_path, index=False)

num_correct = results_df['correct'].sum()

print(f"  Results saved to {results_path}")
print(f"\nSummary:")
print(f"  Total predictions: {len(results_df)}")
print(f"  Correct: {num_correct}")
print(f"  Incorrect: {len(results_df) - num_correct}")
print(f"  Accuracy: {results_df['correct'].mean():.4f}")

## 13. Next Steps

 **Achieved Goals:**
-  Implemented multimodal VQA with vision + text
-  Trained a fusion model (strategy chosen via `model_type` in the configuration cell)
-  Compared with text-only baseline (47.36%)

 **Potential Improvements:**

1. **Try different fusion strategies:**
   - Attention fusion (better weighting)
   - Bilinear fusion (richer interactions)
   - Cross-modal attention (full co-attention)

2. **Better vision encoder:**
   - Use AttentionVisionEncoder with spatial attention
   - Try different CNN backbones (ResNet101, EfficientNet)
   - Fine-tune vision encoder instead of freezing

3. **Data augmentation:**
   - Stronger image augmentation
   - Text augmentation (paraphrasing)

4. **Hyperparameter tuning:**
   - Learning rate scheduling
   - Different batch sizes
   - Gradient clipping

5. **Pre-trained models:**
   - Use CLIP or ViLT
   - Fine-tune BERT for questions

## 14. Improved Multimodal Model

Let's implement an **improved version** with the following enhancements:
1. **Trainable ResNet50** (instead of frozen)
2. **Cross-modal attention fusion** (instead of simple concatenation)
3. **Spatial attention** for vision features
4. **Better regularization** and training strategies

This should raise accuracy significantly, from the original multimodal model's 41.25% to potentially 50%+!

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class ImprovedMultimodalVQA(nn.Module):
    """Multimodal VQA classifier with a trainable ResNet50 and cross-modal attention.

    The question is encoded by a bidirectional LSTM. The image is encoded into a
    grid of region features that are gated by a learned spatial attention map.
    The question vector then attends over the image regions, and the attended
    result is classified into one of ``num_classes`` answers.

    Args:
        vocab_size: Number of entries in the question vocabulary (index 0 is padding).
        num_classes: Number of answer classes.
        embedding_dim: Size of the word embeddings.
        text_hidden_dim: Hidden size of the LSTM per direction.
        fusion_hidden_dim: Shared dimension both modalities are projected to;
            must be divisible by the 8 attention heads.
        dropout: Dropout probability (the classifier uses 1.5x this value).
    """
    
    def __init__(self, vocab_size, num_classes, embedding_dim=300, 
                 text_hidden_dim=512, fusion_hidden_dim=512, dropout=0.3):
        super().__init__()
        self.num_classes = num_classes
        
        # Text encoder
        self.text_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.text_lstm = nn.LSTM(embedding_dim, text_hidden_dim, 
                                batch_first=True, bidirectional=True)
        self.text_dropout = nn.Dropout(dropout)
        
        # Vision encoder (trainable). `weights=` replaces the deprecated `pretrained=` flag.
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Drop the last two children (global pooling + classifier) to keep the spatial feature maps
        self.vision_encoder = nn.Sequential(*list(backbone.children())[:-2])
        
        # Spatial attention: one sigmoid gate per image location
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2048, 512, 1),
            nn.ReLU(),
            nn.Conv2d(512, 1, 1),
            nn.Sigmoid()
        )
        
        # Project both modalities into the shared fusion space
        self.vision_proj = nn.Linear(2048, fusion_hidden_dim)
        self.text_proj = nn.Linear(text_hidden_dim * 2, fusion_hidden_dim)
        
        # Multi-head cross-attention (expects [seq, batch, dim] inputs)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=fusion_hidden_dim, 
            num_heads=8, 
            dropout=dropout
        )
        
        # Final classifier with more regularization
        self.classifier = nn.Sequential(
            nn.Linear(fusion_hidden_dim, fusion_hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout * 1.5),  # Stronger dropout
            nn.Linear(fusion_hidden_dim // 2, num_classes)
        )
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialise the newly added text, projection and classifier layers.

        The pretrained vision encoder is deliberately left untouched.
        """
        for block in (self.text_embedding, self.vision_proj, self.text_proj, self.classifier):
            # Walk sub-modules so Linear layers nested inside nn.Sequential are reached too
            for module in block.modules():
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    nn.init.constant_(module.bias, 0)
                elif isinstance(module, nn.Embedding):
                    nn.init.normal_(module.weight, mean=0, std=0.1)
                    # Re-initialising overwrites the zero padding vector; restore it
                    if module.padding_idx is not None:
                        with torch.no_grad():
                            module.weight[module.padding_idx].zero_()
    
    def forward(self, questions, images):
        """Compute answer logits.

        Args:
            questions: Integer token ids of shape [batch, seq_len].
            images: Image tensor of shape [batch, 3, H, W].

        Returns:
            Logits of shape [batch, num_classes].
        """
        # Text processing
        text_embedded = self.text_embedding(questions)
        _, (text_hidden, _) = self.text_lstm(text_embedded)
        
        # Final hidden states of the forward and backward directions
        text_features = torch.cat([text_hidden[0], text_hidden[1]], dim=1)  # [batch, 2 * text_hidden_dim]
        text_features = self.text_dropout(text_features)
        
        # Vision processing with spatial attention
        vision_maps = self.vision_encoder(images)  # [batch, 2048, h, w] (7x7 for 224x224 input)
        spatial_gate = self.spatial_attention(vision_maps)  # [batch, 1, h, w]
        attended_vision = vision_maps * spatial_gate  # Broadcast multiply
        
        # Keep every image region as a separate key/value. Attending over a single pooled
        # vector makes the softmax weight always 1, so the question would have no
        # influence on the output at all. flatten() also keeps the batch dimension
        # intact for batch size 1, unlike a bare squeeze().
        region_features = attended_vision.flatten(2).transpose(1, 2)  # [batch, h*w, 2048]
        
        # Project to same dimension
        vision_proj = self.vision_proj(region_features)          # [batch, h*w, fusion_hidden_dim]
        text_proj = self.text_proj(text_features).unsqueeze(1)   # [batch, 1, fusion_hidden_dim]
        
        # Attention: query=text, key=vision regions, value=vision regions
        attended_features, _ = self.cross_attention(
            query=text_proj.transpose(0, 1),
            key=vision_proj.transpose(0, 1),
            value=vision_proj.transpose(0, 1)
        )
        
        # Back to [batch, fusion_hidden_dim]
        fused_features = attended_features.transpose(0, 1).squeeze(1)
        
        # Final classification
        logits = self.classifier(fused_features)
        
        return logits

# Announce the class definition and list what changed relative to the first model
print("Improved multimodal model class defined!")
print("\nKey improvements:")
print("\n".join([
    "1. ResNet50 is now trainable (not frozen)",
    "2. Spatial attention for focusing on important image regions",
    "3. Cross-modal attention for better text-vision fusion",
    "4. Better weight initialization and stronger regularization",
]))

In [None]:
# Create improved multimodal model
model = ImprovedMultimodalVQA(
    vocab_size=len(vocab),
    num_classes=num_classes,
    embedding_dim=300,
    text_hidden_dim=512,
    fusion_hidden_dim=512,
    dropout=0.3
).to(device)

print(f"Model created and moved to {device}")
print(f"\nModel parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Test forward pass
try:
    # Get a small batch for testing
    test_batch = next(iter(train_loader))
    
    # Handle different data loader formats
    if isinstance(test_batch, dict):
        # Dictionary format: {'question': tensor, 'image': tensor, 'answer': tensor}
        test_questions = test_batch['question'].to(device)
        test_images = test_batch['image'].to(device) 
        test_answers = test_batch['answer'].to(device)
    elif isinstance(test_batch, (list, tuple)) and len(test_batch) == 3:
        # List/tuple format: (questions, images, answers)
        test_questions, test_images, test_answers = test_batch
        test_questions = test_questions.to(device)
        test_images = test_images.to(device)
        test_answers = test_answers.to(device)
    else:
        raise ValueError(f"Unexpected batch format: {type(test_batch)}")

    with torch.no_grad():
        test_outputs = model(test_questions, test_images)

    print(f"Forward pass successful!")
    print(f"Input shapes: questions={test_questions.shape}, images={test_images.shape}")
    print(f"Output shape: {test_outputs.shape}")
    print(f"Output range: [{test_outputs.min().item():.3f}, {test_outputs.max().item():.3f}]")
    print(f"Batch format: {type(test_batch)}")

except Exception as e:
    print(f"Error in forward pass: {e}")
    print(f"Batch type: {type(test_batch)}")
    print(f"Batch content: {test_batch if not isinstance(test_batch, (dict, list, tuple)) else 'batch data'}")
    import traceback
    traceback.print_exc()

## 15. Enhanced Training Setup

Now let's set up **enhanced training** with:
- **Differential learning rates** (lower for pretrained vision, higher for new components)
- **Label smoothing** to reduce overconfidence
- **Better scheduling** and **gradient clipping**

In [None]:
# Enhanced training configuration with differential learning rates
import torch.optim as optim

vision_params = []
other_params = []

for name, param in model.named_parameters():
    if 'vision_encoder' in name:
        vision_params.append(param)
    else:
        other_params.append(param)

# Different learning rates for different components
optimizer = optim.AdamW([
    {'params': vision_params, 'lr': 1e-5, 'weight_decay': 1e-4},    # Lower for pretrained
    {'params': other_params, 'lr': 1e-3, 'weight_decay': 1e-4}     # Higher for new layers
], betas=(0.9, 0.999))

# Enhanced loss function with label smoothing
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=5, T_mult=2, eta_min=1e-6
)

print("Enhanced training setup configured:")
print(f"- Vision encoder parameters: {len(vision_params):,}")
print(f"- Other parameters: {len(other_params):,}")
print(f"- Vision encoder LR: 1e-5")
print(f"- Other components LR: 1e-3")
print(f"- Using label smoothing: 0.1")
print(f"- Scheduler: CosineAnnealingWarmRestarts")

## 16. Train Improved Model

Let's train the improved model and see if we can beat the original 41.25% accuracy!

In [None]:
def _unpack_batch(batch, device):
    """Return `(questions, images, answers)` from one batch, moved to `device`.

    Accepts either a dict with 'question' / 'image' / 'answer' entries or a
    3-item sequence in that order.
    """
    if isinstance(batch, dict):
        questions, images, answers = batch['question'], batch['image'], batch['answer']
    else:
        questions, images, answers = batch
    return questions.to(device), images.to(device), answers.to(device)


def _mean_or_nan(values):
    """Arithmetic mean of `values`, or NaN when the list is empty."""
    return sum(values) / len(values) if values else float('nan')


def _percent(correct, total):
    """`correct / total` expressed as a percentage, or 0.0 when `total` is zero."""
    return 100. * correct / total if total else 0.0


def train_epoch_with_history(model, train_loader, val_loader, optimizer, criterion, scheduler, 
                           device, epoch, total_epochs):
    """Run one training pass and one validation pass, then step the scheduler.

    Args:
        model: Module called as `model(questions, images)` and returning class logits.
        train_loader: Iterable of training batches (dict or 3-tuple, see `_unpack_batch`).
            Must support `len()` for the progress line.
        val_loader: Iterable of validation batches in the same format.
        optimizer: Optimizer updating `model`'s parameters.
        criterion: Loss called as `criterion(outputs, answers)`.
        scheduler: LR scheduler; stepped once, after validation.
        device: Device the batch tensors are moved to.
        epoch: Current epoch number (used for logging only).
        total_epochs: Total number of epochs (used for logging only).

    Returns:
        dict with 'train_loss', 'train_accuracy', 'val_loss', 'val_accuracy'.
        Losses are means of the per-batch losses; accuracies are percentages.
        A loader that yields no batches produces a NaN loss and 0.0 accuracy
        instead of raising ZeroDivisionError.
    """

    # Training phase
    model.train()
    train_losses = []
    train_correct = 0
    train_total = 0

    print(f"Epoch {epoch}/{total_epochs}")
    print("-" * 50)

    for batch_idx, batch in enumerate(train_loader):
        questions, images, answers = _unpack_batch(batch, device)

        optimizer.zero_grad()
        outputs = model(questions, images)
        loss = criterion(outputs, answers)
        loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Track metrics
        train_losses.append(loss.item())
        predicted = outputs.detach().argmax(dim=1)
        train_total += answers.size(0)
        train_correct += (predicted == answers).sum().item()

        if batch_idx % 50 == 0:
            current_acc = _percent(train_correct, train_total)
            print(f"Batch {batch_idx:3d}/{len(train_loader):3d} | "
                  f"Loss: {loss.item():.4f} | "
                  f"Acc: {current_acc:.2f}%")

    # Calculate epoch averages
    avg_train_loss = _mean_or_nan(train_losses)
    train_accuracy = _percent(train_correct, train_total)

    # Validation phase
    model.eval()
    val_losses = []
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for batch in val_loader:
            questions, images, answers = _unpack_batch(batch, device)

            outputs = model(questions, images)
            loss = criterion(outputs, answers)

            val_losses.append(loss.item())
            predicted = outputs.argmax(dim=1)
            val_total += answers.size(0)
            val_correct += (predicted == answers).sum().item()

    avg_val_loss = _mean_or_nan(val_losses)
    val_accuracy = _percent(val_correct, val_total)

    # Capture the learning rates that were actually used during this epoch
    # before the scheduler moves them on to the next epoch's values.
    epoch_lrs = [group['lr'] for group in optimizer.param_groups]

    # Update scheduler
    scheduler.step()

    # Print epoch summary (one LR per optimizer parameter group)
    lr_summary = ", ".join(f"{lr:.2e}" for lr in epoch_lrs)
    print(f"\nEpoch {epoch} Summary:")
    print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_accuracy:.2f}%")
    print(f"Learning Rate: {lr_summary}")
    print("=" * 50)

    return {
        'train_loss': avg_train_loss,
        'train_accuracy': train_accuracy,
        'val_loss': avg_val_loss,
        'val_accuracy': val_accuracy
    }

In [None]:
# Custom training loop with comprehensive history tracking.
# Runs EPOCHS epochs, records per-epoch metrics in `history`, and keeps / saves
# a snapshot of the weights from the epoch with the best validation accuracy.
EPOCHS = 15
history = {
    'train_losses': [],
    'train_accuracies': [],
    'val_losses': [],
    'val_accuracies': []
}

best_val_acc = 0.0
best_model_state = None

# Make sure the checkpoint directory exists before the first torch.save call.
checkpoint_path = project_root / 'checkpoints' / 'multimodal_concat' / 'best_model.pth'
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

print("Starting enhanced multimodal training...")
print(f"Training for {EPOCHS} epochs with differential learning rates")
print(f"Device: {device}")

start_time = time.time()

try:
    for epoch in range(1, EPOCHS + 1):
        # Train one epoch
        epoch_results = train_epoch_with_history(
            model, train_loader, val_loader, 
            optimizer, criterion, scheduler, 
            device, epoch, EPOCHS
        )
        
        # Store history
        history['train_losses'].append(epoch_results['train_loss'])
        history['train_accuracies'].append(epoch_results['train_accuracy'])
        history['val_losses'].append(epoch_results['val_loss'])
        history['val_accuracies'].append(epoch_results['val_accuracy'])
        
        # Save best model
        if epoch_results['val_accuracy'] > best_val_acc:
            best_val_acc = epoch_results['val_accuracy']
            # Clone every tensor: state_dict().copy() is a shallow copy whose
            # tensors still alias the live weights, so the "best" snapshot
            # would silently follow all later optimizer updates.
            best_model_state = {
                key: tensor.detach().clone()
                for key, tensor in model.state_dict().items()
            }
            print(f"New best validation accuracy: {best_val_acc:.2f}%")
            
            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': best_model_state,
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_val_acc,
                'history': history
            }, checkpoint_path)

except KeyboardInterrupt:
    # Allow a manual stop while keeping whatever history / best state exists so far.
    print("\nTraining interrupted by user")
except Exception as e:
    # Best-effort: report the failure but let the rest of the cell run.
    print(f"Error during training: {e}")
    import traceback
    traceback.print_exc()

training_time = time.time() - start_time
print(f"\nTraining completed in {training_time:.2f} seconds")
print(f"Best validation accuracy: {best_val_acc:.2f}%")

# Load best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Best model loaded for evaluation")

## 17. Compare Training Results

In [None]:
# Comprehensive visualization of training history:
# loss curves, accuracy curves, the learning-rate schedule and a text summary.
def _warm_restart_lr(steps_taken, base_lr, t_0, t_mult, eta_min):
    """LR that CosineAnnealingWarmRestarts yields after `steps_taken` per-epoch steps."""
    t_cur, t_i = steps_taken, t_0
    # Walk through the completed cycles; each cycle is t_mult times longer than the last.
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + (base_lr - eta_min) * (1 + np.cos(np.pi * t_cur / t_i)) / 2


fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

epochs_range = range(1, len(history['train_losses']) + 1)

# Loss curves
ax1.plot(epochs_range, history['train_losses'], 'b-', label='Training Loss', linewidth=2)
ax1.plot(epochs_range, history['val_losses'], 'r-', label='Validation Loss', linewidth=2)
ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy curves
ax2.plot(epochs_range, history['train_accuracies'], 'b-', label='Training Accuracy', linewidth=2)
ax2.plot(epochs_range, history['val_accuracies'], 'r-', label='Validation Accuracy', linewidth=2)
ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Learning rate per epoch, reconstructed from the scheduler's own settings so the
# plot follows the cosine-with-warm-restarts schedule used in training rather
# than a step-decay approximation. Epoch e runs after e - 1 scheduler steps.
# NOTE(review): assumes `scheduler` is still the CosineAnnealingWarmRestarts
# instance from the training-setup cell and that the optimizer groups are
# ordered (vision encoder, other components) as defined there.
lr_group_labels = ['Vision encoder LR', 'Other components LR']
for base_lr, group_label in zip(scheduler.base_lrs, lr_group_labels):
    lr_values = [
        _warm_restart_lr(epoch - 1, base_lr, scheduler.T_0, scheduler.T_mult, scheduler.eta_min)
        for epoch in epochs_range
    ]
    ax3.semilogy(epochs_range, lr_values, label=group_label, linewidth=2)
ax3.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
ax3.set_xlabel('Epochs')
ax3.set_ylabel('Learning Rate (log scale)')
ax3.legend()
ax3.grid(True, alpha=0.3)

# Training summary statistics
final_stats = {
    'Final Train Acc': f"{history['train_accuracies'][-1]:.2f}%",
    'Final Val Acc': f"{history['val_accuracies'][-1]:.2f}%",
    'Best Val Acc': f"{max(history['val_accuracies']):.2f}%",
    'Final Train Loss': f"{history['train_losses'][-1]:.4f}",
    'Final Val Loss': f"{history['val_losses'][-1]:.4f}",
    'Total Epochs': len(history['train_losses'])
}

# Fourth panel is a text box rather than a plot.
ax4.axis('off')
stats_text = '\n'.join([f"{k}: {v}" for k, v in final_stats.items()])
ax4.text(0.1, 0.5, stats_text, fontsize=12, verticalalignment='center',
         bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7))
ax4.set_title('Training Summary', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

# Print detailed results
print("Enhanced Multimodal Model Training Results:")
print("=" * 50)
print(f"Final Training Accuracy: {history['train_accuracies'][-1]:.2f}%")
print(f"Final Validation Accuracy: {history['val_accuracies'][-1]:.2f}%")
print(f"Best Validation Accuracy: {max(history['val_accuracies']):.2f}%")
print(f"Total Training Epochs: {len(history['train_losses'])}")
print("=" * 50)

## 18. Evaluate Improved Model on Test Set

In [None]:
# Final evaluation on test set
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def evaluate_model(model, test_loader, device):
    """Score `model` on every batch of `test_loader`.

    The model is put in eval mode and run with gradients disabled. Loss is a
    plain `nn.CrossEntropyLoss` (no label smoothing), summed per batch and
    divided by the number of batches.

    Args:
        model: Module called as `model(questions, images)` returning class logits.
        test_loader: Iterable of batches, each either a dict with 'question',
            'image' and 'answer' entries or a 3-item sequence in that order.
            Must support `len()`.
        device: Device the batch tensors are moved to.

    Returns:
        dict with 'accuracy', weighted 'precision' / 'recall' / 'f1_score'
        (computed with zero_division=0), 'avg_loss', and the flat
        'predictions' and 'targets' lists.
    """
    loss_fn = nn.CrossEntropyLoss()
    predicted_labels = []
    true_labels = []
    running_loss = 0.0

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            # Normalise both supported batch layouts to one 3-tuple first.
            if isinstance(batch, dict):
                fields = (batch['question'], batch['image'], batch['answer'])
            else:
                fields = batch
            questions, images, answers = fields
            questions = questions.to(device)
            images = images.to(device)
            answers = answers.to(device)

            logits = model(questions, images)
            running_loss += loss_fn(logits, answers).item()

            batch_predictions = torch.max(logits, 1)[1]
            predicted_labels.extend(batch_predictions.cpu().numpy())
            true_labels.extend(answers.cpu().numpy())

    # Shared keyword arguments for the three class-weighted metrics.
    weighted = {'average': 'weighted', 'zero_division': 0}

    return {
        'accuracy': accuracy_score(true_labels, predicted_labels),
        'precision': precision_score(true_labels, predicted_labels, **weighted),
        'recall': recall_score(true_labels, predicted_labels, **weighted),
        'f1_score': f1_score(true_labels, predicted_labels, **weighted),
        'avg_loss': running_loss / len(test_loader),
        'predictions': predicted_labels,
        'targets': true_labels
    }

# Imported unconditionally: the results-saving cell further down calls
# json.dump even when no baseline file exists, so the import must not be
# hidden inside the `if baseline_path.exists()` branch.
import json

# Evaluate the improved model
print("Evaluating improved multimodal model on test set...")
improved_results = evaluate_model(model, test_loader, device)

print("\nImproved Multimodal Model Results:")
print("=" * 40)
print(f"Test Accuracy: {improved_results['accuracy']:.4f} ({improved_results['accuracy']*100:.2f}%)")
print(f"Test Precision: {improved_results['precision']:.4f}")
print(f"Test Recall: {improved_results['recall']:.4f}")
print(f"Test F1-Score: {improved_results['f1_score']:.4f}")
print(f"Test Loss: {improved_results['avg_loss']:.4f}")
print("=" * 40)

# Load and compare with baseline if available
baseline_path = project_root / 'results' / 'text_baseline_results.json'
if baseline_path.exists():
    with open(baseline_path, 'r') as f:
        baseline_results = json.load(f)
    
    # A baseline file without an 'accuracy' entry is treated as 0.0.
    baseline_acc = baseline_results.get('accuracy', 0.0)
    improvement = improved_results['accuracy'] - baseline_acc
    
    print(f"\nComparison with Text Baseline:")
    print(f"Baseline Accuracy: {baseline_acc:.4f} ({baseline_acc*100:.2f}%)")
    print(f"Improved Model: {improved_results['accuracy']:.4f} ({improved_results['accuracy']*100:.2f}%)")
    print(f"Improvement: {improvement:.4f} ({improvement*100:.2f}% points)")
    
    if improvement > 0:
        print("Performance improvement achieved!")
    else:
        print("Performance did not improve - consider further tuning")
else:
    print("Baseline results not found - run text baseline first for comparison")

## 19. Save Results, Confusion Matrix, and Baseline Comparison

In [None]:
# Save improved model results: test metrics, training history and the
# configuration used, written as JSON under <project_root>/results.
# json is imported here because the only other import in this part of the
# notebook sits inside a branch that is skipped when no baseline file exists.
import json

results_dir = project_root / 'results'
results_dir.mkdir(parents=True, exist_ok=True)

# Validation accuracies are stored as percentages; convert to fractions so they
# are on the same scale as the test metrics. None when no epoch completed.
val_accuracies = history['val_accuracies']
best_val_fraction = max(val_accuracies) / 100.0 if val_accuracies else None
final_val_fraction = val_accuracies[-1] / 100.0 if val_accuracies else None

# Prepare results dictionary
improved_model_results = {
    'model_type': 'improved_multimodal',
    'architecture': 'ResNet50 + LSTM + Cross-Modal Attention',
    # float(...) guards against numpy scalar types that json cannot encode.
    'test_metrics': {
        'accuracy': float(improved_results['accuracy']),
        'precision': float(improved_results['precision']),
        'recall': float(improved_results['recall']),
        'f1_score': float(improved_results['f1_score']),
        'loss': float(improved_results['avg_loss'])
    },
    'training_history': history,
    'best_validation_accuracy': best_val_fraction,
    'final_validation_accuracy': final_val_fraction,
    # NOTE(review): the dimensions and dropout below are hard-coded here, not
    # read from the model - confirm they match the model actually trained.
    'model_parameters': {
        'vocab_size': len(vocab),
        'num_classes': NUM_CLASSES,
        'embedding_dim': 300,
        'text_hidden_dim': 512,
        'fusion_hidden_dim': 512,
        'dropout': 0.3
    },
    'training_config': {
        'epochs': len(history['train_losses']),
        'vision_lr': 1e-5,
        'other_lr': 1e-3,
        'label_smoothing': 0.1,
        'scheduler': 'CosineAnnealingWarmRestarts'
    },
    'improvements': [
        'Trainable ResNet50 vision encoder',
        'Spatial attention mechanism',
        'Cross-modal attention fusion',
        'Differential learning rates',
        'Label smoothing',
        'Gradient clipping',
        'Enhanced regularization'
    ]
}

# Save results
results_file = results_dir / 'improved_multimodal_results.json'
with open(results_file, 'w') as f:
    json.dump(improved_model_results, f, indent=2)

print(f"Results saved to: {results_file}")

# Generate confusion matrix
# confusion_matrix is imported here because the sklearn import at the top of the
# evaluation cell only brings in the four score functions.
from sklearn.metrics import confusion_matrix

if len(set(improved_results['targets'])) > 1:  # Check if we have multiple classes
    # Fix the label set explicitly so the matrix is always NUM_CLASSES x NUM_CLASSES
    # and lines up with the tick labels, even when some classes never occur in
    # the test targets or predictions.
    class_labels = list(range(NUM_CLASSES))

    plt.figure(figsize=(10, 8))
    cm = confusion_matrix(improved_results['targets'], improved_results['predictions'],
                          labels=class_labels)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_labels, 
                yticklabels=class_labels)
    plt.title('Improved Multimodal Model - Confusion Matrix')
    plt.xlabel('Predicted Class')
    plt.ylabel('True Class')
    plt.tight_layout()
    
    # Save confusion matrix
    cm_path = results_dir / 'figures' / 'improved_multimodal_confusion_matrix.png'
    cm_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(cm_path, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"Confusion matrix saved to: {cm_path}")

# Performance comparison visualization if baseline exists:
# a two-bar chart of text-baseline vs. improved-model test accuracy.
baseline_path = project_root / 'results' / 'text_baseline_results.json'
if baseline_path.exists():
    with open(baseline_path, 'r') as f:
        baseline_results = json.load(f)
    
    # Same tolerance as the evaluation cell: a baseline file without an
    # 'accuracy' entry counts as 0.0 instead of raising KeyError.
    baseline_acc = baseline_results.get('accuracy', 0.0)
    
    # Create comparison chart
    models = ['Text Baseline', 'Improved Multimodal']
    accuracies = [baseline_acc, improved_results['accuracy']]
    
    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(models, [acc*100 for acc in accuracies], 
                  color=['skyblue', 'lightcoral'], alpha=0.8)
    
    # Add value labels on bars
    for bar, acc in zip(bars, accuracies):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{acc*100:.2f}%', ha='center', va='bottom', fontweight='bold')
    
    ax.set_title('Model Performance Comparison', fontsize=14, fontweight='bold')
    ax.set_ylabel('Accuracy (%)')
    ax.set_ylim(0, max(accuracies)*100 + 5)
    ax.grid(True, alpha=0.3, axis='y')
    
    # Add improvement annotation (difference in percentage points)
    improvement = (improved_results['accuracy'] - baseline_acc) * 100
    ax.annotate(f'Improvement: {improvement:+.2f}%', 
                xy=(1, accuracies[1]*100), xytext=(1.2, accuracies[1]*100 + 2),
                arrowprops=dict(arrowstyle='->', color='green', lw=2),
                fontsize=12, fontweight='bold', color='green')
    
    fig.tight_layout()
    
    # Save comparison chart
    comparison_path = results_dir / 'figures' / 'model_comparison.png'
    comparison_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(comparison_path, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"Comparison chart saved to: {comparison_path}")

print("\nImproved multimodal training completed successfully!")
print("All results, visualizations, and model checkpoints have been saved.")