In [None]:
import json
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import pickle
from pathlib import Path
from collections import Counter
from sklearn.manifold import TSNE
import torch

# Global plot styling applied to every figure below
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# ============================================================
# CONFIGURATION: Change this for different models
# ============================================================
MODEL_NAME = 'baseline'  # 'baseline', 'gcn', 'gat', 'focal_loss', etc.

# Inputs: trained model folders and saved prediction pickles
models_dir = Path('../models')
predictions_dir = Path('../results/predictions')

# Outputs: per-model figures, plus a shared folder for cross-model (ablation) plots
fig_dir = Path(f'../results/images/{MODEL_NAME}')
comparison_dir = Path('../results/images/ablation')

for out_dir in (fig_dir, comparison_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

print('=' * 60)
print(f"  Visualization for: {MODEL_NAME.upper()}")
print('=' * 60)
print(f"✓ Images will be saved to: {fig_dir}")
print(f"✓ Comparison plots to: {comparison_dir}\n")

## 1. Training Loss Curve

In [None]:
# Plot this model's loss curves (left) and F1 curves (right) from training_history.json.
history_file = models_dir / MODEL_NAME / 'training_history.json'

if history_file.exists():
    with open(history_file) as f:
        history = json.load(f)

    fig, (ax_loss, ax_f1) = plt.subplots(1, 2, figsize=(16, 6))
    epochs = history['epochs']

    # Left panel: training loss, plus validation loss when it was logged
    ax_loss.plot(epochs, history['train_loss'],
                 marker='o', linewidth=2.5, label='Training Loss', color='#2E86AB')
    if 'val_loss' in history:
        ax_loss.plot(epochs, history['val_loss'],
                     marker='s', linewidth=2.5, label='Validation Loss', color='#A23B72')
    ax_loss.set_xlabel('Epoch', fontsize=13)
    ax_loss.set_ylabel('Loss', fontsize=13)
    ax_loss.set_title(f'{MODEL_NAME.upper()} - Training Loss', fontsize=15, fontweight='bold')
    ax_loss.legend(fontsize=11)
    ax_loss.grid(True, alpha=0.3)

    # Right panel: F1 curves, or a grey placeholder when train F1 was not tracked
    if 'train_f1' not in history:
        ax_f1.text(0.5, 0.5, 'F1 metrics not available',
                   ha='center', va='center', fontsize=12, color='gray')
        ax_f1.set_title('F1 Score (Not Available)', fontsize=15, fontweight='bold')
    else:
        ax_f1.plot(epochs, history['train_f1'],
                   marker='o', linewidth=2.5, label='Train F1', color='#2E86AB')
        if 'val_f1' in history:
            ax_f1.plot(epochs, history['val_f1'],
                       marker='s', linewidth=2.5, label='Val F1', color='#A23B72')
        ax_f1.set_xlabel('Epoch', fontsize=13)
        ax_f1.set_ylabel('F1 Score', fontsize=13)
        ax_f1.set_title(f'{MODEL_NAME.upper()} - F1 Score', fontsize=15, fontweight='bold')
        ax_f1.legend(fontsize=11)
        ax_f1.grid(True, alpha=0.3)

    curves_png = fig_dir / 'training_curves.png'
    plt.tight_layout()
    plt.savefig(curves_png, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"✓ Saved: {curves_png}")
else:
    print(f"⚠️  Training history not found: {history_file}")
    available_models = [d.name for d in models_dir.iterdir() if d.is_dir()]
    print(f"   Available models: {available_models}")

## 2. Multiple Models Comparison

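**Note:** This section compares every model listed in `model_configs`, regardless of `MODEL_NAME`; models without a `training_history.json` are skipped with a warning. To switch the model used by the single-model sections, change `MODEL_NAME` in the first cell (baseline, gcn, gat, etc.).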

In [None]:
# Compare training-loss curves of every model in `model_configs` (independent of MODEL_NAME).
# Loop variables use `other_*` names so the `history` loaded for MODEL_NAME in the
# training-curve cell is not silently replaced by the last model read here.
model_configs = [
    ('baseline', 'BERT Baseline'),
    ('gcn', 'GCN'),
    ('gat', 'GAT'),
    ('focal_loss', 'Focal Loss'),
    ('asymmetric_loss', 'Asymmetric Loss'),  # same model list as the metrics-table cell
]

# `comparison_dir` comes from the configuration cell; mkdir is idempotent
comparison_dir.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(12, 6))
n_plotted = 0

for model_name, label in model_configs:
    other_history_path = models_dir / model_name / 'training_history.json'
    if not other_history_path.exists():
        print(f"⚠️  {model_name}: training_history.json not found")
        continue
    with open(other_history_path) as f:
        other_history = json.load(f)
    ax.plot(other_history['epochs'], other_history['train_loss'],
            marker='o', linewidth=2, label=label)
    n_plotted += 1

if n_plotted:
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Training Loss', fontsize=12)
    ax.set_title('Model Comparison: Training Loss', fontsize=14, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    comparison_png = comparison_dir / 'model_comparison_loss.png'
    fig.savefig(comparison_png, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Saved: {comparison_png}")
else:
    # Nothing to compare: don't display or overwrite the saved chart with an empty one
    plt.close(fig)
    print("⚠️  No training histories found; comparison plot not saved")

## 3. Evaluation Metrics Table

In [None]:
# Load each model's best-checkpoint metrics into one comparison table (+ CSV).
results = []
model_configs = [
    ('baseline', 'BERT Baseline'),
    ('gcn', 'GCN'),
    ('gat', 'GAT'),
    ('focal_loss', 'Focal Loss'),
    ('asymmetric_loss', 'Asymmetric Loss'),
]
# Keys read from checkpoint['metrics']; missing ones are reported, then shown as 0.0
EXPECTED_METRIC_KEYS = ('micro_f1', 'macro_f1', 'precision', 'recall')


def to_scalar(value):
    """Unwrap a 0-d tensor / numpy scalar into a plain Python number.

    Other values are returned unchanged. Checkpoint metrics are often stored as
    ``torch.Tensor`` or ``numpy.float64``; unwrapping keeps the DataFrame numeric
    and the CSV readable (``0.812`` rather than ``tensor(0.8120)``).
    """
    if hasattr(value, 'item'):
        try:
            return value.item()
        except (ValueError, RuntimeError):  # multi-element array/tensor: leave as-is
            return value
    return value


for model_name, label in model_configs:
    checkpoint_path = models_dir / model_name / 'best_model.pt'
    if not checkpoint_path.exists():
        print(f"⚠️  {model_name}: checkpoint not found")
        continue
    try:
        # NOTE: torch.load unpickles the file, so only load checkpoints you produced.
        # On PyTorch >= 2.6 (weights_only=True by default), checkpoints holding
        # non-tensor objects (e.g. numpy scalars) may be rejected here.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if not isinstance(checkpoint, dict):
            print(f"⚠️  {model_name}: checkpoint is a {type(checkpoint).__name__}, not a dict - skipped")
            continue

        # Extract metrics (adjust keys based on your checkpoint structure)
        metrics = checkpoint.get('metrics') or {}
        missing = [key for key in EXPECTED_METRIC_KEYS if key not in metrics]
        if missing:
            # Make it explicit that a 0.0 in the table means "not recorded", not a real score
            print(f"⚠️  {model_name}: {missing} not in checkpoint['metrics'] - shown as 0.0")

        results.append({
            'Model': label,
            'Micro F1': to_scalar(metrics.get('micro_f1', 0.0)),
            'Macro F1': to_scalar(metrics.get('macro_f1', 0.0)),
            'Precision': to_scalar(metrics.get('precision', 0.0)),
            'Recall': to_scalar(metrics.get('recall', 0.0)),
            'Final Loss': to_scalar(checkpoint.get('loss', 0.0)),
            'Epoch': to_scalar(checkpoint.get('epoch', 0)),
        })
        print(f"✓ Loaded metrics for {model_name}")
    except Exception as e:
        print(f"⚠️  Error loading {model_name}: {e}")

if results:
    # Create DataFrame
    df = pd.DataFrame(results)

    print("\n" + "="*70)
    print("                MODEL PERFORMANCE COMPARISON")
    print("="*70)
    print(df.to_string(index=False))
    print("="*70 + "\n")

    # Save as CSV
    csv_path = comparison_dir / 'ablation_results.csv'
    df.to_csv(csv_path, index=False)
    print(f"✓ Saved: {csv_path}")
else:
    print("⚠️  No model results found")

## 4. Metrics Bar Chart

In [None]:
# Visualize metrics comparison: one panel per metric, one bar per model (uses `df`
# built by the metrics-table cell).
if results:
    metric_columns = ['Micro F1', 'Macro F1', 'Precision', 'Recall']

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    colors = plt.cm.Set3(np.linspace(0, 1, len(df)))

    for ax, metric in zip(axes.flatten(), metric_columns):
        bars = ax.bar(df['Model'], df[metric], alpha=0.8, edgecolor='black', linewidth=1.5, color=colors)

        # Value labels on non-zero bars (0.0 is also the table cell's default for a missing metric)
        for bar in bars:
            height = bar.get_height()
            if height > 0:
                ax.text(bar.get_x() + bar.get_width() / 2., height, f'{height:.3f}',
                        ha='center', va='bottom', fontsize=10, fontweight='bold')

        ax.set_ylabel(metric, fontsize=12, fontweight='bold')
        ax.set_title(f'{metric} Comparison', fontsize=14, fontweight='bold')
        ax.tick_params(axis='x', rotation=25, labelsize=10)

        # Leave ~20% headroom for the value labels, but never produce a degenerate
        # (0, 0) axis when all models score 0, and never clip values above 1.
        peak = df[metric].max()
        if pd.isna(peak) or peak <= 0:
            upper = 1.0
        elif peak <= 1.0:
            upper = min(1.0, peak * 1.2)  # fraction-scale metrics cannot exceed 1
        else:
            upper = peak * 1.2            # e.g. metrics logged as percentages
        ax.set_ylim(0, upper)
        ax.grid(axis='y', alpha=0.3, linestyle='--')

    plt.tight_layout()
    plt.savefig(comparison_dir / 'metrics_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"✓ Saved: {comparison_dir / 'metrics_comparison.png'}")
else:
    print("⚠️  No results to plot")

## 5. Predicted Label Count Distribution

In [None]:
# Load the newest prediction file for MODEL_NAME and plot how many labels each sample gets.
pred_files = list(predictions_dir.glob(f'{MODEL_NAME}_*.pkl'))

if pred_files:
    # "Most recent" = newest modification time. A name sort only works when file
    # names embed a sortable timestamp ('..._epoch9' sorts after '..._epoch10').
    pred_file = max(pred_files, key=lambda p: p.stat().st_mtime)
    print(f"Loading: {pred_file.name}")

    # SECURITY: pickle.load can execute arbitrary code - only load files you produced.
    # Named `pred_data` (not `results`) so the metrics list from the table cell is
    # not overwritten if the bar-chart cell is re-run afterwards.
    with open(pred_file, 'rb') as f:
        pred_data = pickle.load(f)

    # Presumably one sequence of predicted class ids per sample (the top-classes cell
    # counts the elements as class ids) - verify against the prediction writer.
    predictions = pred_data['predictions']

    # Count labels per sample; samples with no predicted label count as 0
    label_counts = [len(p) for p in predictions]

    if not label_counts:
        print(f"⚠️  {pred_file.name} contains no predictions; nothing to plot")
    else:
        lo, hi = min(label_counts), max(label_counts)
        mean_count = np.mean(label_counts)

        fig, axes = plt.subplots(1, 2, figsize=(16, 6))

        # Histogram: unit-width bins centred on each integer count, starting at the
        # smallest observed count so samples with 0 labels are not dropped.
        axes[0].hist(label_counts, bins=np.arange(lo - 0.5, hi + 1.5),
                     alpha=0.8, edgecolor='black', color='#F18F01')
        axes[0].axvline(mean_count, color='red', linestyle='--',
                        linewidth=2, label=f'Mean: {mean_count:.2f}')
        axes[0].set_xlabel('Number of Labels per Sample', fontsize=13, fontweight='bold')
        axes[0].set_ylabel('Frequency', fontsize=13, fontweight='bold')
        axes[0].set_title(f'{MODEL_NAME.upper()} - Label Distribution', fontsize=15, fontweight='bold')
        axes[0].legend(fontsize=11)
        axes[0].grid(axis='y', alpha=0.3)

        # Box plot
        axes[1].boxplot([label_counts], vert=True, patch_artist=True,
                        boxprops=dict(facecolor='#F18F01', alpha=0.7),
                        medianprops=dict(color='red', linewidth=2),
                        whiskerprops=dict(linewidth=1.5),
                        capprops=dict(linewidth=1.5))
        axes[1].set_ylabel('Number of Labels', fontsize=13, fontweight='bold')
        axes[1].set_title('Label Count Distribution', fontsize=15, fontweight='bold')
        axes[1].set_xticklabels([MODEL_NAME.upper()])
        axes[1].grid(axis='y', alpha=0.3)

        plt.tight_layout()
        plt.savefig(fig_dir / 'label_distribution.png', dpi=300, bbox_inches='tight')
        plt.show()

        print(f"✓ Saved: {fig_dir / 'label_distribution.png'}")
        print(f"\nStatistics:")
        print(f"  Mean: {mean_count:.2f}")
        print(f"  Median: {np.median(label_counts):.1f}")
        print(f"  Std: {np.std(label_counts):.2f}")
        print(f"  Min/Max: {lo} / {hi}")
else:
    print(f"⚠️  No prediction files found for {MODEL_NAME} in {predictions_dir}")

## 6. Top Predicted Classes

In [None]:
# Bar chart of the most frequently predicted classes for MODEL_NAME.
top_k = 30

if pred_files:
    # Count class frequencies across all samples
    all_labels = [label for pred in predictions for label in pred]
    label_freq = Counter(all_labels)
    # Always defined once predictions are loaded; the t-SNE cell reads it.
    top_classes = label_freq.most_common(top_k)

    if not top_classes:
        # zip(*[]) cannot be unpacked into (classes, counts)
        print("⚠️  Predictions contain no labels; top-class plot skipped")
    else:
        classes, counts = zip(*top_classes)
        n_shown = len(classes)

        plt.figure(figsize=(14, 10))
        plt.barh(range(n_shown), counts, alpha=0.8, edgecolor='black',
                 color=plt.cm.viridis(np.linspace(0, 1, n_shown)))
        plt.yticks(range(n_shown), [f'Class {c}' for c in classes], fontsize=10)
        plt.xlabel('Frequency', fontsize=13, fontweight='bold')
        plt.ylabel('Class ID', fontsize=13, fontweight='bold')
        # Report how many classes are actually shown (fewer than top_k if few were predicted)
        plt.title(f'{MODEL_NAME.upper()} - Top {n_shown} Predicted Classes',
                  fontsize=15, fontweight='bold')
        plt.gca().invert_yaxis()  # most frequent class at the top
        plt.grid(axis='x', alpha=0.3, linestyle='--')

        # Count label just to the right of each bar
        for row, count in enumerate(counts):
            plt.text(count, row, f' {count}', va='center', fontsize=9)

        plt.tight_layout()
        plt.savefig(fig_dir / 'top_classes.png', dpi=300, bbox_inches='tight')
        plt.show()

        print(f"✓ Saved: {fig_dir / 'top_classes.png'}")
        print(f"\nTotal unique classes predicted: {len(label_freq)}")
        print(f"Most frequent class: {classes[0]} ({counts[0]} times)")

## 7. t-SNE Visualization of Label Embeddings

In [None]:
# Project the classifier's per-label weight vectors ("label embeddings") to 2-D with t-SNE.
checkpoint_path = models_dir / MODEL_NAME / 'best_model.pt'


def find_label_embedding_key(state_dict):
    """Return the last 2-D weight key whose name contains 'classifier', or None.

    nn.Linear stores weights as (out_features, in_features), so for the output
    layer each row corresponds to one label. Requiring ndim == 2 skips 1-D
    (e.g. norm) weights; taking the *last* match prefers the layer registered
    last in a multi-layer head - presumably the output projection (verify per model).
    """
    candidates = [key for key, value in state_dict.items()
                  if 'classifier' in key.lower() and 'weight' in key
                  and getattr(value, 'ndim', 0) == 2]
    return candidates[-1] if candidates else None


if checkpoint_path.exists():
    try:
        print(f"Loading model checkpoint: {checkpoint_path}")
        # NOTE: torch.load unpickles the file - only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Accept both {'model_state_dict': ...} checkpoints and bare state dicts
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        classifier_key = find_label_embedding_key(state_dict)

        if classifier_key is None:
            print("⚠️  Classifier layer not found in checkpoint")
            print(f"   Available keys: {list(state_dict.keys())[:10]}...")
        else:
            # .float() first: numpy has no bfloat16, so half-precision weights would fail .numpy()
            label_embeddings = state_dict[classifier_key].detach().float().cpu().numpy()
            n_labels = label_embeddings.shape[0]
            print(f"✓ Found label embeddings: {label_embeddings.shape} ({classifier_key})")

            if n_labels < 2:
                print("⚠️  Need at least 2 label vectors for t-SNE; skipping")
            else:
                # scikit-learn requires perplexity < n_samples. The iteration count is
                # left at its default (1000): `n_iter` was renamed `max_iter` in
                # scikit-learn 1.5 and the old name was later removed.
                perplexity = min(30, n_labels - 1)
                print("Running t-SNE... (this may take a minute)")
                tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
                embeddings_2d = tsne.fit_transform(label_embeddings)

                plt.figure(figsize=(14, 10))
                scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                                      c=range(n_labels), cmap='tab20',
                                      alpha=0.6, s=50, edgecolors='black', linewidth=0.5)

                # Star the 10 most frequently predicted classes (from the top-classes cell),
                # assuming class id == row index of the weight matrix.
                top_ids = [c for c, _ in top_classes[:10]] if pred_files else []
                highlight_idx = []
                for cls_id in top_ids:
                    try:
                        idx = int(cls_id)
                    except (TypeError, ValueError):
                        continue  # non-numeric class ids cannot index weight rows
                    if 0 <= idx < n_labels:
                        highlight_idx.append(idx)
                if highlight_idx:
                    plt.scatter(embeddings_2d[highlight_idx, 0], embeddings_2d[highlight_idx, 1],
                                c='red', s=200, marker='*', edgecolors='black',
                                linewidth=1.5, label='Top predicted classes')
                    plt.legend(fontsize=11)

                plt.colorbar(scatter, label='Class ID')
                plt.xlabel('t-SNE Dimension 1', fontsize=13, fontweight='bold')
                plt.ylabel('t-SNE Dimension 2', fontsize=13, fontweight='bold')
                plt.title(f'{MODEL_NAME.upper()} - Label Embeddings (t-SNE)',
                          fontsize=15, fontweight='bold')
                plt.grid(True, alpha=0.3)
                plt.tight_layout()
                plt.savefig(fig_dir / 'label_tsne.png', dpi=300, bbox_inches='tight')
                plt.show()

                print(f"✓ Saved: {fig_dir / 'label_tsne.png'}")

    except Exception as e:
        print(f"⚠️  Error loading model: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f"⚠️  Checkpoint not found: {checkpoint_path}")

## 8. Predictions by Hierarchy Level

In [None]:
# Distribution of MODEL_NAME's predicted labels over hierarchy levels.
# Only the hierarchy file and the loaded predictions are needed (no true labels).


def compute_hierarchy_levels(lines):
    """Map every class in a parent/child edge list to its depth (roots = level 0).

    Each edge line is written either ``parent -> child`` or as two tab- (or
    whitespace-) separated fields; blank or unparseable lines are ignored.
    Depth is the shortest distance from any root (a node never listed as a child),
    computed breadth-first, so the result does not depend on the order of lines.
    Nodes reachable only through a cycle get no level.

    Args:
        lines: iterable of text lines (e.g. an open file).
    Returns:
        dict mapping class name (str) -> integer level.
    """
    children = {}
    has_parent = set()
    for raw in lines:
        line = raw.strip()
        if '->' in line:
            parts = line.split('->')
        elif '\t' in line:
            parts = line.split('\t')
        else:
            parts = line.split()
        parts = [p.strip() for p in parts]
        if len(parts) != 2 or not all(parts):
            continue
        parent, child = parts
        children.setdefault(parent, []).append(child)
        has_parent.add(child)

    levels = {}
    frontier = [node for node in children if node not in has_parent]
    depth = 0
    while frontier:
        next_frontier = []
        for node in frontier:
            if node not in levels:  # first visit = shortest depth
                levels[node] = depth
                next_frontier.extend(children.get(node, []))
        frontier = next_frontier
        depth += 1
    return levels


hierarchy_file = Path('../data/raw/Amazon_products/class_hierarchy.txt')

try:
    if not hierarchy_file.exists():
        print(f"⚠️  Hierarchy file not found: {hierarchy_file}")
    else:
        with open(hierarchy_file, 'r') as f:
            class_levels = compute_hierarchy_levels(f)

        if not class_levels:
            print("⚠️  Hierarchy information not available")
        elif not pred_files:
            print(f"⚠️  No predictions loaded for {MODEL_NAME}; hierarchy plot skipped")
        else:
            # Class ids are compared as strings because the hierarchy file is parsed as text.
            level_counts = Counter()
            n_unmatched = 0
            for pred in predictions:
                for cls_id in pred:
                    level = class_levels.get(str(cls_id))
                    if level is None:
                        n_unmatched += 1  # keep unknown classes out of the root-level bar
                    else:
                        level_counts[level] += 1
            if n_unmatched:
                print(f"⚠️  {n_unmatched} predicted labels are not in the hierarchy file (excluded)")

            if not level_counts:
                print("⚠️  No predicted label matched a hierarchy class")
            else:
                levels = sorted(level_counts)
                level_totals = [level_counts[lvl] for lvl in levels]

                plt.figure(figsize=(12, 6))
                bars = plt.bar(levels, level_totals, alpha=0.8, edgecolor='black',
                               color=plt.cm.Spectral(np.linspace(0, 1, len(levels))))
                plt.xlabel('Hierarchy Level', fontsize=13, fontweight='bold')
                plt.ylabel('Number of Predictions', fontsize=13, fontweight='bold')
                plt.title(f'{MODEL_NAME.upper()} - Predictions by Hierarchy Level',
                          fontsize=15, fontweight='bold')
                plt.xticks(levels)
                plt.grid(axis='y', alpha=0.3)

                # Value label above each bar
                for bar in bars:
                    height = bar.get_height()
                    plt.text(bar.get_x() + bar.get_width() / 2., height, f'{int(height)}',
                             ha='center', va='bottom', fontsize=11, fontweight='bold')

                plt.tight_layout()
                plt.savefig(fig_dir / 'hierarchy_level_distribution.png', dpi=300, bbox_inches='tight')
                plt.show()

                print(f"✓ Saved: {fig_dir / 'hierarchy_level_distribution.png'}")

except Exception as e:
    print(f"⚠️  Error analyzing hierarchy: {e}")

## 9. Summary

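**All available visualizations have been saved.** Plots whose inputs (training history, checkpoint, predictions, or hierarchy file) are missing are skipped with a ⚠️ warning.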

### Model-Specific Images (`results/images/{MODEL_NAME}/`):
- `training_curves.png`: Train/Val loss and F1 curves
- `label_distribution.png`: Predicted label count distribution
- `top_classes.png`: Top 30 most frequent predicted classes
- `label_tsne.png`: t-SNE visualization of label embeddings
- `hierarchy_level_distribution.png`: Predictions by hierarchy level

### Cross-Model Comparisons (`results/images/ablation/`):
- `model_comparison_loss.png`: Training loss comparison across models
- `metrics_comparison.png`: F1, Precision, Recall comparison
- `ablation_results.csv`: Detailed metrics table

---

### To visualize a different model:
Change `MODEL_NAME` in the first cell to: `'baseline'`, `'gcn'`, `'gat'`, `'focal_loss'`, etc.

---

### Notes:
- Requires `training_history.json` and `best_model.pt` in `models/{MODEL_NAME}/`
- Requires predictions in `results/predictions/{MODEL_NAME}_*.pkl`
- All images are saved at 300 DPI for publication quality