# StyleGAN3 Training Visualization

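This notebook helps visualize training progress:
- Load and inspect checkpoints
- Generate sample images from trained models
- Visualize training curves (if using WandB or saved logs) — **TODO**: not yet implemented in this notebook
- Interpolate in latent space
- Test regional conditioning
- Compare truncation settings and generate variations of a base motif
- Save generated samples to disk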

In [None]:
# Setup cell: make the project package importable, load dependencies,
# configure inline plotting, and choose the compute device used by later cells.
import sys
from pathlib import Path

# Add project root to path
# Assumes the kernel's working directory is one level below the project root
# (e.g. `<root>/notebooks/`) — TODO confirm if the notebook is moved.
# This must run before the `src.` imports below.
project_root = Path().resolve().parent
sys.path.insert(0, str(project_root))

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image  # NOTE(review): not used by any visible cell
import torchvision

from src.generation.generate_gan import GreekMotifGenerator
from src.models.stylegan3_model import StyleGAN3Generator  # NOTE(review): not used by any visible cell

# Configure plotting
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Prefer the GPU when available; later cells pass `device` to the generator
# and move latent tensors onto it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

## 1. Load Trained Model

In [None]:
# Checkpoint to visualize; defaults to the best model saved during training.
checkpoint_path = project_root / "models" / "checkpoints" / "best_model.pt"

# Alternative: use specific epoch checkpoint
# checkpoint_path = project_root / "models" / "checkpoints" / "checkpoint_epoch_0100.pt"

# Fail fast: every later cell uses `generator`, so continuing without it would
# only surface as confusing NameErrors further down the notebook.
if not checkpoint_path.exists():
    raise FileNotFoundError(
        f"❌ Checkpoint not found: {checkpoint_path}\n"
        "Please train a model first or specify correct checkpoint path."
    )

print(f"Loading checkpoint: {checkpoint_path}")
generator = GreekMotifGenerator(str(checkpoint_path), device=device)
print("✓ Model loaded successfully")

## 2. Generate Sample Motifs

In [None]:
def show_generated_grid(images, title="Generated Greek Motifs", nrow=4):
    """
    Display a grid of generated images.

    Args:
        images: Iterable of images, each an (H, W, C) numpy array with values
            in [0, 255] (presumably what ``generator.generate`` returns —
            TODO confirm). All images must share the same shape.
        title: Figure title.
        nrow: Number of images per grid row (forwarded to ``make_grid``).

    Raises:
        ValueError: If ``images`` is empty.
    """
    images = list(images)
    if not images:
        raise ValueError("show_generated_grid() needs at least one image")

    # Stack once, convert HWC -> CHW, and scale to [0, 1] as make_grid expects.
    batch = torch.from_numpy(np.stack(images)).permute(0, 3, 1, 2).float() / 255.0
    grid = torchvision.utils.make_grid(batch, nrow=nrow, padding=2)

    # Match the figure to the grid's aspect ratio so wide, short grids are not
    # drawn inside a mostly-empty square canvas.
    grid_h, grid_w = grid.shape[1:]
    fig_w = 15
    fig, ax = plt.subplots(figsize=(fig_w, max(fig_w * grid_h / grid_w, 2)))
    ax.imshow(grid.permute(1, 2, 0).numpy())
    ax.set_title(title, fontsize=16, fontweight='bold')
    ax.axis('off')
    fig.tight_layout()
    plt.show()

# Draw a reproducible batch of unconditional samples as a quick visual sanity check.
num_samples = 16
print(f"Generating {num_samples} random samples...")

samples = generator.generate(
    num_samples=num_samples,
    truncation_psi=0.7,
    seed=42,  # fixed seed -> identical images on every re-run
)
show_generated_grid(samples, title="Generated Greek Motifs (Truncation=0.7)")

## 3. Test Regional Conditioning

In [None]:
# Region labels passed to the conditional generator
# (presumably the training-time region set — TODO confirm against the dataset).
regions = [
    'Aegean_Islands', 'Cyclades', 'Dodecanese', 'Epirus',
    'Greece', 'Lesvos', 'North_Aegean', 'Rhodes',
    'Thessaly', 'Thrace', 'Turkey'
]

# Preview only the first four regions to keep the output short. Every call
# reuses seed=42, so (presumably) the grids differ only via the region condition.
for region in regions[:4]:
    print(f"\nGenerating samples for region: {region}")
    region_samples = generator.generate(num_samples=8, region=region,
                                        seed=42, truncation_psi=0.7)
    show_generated_grid(region_samples, title=f"Generated Motifs - {region} Style", nrow=4)

## 4. Truncation Trick Comparison

In [None]:
# Compare truncation strengths on the same seed, one panel per value.
truncation_values = [0.3, 0.5, 0.7, 1.0]

# Derive the panel layout from the number of values so editing the list does
# not break the figure (a fixed 2x2 grid only fits exactly four values).
n_cols = 2
n_rows = -(-len(truncation_values) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, psi in zip(axes, truncation_values):
    print(f"Generating with truncation_psi={psi}...")

    # Separate name so this cell does not clobber `samples` from section 2.
    psi_samples = generator.generate(
        num_samples=4,
        seed=123,  # Same seed for comparison
        truncation_psi=psi
    )

    # HWC images in [0, 255] -> CHW floats in [0, 1] for make_grid.
    images_tensor = torch.stack([torch.from_numpy(img).permute(2, 0, 1) / 255.0
                                 for img in psi_samples])
    grid = torchvision.utils.make_grid(images_tensor, nrow=2, padding=2)

    ax.imshow(grid.permute(1, 2, 0).numpy())
    ax.set_title(f'Truncation ψ = {psi}', fontsize=12, fontweight='bold')
    ax.axis('off')

# Blank out any panels left over when the value count is not a multiple of n_cols.
for ax in axes[len(truncation_values):]:
    ax.axis('off')

fig.suptitle('Effect of Truncation Trick', fontsize=16, fontweight='bold')
fig.tight_layout()
plt.show()

print("\nNote: Lower truncation values (e.g., 0.3) produce more typical/conservative samples.")
print("Higher values (e.g., 1.0) allow more diversity but may reduce quality.")

## 5. Latent Space Interpolation

In [None]:
# Draw two endpoint latents from a seeded RNG so the interpolation is reproducible.
# NOTE(review): z_dim must match the generator's latent size; section 8 prints
# the checkpoint's configured z_dim — confirm they agree.
z_dim = 512
torch.manual_seed(42)
z1, z2 = (torch.randn(1, z_dim).to(device) for _ in range(2))

n_steps = 10
print(f"Interpolating between two latent vectors ({n_steps} steps)...")
interpolated = generator.interpolate(start_latent=z1, end_latent=z2, steps=n_steps)

# Two rows of five frames, read left-to-right.
show_generated_grid(interpolated, title="Latent Space Interpolation", nrow=5)

print("\nSmooth transitions indicate well-structured latent space.")

## 6. Generate Variations of a Base Motif

In [None]:
# Fixed base latent (z_dim comes from the interpolation cell above).
torch.manual_seed(99)
base_z = torch.randn(1, z_dim).to(device)

# Render the unperturbed base motif.
original = generator.generate(num_samples=1, latent_vectors=base_z)

# Render small perturbations around the same latent.
variations = generator.generate_variations(
    base_latent=base_z,
    num_variations=8,
    variation_strength=0.1
)

# Join as Python lists: if `generate` / `generate_variations` return numpy
# arrays, `original + variations` would broadcast-ADD pixel values instead of
# concatenating the two batches. list(...) is correct for lists and arrays alike.
all_samples = list(original) + list(variations)

show_generated_grid(
    all_samples,
    title="Original Motif + Variations (strength=0.1)",
    nrow=3
)

print("\nFirst image: Original")
print("Remaining: Variations with small perturbations")

## 7. Save Generated Samples

In [None]:
# Generate and save a batch of high-quality samples
output_dir = project_root / "outputs" / "generated_samples"
output_dir.mkdir(parents=True, exist_ok=True)

num_to_save = 20
print(f"Generating and saving {num_to_save} samples to {output_dir}...")

samples_to_save = generator.generate(
    num_samples=num_to_save,
    seed=2024,  # fixed seed so Restart & Run All writes the same images every time
    truncation_psi=0.7
)

generator.save_images(
    images=samples_to_save,
    output_dir=str(output_dir),
    prefix="greek_motif"
)

# Report the size of the batch handed to save_images, not the requested count.
print(f"\n✓ Saved {len(samples_to_save)} images to {output_dir}")

## 8. Model Inspection

In [None]:
# Load checkpoint to inspect training details.
# map_location='cpu': this cell only reads metadata and counts tensor elements,
# so there is no reason to materialize every stored tensor in GPU memory.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# On PyTorch >= 2.6 the default is weights_only=True, which may reject a
# checkpoint whose 'config' holds non-primitive objects — confirm.
checkpoint = torch.load(checkpoint_path, map_location='cpu')


def _count_state_dict_elements(state_dict):
    """Return the total element count over the tensor values of a state_dict.

    A state_dict contains registered buffers as well as trainable parameters,
    so this is an upper bound on the trainable-parameter count.
    """
    return sum(t.numel() for t in state_dict.values() if isinstance(t, torch.Tensor))


print("Checkpoint Information:")
print("=" * 60)
print(f"Epoch: {checkpoint.get('epoch', 'Unknown')}")
print(f"Global step: {checkpoint.get('global_step', 'Unknown')}")

if 'config' in checkpoint:
    # Assumes the stored config is a dict — TODO confirm against the trainer.
    config = checkpoint['config']
    print("\nModel Configuration:")
    print(f"  Image resolution: {config.get('img_resolution', 'N/A')}")
    print(f"  Latent dimension: {config.get('z_dim', 'N/A')}")
    print(f"  Conditioning dim: {config.get('condition_dim', 'N/A')}")
    print(f"  Batch size: {config.get('batch_size', 'N/A')}")
    print(f"  Learning rates: G={config.get('g_lr', 'N/A')}, D={config.get('d_lr', 'N/A')}")

# Count state_dict elements (parameters plus buffers)
if 'generator_state_dict' in checkpoint:
    g_params = _count_state_dict_elements(checkpoint['generator_state_dict'])
    print(f"\nGenerator parameters (incl. buffers): {g_params:,}")

if 'discriminator_state_dict' in checkpoint:
    d_params = _count_state_dict_elements(checkpoint['discriminator_state_dict'])
    print(f"Discriminator parameters (incl. buffers): {d_params:,}")

print("\n" + "=" * 60)