# Ablation Study Analysis

Amazon Product Classification - DATA304 Final Project

**Compare:** Model architectures (BERT/GCN/GAT), loss functions (BCE/Focal/Asymmetric), and training convergence speed

## Setup

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import torch
from matplotlib.patches import Patch

# Plot appearance shared by every figure produced below
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('Set2')

# Inputs are read from the trained-model directories; figures and tables go to fig_dir
models_dir = Path('../models')
fig_dir = Path('../results/images/ablation')
fig_dir.mkdir(parents=True, exist_ok=True)  # created on first run, no-op afterwards

# Banner
separator = '=' * 70
print(separator)
print("              ABLATION STUDY ANALYSIS")
print(separator)
print(f"✓ Images will be saved to: {fig_dir}\n")

## 1. Model Architecture Comparison

In [None]:
# Architecture experiments: (sub-directory of models_dir, legend label, line colour)
model_configs = [
    ('baseline', 'BERT Baseline', '#2E86AB'),
    ('gcn', 'GCN + Hierarchy', '#A23B72'),
    ('gat', 'GAT + Hierarchy', '#F18F01'),
]

# histories: model_name -> (parsed training_history.json, label, colour).
# Runs without a history file are reported and skipped, so later cells only
# see experiments that actually exist on disk.
histories = {}

for model_name, label, color in model_configs:
    history_file = models_dir / model_name / 'training_history.json'
    if not history_file.exists():
        print(f"⚠️  Not found: {model_name}")
        continue
    with open(history_file) as fh:
        histories[model_name] = (json.load(fh), label, color)
    print(f"✓ Loaded: {model_name}")

print(f"\nLoaded {len(histories)} models for comparison")

In [None]:
# Plot training loss comparison: one curve per loaded architecture, saved as a PNG.
if not histories:
    print("⚠️  No training histories found. Train models first!")
else:
    loss_plot_path = fig_dir / 'architecture_comparison_loss.png'
    fig, ax = plt.subplots(figsize=(14, 7))

    for history, label, color in histories.values():
        ax.plot(history['epochs'], history['train_loss'],
                marker='o', linewidth=2.5, label=label, color=color, markersize=8)

    ax.set_xlabel('Epoch', fontsize=14, fontweight='bold')
    ax.set_ylabel('Training Loss', fontsize=14, fontweight='bold')
    ax.set_title('Model Architecture Comparison - Training Loss', fontsize=16, fontweight='bold')
    ax.legend(fontsize=12, loc='upper right')
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(loss_plot_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"✓ Saved: {loss_plot_path}")

## 2. Loss Function Ablation

In [None]:
# Define loss function experiments: (sub-directory of models_dir, legend label, line colour).
# The 'baseline' run is reused here as the BCE reference.
loss_configs = [
    ('baseline', 'BCE Loss', '#264653'),
    ('focal_loss', 'Focal Loss', '#2A9D8F'),
    ('asymmetric_loss', 'Asymmetric Loss', '#E76F51'),
]

# Load loss function experiments. Missing runs are reported, matching the
# architecture cell, so a curve absent from the plot is never unexplained.
loss_histories = {}

for model_name, label, color in loss_configs:
    history_path = models_dir / model_name / 'training_history.json'
    if history_path.exists():
        with open(history_path) as f:
            loss_histories[model_name] = (json.load(f), label, color)
        print(f"✓ Loaded: {model_name}")
    else:
        print(f"⚠️  Not found: {model_name}")

if loss_histories:
    # NOTE: the objectives are different functions and are not on a common
    # scale, so compare curve shapes / convergence rather than absolute values.
    plt.figure(figsize=(14, 7))

    for model_name, (history, label, color) in loss_histories.items():
        plt.plot(history['epochs'], history['train_loss'],
                 marker='s', linewidth=2.5, label=label, color=color, markersize=8)

    plt.xlabel('Epoch', fontsize=14, fontweight='bold')
    plt.ylabel('Training Loss', fontsize=14, fontweight='bold')
    plt.title('Loss Function Ablation Study', fontsize=16, fontweight='bold')
    plt.legend(fontsize=12, loc='upper right')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(fig_dir / 'loss_function_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"✓ Saved: {fig_dir / 'loss_function_comparison.png'}")
else:
    print("⚠️  No loss function experiments found")

## 3. Performance Metrics Comparison

In [None]:
# Load evaluation metrics from each experiment's best checkpoint into `results`
# (one row per config entry), then build the `df` summary table.
#
# Each row is tagged with the comparison group its config came from. 'baseline'
# is listed in both groups ("BERT Baseline" and "BCE Loss"), so tagging by
# source list labels it 'Architecture' in one row and 'Loss Function' in the other.
all_models = model_configs + loss_configs  # combined list ('baseline' appears twice)
tagged_models = ([(config, 'Architecture') for config in model_configs]
                 + [(config, 'Loss Function') for config in loss_configs])
results = []
checkpoint_cache = {}  # model_name -> loaded checkpoint, so a shared run is read once

for (model_name, label, _color), model_type in tagged_models:
    checkpoint_path = models_dir / model_name / 'best_model.pt'
    if not checkpoint_path.exists():
        continue
    try:
        if model_name not in checkpoint_cache:
            # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
            # which rejects arbitrary pickled objects; such failures are reported below.
            checkpoint_cache[model_name] = torch.load(checkpoint_path, map_location='cpu')
        checkpoint = checkpoint_cache[model_name]
        metrics = checkpoint.get('metrics', {})

        results.append({
            'Model': label,
            'Type': model_type,
            'Micro F1': metrics.get('micro_f1', 0.0),
            'Macro F1': metrics.get('macro_f1', 0.0),
            'Precision': metrics.get('precision', 0.0),
            'Recall': metrics.get('recall', 0.0),
            'Hamming Loss': metrics.get('hamming_loss', 0.0),
            'Final Loss': checkpoint.get('loss', 0.0),
            'Epoch': checkpoint.get('epoch', 0)
        })
        print(f"✓ Loaded metrics: {model_name}")
    except Exception as e:
        # Best effort: one unreadable checkpoint should not abort the whole study
        print(f"⚠️  Error loading {model_name}: {e}")

# Create DataFrame
if results:
    df = pd.DataFrame(results)

    print("\n" + "="*90)
    print("                      ABLATION STUDY RESULTS")
    print("="*90)
    print(df.to_string(index=False))
    print("="*90 + "\n")

    # Save CSV
    csv_path = fig_dir / 'ablation_results_detailed.csv'
    df.to_csv(csv_path, index=False)
    print(f"✓ Saved: {csv_path}")
else:
    print("⚠️  No results to display. Train models first!")
    df = None

In [None]:
# Visualize metrics comparison: a 2x2 grid of bar charts (one per metric),
# with each bar coloured by its experiment group.
if not results or df is None:
    print("⚠️  No results to plot")
else:
    metrics_to_plot = ['Micro F1', 'Macro F1', 'Precision', 'Recall']
    arch_color, loss_color = '#2E86AB', '#E76F51'

    # Bar positions, width and colours are identical for every subplot
    x = np.arange(len(df))
    width = 0.7
    colors = [arch_color if t == 'Architecture' else loss_color for t in df['Type']]

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    axes = axes.flatten()

    for ax, metric in zip(axes, metrics_to_plot):
        values = df[metric]
        bars = ax.bar(x, values, width=width, alpha=0.8, edgecolor='black',
                      color=colors)

        # Annotate every positive bar with its value
        for bar in bars:
            height = bar.get_height()
            if height > 0:
                ax.text(bar.get_x() + bar.get_width() / 2., height,
                        f'{height:.3f}', ha='center', va='bottom',
                        fontsize=9, fontweight='bold')

        # Leave 20% headroom for the labels, capped at 1.0; fall back to 1.0 if nothing is positive
        peak = values.max()
        upper = min(1.0, peak * 1.2) if peak > 0 else 1.0

        ax.set_ylabel(metric, fontsize=12, fontweight='bold')
        ax.set_title(f'{metric} Comparison', fontsize=14, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels(df['Model'], rotation=25, ha='right', fontsize=10)
        ax.set_ylim(0, upper)
        ax.grid(axis='y', alpha=0.3, linestyle='--')

    # Figure-level legend explaining the two bar colours
    legend_elements = [
        Patch(facecolor=arch_color, label='Architecture'),
        Patch(facecolor=loss_color, label='Loss Function')
    ]
    fig.legend(handles=legend_elements, loc='upper right', fontsize=12)

    metrics_plot_path = fig_dir / 'ablation_metrics_detailed.png'
    plt.tight_layout()
    plt.savefig(metrics_plot_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"✓ Saved: {metrics_plot_path}")

## 4. Convergence Speed Analysis

In [None]:
# Analyze convergence speed: for each architecture, the first epoch at which the
# training loss has covered `convergence_fraction` of its total drop from the
# first to the last recorded epoch (a loss-based measure, not a metric-based one).
convergence_fraction = 0.95
convergence_data = []

for model_name, (history, label, _color) in histories.items():
    train_loss = history['train_loss']
    if not train_loss:
        # Nothing to measure; skip instead of failing on train_loss[0]
        print(f"⚠️  Empty training history: {model_name}")
        continue

    initial_loss = train_loss[0]
    final_loss = train_loss[-1]
    threshold = initial_loss - convergence_fraction * (initial_loss - final_loss)

    # First 1-based epoch whose loss reaches the threshold. The fallback to the
    # total epoch count is defensive (e.g. NaN losses never compare <= threshold).
    convergence_epoch = next(
        (epoch for epoch, loss in enumerate(train_loss, start=1) if loss <= threshold),
        len(train_loss),
    )

    # Relative improvement is undefined when the initial loss is zero
    improvement = ((initial_loss - final_loss) / initial_loss * 100
                   if initial_loss else float('nan'))

    convergence_data.append({
        'Model': label,
        'Convergence Epoch': convergence_epoch,
        'Initial Loss': initial_loss,
        'Final Loss': final_loss,
        'Improvement': improvement
    })

if convergence_data:
    conv_df = pd.DataFrame(convergence_data)

    plt.figure(figsize=(12, 7))
    bars = plt.bar(conv_df['Model'], conv_df['Convergence Epoch'],
                   alpha=0.8, edgecolor='black', color='#4ECDC4', width=0.6)

    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                 f'{int(height)}', ha='center', va='bottom',
                 fontsize=12, fontweight='bold')

    plt.ylabel(f'Epochs to {convergence_fraction:.0%} Convergence', fontsize=13, fontweight='bold')
    plt.title('Model Convergence Speed Comparison', fontsize=15, fontweight='bold')
    plt.xticks(rotation=15)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(fig_dir / 'convergence_speed.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"✓ Saved: {fig_dir / 'convergence_speed.png'}")
    print("\nConvergence Analysis:")
    print(conv_df.to_string(index=False))
else:
    print("⚠️  No convergence data available")

## 5. Summary Heatmap

In [None]:
# Summary heatmap: metrics as rows, models as columns, on a fixed 0-1 colour scale.
if not results or df is None:
    print("⚠️  No data for heatmap")
else:
    heatmap_metrics = ['Micro F1', 'Macro F1', 'Precision', 'Recall']
    by_model = df.set_index('Model')
    heatmap_data = by_model[heatmap_metrics].T  # transpose so each metric is a row
    heatmap_path = fig_dir / 'ablation_heatmap.png'

    plt.figure(figsize=(14, 8))
    sns.heatmap(
        heatmap_data,
        annot=True,
        fmt='.3f',
        cmap='YlGnBu',
        linewidths=1.5,
        linecolor='gray',
        cbar_kws={'label': 'Score'},
        vmin=0,
        vmax=1.0,
    )
    plt.title('Ablation Study - Performance Heatmap', fontsize=16, fontweight='bold', pad=20)
    plt.ylabel('Metric', fontsize=13, fontweight='bold')
    plt.xlabel('Model', fontsize=13, fontweight='bold')
    plt.xticks(rotation=25, ha='right')
    plt.tight_layout()
    plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"✓ Saved: {heatmap_path}")

## Summary

✅ All ablation study visualizations saved to: `results/images/ablation/`

**Generated Files:**
- `architecture_comparison_loss.png` - BERT vs GCN vs GAT training curves
- `loss_function_comparison.png` - BCE vs Focal vs Asymmetric loss comparison
- `ablation_metrics_detailed.png` - Detailed F1, Precision, Recall comparison
- `convergence_speed.png` - Epochs needed to achieve 95% of the total training-loss reduction
- `ablation_heatmap.png` - Performance summary heatmap
- `ablation_results_detailed.csv` - Complete results table

---

### Key Findings (to be filled in from the results above):

**Best Architecture:**
- Compare BERT baseline vs GCN vs GAT
- Analyze which model converges faster
- Evaluate final performance metrics

**Best Loss Function:**
- BCE: Standard binary cross-entropy baseline
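- Focal Loss: Down-weights well-classified examples, intended to help with class imbalance
- Asymmetric Loss: Treats positive and negative labels differently; designed for multi-label classification with many negative labels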

**Training Efficiency:**
- Convergence speed comparison
- Training time vs. performance trade-off (training time is not measured in this notebook)

**Overall Recommendations:**
1. Choose model based on performance-speed trade-off
2. Consider loss function for specific dataset characteristics
3. Use the hierarchy-aware models (GCN/GAT) only if they outperform the BERT baseline

---

### Next Steps:
1. Run detailed case studies on best-performing model
2. Analyze failure cases and error patterns
3. Fine-tune hyperparameters for production deployment