## 1. Setup and Configuration

In [None]:
# Environment setup
import sys
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
import json
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Resolve the project root: when the notebook is started from a `notebooks/`
# folder the root is its parent, otherwise the working directory is the root.
current_dir = Path.cwd()
if current_dir.name == 'notebooks':
    project_root = current_dir.parent
else:
    project_root = current_dir

# Put the root first on the import path so project modules can be imported.
sys.path.insert(0, str(project_root))
print(f"Project root: {project_root}")

# Choose the compute device once; later cells reuse `device`.
print(f"GPU available: {torch.cuda.is_available()}")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

## 2. Fine-tuning Configuration

In [None]:
# Fine-tuning strategy selector
# The chosen key picks one entry of FINE_TUNING_CONFIG below. Later cells also
# branch directly on this name ("layerwise" and "progressive" get special handling).
FINE_TUNING_STRATEGY = "conservative"  # Options: conservative, layerwise, progressive, augmented, regularization, architecture

# Base configuration
# One settings dict per strategy. Keys, as consumed elsewhere in this notebook:
#   learning_rate            - base learning rate for the non-vision parameters
#   vision_lr_factor         - multiplier applied to learning_rate for vision parameters
#   text_lr_factor,
#   fusion_lr_factor         - per-group multipliers, read only for "layerwise"
#   epochs, batch_size       - number of training epochs and dataloader batch size
#   weight_decay             - passed to the AdamW optimizer
#   label_smoothing          - passed to the training CrossEntropyLoss
#   dropout_factor           - scales the model's existing dropout probabilities
#   freeze_layers            - parameter-name substrings frozen before training
#   freeze_schedule          - {epoch: [name substrings]}; matching parameters get
#                              requires_grad=True at that epoch ("progressive" only)
#   augmentation_strength    - "light" / "medium" / "strong" transform pipeline
#   add_batch_norm,
#   increase_attention_heads - optional architectural tweaks applied after loading
FINE_TUNING_CONFIG = {
    "conservative": {
        "name": "Conservative Fine-tuning",
        "description": "Lower learning rates, maintain model stability",
        "learning_rate": 1e-6,
        "vision_lr_factor": 0.1,
        "epochs": 5,
        "batch_size": 12,
        "weight_decay": 1e-4,
        "label_smoothing": 0.05,
        "dropout_factor": 1.0,
        "freeze_layers": [],
        "augmentation_strength": "light"
    },
    "layerwise": {
        "name": "Layer-wise Fine-tuning",
        "description": "Different learning rates for vision, text, and fusion components",
        "learning_rate": 5e-6,
        "vision_lr_factor": 0.05,
        "text_lr_factor": 0.3,
        "fusion_lr_factor": 1.0,
        "epochs": 8,
        "batch_size": 10,
        "weight_decay": 5e-4,
        "label_smoothing": 0.1,
        "dropout_factor": 1.0,
        "freeze_layers": [],
        "augmentation_strength": "medium"
    },
    "progressive": {
        "name": "Progressive Unfreezing",
        "description": "Gradually unfreeze layers during training",
        "learning_rate": 3e-6,
        "vision_lr_factor": 0.1,
        "epochs": 12,
        "batch_size": 12,
        "weight_decay": 1e-4,
        "label_smoothing": 0.1,
        "dropout_factor": 1.0,
        # NOTE(review): this entry has no "freeze_layers", so nothing is frozen
        # before epoch 1 and the schedule only re-enables parameters that are
        # already trainable. The empty list at epoch 5 matches no parameter.
        # Confirm which layers were meant to start frozen.
        "freeze_schedule": {1: ["vision_encoder.layer4"], 3: ["vision_encoder.layer3"], 5: []},
        "augmentation_strength": "medium"
    },
    "augmented": {
        "name": "Data Augmentation Enhanced",
        "description": "Strong data augmentation for better generalization",
        "learning_rate": 2e-6,
        "vision_lr_factor": 0.1,
        "epochs": 10,
        "batch_size": 10,
        "weight_decay": 1e-3,
        "label_smoothing": 0.15,
        "dropout_factor": 1.0,
        "freeze_layers": [],
        "augmentation_strength": "strong"
    },
    "regularization": {
        "name": "Regularization Tuning",
        "description": "Optimize regularization parameters",
        "learning_rate": 3e-6,
        "vision_lr_factor": 0.1,
        "epochs": 8,
        "batch_size": 12,
        "weight_decay": 2e-3,
        "label_smoothing": 0.2,
        # The only strategy that actually changes dropout (factor != 1.0).
        "dropout_factor": 1.2,
        "freeze_layers": [],
        "augmentation_strength": "medium"
    },
    "architecture": {
        "name": "Architecture Tweaking",
        "description": "Minor architectural modifications",
        "learning_rate": 5e-6,
        "vision_lr_factor": 0.1,
        "epochs": 10,
        "batch_size": 10,
        "weight_decay": 1e-4,
        "label_smoothing": 0.1,
        "dropout_factor": 1.0,
        "freeze_layers": [],
        "augmentation_strength": "medium",
        "add_batch_norm": True,
        "increase_attention_heads": True
    }
}

# Resolve the active settings; an unknown strategy name raises KeyError here.
config = FINE_TUNING_CONFIG[FINE_TUNING_STRATEGY]
print(f"Selected Strategy: {config['name']}")
print(f"Description: {config['description']}")
print(f"Epochs: {config['epochs']}, Learning Rate: {config['learning_rate']}")

## 3. Load Data with Enhanced Augmentation

In [None]:
# Import data loading modules
from src.data.dataset import create_multimodal_dataloaders
from torchvision import transforms

def get_augmentation_transforms(strength="medium"):
    """Build the training-image transform pipeline for an augmentation level.

    Every pipeline starts with a resize to 224x224 and ends with tensor
    conversion followed by per-channel normalization; only the random
    augmentations in between depend on ``strength``.

    Args:
        strength: ``"light"`` or ``"medium"``; any other value selects the
            strongest pipeline.

    Returns:
        A ``transforms.Compose`` holding the selected pipeline.
    """
    if strength == "light":
        random_ops = [
            transforms.RandomHorizontalFlip(p=0.3),
            transforms.RandomRotation(5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
        ]
    elif strength == "medium":
        random_ops = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
        ]
    else:  # strong
        random_ops = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        ]

    # Shared head and tail wrapped around the strength-specific augmentations.
    pipeline = [transforms.Resize((224, 224))]
    pipeline.extend(random_ops)
    pipeline.append(transforms.ToTensor())
    pipeline.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )
    return transforms.Compose(pipeline)

# Dataset locations, relative to the project root
data_dir = project_root / 'data'
dataset_path = data_dir / 'train'

print(f"Loading data with {config['augmentation_strength']} augmentation...")

# Gather the loader arguments first so the call itself stays readable.
_loader_kwargs = dict(
    train_csv=str(data_dir / 'trainrenamed.csv'),
    test_csv=str(data_dir / 'testrenamed.csv'),
    image_dir=str(dataset_path),
    answers_file=str(data_dir / 'answers.txt'),
    batch_size=config['batch_size'],
    val_split=0.1,
    num_workers=0,
    image_size=224,
    train_transform=get_augmentation_transforms(config['augmentation_strength']),
)

# The helper returns the three loaders followed by vocabulary/label metadata.
(
    train_loader,
    val_loader,
    test_loader,
    vocab_size,
    num_classes,
    vocab,
    answer_to_idx,
) = create_multimodal_dataloaders(**_loader_kwargs)

print(f"Data loaded: {len(train_loader)} train, {len(val_loader)} val, {len(test_loader)} test batches")
print(f"Vocabulary size: {vocab_size}, Classes: {num_classes}")

## 4. Load Best Model and Create Fine-tuned Version

In [None]:
# Import model classes
from improved_multimodal_model import ImprovedMultimodalVQA

def create_fine_tuned_model(base_model, config):
    """Apply the strategy's optional architectural tweaks to a loaded model.

    The model is modified in place and also returned. Pretrained weights are
    carried over wherever the tweak does not change parameter shapes, so the
    model keeps what it learned before fine-tuning starts.

    Args:
        base_model: Model exposing ``classifier`` (indexable; positions 0 and
            4 are read as linear layers, position 2 as a dropout layer),
            ``cross_attention`` and ``text_dropout``.
        config: Strategy settings. Reads the optional keys
            ``add_batch_norm``, ``increase_attention_heads`` and
            ``dropout_factor`` (default 1.0).

    Returns:
        ``base_model`` with the requested modifications applied.
    """
    dropout_factor = config.get('dropout_factor', 1.0)
    add_batch_norm = config.get('add_batch_norm', False)

    if add_batch_norm:
        print("Adding batch normalization to classifier")
        old_classifier = base_model.classifier
        first_linear = old_classifier[0]
        last_linear = old_classifier[4]
        # Reuse the existing linear layers instead of building fresh ones:
        # new nn.Linear modules would be randomly initialised and discard the
        # classifier weights that were just loaded from the checkpoint.
        base_model.classifier = nn.Sequential(
            first_linear,
            nn.BatchNorm1d(first_linear.out_features),
            nn.ReLU(),
            nn.Dropout(old_classifier[2].p * dropout_factor),
            last_linear,
        )

    if config.get('increase_attention_heads', False):
        print("Increasing attention heads to 16")
        old_attention = getattr(base_model, 'cross_attention', None)
        can_reuse = isinstance(old_attention, nn.MultiheadAttention)
        if can_reuse:
            # Mirror the existing layer's layout so its weights fit the new one.
            attention_kwargs = dict(
                embed_dim=old_attention.embed_dim,
                kdim=old_attention.kdim,
                vdim=old_attention.vdim,
                bias=old_attention.in_proj_bias is not None,
                add_bias_kv=old_attention.bias_k is not None,
                add_zero_attn=old_attention.add_zero_attn,
                batch_first=old_attention.batch_first,
            )
        else:
            # Unknown attention type: fall back to the fixed width used before.
            attention_kwargs = dict(embed_dim=512)

        new_attention = nn.MultiheadAttention(
            num_heads=16,
            dropout=0.3 * dropout_factor,
            **attention_kwargs,
        )
        if can_reuse:
            # The number of heads does not change any parameter shape, so the
            # pretrained projections can be copied over as a warm start.
            new_attention.load_state_dict(old_attention.state_dict())
        base_model.cross_attention = new_attention

    # Scale the remaining dropout layers; the caps keep probabilities moderate.
    if dropout_factor != 1.0:
        base_model.text_dropout.p = min(0.5, base_model.text_dropout.p * dropout_factor)

        # The batch-norm branch already built its dropout with the factor applied.
        if not add_batch_norm:
            base_model.classifier[2].p = min(0.7, base_model.classifier[2].p * dropout_factor)

    return base_model

# Load the best checkpoint
checkpoint_dir = project_root / 'checkpoints' / 'multimodal_concat'
checkpoint_path = checkpoint_dir / 'best_model.pth'

# Fail fast with an actionable message. The hint is part of the exception
# text because anything placed after a `raise` would never run.
if not checkpoint_path.exists():
    raise FileNotFoundError(
        f"Best model checkpoint not found at {checkpoint_path}. "
        "Please run the main training notebook first to create the best model checkpoint."
    )

print(f"Loading best model from {checkpoint_path}")

# Rebuild the architecture with the hyperparameters hard-coded here, then
# restore the trained weights into it.
model = ImprovedMultimodalVQA(
    vocab_size=vocab_size,
    num_classes=num_classes,
    embedding_dim=300,
    text_hidden_dim=512,
    fusion_hidden_dim=512,
    dropout=0.3
)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Apply the strategy's modifications after the weights are in place, then
# move the final model to the compute device.
model = create_fine_tuned_model(model, config)
model = model.to(device)

print(f"Model loaded successfully!")
print(f"Original best validation accuracy from checkpoint: {checkpoint.get('best_val_acc', 'N/A')}")

# Parameter counts are reused later when the results file is written.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 5. Setup Fine-tuning Optimizer and Training Strategy

In [None]:
def setup_fine_tuning_optimizer(model, config):
    """Create an AdamW optimizer with per-component learning rates.

    Parameters are bucketed by substrings of their names into vision, text
    and fusion (everything else). With the "layerwise" strategy each bucket
    gets its own learning rate; otherwise text and fusion share the base rate.

    Args:
        model: Module whose ``named_parameters()`` are grouped.
        config: Strategy settings; reads ``learning_rate``,
            ``vision_lr_factor``, ``weight_decay`` and, for "layerwise",
            the optional ``text_lr_factor`` / ``fusion_lr_factor``.

    Returns:
        The configured ``optim.AdamW`` instance.
    """
    vision_markers = ('vision_encoder', 'spatial_attention', 'vision_proj')
    text_markers = ('text_embedding', 'text_lstm', 'text_proj', 'text_dropout')

    buckets = {'vision': [], 'text': [], 'fusion': []}
    for pname, tensor in model.named_parameters():
        if any(marker in pname for marker in vision_markers):
            bucket = 'vision'
        elif any(marker in pname for marker in text_markers):
            bucket = 'text'
        else:
            bucket = 'fusion'
        buckets[bucket].append(tensor)

    print(f"Vision parameters: {len(buckets['vision'])}")
    print(f"Text parameters: {len(buckets['text'])}")
    print(f"Fusion parameters: {len(buckets['fusion'])}")

    base_lr = config['learning_rate']
    vision_lr = base_lr * config['vision_lr_factor']
    vision_group = {'params': buckets['vision'], 'lr': vision_lr, 'name': 'vision'}

    if FINE_TUNING_STRATEGY != "layerwise":
        # Default layout: one reduced-rate vision group, one group for the rest.
        param_groups = [
            vision_group,
            {'params': buckets['text'] + buckets['fusion'], 'lr': base_lr, 'name': 'other'},
        ]
        print(f"Standard LR: vision={vision_lr:.2e}, other={base_lr:.2e}")
    else:
        # Layer-wise layout: a separate rate for each component.
        text_lr = base_lr * config.get('text_lr_factor', 0.3)
        fusion_lr = base_lr * config.get('fusion_lr_factor', 1.0)
        param_groups = [
            vision_group,
            {'params': buckets['text'], 'lr': text_lr, 'name': 'text'},
            {'params': buckets['fusion'], 'lr': fusion_lr, 'name': 'fusion'},
        ]
        print(f"Layer-wise LR: vision={vision_lr:.2e}, text={text_lr:.2e}, fusion={fusion_lr:.2e}")

    return optim.AdamW(param_groups, weight_decay=config['weight_decay'])

def apply_layer_freezing(model, freeze_layers):
    """Disable gradients for every parameter whose name contains a listed substring.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        freeze_layers: Iterable of name substrings; each is matched against
            every parameter name in turn.
    """
    for pattern in freeze_layers:
        matches = (
            (pname, tensor)
            for pname, tensor in model.named_parameters()
            if pattern in pname
        )
        for pname, tensor in matches:
            tensor.requires_grad = False
            print(f"Frozen: {pname}")

# Build the optimizer with the strategy-specific learning rates
optimizer = setup_fine_tuning_optimizer(model, config)

# Freeze the layers the strategy asks for up front (skipped when the list is
# empty or the key is absent)
if config.get('freeze_layers'):
    apply_layer_freezing(model, config['freeze_layers'])

# Cross-entropy with the strategy's label smoothing
criterion = nn.CrossEntropyLoss(label_smoothing=config['label_smoothing'])
print(f"Using label smoothing: {config['label_smoothing']}")

# Learning-rate schedule: cosine warm restarts by default, a linear ramp from
# half the rate for the progressive strategy
if FINE_TUNING_STRATEGY != "progressive":
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=max(2, config['epochs'] // 3),
        T_mult=2,
        eta_min=1e-7,
    )
else:
    scheduler = optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.5,
        total_iters=config['epochs'] // 2,
    )

print(f"Setup complete for {config['name']} strategy")

## 6. Fine-tuning Training Loop with Advanced Monitoring

In [None]:
def fine_tune_epoch(model, train_loader, val_loader, optimizer, criterion, scheduler, device, epoch, config):
    """Run one training pass and one validation pass, then step the scheduler.

    For the "progressive" strategy, parameters whose names match the entries
    scheduled for ``epoch`` have their gradients re-enabled before training.

    Args:
        model: Network called as ``model(questions, images)``.
        train_loader: Iterable of ``(questions, images, answers)`` batches.
        val_loader: Iterable of ``(questions, images, answers)`` batches.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss applied to ``(outputs, answers)`` in both passes.
        scheduler: Learning-rate scheduler stepped once per epoch.
        device: Device the batches are moved to.
        epoch: Current epoch number (used for the schedule and for logging).
        config: Strategy settings; reads ``epochs`` and ``freeze_schedule``.

    Returns:
        Dict with ``train_loss``, ``train_accuracy``, ``val_loss`` and
        ``val_accuracy``; accuracies are percentages, losses are the mean of
        the per-batch loss values.
    """
    # Scheduled unfreezing (progressive strategy only)
    uses_schedule = FINE_TUNING_STRATEGY == "progressive" and 'freeze_schedule' in config
    if uses_schedule and epoch in config['freeze_schedule']:
        due_layers = config['freeze_schedule'][epoch]
        for pname, tensor in model.named_parameters():
            if any(layer in pname for layer in due_layers):
                tensor.requires_grad = True
                print(f"Unfrozen at epoch {epoch}: {pname}")

    # ---- Training pass ----
    model.train()
    batch_losses = []
    hits = 0
    seen = 0

    print(f"Fine-tuning Epoch {epoch}/{config['epochs']}")
    print("-" * 50)

    progress = tqdm(train_loader, desc=f"Training")
    for questions, images, answers in progress:
        questions = questions.to(device)
        images = images.to(device)
        answers = answers.to(device)

        optimizer.zero_grad()
        logits = model(questions, images)
        batch_loss = criterion(logits, answers)
        batch_loss.backward()

        # Tight clipping keeps each fine-tuning update small.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()

        # Running metrics for the progress bar and the epoch summary
        batch_losses.append(batch_loss.item())
        guesses = logits.argmax(dim=1)
        seen += answers.size(0)
        hits += (guesses == answers).sum().item()

        running_acc = 100. * hits / seen
        progress.set_postfix({'loss': f'{batch_loss.item():.4f}', 'acc': f'{running_acc:.2f}%'})

    avg_train_loss = sum(batch_losses) / len(batch_losses)
    train_accuracy = 100. * hits / seen

    # ---- Validation pass ----
    model.eval()
    eval_losses = []
    eval_hits = 0
    eval_seen = 0

    with torch.no_grad():
        for questions, images, answers in tqdm(val_loader, desc="Validating"):
            questions = questions.to(device)
            images = images.to(device)
            answers = answers.to(device)

            logits = model(questions, images)
            eval_losses.append(criterion(logits, answers).item())

            guesses = logits.argmax(dim=1)
            eval_seen += answers.size(0)
            eval_hits += (guesses == answers).sum().item()

    avg_val_loss = sum(eval_losses) / len(eval_losses)
    val_accuracy = 100. * eval_hits / eval_seen

    # One scheduler step per epoch, after validation
    scheduler.step()

    # Epoch summary; the first and last optimizer groups are the ones reported.
    first_lr = optimizer.param_groups[0]['lr']
    last_lr = optimizer.param_groups[-1]['lr']
    print(f"\nEpoch {epoch} Results:")
    print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_accuracy:.2f}%")
    print(f"Learning Rate: Vision={first_lr:.2e}, Other={last_lr:.2e}")
    print("=" * 50)

    return {
        'train_loss': avg_train_loss,
        'train_accuracy': train_accuracy,
        'val_loss': avg_val_loss,
        'val_accuracy': val_accuracy,
    }

print("Fine-tuning training function ready!")

## 7. Execute Fine-tuning

In [None]:
# Fine-tuning execution
import copy  # cell-local import: used to snapshot the best weights below

print(f"Starting {config['name']}...")
print(f"Strategy: {config['description']}")
print(f"Epochs: {config['epochs']}, Batch Size: {config['batch_size']}")

# Per-epoch metric history (one entry appended per completed epoch)
fine_tune_history = {
    'train_losses': [],
    'train_accuracies': [],
    'val_losses': [],
    'val_accuracies': []
}

# Baseline to beat: the validation accuracy stored in the loaded checkpoint
original_best_acc = checkpoint.get('best_val_acc', 55.39)  # Default to known best
print(f"Target: Improve upon {original_best_acc:.2f}% validation accuracy")

best_fine_tune_acc = 0.0
best_model_state = None

# Resolve the checkpoint location before the loop so `ft_checkpoint_path` is
# defined for later cells even if no epoch ever improves or the run is
# interrupted early.
ft_checkpoint_dir = project_root / 'checkpoints' / f'fine_tuned_{FINE_TUNING_STRATEGY}'
ft_checkpoint_path = ft_checkpoint_dir / 'best_model.pth'

start_time = time.time()

try:
    for epoch in range(1, config['epochs'] + 1):
        # Fine-tune one epoch
        epoch_results = fine_tune_epoch(
            model, train_loader, val_loader,
            optimizer, criterion, scheduler,
            device, epoch, config
        )

        # Store history
        fine_tune_history['train_losses'].append(epoch_results['train_loss'])
        fine_tune_history['train_accuracies'].append(epoch_results['train_accuracy'])
        fine_tune_history['val_losses'].append(epoch_results['val_loss'])
        fine_tune_history['val_accuracies'].append(epoch_results['val_accuracy'])

        # Check for improvement
        current_val_acc = epoch_results['val_accuracy']
        if current_val_acc > best_fine_tune_acc:
            best_fine_tune_acc = current_val_acc
            # Deep copy: state_dict() returns references to the live tensors,
            # so a shallow .copy() would keep changing as training continues
            # and the "best" snapshot would silently become the last weights.
            best_model_state = copy.deepcopy(model.state_dict())
            improvement_vs_original = current_val_acc - original_best_acc
            print(f"New best fine-tuned accuracy: {current_val_acc:.2f}%")
            if improvement_vs_original > 0:
                print(f"Improvement over original: +{improvement_vs_original:.2f} pp")

            # Save fine-tuned checkpoint (create missing parent folders too)
            ft_checkpoint_dir.mkdir(parents=True, exist_ok=True)

            torch.save({
                'epoch': epoch,
                'model_state_dict': best_model_state,
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_fine_tune_acc,
                'original_acc': original_best_acc,
                'fine_tune_history': fine_tune_history,
                'config': config,
                'strategy': FINE_TUNING_STRATEGY
            }, ft_checkpoint_path)

        # Early stopping check (very patient for fine-tuning): after at least
        # 5 epochs, stop when each of the last 3 validation accuracies is more
        # than 1 point below the best seen so far.
        if epoch >= 5:  # Minimum epochs
            recent_accs = fine_tune_history['val_accuracies'][-3:]
            if len(recent_accs) >= 3 and all(acc < max(fine_tune_history['val_accuracies']) - 1.0 for acc in recent_accs):
                print(f"Early stopping triggered - no improvement for 3 epochs")
                break

except KeyboardInterrupt:
    # A manual stop keeps whatever was collected so far.
    print("\nFine-tuning interrupted by user")

training_time = time.time() - start_time
print(f"\nFine-tuning completed in {training_time/60:.2f} minutes")
print(f"Best fine-tuned validation accuracy: {best_fine_tune_acc:.2f}%")
print(f"Original best accuracy: {original_best_acc:.2f}%")
improvement = best_fine_tune_acc - original_best_acc
print(f"Overall improvement: {improvement:+.2f} percentage points")

# Restore the best snapshot so evaluation uses it rather than the last epoch
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Best fine-tuned model loaded for evaluation")

## 8. Comprehensive Evaluation and Comparison

In [None]:
# Comprehensive evaluation function
def evaluate_fine_tuned_model(model, test_loader, device, strategy_name):
    """Score the model on the test loader using plain cross-entropy.

    Args:
        model: Network called as ``model(questions, images)``.
        test_loader: Iterable of ``(questions, images, answers)`` batches.
        device: Device the batches are moved to.
        strategy_name: Label used only in the progress message.

    Returns:
        Dict with ``accuracy`` (as returned by ``accuracy_score``),
        ``avg_loss`` (sum of per-batch losses divided by the number of
        batches), and the flat ``predictions`` / ``targets`` lists.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    predicted_labels = []
    true_labels = []
    running_loss = 0.0

    print(f"Evaluating {strategy_name} fine-tuned model on test set...")

    with torch.no_grad():
        for questions, images, answers in tqdm(test_loader, desc="Testing"):
            questions = questions.to(device)
            images = images.to(device)
            answers = answers.to(device)

            logits = model(questions, images)
            running_loss += loss_fn(logits, answers).item()

            # Collect class indices on the CPU for the sklearn metrics.
            predicted_labels.extend(logits.argmax(dim=1).cpu().numpy())
            true_labels.extend(answers.cpu().numpy())

    return {
        'accuracy': accuracy_score(true_labels, predicted_labels),
        'avg_loss': running_loss / len(test_loader),
        'predictions': predicted_labels,
        'targets': true_labels,
    }

# Score the fine-tuned weights on the held-out test set
ft_results = evaluate_fine_tuned_model(model, test_loader, device, config['name'])

print("\n" + "=" * 60)
print(f"FINE-TUNING RESULTS - {config['name'].upper()}")
print("=" * 60)
print(f"Test Accuracy: {ft_results['accuracy']:.4f} ({ft_results['accuracy']*100:.2f}%)")
print(f"Test Loss: {ft_results['avg_loss']:.4f}")

# Result files written by earlier experiments; each comparison below is
# printed only when its file is present.
baseline_results_path = project_root / 'results' / 'text_baseline_results.json'
original_multimodal_path = project_root / 'results' / 'improved_multimodal_results.json'

print("\nPerformance Comparison:")
print("-" * 30)

# Text-only baseline
if baseline_results_path.exists():
    baseline_results = json.loads(baseline_results_path.read_text())
    baseline_acc = baseline_results['accuracy']
    improvement_vs_baseline = (ft_results['accuracy'] - baseline_acc) * 100
    print(f"Text Baseline:        {baseline_acc:.4f} ({baseline_acc*100:.2f}%)")
    print(f"vs Text Baseline:     {improvement_vs_baseline:+.2f} pp")

# Multimodal model before fine-tuning
if original_multimodal_path.exists():
    original_results = json.loads(original_multimodal_path.read_text())
    original_acc = original_results['test_metrics']['accuracy']
    improvement_vs_original = (ft_results['accuracy'] - original_acc) * 100
    print(f"Original Multimodal:  {original_acc:.4f} ({original_acc*100:.2f}%)")
    print(f"vs Original:          {improvement_vs_original:+.2f} pp")

print(f"Fine-tuned Model:     {ft_results['accuracy']:.4f} ({ft_results['accuracy']*100:.2f}%)")
print("=" * 60)

## 9. Visualization and Analysis

In [None]:
# Training history visualization
# One 2x2 figure: loss curves (ax1), accuracy curves (ax2), a bar chart
# comparing models (ax3) and a text summary panel (ax4).
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# X axis covers the epochs that actually completed (early stopping or an
# interrupt can make this shorter than config['epochs']).
epochs_range = range(1, len(fine_tune_history['train_losses']) + 1)

# Loss curves
ax1.plot(epochs_range, fine_tune_history['train_losses'], 'b-', label='Training Loss', linewidth=2)
ax1.plot(epochs_range, fine_tune_history['val_losses'], 'r-', label='Validation Loss', linewidth=2)
ax1.set_title(f'Fine-tuning Loss Curves - {config["name"]}', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy curves, with the pre-fine-tuning validation accuracy drawn as a
# dashed horizontal reference line
ax2.plot(epochs_range, fine_tune_history['train_accuracies'], 'b-', label='Training Accuracy', linewidth=2)
ax2.plot(epochs_range, fine_tune_history['val_accuracies'], 'r-', label='Validation Accuracy', linewidth=2)
ax2.axhline(y=original_best_acc, color='g', linestyle='--', alpha=0.7, label=f'Original Best ({original_best_acc:.1f}%)')
ax2.set_title('Fine-tuning Accuracy Curves', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Performance comparison bar chart
models = ['Text Baseline', 'Original Multimodal', f'Fine-tuned\n({FINE_TUNING_STRATEGY})']
accuracies = []

# `baseline_acc` / `original_acc` are set by the comparison cell only when the
# corresponding results file exists, hence the same existence checks here;
# otherwise hard-coded reference values are plotted instead.
if baseline_results_path.exists():
    accuracies.append(baseline_acc * 100)
else:
    accuracies.append(47.36)  # Known baseline

if original_multimodal_path.exists():
    accuracies.append(original_acc * 100)
else:
    accuracies.append(55.39)  # Known best

accuracies.append(ft_results['accuracy'] * 100)

bars = ax3.bar(models, accuracies, color=['skyblue', 'lightgreen', 'lightcoral'], alpha=0.8)

# Add value labels on bars
for bar, acc in zip(bars, accuracies):
    height = bar.get_height()
    ax3.text(bar.get_x() + bar.get_width()/2., height + 0.3,
            f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold')

ax3.set_title('Model Performance Comparison', fontsize=14, fontweight='bold')
ax3.set_ylabel('Accuracy (%)')
# Leave headroom above the tallest bar for its value label.
ax3.set_ylim(0, max(accuracies) + 5)
ax3.grid(True, alpha=0.3, axis='y')

# Fine-tuning strategy summary
# NOTE: `improvement` is the validation-accuracy gain over the checkpoint's
# recorded best, not a test-set difference.
ax4.axis('off')
summary_text = f"""
Fine-tuning Strategy: {config['name']}

Configuration:
• Learning Rate: {config['learning_rate']:.2e}
• Vision LR Factor: {config['vision_lr_factor']}
• Epochs: {config['epochs']}
• Label Smoothing: {config['label_smoothing']}
• Augmentation: {config['augmentation_strength']}

Results:
• Best Val Acc: {best_fine_tune_acc:.2f}%
• Test Accuracy: {ft_results['accuracy']*100:.2f}%
• Improvement: {improvement:+.2f} pp
• Training Time: {training_time/60:.1f} min
"""

ax4.text(0.05, 0.95, summary_text, fontsize=11, verticalalignment='top',
         bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.7))
ax4.set_title('Fine-tuning Summary', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

# Save visualization
# The figure is written to results/figures/ (created if missing) at 300 dpi.
results_dir = project_root / 'results' / 'figures'
results_dir.mkdir(exist_ok=True, parents=True)
fig.savefig(results_dir / f'fine_tuning_{FINE_TUNING_STRATEGY}_results.png', dpi=300, bbox_inches='tight')
print(f"\nVisualization saved to: {results_dir / f'fine_tuning_{FINE_TUNING_STRATEGY}_results.png'}")

## 10. Save Fine-tuning Results and Generate Report

In [None]:
# Save comprehensive fine-tuning results
results_dir = project_root / 'results'
# Create results/ and results/figures/ together so the confusion-matrix save
# further down works even if the visualization cell was skipped.
(results_dir / 'figures').mkdir(parents=True, exist_ok=True)

fine_tuning_results = {
    'strategy': FINE_TUNING_STRATEGY,
    'strategy_name': config['name'],
    'strategy_description': config['description'],
    'configuration': config,
    'training_history': fine_tune_history,
    'results': {
        # Stored as a fraction to match test_accuracy
        'best_validation_accuracy': best_fine_tune_acc / 100,
        'test_accuracy': ft_results['accuracy'],
        'test_loss': ft_results['avg_loss'],
        'training_time_minutes': training_time / 60,
        'total_epochs': len(fine_tune_history['train_losses'])
    },
    'improvements': {
        # Validation-accuracy gain (percentage points) over the loaded checkpoint
        'vs_original_multimodal': improvement,
        # Only defined when the text-baseline results file was found earlier
        'vs_text_baseline': globals().get('improvement_vs_baseline')
    },
    'model_info': {
        'total_parameters': total_params,
        'trainable_parameters': trainable_params
    }
}

# Serialise first, then write: a value that cannot be encoded raises before
# the file is touched instead of leaving a truncated JSON file behind.
results_file = results_dir / f'fine_tuning_{FINE_TUNING_STRATEGY}_results.json'
results_file.write_text(json.dumps(fine_tuning_results, indent=2))

print(f"Fine-tuning results saved to: {results_file}")

# Generate confusion matrix if we have good performance
if ft_results['accuracy'] > 0.5:  # Only if accuracy is decent
    plt.figure(figsize=(10, 8))

    # Use a subset of classes for cleaner visualization
    unique_targets = sorted(set(ft_results['targets']))
    if len(unique_targets) > 20:  # Too many classes for clean visualization
        # Show top 20 most common classes
        from collections import Counter
        target_counts = Counter(ft_results['targets'])
        top_classes = [cls for cls, _ in target_counts.most_common(20)]
        top_class_set = set(top_classes)  # O(1) membership for the filter below

        # Keep only the samples whose true class is one of the top classes
        kept_pairs = [
            (target, pred)
            for target, pred in zip(ft_results['targets'], ft_results['predictions'])
            if target in top_class_set
        ]
        filtered_targets = [target for target, _ in kept_pairs]
        filtered_preds = [pred for _, pred in kept_pairs]

        cm = confusion_matrix(filtered_targets, filtered_preds, labels=top_classes)
        title_suffix = " (Top 20 Classes)"
    else:
        cm = confusion_matrix(ft_results['targets'], ft_results['predictions'])
        title_suffix = ""

    sns.heatmap(cm, annot=False, fmt='d', cmap='Blues', cbar=True)
    plt.title(f'Fine-tuned Model Confusion Matrix{title_suffix}\n{config["name"]}')
    plt.xlabel('Predicted Class')
    plt.ylabel('True Class')
    plt.tight_layout()

    # Save confusion matrix
    cm_path = results_dir / 'figures' / f'fine_tuned_{FINE_TUNING_STRATEGY}_confusion_matrix.png'
    plt.savefig(cm_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Confusion matrix saved to: {cm_path}")

# Print final summary
print("\n" + "=" * 70)
print(f"FINE-TUNING COMPLETED: {config['name'].upper()}")
print("=" * 70)
print(f"Strategy Description: {config['description']}")
print(f"Training Time: {training_time/60:.2f} minutes")
print(f"Best Validation Accuracy: {best_fine_tune_acc:.2f}%")
print(f"Test Accuracy: {ft_results['accuracy']*100:.2f}%")
print(f"Improvement over Original: {improvement:+.2f} percentage points")

if improvement > 0:
    print("\nCONGRATULATIONS! Fine-tuning was successful!")
    print(f"Your model has improved by {improvement:.2f} percentage points.")
else:
    print("\nFine-tuning did not improve performance.")
    print("Consider trying a different strategy or adjusting hyperparameters.")

print(f"\nAll results and checkpoints saved in:")
print(f"- Results: {results_file}")
# A checkpoint is written only when some epoch set a new best, which is also
# the only time `best_model_state` gets a value.
if best_model_state is not None:
    print(f"- Checkpoint: {ft_checkpoint_path}")
else:
    print("- Checkpoint: none saved (no epoch produced a new best)")
print(f"- Visualizations: {results_dir / 'figures'}")
print("=" * 70)

## 11. Strategy Comparison and Recommendations

**Available Fine-tuning Strategies:**

1. **Conservative:** Safe, minimal changes with very low learning rates
2. **Layerwise:** Different learning rates for different model components  
3. **Progressive:** Gradually unfreeze layers during training
4. **Augmented:** Enhanced data augmentation for better generalization
5. **Regularization:** Optimize dropout and weight decay parameters
6. **Architecture:** Minor architectural modifications

**To try different strategies:**
1. Change `FINE_TUNING_STRATEGY` at the top of Section 2 (Fine-tuning Configuration)
2. Re-run every cell from Section 2 onwards
3. Compare results across strategies

**Next Steps:**
- Try multiple strategies and compare results
- Ensemble the best fine-tuned models
- Use the best model for inference applications