In [None]:
# Setup and imports
import sys
import os
from pathlib import Path
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import Counter
from torchvision import transforms, models
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import time

# Setup paths
# The project root is taken to be the parent of the current working directory
# (assumes the notebook is launched from a first-level subdirectory of the
# repository — TODO confirm).
project_root = Path().absolute().parent
# Prepend the root to sys.path so the project-local `src` package imported
# below can be resolved; this must run before that import.
sys.path.insert(0, str(project_root))

# Import existing modules
from src.data.dataset import create_multimodal_dataloaders

# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"Project root: {project_root}")

In [None]:
# Create reduced vocabulary with top 1000 classes
def create_reduced_vocabulary(train_csv_path, top_n=1000, output_dir=None):
    """Create an answer vocabulary limited to the ``top_n`` most frequent answers.

    The training CSV is read with pandas, the values of its ``answer`` column
    are counted, and the ``top_n`` most common ones are written one per line
    to ``answers_top_{top_n}.txt`` followed by a final ``<UNK>`` line.

    Args:
        train_csv_path: Path to the training CSV; must contain an ``answer``
            column.
        top_n: Number of most frequent answers to keep.
        output_dir: Directory for the answers file. Defaults to
            ``project_root / 'data'`` (the notebook-level ``project_root``).
            The directory is created if it does not exist.

    Returns:
        tuple: ``(path, answer_to_idx)`` where ``path`` is the answers file
        path as a string and ``answer_to_idx`` maps each kept answer to its
        rank (0 = most frequent) with ``'<UNK>'`` mapped to the last index.
    """
    # Load training data
    train_df = pd.read_csv(train_csv_path)

    # Count answer frequencies. Rows whose answer is missing are skipped so
    # that NaN never turns into a literal "nan" class.
    # NOTE(review): pandas' default NA parsing also treats strings such as
    # "NA" or "None" as missing — confirm this matches the dataset loader.
    answer_counts = Counter(train_df['answer'].dropna())

    # Get top N answers, most frequent first
    top_answers = [answer for answer, _ in answer_counts.most_common(top_n)]

    # Create mapping; <UNK> always takes the index after the last kept answer
    answer_to_idx = {answer: idx for idx, answer in enumerate(top_answers)}
    answer_to_idx['<UNK>'] = len(answer_to_idx)

    # Save reduced answers file
    dataset_path = Path(output_dir) if output_dir is not None else project_root / 'data'
    dataset_path.mkdir(parents=True, exist_ok=True)
    reduced_answers_path = dataset_path / f'answers_top_{top_n}.txt'

    # Explicit UTF-8 so non-ASCII answers are written identically on every platform.
    with open(reduced_answers_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{answer}\n" for answer in top_answers)
        f.write("<UNK>\n")

    # Share of (non-missing) training answers covered by the kept classes.
    total_answers = sum(answer_counts.values())
    covered_answers = sum(answer_counts[ans] for ans in top_answers)
    coverage = covered_answers / total_answers * 100 if total_answers else 0.0

    print(f"Created reduced vocabulary:")
    print(f"  Classes: {len(answer_to_idx)}")
    print(f"  Coverage: {coverage:.1f}% of training data")
    print(f"  Saved to: {reduced_answers_path}")

    return str(reduced_answers_path), answer_to_idx

# Build the top-1000 answer vocabulary from the training CSV under <project_root>/data.
# The returned file path is handed to the dataloader factory in a later cell.
dataset_path = project_root.joinpath('data')
reduced_answers_file, reduced_answer_to_idx = create_reduced_vocabulary(
    str(dataset_path.joinpath('trainrenamed.csv')),
    top_n=1000,
)

In [None]:
# Enhanced data transforms with augmentation
def get_enhanced_transforms():
    """Build the training and validation image pipelines.

    Returns:
        tuple: ``(train_transforms, val_transforms)``. The training pipeline
        resizes to 256x256, takes a random 224x224 crop and applies random
        flip / rotation / colour jitter / affine perturbations before tensor
        conversion, normalisation and random erasing. The validation pipeline
        only resizes to 224x224, converts to a tensor and normalises.
    """
    # Per-channel statistics shared by both pipelines.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    # Order matters: PIL-space augmentations first, then ToTensor/Normalize,
    # and RandomErasing last because it operates on tensors.
    augmentation_steps = [
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
        transforms.RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.95, 1.05)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(channel_mean), std=list(channel_std)),
        transforms.RandomErasing(p=0.1, scale=(0.02, 0.08)),
    ]
    train_pipeline = transforms.Compose(augmentation_steps)

    # Deterministic preprocessing for evaluation.
    evaluation_steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(channel_mean), std=list(channel_std)),
    ]
    val_pipeline = transforms.Compose(evaluation_steps)

    return train_pipeline, val_pipeline

# Instantiate both pipelines once at notebook scope.
# NOTE(review): neither `train_transforms` nor `val_transforms` is passed to the
# create_multimodal_dataloaders(...) call in the data-loading cell, so this
# augmentation may never be applied to the images — confirm how the dataset
# builds its transforms.
train_transforms, val_transforms = get_enhanced_transforms()
print("Enhanced data transforms created with augmentation")

In [None]:
# Load data with reduced vocabulary
# NOTE(review): the `train_transforms` / `val_transforms` pipelines built in the
# previous cell are not passed to the loader factory here; whether augmentation
# is applied depends on what create_multimodal_dataloaders does internally —
# confirm against src/data/dataset.py.
BATCH_SIZE = 32  # single source of truth for the loader call and the summary below

train_loader, val_loader, test_loader, vocab_size, num_classes, vocab, answer_to_idx = create_multimodal_dataloaders(
    train_csv=str(dataset_path / 'trainrenamed.csv'),
    test_csv=str(dataset_path / 'testrenamed.csv'),
    image_dir=str(dataset_path / 'train'),
    answers_file=reduced_answers_file,  # Use reduced vocabulary
    batch_size=BATCH_SIZE,  # Increased batch size
    val_split=0.1,
    num_workers=2,
    image_size=224
)

print(f"Data loaded with improvements:")
print(f"  Vocabulary size: {vocab_size}")
print(f"  Number of classes: {num_classes}")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")
print(f"  Batch size increased to: {BATCH_SIZE}")

In [None]:
# Improved multimodal model
class ImprovedMultimodalVQA(nn.Module):
    """Question + image classifier over a fixed set of answer classes.

    Pipeline (see ``forward``):
      1. Question token ids -> embedding -> single-layer bidirectional LSTM;
         the final hidden states of both directions are concatenated.
      2. Images -> ResNet-50 with its last two child modules removed ->
         2048-channel feature map, gated by a learned sigmoid spatial mask.
      3. The projected question vector is used as the attention *query* over
         the projected spatial locations (keys/values); the attended vision
         vector is added to the question projection and classified by an MLP.

    Args:
        vocab_size: Number of rows in the question embedding table (index 0
            is the padding index).
        num_classes: Number of answer classes (size of the output logits).
        embedding_dim: Word embedding size.
        text_hidden_dim: LSTM hidden size per direction.
        fusion_hidden_dim: Shared width of the projected text/vision features;
            must be divisible by the 4 attention heads.
        dropout: Dropout probability used on the text features, inside the
            attention layer and in the classifier.
    """

    def __init__(self, vocab_size, num_classes, embedding_dim=300,
                 text_hidden_dim=512, fusion_hidden_dim=512, dropout=0.1):
        super().__init__()
        self.num_classes = num_classes

        # Text encoder
        self.text_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        # nn.LSTM's `dropout` argument only acts *between* stacked layers; with
        # the single layer used here it has no effect and only raises a
        # UserWarning, so it is not passed. Regularisation comes from
        # `text_dropout` below.
        self.text_lstm = nn.LSTM(embedding_dim, text_hidden_dim,
                                 batch_first=True, bidirectional=True)
        self.text_dropout = nn.Dropout(dropout)

        # Vision encoder (trainable): pretrained ResNet-50 without its last
        # two child modules, so it outputs a spatial feature map.
        self.vision_encoder = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.vision_encoder = nn.Sequential(*list(self.vision_encoder.children())[:-2])

        # Simplified spatial attention: one sigmoid gate per spatial location
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2048, 256, 1),
            nn.ReLU(),
            nn.Conv2d(256, 1, 1),
            nn.Sigmoid()
        )

        # Cross-modal fusion
        self.vision_proj = nn.Linear(2048, fusion_hidden_dim)
        self.text_proj = nn.Linear(text_hidden_dim * 2, fusion_hidden_dim)

        # Reduced attention heads for efficiency
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=fusion_hidden_dim,
            num_heads=4,  # Reduced from 8
            dropout=dropout
        )

        # Enhanced classifier
        self.classifier = nn.Sequential(
            nn.Linear(fusion_hidden_dim, fusion_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_hidden_dim, fusion_hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_hidden_dim // 2, num_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise Linear layers and draw embeddings from N(0, 0.1).

        Only ``nn.Linear`` and ``nn.Embedding`` modules are touched; every
        other module keeps the weights it was constructed or loaded with.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0, std=0.1)
                # normal_ also overwrites the padding row that nn.Embedding
                # zeroed at construction; restore it so padding tokens embed
                # to the zero vector again.
                if module.padding_idx is not None:
                    with torch.no_grad():
                        module.weight[module.padding_idx].fill_(0)

    def forward(self, questions, images):
        """Compute answer logits.

        Args:
            questions: Integer token-id tensor, batch-first
                (presumably ``(batch, seq_len)`` — confirm against the dataset).
            images: Image batch accepted by the ResNet trunk
                (presumably ``(batch, 3, H, W)`` — confirm against the dataset).

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        # Text processing
        # NOTE(review): the LSTM also runs over padded positions; if questions
        # are right-padded, consider pack_padded_sequence (needs lengths).
        text_embedded = self.text_embedding(questions)
        _, (text_hidden, _) = self.text_lstm(text_embedded)
        # Final hidden state of the forward ([-2]) and backward ([-1]) direction.
        text_features = torch.cat([text_hidden[-2], text_hidden[-1]], dim=1)
        text_features = self.text_dropout(text_features)

        # Vision processing with spatial attention
        vision_maps = self.vision_encoder(images)                # (B, 2048, h, w)
        attention_weights = self.spatial_attention(vision_maps)  # (B, 1, h, w)
        attended_vision = vision_maps * attention_weights

        # Keep one token per spatial location instead of pooling to a single
        # vector: with a single key/value token the attention softmax is
        # always 1, which made the output independent of the question.
        vision_tokens = attended_vision.flatten(2).transpose(1, 2)  # (B, h*w, 2048)

        # Project both modalities to the shared fusion width
        vision_proj = self.vision_proj(vision_tokens)            # (B, h*w, F)
        text_proj = self.text_proj(text_features).unsqueeze(1)   # (B, 1, F)

        # Cross-modal attention (the layer is sequence-first, hence the transposes)
        attended_features, _ = self.cross_attention(
            query=text_proj.transpose(0, 1),
            key=vision_proj.transpose(0, 1),
            value=vision_proj.transpose(0, 1),
            need_weights=False
        )

        # Residual connection keeps a direct path for the question features
        # alongside the question-conditioned vision summary.
        fused_features = attended_features.squeeze(0) + text_proj.squeeze(1)
        logits = self.classifier(fused_features)

        return logits

# Instantiate the network with the sizes reported by the dataloader factory,
# then move it to the selected device.
model = ImprovedMultimodalVQA(
    vocab_size=vocab_size, num_classes=num_classes,
    embedding_dim=300, text_hidden_dim=512, fusion_hidden_dim=512,
    dropout=0.1,  # Reduced dropout
)
model = model.to(device)

# Summary: only parameters that require gradients are counted.
print(f"Improved model created:")
print(f"  Parameters: {sum(weight.numel() for weight in model.parameters() if weight.requires_grad):,}")
print(f"  Reduced dropout: 0.1")
print(f"  Simplified attention: 4 heads")

In [None]:
# Enhanced training setup
class FocalLoss(nn.Module):
    """Focal loss for class imbalance, with optional label smoothing.

    Per sample: ``alpha * (1 - p_t) ** gamma * CE_smooth`` where ``p_t`` is the
    softmax probability the model assigns to the true class and ``CE_smooth``
    is the (optionally label-smoothed) cross-entropy. The per-sample values
    are averaged over the batch.

    Args:
        alpha: Constant scale applied to every sample's loss.
        gamma: Focusing exponent; larger values down-weight well-classified
            samples more strongly.
        label_smoothing: Smoothing factor forwarded to ``cross_entropy``.
    """

    def __init__(self, alpha=1, gamma=2, label_smoothing=0.1):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        """Return the mean focal loss.

        Args:
            inputs: Logits of shape ``(batch, num_classes)``.
            targets: Class indices of shape ``(batch,)``.
        """
        # True-class probability from the *unsmoothed* distribution. Deriving
        # it as exp(-CE) from the smoothed cross-entropy is only correct when
        # label_smoothing == 0; with smoothing it under-estimates p_t and
        # makes confidently correct samples look hard.
        log_probs = nn.functional.log_softmax(inputs, dim=1)
        pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1).exp()

        ce_loss = nn.functional.cross_entropy(
            inputs, targets,
            reduction='none',
            label_smoothing=self.label_smoothing
        )
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()

# Setup enhanced training
# Split parameters so the pretrained vision backbone and the remaining
# (text / fusion / classifier) layers can use different learning rates.
vision_params = []
other_params = []

for name, param in model.named_parameters():
    if 'vision_encoder' in name:
        vision_params.append(param)
    else:
        other_params.append(param)

# AdamW with two parameter groups: the backbone gets a 10x smaller step than
# the other layers.
optimizer = optim.AdamW([
    {'params': vision_params, 'lr': 1e-5, 'weight_decay': 1e-4},    # vision backbone
    {'params': other_params, 'lr': 1e-4, 'weight_decay': 1e-4}     # all other layers
], betas=(0.9, 0.999))

# Multiply every group's LR by 0.3 once epoch 8 and again once epoch 15 is reached
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8, 15], gamma=0.3)

# Focal loss with label smoothing
criterion = FocalLoss(alpha=1, gamma=2, label_smoothing=0.1)

# Report scalar parameter counts (numel), consistent with the model summary
# cell; len(...) would only give the number of parameter tensors.
print("Enhanced training setup:")
print(f"  Vision parameters: {sum(p.numel() for p in vision_params):,}")
print(f"  Other parameters: {sum(p.numel() for p in other_params):,}")
print(f"  Vision LR: 1e-5")
print(f"  Other LR: 1e-4")
print(f"  Using Focal Loss with label smoothing")
print(f"  MultiStepLR scheduler")

In [None]:
# Enhanced training loop
def _unpack_batch(batch, device):
    """Return ``(questions, images, answers)`` on ``device`` from a dict- or tuple-style batch."""
    if isinstance(batch, dict):
        questions, images, answers = batch['question'], batch['image'], batch['answer']
    else:
        questions, images, answers = batch
    return questions.to(device), images.to(device), answers.to(device)


def _mean_or_nan(values):
    """Arithmetic mean of ``values``; NaN when the list is empty."""
    return sum(values) / len(values) if values else float('nan')


def train_epoch_improved(model, train_loader, val_loader, optimizer, criterion, scheduler, device, epoch, total_epochs):
    """Run one training pass followed by one validation pass.

    Args:
        model: Module called as ``model(questions, images)`` returning logits.
        train_loader: Sized iterable of training batches. Each batch is either
            a dict with keys ``'question'``, ``'image'``, ``'answer'`` or a
            ``(questions, images, answers)`` tuple.
        val_loader: Iterable of validation batches in the same format.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss called as ``criterion(logits, answers)``.
        scheduler: LR scheduler stepped once at the end of the epoch.
        device: Device the batch tensors are moved to.
        epoch: Current epoch number (used for logging only).
        total_epochs: Total number of epochs (used for logging only).

    Returns:
        dict with keys ``'train_loss'``, ``'train_accuracy'``, ``'val_loss'``
        and ``'val_accuracy'``. Losses are means of the per-batch losses and
        accuracies are percentages. For an empty loader the loss is NaN and
        the accuracy 0.0 instead of raising ``ZeroDivisionError``.
    """
    model.train()
    train_losses = []
    train_correct = 0
    train_total = 0

    print(f"Epoch {epoch}/{total_epochs}")
    print("-" * 50)

    for batch_idx, batch in enumerate(train_loader):
        questions, images, answers = _unpack_batch(batch, device)

        optimizer.zero_grad()
        outputs = model(questions, images)
        loss = criterion(outputs, answers)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Track metrics (running accuracy is measured in train mode)
        train_losses.append(loss.item())
        predicted = outputs.argmax(dim=1)
        train_total += answers.size(0)
        train_correct += (predicted == answers).sum().item()

        if batch_idx % 30 == 0:  # More frequent updates
            current_acc = 100. * train_correct / train_total if train_total else 0.0
            print(f"Batch {batch_idx:3d}/{len(train_loader):3d} | Loss: {loss.item():.4f} | Acc: {current_acc:.2f}%")

    # Calculate averages
    avg_train_loss = _mean_or_nan(train_losses)
    train_accuracy = 100. * train_correct / train_total if train_total else 0.0

    # Validation phase
    model.eval()
    val_losses = []
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for batch in val_loader:
            questions, images, answers = _unpack_batch(batch, device)

            outputs = model(questions, images)
            loss = criterion(outputs, answers)

            val_losses.append(loss.item())
            predicted = outputs.argmax(dim=1)
            val_total += answers.size(0)
            val_correct += (predicted == answers).sum().item()

    avg_val_loss = _mean_or_nan(val_losses)
    val_accuracy = 100. * val_correct / val_total if val_total else 0.0

    # Update scheduler
    scheduler.step()

    # The learning rate printed below is read after scheduler.step(), i.e. it
    # is the value the first parameter group will use in the next epoch.
    print(f"\nEpoch {epoch} Summary:")
    print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_accuracy:.2f}%")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
    print("=" * 50)

    return {
        'train_loss': avg_train_loss,
        'train_accuracy': train_accuracy,
        'val_loss': avg_val_loss,
        'val_accuracy': val_accuracy
    }

# Status banner only: this cell defines the epoch function; the loop that
# actually calls it lives in the next cell.
print("Enhanced training function ready")
print("This should show significant improvement over the current 25% plateau")

In [None]:
# Run improved training
EPOCHS = 20
history = {
    'train_losses': [],
    'train_accuracies': [],
    'val_losses': [],
    'val_accuracies': []
}

best_val_acc = 0.0
best_model_state = None
# Reported in the final summary so an aborted run is not labelled "completed".
training_status = "completed"

print("Starting improved multimodal training...")
print(f"Training for {EPOCHS} epochs with all optimizations")
print(f"Expected target: 35-45% accuracy (vs current 25%)")

start_time = time.time()

try:
    for epoch in range(1, EPOCHS + 1):
        epoch_results = train_epoch_improved(
            model, train_loader, val_loader,
            optimizer, criterion, scheduler,
            device, epoch, EPOCHS
        )

        # Store history
        history['train_losses'].append(epoch_results['train_loss'])
        history['train_accuracies'].append(epoch_results['train_accuracy'])
        history['val_losses'].append(epoch_results['val_loss'])
        history['val_accuracies'].append(epoch_results['val_accuracy'])

        # Save best model
        if epoch_results['val_accuracy'] > best_val_acc:
            best_val_acc = epoch_results['val_accuracy']
            # state_dict() returns references to the live parameter tensors and
            # dict.copy() is shallow, so each tensor is cloned; otherwise the
            # "best" snapshot would silently follow later optimizer updates.
            best_model_state = {
                key: tensor.detach().clone()
                for key, tensor in model.state_dict().items()
            }
            print(f"New best validation accuracy: {best_val_acc:.2f}%")

            # Save checkpoint
            checkpoint_path = project_root / 'checkpoints' / 'improved_multimodal' / 'best_model.pth'
            checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save({
                'epoch': epoch,
                'model_state_dict': best_model_state,
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_val_acc,
                'history': history,
                'num_classes': num_classes
            }, checkpoint_path)

except KeyboardInterrupt:
    training_status = "interrupted"
    print("\nTraining interrupted by user")
except Exception as e:
    # Best-effort: keep the notebook alive so partial history can still be
    # inspected, but surface the full traceback.
    training_status = "failed"
    print(f"Error during training: {e}")
    import traceback
    traceback.print_exc()

training_time = time.time() - start_time
print(f"\nTraining {training_status} in {training_time/60:.2f} minutes")
print(f"Best validation accuracy: {best_val_acc:.2f}%")

# Restore the weights of the best epoch (if any epoch improved on 0%).
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Best model loaded for evaluation")

In [None]:
# Visualize improved training results
# Reference accuracies (in %) that the curves and the summary are compared against.
TEXT_BASELINE_ACC = 47.36   # text baseline this notebook is trying to beat
PREVIOUS_BEST_ACC = 25.0    # plateau of the previous multimodal run
GOOD_RESULT_ACC = 35.0      # lower end of the 35-45% range targeted above

if len(history['train_losses']) > 0:
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    epochs_range = range(1, len(history['train_losses']) + 1)

    # Loss curves
    ax1.plot(epochs_range, history['train_losses'], 'b-', label='Training Loss', linewidth=2)
    ax1.plot(epochs_range, history['val_losses'], 'r-', label='Validation Loss', linewidth=2)
    ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Accuracy curves
    ax2.plot(epochs_range, history['train_accuracies'], 'b-', label='Training Accuracy', linewidth=2)
    ax2.plot(epochs_range, history['val_accuracies'], 'r-', label='Validation Accuracy', linewidth=2)
    ax2.axhline(y=TEXT_BASELINE_ACC, color='g', linestyle='--',
                label=f'Text Baseline ({TEXT_BASELINE_ACC:.2f}%)', alpha=0.7)
    ax2.axhline(y=PREVIOUS_BEST_ACC, color='orange', linestyle=':',
                label=f'Previous Best ({PREVIOUS_BEST_ACC:.0f}%)', alpha=0.7)
    ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Improvement tracking (may be negative if the run regressed)
    max_val_acc = max(history['val_accuracies'])
    improvement_from_previous = max_val_acc - PREVIOUS_BEST_ACC

    ax3.bar(['Previous Best', 'Improved Model'], [PREVIOUS_BEST_ACC, max_val_acc],
           color=['orange', 'green'], alpha=0.7)
    ax3.axhline(y=TEXT_BASELINE_ACC, color='red', linestyle='--', alpha=0.7,
                label=f'Target ({TEXT_BASELINE_ACC:.2f}%)')
    ax3.set_title('Performance Comparison', fontsize=14, fontweight='bold')
    ax3.set_ylabel('Accuracy (%)')
    ax3.legend()

    # Summary statistics. The "+" format flag prints the real sign, so a
    # regression shows as "-3.00%" rather than "+-3.00%".
    ax4.axis('off')
    stats_text = f"""Improved Training Results:
    
Final Validation Accuracy: {history['val_accuracies'][-1]:.2f}%
Best Validation Accuracy: {max_val_acc:.2f}%
Improvement over Previous: {improvement_from_previous:+.2f}%
Target Achievement: {max_val_acc/TEXT_BASELINE_ACC*100:.1f}%
    
Total Training Epochs: {len(history['train_losses'])}
Final Training Loss: {history['train_losses'][-1]:.4f}
Final Validation Loss: {history['val_losses'][-1]:.4f}
    
Key Improvements Applied:
- Reduced classes: 4,142 → 1,000
- Enhanced data augmentation
- Higher learning rates
- Focal loss with label smoothing
- Optimized architecture"""

    ax4.text(0.1, 0.5, stats_text, fontsize=11, verticalalignment='center',
             bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7))
    ax4.set_title('Training Summary', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Print key results
    print("Improved Multimodal Training Results:")
    print("=" * 50)
    print(f"Best Validation Accuracy: {max_val_acc:.2f}%")
    print(f"Improvement: {improvement_from_previous:+.2f}% over previous")
    print(f"Target Progress: {max_val_acc/TEXT_BASELINE_ACC*100:.1f}% toward {TEXT_BASELINE_ACC:.2f}% baseline")

    if max_val_acc > TEXT_BASELINE_ACC:
        print("SUCCESS: Exceeded text baseline!")
    elif max_val_acc > GOOD_RESULT_ACC:
        print("GOOD: Significant improvement achieved")
    elif max_val_acc > PREVIOUS_BEST_ACC:
        print("PROGRESS: Some improvement, may need further tuning")
    else:
        # Do not report progress when the run did not beat the previous best.
        print("NO IMPROVEMENT: Did not exceed the previous best, needs further tuning")

else:
    print("No training history available - run training first")

## Key Improvements Summary

This improved training setup addresses the main bottlenecks:

1. **Class Reduction**: 4,142 → 1,000 classes (biggest impact)
2. **Enhanced Data Augmentation**: Better generalization 
3. **Higher Learning Rates**: Faster convergence
4. **Focal Loss**: Handles class imbalance
5. **Optimized Architecture**: 4 attention heads, reduced dropout
6. **Better Scheduling**: More stable MultiStepLR

- **Expected Results**: 35-45% accuracy vs. the previous 25% plateau
- **Target**: Beat the 47.36% text baseline