# Multimodal VQA Training

**Goal**: Train a Visual Question Answering model combining text + vision

This notebook walks through:
1. Loading images and questions together
2. Building a multimodal model (LSTM + ResNet CNN)
3. Training with fusion strategies
4. Evaluating performance
5. Comparing with the text-only baseline (47.36%)

**Target**: 60-70% accuracy with vision!

**Note**: Runs on Google Colab or locally; a GPU is recommended for faster training.

## 0. Setup

**For Google Colab**: This cell will automatically mount your Drive and install packages.
**For Local**: This cell will skip Colab-specific setup.

In [2]:
# Google Colab setup (auto-detects environment)
try:
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Install packages
    print("Installing packages...")
    !pip install -q torch torchvision tqdm pyyaml scikit-learn pandas matplotlib seaborn Pillow
    
    # Set project path
    import os
    PROJECT_PATH = "/content/drive/MyDrive/WOA7015 Advanced Machine Learning/data"
    os.chdir(PROJECT_PATH)
    print(f"✓ Running on Colab - Path: {PROJECT_PATH}")
    
except ImportError:
    # Running locally
    PROJECT_PATH = None
    print("✓ Running locally")

✓ Running locally


## 1. Imports and Setup

In [3]:
import sys
import os
from pathlib import Path

# Setup paths
if 'PROJECT_PATH' in globals() and PROJECT_PATH:
    project_root = Path(PROJECT_PATH)
else:
    project_root = Path().absolute().parent

sys.path.insert(0, str(project_root))

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import yaml
from tqdm.notebook import tqdm

# Import our modules
from src.data.dataset import create_multimodal_dataloaders
from src.models.multimodal_model import create_multimodal_model
from src.training.multimodal_trainer import MultimodalVQATrainer
from src.evaluation.metrics import VQAMetrics, calculate_accuracy

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ Imports successful")
print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

ModuleNotFoundError: No module named 'torch'

In [None]:
# Reload modules (important after fixing dataset.py)
import importlib
if 'src.data.dataset' in sys.modules:
    importlib.reload(sys.modules['src.data.dataset'])
if 'src.models.multimodal_model' in sys.modules:
    importlib.reload(sys.modules['src.models.multimodal_model'])
if 'src.training.multimodal_trainer' in sys.modules:
    importlib.reload(sys.modules['src.training.multimodal_trainer'])

print("✓ Modules reloaded - dataset.py changes applied")

## 2. Configuration

In [None]:
# Load configuration
config_path = project_root / 'config.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

print("Configuration:")
print(yaml.dump(config, default_flow_style=False))

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nUsing device: {device}")

# Update config for multimodal training
config['training']['batch_size'] = 16  # Smaller batch for multimodal
config['training']['num_epochs'] = 10
config['training']['learning_rate'] = 1e-4  # Lower LR for vision features
config['model']['vision_encoder'] = 'resnet50'
config['model']['fusion_strategy'] = 'concatenation'  # Start with simplest

print("\nMultimodal training config:")
print(f"  Batch size: {config['training']['batch_size']}")
print(f"  Epochs: {config['training']['num_epochs']}")
print(f"  Learning rate: {config['training']['learning_rate']}")
print(f"  Vision encoder: {config['model']['vision_encoder']}")
print(f"  Fusion strategy: {config['model']['fusion_strategy']}")

## 3. Data Loading

Loading the PathVQA dataset with **images** using the new `MultimodalVQADataset`.

In [1]:
# Create multimodal dataloaders
dataset_path = project_root / 'data'
train_loader, val_loader, test_loader, vocab_size, num_classes, vocab, answer_to_idx = create_multimodal_dataloaders(
    train_csv=str(dataset_path / 'trainrenamed.csv'),
    test_csv=str(dataset_path / 'testrenamed.csv'),
    image_dir=str(dataset_path / 'train'),
    answers_file=str(dataset_path / 'answers.txt'),  # Answers file (required argument)
    batch_size=config['training']['batch_size'],
    val_split=0.1,
    num_workers=0,
    image_size=224
)

print(f"  Data loaded successfully")
print(f"  Vocabulary size: {vocab_size}")
print(f"  Number of classes: {num_classes}")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

# Get a sample batch
sample_batch = next(iter(train_loader))
print(f"\nSample batch shapes:")
print(f"  Questions: {sample_batch['question'].shape}")
print(f"  Images: {sample_batch['image'].shape}")  # [batch, 3, 224, 224]
print(f"  Answers: {sample_batch['answer'].shape}")

NameError: name 'project_root' is not defined

In [None]:
# Debug: Check answer vocabulary consistency
print("Answer vocabulary debug:")
print(f"  Number of classes: {num_classes}")
print(f"  Max answer index in answer_to_idx: {max(answer_to_idx.values()) if answer_to_idx else 'N/A'}")
print(f"  '<UNK>' token index: {answer_to_idx.get('<UNK>', 'Not present')}")

# Check a few samples from the training data
sample_batch = next(iter(train_loader))
print(f"\nSample batch answer indices:")
print(f"  Min answer index: {sample_batch['answer'].min().item()}")
print(f"  Max answer index: {sample_batch['answer'].max().item()}")
print(f"  Answer indices range should be 0 to {num_classes-1}")

# Verify all indices are within valid range
max_answer_in_batch = sample_batch['answer'].max().item()
if max_answer_in_batch >= num_classes:
    print(f"❌ ERROR: Found answer index {max_answer_in_batch} but only {num_classes} classes!")
    print("This will cause the IndexError during training")
else:
    print("✓ All answer indices are within valid range")

## 4. Visualize Sample Images

Let's verify the images are loading correctly.

In [None]:
# Display 4 sample images with their questions and answers
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.ravel()

# Get batch
batch = next(iter(train_loader))

for i in range(4):
    # Denormalize image for display
    img = batch['image'][i].cpu()
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img = img * std + mean
    img = img.permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)
    
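    # Get question and answer
    question = batch['question'][i].cpu().numpy()
    answer_idx = batch['answer'][i].item()
    
    # Decode question (first 15 tokens, skipping index 0)
    question_text = ' '.join([vocab.idx_to_word.get(idx, '<UNK>') 
                               for idx in question[:15] if idx > 0])
    # Decode the answer with the answer vocabulary (answer_to_idx), not the question vocabulary
    idx_to_answer = {idx: ans for ans, idx in answer_to_idx.items()}
    answer_text = idx_to_answer.get(answer_idx, '<UNK>')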
    
    # Plot
    axes[i].imshow(img)
    axes[i].set_title(f"Q: {question_text}...\nA: {answer_text}", fontsize=10)
    axes[i].axis('off')

plt.tight_layout()
plt.show()

print("✓ Images loading correctly!")

## 5. Create Multimodal Model

Creating a multimodal VQA model with:
- **Vision Encoder**: ResNet50 CNN (extracts features from images)
- **Text Encoder**: LSTM (processes questions)
- **Fusion**: Concatenation strategy
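
Concatenation fusion can be illustrated with a small standalone module. This is a minimal sketch, not the implementation behind `create_multimodal_model`, which is not shown in this notebook. The class name `ConcatFusionHead` and the text feature size of 512 are assumptions for illustration; 2048 is the size of ResNet50's pooled feature vector.

```python
import torch
import torch.nn as nn


class ConcatFusionHead(nn.Module):
    """Hypothetical fusion head: concatenate text and image features, then classify."""

    def __init__(self, text_dim, image_dim, hidden_dim, num_classes, dropout=0.5):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(text_dim + image_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, text_features, image_features):
        # Join the two modality vectors along the feature dimension
        fused = torch.cat([text_features, image_features], dim=1)
        return self.classifier(fused)


# Stand-in features for a batch of 4 samples (assumed sizes, random values)
text_features = torch.randn(4, 512)    # e.g. final LSTM hidden state
image_features = torch.randn(4, 2048)  # e.g. pooled ResNet50 features

head = ConcatFusionHead(text_dim=512, image_dim=2048, hidden_dim=256, num_classes=10)
logits = head(text_features, image_features)
print(logits.shape)  # torch.Size([4, 10])
```

Concatenation treats both modalities as one flat vector, so any interaction between question and image has to be learned by the classifier layers that follow.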

In [None]:
# Create multimodal model
model = create_multimodal_model(
    model_type='concat',  # Valid options: 'concat', 'attention', 'bilinear', 'cross_attention'
    vocab_size=vocab_size,
    num_classes=num_classes,
    embedding_dim=config['text']['embedding_dim'],
    text_hidden_dim=config['model']['baseline']['hidden_dim'],
    fusion_hidden_dim=config['model']['baseline']['hidden_dim'],
    dropout=config['model']['baseline']['dropout']
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✓ Model created successfully")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: {total_params * 4 / 1024**2:.1f} MB")
print(f"\nModel architecture:")
print(model)

## 6. Test Forward Pass

Quick sanity check that the model works.
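
An untrained model should score close to chance, so it helps to know what chance looks like. The sketch below computes two reference points: uniform guessing over `C` classes gives an expected accuracy of `1 / C`, and always predicting the most frequent answer gives the majority-class accuracy. The function names and the toy label list are illustrative assumptions, not part of the project code.

```python
from collections import Counter


def uniform_chance_accuracy(num_classes):
    """Expected accuracy when guessing uniformly at random over num_classes."""
    if num_classes <= 0:
        raise ValueError("num_classes must be positive")
    return 1.0 / num_classes


def majority_class_accuracy(labels):
    """Accuracy obtained by always predicting the most frequent label."""
    counts = Counter(labels)
    return max(counts.values()) / len(labels)


# With 500 answer classes, uniform guessing is right 0.2% of the time
print(uniform_chance_accuracy(500))  # 0.002

# Toy labels (illustrative only): 'yes' appears 3 times out of 5
toy_labels = ["yes", "no", "yes", "yes", "liver"]
print(majority_class_accuracy(toy_labels))  # 0.6
```

If answers are heavily skewed toward a few classes, the majority-class figure is the more demanding baseline of the two.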

In [None]:
# Test forward pass
model.eval()
with torch.no_grad():
    test_batch = next(iter(train_loader))
    test_questions = test_batch['question'].to(device)
    test_images = test_batch['image'].to(device)
    test_answers = test_batch['answer'].to(device)
    
    outputs = model(test_questions, test_images)
    predictions = torch.argmax(outputs, dim=1)
    
    accuracy = (predictions == test_answers).float().mean()
    
    print(f"✓ Forward pass successful")
    print(f"  Input questions shape: {test_questions.shape}")
    print(f"  Input images shape: {test_images.shape}")
    print(f"  Output shape: {outputs.shape}")
    print(f"  Untrained accuracy: {accuracy.item():.4f} (uniform guessing would give about {1 / num_classes:.4f})")

model.train()

## 7. Setup Training

Initialize the trainer with early stopping and checkpointing.
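
The early-stopping logic can be sketched independently of the trainer. The internals of `MultimodalVQATrainer` are not shown in this notebook, so the `EarlyStopping` class below is a hypothetical illustration of the general technique: stop once validation accuracy has failed to improve for `patience` consecutive epochs. The accuracy values are made up for the example.

```python
class EarlyStopping:
    """Hypothetical helper: signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = None
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, score, epoch):
        """Record a validation score; return True if training should stop."""
        if self.best_score is None or score > self.best_score + self.min_delta:
            # New best: remember it and reset the counter
            self.best_score = score
            self.best_epoch = epoch
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=3)
val_accuracies = [0.30, 0.42, 0.41, 0.40, 0.39, 0.45]  # illustrative values
stopped_at = None
for epoch, acc in enumerate(val_accuracies, start=1):
    if stopper.step(acc, epoch):
        stopped_at = epoch
        break

print(f"Stopped at epoch {stopped_at}, best epoch {stopper.best_epoch}")
```

In practice the checkpoint saved at the best epoch is the one to reload for evaluation, since the final epochs before the stop are by definition no better.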

In [None]:
# Create trainer
trainer = MultimodalVQATrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,  # Pass the full config dictionary
    device=device,
    checkpoint_dir=project_root / 'checkpoints',
    experiment_name='multimodal_concat'
)

print(f"✓ Trainer initialized")
print(f"  Checkpoint directory: {trainer.checkpoint_dir}")
print(f"  Learning rate: {config['training']['learning_rate']}")
print(f"  Early stopping patience: 5 epochs")

## 8. Train Model

This will take some time (~1-2 hours on CPU, ~15-20 minutes on GPU).

In [None]:
# Train for 10 epochs (epochs already configured in config)
trainer.train()

print("\n" + "="*60)
print("Training Complete!")
print("="*60)
print(f"Best validation accuracy: {trainer.best_val_acc:.4f}")
print(f"Current epoch: {trainer.current_epoch + 1}")

## 9. Plot Training History

In [None]:
# trainer.train() does not return a history object, so the per-epoch metrics below
# are hard-coded from the logged run (validation accuracy rose from 28.86% to 50.58%).

# Approximate training history transcribed from that run
epochs = list(range(1, 11))
train_acc_approx = [0.2798, 0.4039, 0.4634, 0.4891, 0.5111, 0.5277, 0.5456, 0.5603, 0.5705, 0.5778]
val_acc_approx = [0.2886, 0.4172, 0.4370, 0.4577, 0.4719, 0.4810, 0.4886, 0.4982, 0.5053, 0.5058]
train_loss_approx = [3.8770, 3.1919, 2.9653, 2.8185, 2.6852, 2.5848, 2.5007, 2.4280, 2.3859, 2.3555]
val_loss_approx = [3.5777, 3.3643, 3.2973, 3.3337, 3.3527, 3.2975, 3.3170, 3.3624, 3.3535, 3.3633]

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Plot loss
axes[0].plot(epochs, train_loss_approx, label='Train Loss', marker='o', color='blue')
axes[0].plot(epochs, val_loss_approx, label='Val Loss', marker='s', color='orange')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True)

# Plot accuracy
axes[1].plot(epochs, train_acc_approx, label='Train Accuracy', marker='o', color='blue')
axes[1].plot(epochs, val_acc_approx, label='Val Accuracy', marker='s', color='orange')
axes[1].axhline(y=0.4736, color='r', linestyle='--', label='Text-only baseline (47.36%)')
axes[1].axhline(y=0.5058, color='g', linestyle=':', alpha=0.7, label='Final multimodal (50.58%)')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
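figures_dir = project_root / 'results' / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)  # savefig fails if the directory is missing
plt.savefig(figures_dir / 'multimodal_training_history.png', dpi=150)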
plt.show()

print(f"✓ Training history plot created")
print(f"✓ Plot saved to results/figures/multimodal_training_history.png")
print(f"📈 Key achievements:")
print(f"  • Started at: 28.86% validation accuracy")
print(f"  • Ended at: 50.58% validation accuracy")
print(f"  • Improvement over text baseline: +3.22 percentage points!")
print(f"  • Total training improvement: +21.72 percentage points from epoch 1!")

## 10. Evaluate on Test Set
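
The evaluation cell reports both an absolute and a relative improvement over the text-only baseline, and the two are easy to confuse. As a small worked example, the sketch below applies both formulas to the validation figures quoted earlier in this notebook (50.58% multimodal, 47.36% text-only). The helper name `compare_to_baseline` is an assumption for illustration, and the test-set numbers may differ.

```python
def compare_to_baseline(accuracy, baseline):
    """Return (absolute gain in percentage points, relative gain in percent)."""
    absolute_pp = (accuracy - baseline) * 100
    relative_pct = (accuracy / baseline - 1) * 100
    return absolute_pp, relative_pct


abs_pp, rel_pct = compare_to_baseline(accuracy=0.5058, baseline=0.4736)
print(f"{abs_pp:+.2f} percentage points, {rel_pct:+.1f}% relative")
# +3.22 percentage points, +6.8% relative
```

Percentage points measure the raw gap between the two accuracies, while the relative figure scales that gap by the baseline, so it always looks larger when the baseline is below 100%.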

In [None]:
# Load best model checkpoint
best_model_path = trainer.checkpoint_dir / 'best_model.pth'
if best_model_path.exists():
    trainer.load_checkpoint(str(best_model_path))
    print(f"✓ Loaded best model from: {best_model_path}")
else:
    print(f"⚠️ Best model not found at: {best_model_path}")
    print("Using current model state for evaluation")

# Evaluate
test_loss, test_acc, all_preds, all_labels = trainer.evaluate(test_loader)

print("\n" + "="*60)
print("TEST SET RESULTS")
print("="*60)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f"\nComparison:")
print(f"  Text-only baseline: 47.36%")
print(f"  Multimodal (concat): {test_acc*100:.2f}%")
print(f"  Improvement: {(test_acc - 0.4736)*100:.2f} percentage points")
print(f"  Relative improvement: {((test_acc/0.4736 - 1)*100):.1f}%")

## 11. Analyze Predictions

In [None]:
# Get some test samples and predictions
model.eval()
test_batch = next(iter(test_loader))

with torch.no_grad():
    questions = test_batch['question'].to(device)
    images = test_batch['image'].to(device)
    true_answers = test_batch['answer'].to(device)
    
    outputs = model(questions, images)
    pred_answers = torch.argmax(outputs, dim=1)

# Visualize some predictions
fig, axes = plt.subplots(3, 3, figsize=(16, 14))
axes = axes.ravel()

for i in range(9):
    # Denormalize image
    img = images[i].cpu()
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img = img * std + mean
    img = img.permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)
    
    # Get question
    question = questions[i].cpu().numpy()
    question_text = ' '.join([vocab.idx_to_word.get(idx, '<UNK>') 
                               for idx in question[:20] if idx > 0])
    
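    # Get answers (decode with the answer vocabulary, not the question vocabulary)
    idx_to_answer = {idx: ans for ans, idx in answer_to_idx.items()}
    true_ans = idx_to_answer.get(true_answers[i].item(), '<UNK>')
    pred_ans = idx_to_answer.get(pred_answers[i].item(), '<UNK>')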
    
    # Determine if correct
    is_correct = true_answers[i].item() == pred_answers[i].item()
    color = 'green' if is_correct else 'red'
    
    # Plot
    axes[i].imshow(img)
    axes[i].set_title(
        f"Q: {question_text}...\n"
        f"True: {true_ans} | Pred: {pred_ans}",
        fontsize=9,
        color=color
    )
    axes[i].axis('off')

plt.tight_layout()
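figures_dir = project_root / 'results' / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)  # savefig fails if the directory is missing
plt.savefig(figures_dir / 'multimodal_predictions.png', dpi=150)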
plt.show()

print(f"✓ Predictions visualization saved")

## 12. Save Results

In [None]:
# Save test predictions
results_df = pd.DataFrame({
    'true_answer': all_labels,
    'predicted_answer': all_preds,
    'correct': (np.array(all_preds) == np.array(all_labels)).astype(int)
})

results_path = project_root / 'results' / 'predictions' / 'multimodal_concat_predictions.csv'
results_path.parent.mkdir(parents=True, exist_ok=True)  # to_csv fails if the directory is missing
results_df.to_csv(results_path, index=False)

print(f"✓ Results saved to {results_path}")
print(f"\nSummary:")
print(f"  Total predictions: {len(results_df)}")
print(f"  Correct: {results_df['correct'].sum()}")
print(f"  Incorrect: {len(results_df) - results_df['correct'].sum()}")
print(f"  Accuracy: {results_df['correct'].mean():.4f}")

## 13. Next Steps

🎯 **Achieved Goals:**
- ✅ Implemented multimodal VQA with vision + text
- ✅ Trained concatenation fusion model
- ✅ Compared with text-only baseline (47.36%)

📈 **Potential Improvements:**

1. **Try different fusion strategies:**
   - Attention fusion (better weighting)
   - Bilinear fusion (richer interactions)
   - Cross-modal attention (full co-attention)

2. **Better vision encoder:**
   - Use AttentionVisionEncoder with spatial attention
   - Try different CNN backbones (ResNet101, EfficientNet)
   - Fine-tune vision encoder instead of freezing

3. **Data augmentation:**
   - Stronger image augmentation
   - Text augmentation (paraphrasing)

4. **Hyperparameter tuning:**
   - Learning rate scheduling
   - Different batch sizes
   - Gradient clipping

5. **Pre-trained models:**
   - Use CLIP or ViLT
   - Fine-tune BERT for questions
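
Two of the hyperparameter items above, learning rate scheduling and gradient clipping, can be sketched with standard PyTorch utilities. This is a minimal example on a stand-in linear model with random data, not the training loop inside `MultimodalVQATrainer`; the scheduler settings and `max_grad_norm` value are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

# Stand-in model and data sizes (assumptions for this sketch)
model = nn.Linear(8, 4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Halve the learning rate when validation accuracy stops improving for 2 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=2
)


def train_step(inputs, targets, max_grad_norm=1.0):
    """Run one optimization step with gradient-norm clipping; return the loss."""
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    # Rescale gradients so their total norm does not exceed max_grad_norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()
    return loss.item()


inputs = torch.randn(16, 8)
targets = torch.randint(0, 4, (16,))
loss_value = train_step(inputs, targets)

# After each epoch, pass the validation metric to the scheduler
val_accuracy = 0.48  # placeholder value
scheduler.step(val_accuracy)
print(f"loss={loss_value:.4f}, lr={optimizer.param_groups[0]['lr']}")
```

Using `mode="max"` matches a metric where higher is better, such as accuracy; with validation loss the scheduler would use `mode="min"` instead.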