# Generator Checkpoint Analysis

This notebook helps you explore and analyze generator checkpoints saved during training.

## Features:
- Load and inspect checkpoint metadata
- Visualize generator architecture
- Compare checkpoints across rounds
- Analyze training statistics

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import glob
from collections import defaultdict
import json

# Plot defaults shared by every figure in this notebook
sns.set_style("whitegrid")
plt.rcParams.update({"figure.figsize": (12, 6)})

print("✓ Libraries imported successfully")

## 1. Configuration

In [None]:
# Root folder that is searched (recursively) for generator checkpoints
CHECKPOINT_DIR = "checkpoints/generators"

# To work with one specific file instead, define it here:
# CHECKPOINT_FILE = "checkpoints/generators/client_generator_node0.pt"

# Report the configured location and whether it is present on disk
print(
    f"Checkpoint directory: {CHECKPOINT_DIR}",
    f"Directory exists: {Path(CHECKPOINT_DIR).exists()}",
    sep="\n",
)

## 2. Find All Checkpoints

In [None]:
def find_checkpoints(checkpoint_dir):
    """Return the sorted paths of every ``.pt`` file below ``checkpoint_dir``.

    The search is recursive, so files in nested sub-folders are included.
    An empty list is returned when nothing matches.
    """
    pattern = f"{checkpoint_dir}/**/*.pt"
    matches = list(glob.glob(pattern, recursive=True))
    matches.sort()
    return matches

# Discover every checkpoint below the configured directory, then list them
checkpoints = find_checkpoints(CHECKPOINT_DIR)

print(f"Found {len(checkpoints)} checkpoint(s):\n")
for i, ckpt in enumerate(checkpoints, start=1):
    print(f"{i}. {ckpt}")

## 3. Load and Inspect Checkpoint Metadata

In [None]:
def load_checkpoint_metadata(checkpoint_path):
    """Load a checkpoint file and summarise what it contains.

    Args:
        checkpoint_path: Path to a file written with ``torch.save``. The
            stored object must be a ``dict``.

    Returns:
        A ``(checkpoint, metadata)`` tuple. ``checkpoint`` is the loaded dict.
        ``metadata`` contains every known metadata field (``None`` when the
        checkpoint does not store it), the flags ``has_generator_state`` and
        ``has_optimizer_state`` and, only when a generator state dict is
        present, ``num_parameters`` and ``parameter_keys``.

    Raises:
        TypeError: If the file does not contain a dict.
    """
    # NOTE(security): torch.load unpickles the file. Depending on the installed
    # PyTorch version and its `weights_only` default this can execute arbitrary
    # code, so only open checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Any *.pt file may be passed in (e.g. a bare tensor or a pickled module);
    # fail with a clear message instead of an AttributeError on `.get`.
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Expected a dict checkpoint in {checkpoint_path}, "
            f"got {type(checkpoint).__name__}"
        )

    metadata = {}

    # Extract metadata fields
    metadata_fields = [
        'client_id', 'node_id', 'round', 'generator_type', 'generator_granularity',
        'generator_key', 'diffusion_type', 'visual_dim', 'input_dim', 'hidden_dim',
        'latent_dim', 'sequence_length', 'dataset_name', 'selected_classes',
        'generator_classes', 'training_samples', 'training_epochs',
        'final_loss', 'timestamp', 'device'
    ]

    # Every field is always present in the result; absent ones are stored as None.
    for field in metadata_fields:
        metadata[field] = checkpoint.get(field, None)

    # Check for state dicts
    metadata['has_generator_state'] = 'generator_state_dict' in checkpoint
    metadata['has_optimizer_state'] = 'optimizer_state_dict' in checkpoint

    # Get state dict info if available
    if metadata['has_generator_state']:
        state_dict = checkpoint['generator_state_dict']
        # Count tensors only: a state dict may carry non-tensor entries,
        # which have no numel() and used to abort the whole load.
        metadata['num_parameters'] = sum(
            p.numel() for p in state_dict.values() if torch.is_tensor(p)
        )
        metadata['parameter_keys'] = list(state_dict.keys())

    return checkpoint, metadata

# Show the metadata of the first checkpoint as a worked example
if checkpoints:
    checkpoint_path = checkpoints[0]
    print(f"Loading: {checkpoint_path}\n")

    checkpoint, metadata = load_checkpoint_metadata(checkpoint_path)

    # Banner: rule, title, rule
    print("=" * 80, "CHECKPOINT METADATA", "=" * 80, sep="\n")

    for key, value in metadata.items():
        if key == 'parameter_keys':
            continue  # the key list is too long to print usefully
        print(f"{key:25s}: {value}")

    print("\n" + "=" * 80)

## 4. Analyze All Checkpoints

In [None]:
def analyze_all_checkpoints(checkpoint_files):
    """Build a summary table with one row of metadata per readable checkpoint.

    Files that cannot be loaded are reported on stdout and left out of the
    table rather than stopping the scan. Each row also gets a
    ``checkpoint_file`` entry holding the file's base name.
    """
    rows = []

    for path in checkpoint_files:
        try:
            info = load_checkpoint_metadata(path)[1]
            info['checkpoint_file'] = Path(path).name
        except Exception as e:
            print(f"Error loading {path}: {e}")
        else:
            rows.append(info)

    return pd.DataFrame(rows)

if not checkpoints:
    print("No checkpoints found to analyze.")
else:
    df_checkpoints = analyze_all_checkpoints(checkpoints)

    print("\nCheckpoint Summary:", "=" * 80, sep="\n")

    # Columns worth showing in the overview, in display order
    display_cols = [
        'checkpoint_file',
        'node_id',
        'round',
        'generator_type',
        'generator_granularity',
        'generator_key',
        'final_loss',
        'training_samples',
        'num_parameters',
    ]

    # Keep only the ones the summary table actually has
    available_cols = [name for name in display_cols if name in df_checkpoints.columns]

    display(df_checkpoints[available_cols])

## 5. Visualize Generator Architecture

In [None]:
def visualize_generator_architecture(checkpoint_path):
    """Plot per-tensor sizes and the stored configuration of a checkpoint's generator.

    The left panel is a horizontal bar chart with one bar per tensor in
    ``generator_state_dict``; the right panel lists the architecture fields
    recorded in the checkpoint. Nothing is drawn (a message is printed
    instead) when the checkpoint has no generator state.

    Args:
        checkpoint_path: Path of the checkpoint file to visualise.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    checkpoint, metadata = load_checkpoint_metadata(checkpoint_path)

    if not metadata['has_generator_state']:
        print("No generator state found in checkpoint")
        return

    state_dict = checkpoint['generator_state_dict']

    def _field(name):
        # load_checkpoint_metadata stores None for fields the checkpoint lacks,
        # so a dict.get default would never be used; test for None explicitly.
        value = metadata.get(name)
        return 'N/A' if value is None else value

    # Only tensors have a size to plot; skip any non-tensor entries.
    tensor_items = [(key, param) for key, param in state_dict.items() if torch.is_tensor(param)]
    param_names = [key for key, _ in tensor_items]
    param_counts = [param.numel() for _, param in tensor_items]

    num_parameters = metadata.get('num_parameters')
    total_text = f"{num_parameters:,}" if num_parameters is not None else 'N/A'

    # Create visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Plot 1: Parameter count per layer
    ax1.barh(range(len(param_names)), param_counts)
    ax1.set_yticks(range(len(param_names)))
    ax1.set_yticklabels(param_names, fontsize=8)
    ax1.set_xlabel('Number of Parameters')
    ax1.set_title('Parameters per Layer')
    ax1.grid(axis='x', alpha=0.3)

    # Plot 2: Architecture summary
    architecture_info = [
        f"Generator Type: {_field('generator_type')}",
        f"Input Dim: {_field('input_dim')}",
        f"Hidden Dim: {_field('hidden_dim')}",
        f"Latent Dim: {_field('latent_dim')}",
        f"Visual Dim: {_field('visual_dim')}",
        f"Sequence Length: {_field('sequence_length')}",
        f"\nTotal Parameters: {total_text}",
    ]

    ax2.axis('off')
    ax2.text(0.1, 0.9, '\n'.join(architecture_info),
             fontsize=12, verticalalignment='top',
             fontfamily='monospace')
    ax2.set_title('Architecture Configuration')

    plt.tight_layout()
    plt.show()

    print(f"\nTotal parameters: {total_text}")

# Visualise the first checkpoint found, if any
if checkpoints:
    visualize_generator_architecture(checkpoints[0])

## 6. Compare Checkpoints Across Rounds

In [None]:
def plot_training_progress(df_checkpoints):
    """Plot final loss against training round, per node and as a distribution.

    Args:
        df_checkpoints: Summary table with at least the columns ``round`` and
            ``final_loss``. An optional ``node_id`` column splits the loss
            curve into one line per node.

    Returns:
        None. A message is printed instead of a figure when the required
        columns are missing or no row has both a round and a loss.
    """
    if 'round' not in df_checkpoints.columns or 'final_loss' not in df_checkpoints.columns:
        print("Missing 'round' or 'final_loss' information in checkpoints")
        return

    # Keep only rows that have both values to plot
    has_values = df_checkpoints['round'].notna() & df_checkpoints['final_loss'].notna()
    df_plot = df_checkpoints[has_values].copy()

    if df_plot.empty:
        print("No valid data to plot")
        return

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    loss_ax, dist_ax = axes

    # Plot 1: Loss over rounds.
    # The node_id column usually exists even when checkpoints do not record a
    # node (the value is then missing). Such rows match no `== node_id` filter,
    # so they are drawn separately as a single 'Generator' line.
    if 'node_id' in df_plot.columns:
        has_node = df_plot['node_id'].notna()
    else:
        has_node = pd.Series(False, index=df_plot.index)

    for node_id, node_data in df_plot[has_node].groupby('node_id', sort=False):
        node_data = node_data.sort_values('round')
        loss_ax.plot(node_data['round'], node_data['final_loss'],
                     marker='o', label=f'Node {node_id}')

    without_node = df_plot[~has_node].sort_values('round')
    if not without_node.empty:
        loss_ax.plot(without_node['round'], without_node['final_loss'],
                     marker='o', label='Generator')

    loss_ax.set_xlabel('Round')
    loss_ax.set_ylabel('Final Loss')
    loss_ax.set_title('Training Loss Over Rounds')
    loss_ax.legend()
    loss_ax.grid(alpha=0.3)

    # Plot 2: Loss distribution by round
    if len(df_plot) > 1:
        df_plot.boxplot(column='final_loss', by='round', ax=dist_ax)
        dist_ax.set_xlabel('Round')
        dist_ax.set_ylabel('Final Loss')
        dist_ax.set_title('Loss Distribution by Round')
        fig.suptitle('')  # Remove the title pandas adds for grouped boxplots
    else:
        dist_ax.text(0.5, 0.5, 'Not enough data\nfor distribution plot',
                     ha='center', va='center', fontsize=12)
        dist_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Plot only when the summary table was built earlier in this session
if checkpoints and 'df_checkpoints' in globals():
    plot_training_progress(df_checkpoints)

## 7. Load Generator Model from Checkpoint

In [None]:
# Make the project's `flcore` package importable.
# The location can be overridden with the FEDGFE_SYSTEM_PATH environment
# variable; the previously hard-coded path is kept as the default.
import os
import sys

FEDGFE_SYSTEM_PATH = os.environ.get('FEDGFE_SYSTEM_PATH', '/home/lpala/fedgfe/system')
if FEDGFE_SYSTEM_PATH not in sys.path:  # do not stack duplicates when the cell is re-run
    sys.path.append(FEDGFE_SYSTEM_PATH)

from flcore.trainmodel.generators import ConditionedVAEGenerator, VAELoss

def load_generator_model(checkpoint_path):
    """Rebuild a ``ConditionedVAEGenerator`` from a checkpoint and load its weights.

    Args:
        checkpoint_path: Path of the checkpoint file to load.

    Returns:
        ``(generator, metadata)`` with the generator switched to eval mode, or
        ``None`` when the checkpoint has no ``generator_state_dict``.
    """
    checkpoint, metadata = load_checkpoint_metadata(checkpoint_path)

    if not metadata['has_generator_state']:
        print("No generator state found in checkpoint")
        return None

    def _value(field, default):
        # load_checkpoint_metadata stores None for fields the checkpoint lacks,
        # so `metadata.get(field, default)` would pass None to the constructor
        # instead of the fallback; test for None explicitly.
        stored = metadata.get(field)
        return default if stored is None else stored

    # Create generator with same configuration
    generator = ConditionedVAEGenerator(
        input_dim=_value('input_dim', 768),
        hidden_dim=_value('hidden_dim', 1024),
        latent_dim=_value('latent_dim', 256),
        visual_dim=_value('visual_dim', 4864),
        sequence_length=_value('sequence_length', 4)
    )

    # Load weights
    generator.load_state_dict(checkpoint['generator_state_dict'])
    generator.eval()

    num_parameters = metadata.get('num_parameters')
    parameters_text = f"{num_parameters:,}" if num_parameters is not None else 'N/A'

    print(f"✓ Generator loaded from {Path(checkpoint_path).name}")
    print(f"  - Type: {_value('generator_type', 'N/A')}")
    print(f"  - Parameters: {parameters_text}")
    print(f"  - Training samples: {_value('training_samples', 'N/A')}")
    print(f"  - Final loss: {_value('final_loss', 'N/A')}")

    return generator, metadata

if checkpoints:
    # load_generator_model returns None (not a tuple) when the checkpoint has
    # no generator weights; unpacking that directly would raise TypeError.
    loaded = load_generator_model(checkpoints[0])
    generator, gen_metadata = loaded if loaded is not None else (None, None)

## 8. Export Checkpoint Summary to CSV

In [None]:
if 'df_checkpoints' not in locals() or df_checkpoints.empty:
    print("No checkpoint data to export")
else:
    output_file = f"{CHECKPOINT_DIR}/checkpoint_summary.csv"

    # Leave out the columns that hold lists/nested objects
    export_cols = [
        col for col in df_checkpoints.columns
        if col not in {'parameter_keys', 'selected_classes', 'generator_classes'}
    ]

    df_checkpoints[export_cols].to_csv(output_file, index=False)
    print(f"✓ Checkpoint summary exported to: {output_file}")

## 9. Detailed Checkpoint Inspector

In [None]:
def inspect_checkpoint_detailed(checkpoint_path):
    """Print a sectioned, human-readable report about one checkpoint.

    The report covers node information, generator configuration,
    architecture, training information, class lists and which state
    dictionaries are stored. Fields the checkpoint does not record are shown
    as ``N/A``.

    Args:
        checkpoint_path: Path of the checkpoint file to inspect.

    Returns:
        None. The report is written to stdout.
    """
    _, metadata = load_checkpoint_metadata(checkpoint_path)

    def _field(name):
        # load_checkpoint_metadata stores None for fields the checkpoint lacks,
        # so a dict.get default would never be used; test for None explicitly.
        value = metadata.get(name)
        return 'N/A' if value is None else value

    def _section(title, rows=()):
        # Section header followed by "Label:   value" lines in an 18-char label column.
        print(f"\n{title}")
        print("-" * 80)
        for label, value in rows:
            print(f"{label + ':':<18}{value}")

    def _class_list(name):
        classes = metadata.get(name)
        # Use len() instead of truthiness: arrays/tensors with more than one
        # element raise when evaluated as a boolean.
        if classes is None or len(classes) == 0:
            return 'N/A'
        return ', '.join(map(str, classes))

    # num_parameters exists only when a generator state dict was found; the
    # thousands separator cannot be applied to the 'N/A' placeholder.
    num_parameters = metadata.get('num_parameters')
    total_text = f"{num_parameters:,}" if num_parameters is not None else 'N/A'

    print("\n" + "=" * 80)
    print(f"DETAILED CHECKPOINT INSPECTION: {Path(checkpoint_path).name}")
    print("=" * 80)

    _section("[1] NODE INFORMATION", [
        ("Node ID", _field('node_id')),
        ("Round", _field('round')),
        ("Dataset", _field('dataset_name')),
        ("Timestamp", _field('timestamp')),
        ("Device", _field('device')),
    ])

    _section("[2] GENERATOR CONFIGURATION", [
        ("Type", _field('generator_type')),
        ("Granularity", _field('generator_granularity')),
        ("Generator Key", _field('generator_key')),
        ("Diffusion Type", _field('diffusion_type')),
    ])

    _section("[3] ARCHITECTURE", [
        ("Input Dim", _field('input_dim')),
        ("Hidden Dim", _field('hidden_dim')),
        ("Latent Dim", _field('latent_dim')),
        ("Visual Dim", _field('visual_dim')),
        ("Sequence Length", _field('sequence_length')),
        ("Total Parameters", total_text),
    ])

    _section("[4] TRAINING INFORMATION", [
        ("Training Samples", _field('training_samples')),
        ("Training Epochs", _field('training_epochs')),
        ("Final Loss", _field('final_loss')),
    ])

    _section("[5] CLASSES")
    print(f"Selected Classes: {_class_list('selected_classes')}")
    print(f"Generator Classes: {_class_list('generator_classes')}")

    _section("[6] STATE DICTIONARIES")
    print(f"Has Generator State:  {metadata.get('has_generator_state', False)}")
    print(f"Has Optimizer State:  {metadata.get('has_optimizer_state', False)}")

    if metadata.get('has_generator_state'):
        parameter_keys = metadata.get('parameter_keys', [])
        print(f"\nGenerator Layers ({len(parameter_keys)})")
        # Show at most the first ten keys, then a count of the remainder
        for i, key in enumerate(parameter_keys[:10], 1):
            print(f"  {i}. {key}")
        if len(parameter_keys) > 10:
            print(f"  ... and {len(parameter_keys) - 10} more")

    print("\n" + "=" * 80 + "\n")

# Inspect the first checkpoint
if checkpoints:
    inspect_checkpoint_detailed(checkpoints[0])

## 10. Custom Analysis

Use this section for your own custom analysis.

In [None]:
# Your custom analysis here
# Example: Compare two specific checkpoints

if len(checkpoints) >= 2:
    print("Comparing first two checkpoints...\n")

    for i, ckpt_path in enumerate(checkpoints[:2], 1):
        # Index the result instead of unpacking into `_`, which would shadow
        # IPython's "last output" variable for the rest of the session.
        meta = load_checkpoint_metadata(ckpt_path)[1]
        print(f"Checkpoint {i}: {Path(ckpt_path).name}")
        for label, field in (("Round", 'round'), ("Loss", 'final_loss'), ("Samples", 'training_samples')):
            # Missing fields are stored as None, so a dict.get default would
            # never be used; test for None explicitly to show N/A.
            value = meta.get(field)
            print(f"  {label}: {'N/A' if value is None else value}")
        print()