# Distributed GAN Training Demo - Model Visualization

This notebook demonstrates the results from our distributed GAN training system. You can use it to visualize:
- Training loss curves (Generator and Discriminator)
- Generated face samples
- Training progression over epochs
- Model architecture details

# DCGAN Trained Model Demo

This notebook demonstrates the trained DCGAN model for generating celebrity faces.

**Contents:**
1. Load trained model checkpoint
2. Visualize training history (learning curves)
3. Generate new face images
4. Show training progression

**Requirements:**
- Trained model checkpoint in `../outputs_local/checkpoints/` or `../outputs/checkpoints/`
- PyTorch and required dependencies


In [None]:
import os
import re
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Add src to path to import our modules
sys.path.insert(0, '../src')

from models.dcgan import Generator, Discriminator

# Set plotting style
plt.style.use('default')
%matplotlib inline

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Configuration
LATENT_DIM = 100
IMAGE_SIZE = 64
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Candidate checkpoint locations in priority order: local run first, then distributed run
CHECKPOINT_PATHS = [
    '../outputs_local/checkpoints/checkpoint_latest.pth',
    '../outputs/checkpoints/checkpoint_latest.pth',
]

# First candidate that exists on disk, or None when none of them do
CHECKPOINT_PATH = next((candidate for candidate in CHECKPOINT_PATHS if Path(candidate).exists()), None)

if CHECKPOINT_PATH is not None:
    print(f"Found checkpoint: {CHECKPOINT_PATH}")
    print(f"\nUsing checkpoint: {CHECKPOINT_PATH}")
    print(f"Device: {DEVICE}")
else:
    print("WARNING: No checkpoint found! Please train a model first.")
    print("Expected locations:")
    print("\n".join(f"  - {candidate}" for candidate in CHECKPOINT_PATHS))

In [None]:
# Defaults so later cells can reference these names even when no checkpoint was found
epoch = 'Unknown'
g_losses, d_losses = [], []

if CHECKPOINT_PATH:
    # Initialize models (architectures must match the ones that wrote the checkpoint)
    generator = Generator(latent_dim=LATENT_DIM).to(DEVICE)
    discriminator = Discriminator().to(DEVICE)

    # Load checkpoint; map_location lets a GPU-saved checkpoint load on the current device
    print("Loading checkpoint...")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

    # Load model weights
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

    # Evaluation mode: layers such as BatchNorm/Dropout switch to inference behavior
    generator.eval()
    discriminator.eval()

    # Extract training information (keys are optional in the checkpoint)
    epoch = checkpoint.get('epoch', 'Unknown')
    g_losses = checkpoint.get('g_losses', [])
    d_losses = checkpoint.get('d_losses', [])

    print(f"✓ Loaded checkpoint from epoch {epoch}")
    print(f"✓ Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"✓ Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")

    if g_losses:
        print(f"✓ Training history: {len(g_losses)} epochs")
        print(f"  Final G loss: {g_losses[-1]:.4f}")
        # Guard separately: a checkpoint may carry G history without D history
        if d_losses:
            print(f"  Final D loss: {d_losses[-1]:.4f}")
else:
    print("Warning: Skipping model loading - no checkpoint available")

In [None]:
if CHECKPOINT_PATH and g_losses:
    fig, (ax_raw, ax_smooth) = plt.subplots(1, 2, figsize=(15, 5))

    def _decorate_loss_axes(ax, title):
        """Apply the axis labels, title, legend and grid shared by both loss panels."""
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

    n_epochs = len(g_losses)

    # Left panel: raw per-epoch losses, epochs numbered from 1
    x_raw = range(1, n_epochs + 1)
    ax_raw.plot(x_raw, g_losses, label='Generator Loss', color='blue', linewidth=2)
    ax_raw.plot(x_raw, d_losses, label='Discriminator Loss', color='red', linewidth=2)
    _decorate_loss_axes(ax_raw, 'Training Losses Over Time')

    # Right panel: moving average; window grows with history length, capped at 5
    window = 1 if n_epochs <= 10 else min(5, n_epochs // 10)
    if window <= 1:
        ax_smooth.axis('off')
        ax_smooth.text(0.5, 0.5, 'Not enough epochs\nfor smoothing',
                       ha='center', va='center', fontsize=12)
    else:
        kernel = np.ones(window) / window
        # 'valid' convolution drops window-1 points, so x starts at epoch `window`
        x_smooth = range(window, n_epochs + 1)
        ax_smooth.plot(x_smooth, np.convolve(g_losses, kernel, mode='valid'),
                       label='Generator (smoothed)', color='blue', linewidth=2)
        ax_smooth.plot(x_smooth, np.convolve(d_losses, kernel, mode='valid'),
                       label='Discriminator (smoothed)', color='red', linewidth=2)
        _decorate_loss_axes(ax_smooth, f'Smoothed Losses (window={window})')

    plt.tight_layout()
    plt.show()

    print(f"Training completed after {n_epochs} epochs")
    print(f"Final Generator Loss: {g_losses[-1]:.4f}")
    print(f"Final Discriminator Loss: {d_losses[-1]:.4f}")
else:
    print("Warning: No training history available in checkpoint")

In [None]:
if CHECKPOINT_PATH:
    # Generate images
    num_samples = 64  # 8x8 grid

    print(f"Generating {num_samples} new face images...")

    with torch.no_grad():
        # Sample random noise shaped (num_samples, LATENT_DIM, 1, 1) as generator input
        noise = torch.randn(num_samples, LATENT_DIM, 1, 1, device=DEVICE)

        # Generate images
        generated_images = generator(noise)

        # Move to CPU so later cells can convert individual images with .numpy()
        generated_images = generated_images.cpu()

    print(f"✓ Generated {num_samples} images")
    print(f"  Image shape: {generated_images.shape}")
    print(f"  Value range: [{generated_images.min():.2f}, {generated_images.max():.2f}]")
else:
    print("Warning: Skipping generation - no model loaded")

In [None]:
if not CHECKPOINT_PATH:
    print("Warning: No images to display")
else:
    # Fixed 8x8 layout; panels beyond the number of generated images stay blank
    nrow, ncol = 8, 8
    fig, axes = plt.subplots(nrow, ncol, figsize=(16, 16))

    n_generated = len(generated_images)
    for i, ax in enumerate(axes.flat):
        if i < n_generated:
            # Rescale from [-1, 1] to [0, 1] (CHW -> HWC) and clip for imshow
            img = np.clip((generated_images[i].permute(1, 2, 0).numpy() + 1) / 2, 0, 1)
            ax.imshow(img)
        ax.axis('off')

    plt.suptitle('Generated Celebrity Faces', fontsize=18, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()

    print(f"Displayed {min(num_samples, nrow*ncol)} generated faces in {nrow}x{ncol} grid")

In [None]:
# Look for saved sample images (local run first, then distributed run)
sample_dirs = [
    '../outputs_local/samples/',
    '../outputs/samples/',
]


def natural_sort_key(path):
    """Sort key for file paths that compares embedded numbers numerically.

    re.split with a capturing group alternates text and digit runs, so digit runs
    always sit at odd positions; ints are therefore only compared with ints.
    Example: epoch_2.png sorts before epoch_10.png (plain sorted() gives the reverse).
    """
    return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r'(\d+)', path.name))]


sample_images = []
for sample_dir in sample_dirs:
    sample_path = Path(sample_dir)
    if sample_path.exists():
        # Numeric-aware name order so unpadded numbers (2 vs 10) don't scramble the progression
        images = sorted(sample_path.glob('*.png'), key=natural_sort_key)
        if images:
            print(f"Found {len(images)} sample images in {sample_dir}")

            # Select evenly spaced samples to show progression
            num_to_show = min(8, len(images))
            indices = np.linspace(0, len(images)-1, num_to_show, dtype=int)

            for idx in indices:
                img_path = images[idx]
                # Image.open is lazy and keeps the file open; copy the pixels and
                # let the context manager close the handle immediately.
                with Image.open(img_path) as im:
                    img = im.copy()
                sample_images.append((img_path.name, img))

            print(f"Selected {len(sample_images)} images to show progression")
            break

if not sample_images:
    print("No saved sample images found. Sample images are saved during training in:")
    for sample_dir in sample_dirs:
        print(f"  - {sample_dir}")

In [None]:
if sample_images:
    num_images = len(sample_images)
    ncols = min(4, num_images)
    nrows = (num_images + ncols - 1) // ncols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so one flat list covers
    # the single-image, single-row and multi-row cases uniformly.
    fig, axes_grid = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4), squeeze=False)
    axes = list(axes_grid.flat)

    for ax, (name, img) in zip(axes, sample_images):
        ax.imshow(img)
        ax.set_title(name.replace('.png', '').replace('_', ' '), fontsize=10)

    # Hide ticks/frames on every panel, including unused trailing ones
    for ax in axes:
        ax.axis('off')

    plt.suptitle('Training Progression - Generated Samples Over Time', 
                 fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()

    print(f"Displayed {num_images} samples showing training progression")
    print("\nObservations:")
    print("  - Early epochs: Random noise and blurry patterns")
    print("  - Middle epochs: Face-like structures emerge")
    print("  - Later epochs: Detailed, realistic faces")
else:
    print("No progression images to display")

In [None]:
def generate_faces(num_faces=16, seed=None):
    """
    Generate new faces using the trained generator and display them in a grid.

    Args:
        num_faces: Number of faces to generate (must be >= 1)
        seed: Random seed for reproducibility (optional). Applied via
            torch.manual_seed, so it also reseeds torch's global RNG.

    Raises:
        ValueError: If num_faces < 1 (a model must be loaded for this check to run).
    """
    if not CHECKPOINT_PATH:
        print("Warning: No model loaded")
        return
    if num_faces < 1:
        raise ValueError(f"num_faces must be at least 1, got {num_faces}")

    if seed is not None:
        torch.manual_seed(seed)

    with torch.no_grad():
        noise = torch.randn(num_faces, LATENT_DIM, 1, 1, device=DEVICE)
        images = generator(noise).cpu()

    # Near-square layout: ceil(sqrt(n)) columns and only as many rows as needed
    # (e.g. 5 faces -> 2x3 instead of a 3x3 grid with four empty panels)
    ncol = int(np.ceil(np.sqrt(num_faces)))
    nrow = int(np.ceil(num_faces / ncol))

    # squeeze=False always yields a 2-D array of Axes, including the 1x1 case
    fig, axes = plt.subplots(nrow, ncol, figsize=(ncol * 2, nrow * 2), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < num_faces:
            # Rescale from [-1, 1] to [0, 1] (CHW -> HWC) and clip for imshow
            img = np.clip((images[i].permute(1, 2, 0).numpy() + 1) / 2, 0, 1)
            ax.imshow(img)
        ax.axis('off')

    plt.suptitle(f'{num_faces} Newly Generated Faces', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Example: Generate 16 new faces
if CHECKPOINT_PATH:
    print("Generating 16 new faces...")
    generate_faces(16, seed=42)

    print("\n💡 Tip: Call generate_faces(N) to generate N faces")
    print("   Example: generate_faces(25, seed=123)")
else:
    print("Warning: Skipping - no model loaded")

In [None]:
if CHECKPOINT_PATH:
    print("="*70)
    print("DCGAN MODEL SUMMARY")
    print("="*70)
    print("\n Training Information:")
    print(f"   • Training completed: {epoch} epochs")
    print(f"   • Final Generator Loss: {g_losses[-1]:.4f}" if g_losses else "   • No loss history")
    # Only print the D line when history exists (avoids an empty output line)
    if d_losses:
        print(f"   • Final Discriminator Loss: {d_losses[-1]:.4f}")

    print("\n Model Architecture:")
    print(f"   • Generator Parameters: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"   • Discriminator Parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
    print(f"   • Latent Dimension: {LATENT_DIM}")
    print(f"   • Image Size: {IMAGE_SIZE}x{IMAGE_SIZE}")

    print("\n Hardware:")
    print(f"   • Device: {DEVICE}")
    if torch.cuda.is_available():
        print(f"   • GPU: {torch.cuda.get_device_name(0)}")

    print("\n Key Insights:")
    if g_losses and d_losses:
        # Rough heuristics only: thresholds (1.0 gap, 0.1 change over 5 epochs) are ad hoc
        print(f"   • Training stability: {'Good' if abs(g_losses[-1] - d_losses[-1]) < 1.0 else 'Needs tuning'}")
        print(f"   • Loss convergence: {'Converged' if len(g_losses) > 10 and abs(g_losses[-1] - g_losses[-5]) < 0.1 else 'Still improving'}")

    print("\n What This Model Does:")
    print("   • Generates realistic celebrity faces from random noise")
    print("   • Learned from CelebA dataset (celebrity images)")
    print("   • Can create infinite unique faces")
    print("   • No two generated faces are exactly the same")

    print("\n Next Steps:")
    print("   • Try generating more images with different seeds")
    print("   • Experiment with latent space interpolation")
    print("   • Train for more epochs to improve quality")
    print("   • Try other GAN architectures (StyleGAN, ProGAN)")

    print("\n" + "="*70)
else:
    print("Warning: No model loaded - train a model first!")

---

##  Conclusion

This notebook demonstrated a trained DCGAN model that can generate realistic celebrity faces. The model learned to map random noise vectors to face images through adversarial training.

**Key Takeaways:**
- GANs learn to generate realistic images without explicit pixel-level supervision
- The generator creates faces from random latent vectors
- Training involves a minimax game between generator and discriminator
- Quality improves with more training epochs and larger datasets

**Further Experimentation:**
- Try different random seeds to generate diverse faces
- Explore latent space interpolation (morphing between faces)
- Fine-tune on specific face attributes
- Experiment with conditional GANs (control age, gender, etc.)

**Resources:**
- [DCGAN Paper](https://arxiv.org/abs/1511.06434)
- [CelebA Dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
- [Project Repository](https://github.com/YOUR_USERNAME/GANNs-with-freinds)

---

**Created as part of the Distributed GAN Training Project** 

## 9. Summary and Model Insights

Key observations and insights from the trained DCGAN model.

## 8. Interactive Generation (Optional)

Generate specific numbers of new images on demand.

### Display Training Progression

Show how the generated images improved over the course of training.

## 7. Training Progression (Optional)

If you have saved sample images during training, we can visualize how the generator improved over time.

## 6. Display Generated Images

Visualize the generated face images in a grid layout.

## 5. Generate New Face Images

Generate new celebrity face images using the trained generator with random noise inputs.

## 4. Visualize Training History - Learning Curves

Plot the generator and discriminator losses over training epochs to understand model convergence.

## 3. Load Trained Model

Load the generator and discriminator from the saved checkpoint.

## 2. Configuration and Setup

Set up paths and parameters for loading the model.

## 1. Import Required Libraries

Import necessary libraries for loading models, plotting, and image visualization.

## Import Libraries

In [None]:
import sys
import os
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent / 'src') if '__file__' in globals() else '../src')

import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import json

from models.dcgan import Generator, Discriminator

# Set up matplotlib
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

## Configuration

Set the paths to your trained model checkpoint and output directory.

In [None]:
# Configuration
CHECKPOINT_PATH = '../outputs_local/checkpoint_latest.pth'  # or '../outputs/checkpoint_latest.pth' for distributed
OUTPUT_DIR = Path('../outputs_local')  # or Path('../outputs') for distributed
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Using device: {DEVICE}")
print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Output directory: {OUTPUT_DIR}")

## Load Trained Model

In [None]:
# Load checkpoint (fail early with an actionable message if it is missing)
if not Path(CHECKPOINT_PATH).exists():
    raise FileNotFoundError(
        f"No checkpoint at {CHECKPOINT_PATH}; train a model first or update CHECKPOINT_PATH."
    )
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

# Extract configuration from checkpoint, falling back to 100 / 64 when absent
latent_dim = checkpoint.get('latent_dim', 100)
image_size = checkpoint.get('image_size', 64)

# Initialize models
generator = Generator(latent_dim=latent_dim).to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# Load weights
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

# Set to evaluation mode
generator.eval()
discriminator.eval()


def _latest_loss(ckpt, scalar_key, history_key):
    """Return ckpt[scalar_key]; else the last entry of ckpt[history_key]; else None.

    Supports both checkpoint layouts seen in this notebook: a scalar 'g_loss'/'d_loss'
    or per-epoch 'g_losses'/'d_losses' lists.
    """
    value = ckpt.get(scalar_key)
    if value is None:
        history = ckpt.get(history_key)
        if history is not None and len(history) > 0:
            value = history[-1]
    return value


def _format_loss(value):
    """Format a loss to 4 decimals, or 'N/A' when missing (':.4f' on a str raises ValueError)."""
    return 'N/A' if value is None else f"{value:.4f}"


print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'N/A')}")
print(f"Iteration: {checkpoint.get('iteration', 'N/A')}")
print(f"Generator loss: {_format_loss(_latest_loss(checkpoint, 'g_loss', 'g_losses'))}")
print(f"Discriminator loss: {_format_loss(_latest_loss(checkpoint, 'd_loss', 'd_losses'))}")

## Training Learning Curves

Visualize how the Generator and Discriminator losses evolved during training.

In [None]:
# Load training history if available
history_file = OUTPUT_DIR / 'training_history.json'

if not history_file.exists():
    print(f"Training history not found at {history_file}")
    print("Skipping loss visualization")
else:
    with open(history_file, 'r') as f:
        history = json.load(f)

    # Parallel lists, one entry per logged point (assumed equal length — TODO confirm writer)
    iterations = history['iterations']
    g_losses = history['g_losses']
    d_losses = history['d_losses']

    if not iterations:
        print(f"Training history at {history_file} contains no entries")
        print("Skipping loss visualization")
    else:
        # Separate panels, since the two losses can differ in scale
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        for ax, losses, label, color in (
            (ax1, g_losses, 'Generator Loss', 'blue'),
            (ax2, d_losses, 'Discriminator Loss', 'red'),
        ):
            ax.plot(iterations, losses, label=label, color=color, alpha=0.7)
            ax.set_xlabel('Iteration')
            ax.set_ylabel('Loss')
            ax.set_title(f'{label} Over Time')
            ax.legend()
            ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

        # Combined plot to compare the two losses on one axis
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(iterations, g_losses, label='Generator Loss', color='blue', alpha=0.7)
        ax.plot(iterations, d_losses, label='Discriminator Loss', color='red', alpha=0.7)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Loss')
        ax.set_title('Training Loss Curves')
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.show()

        # len(iterations) counts logged points; it equals total iterations only if
        # every iteration was logged, so report the last logged iteration as well.
        print(f"Logged points: {len(iterations)} (last logged iteration: {iterations[-1]})")
        print(f"Final G loss: {g_losses[-1]:.4f}")
        print(f"Final D loss: {d_losses[-1]:.4f}")

## Generate New Face Images

Generate a batch of fresh face images using the trained generator.

In [None]:
# Generate images
num_samples = 64
z = torch.randn(num_samples, latent_dim, 1, 1, device=DEVICE)

with torch.no_grad():
    fake_images = generator(z)

# Convert to a numpy (N, H, W, C) array in [0, 1] for imshow
fake_images = fake_images.cpu()
# Denormalize from [-1, 1]; clamp so values slightly outside the range don't
# trigger imshow's clipping warning (matches the np.clip used in the first section)
fake_images = ((fake_images + 1) / 2).clamp(0, 1)
fake_images = fake_images.permute(0, 2, 3, 1).numpy()

print(f"Generated {num_samples} images of shape {fake_images.shape[1:]}")

## Display Generated Faces Grid

In [None]:
# Display grid of generated faces; grid size follows the batch so changing
# num_samples in the previous cell cannot cause an IndexError here
n_faces = len(fake_images)
grid_size = int(np.ceil(np.sqrt(n_faces)))
fig, axes = plt.subplots(grid_size, grid_size, figsize=(16, 16), squeeze=False)

for idx, ax in enumerate(axes.flat):
    if idx < n_faces:
        ax.imshow(fake_images[idx])
    ax.axis('off')

plt.suptitle('Generated Faces from Trained GAN', fontsize=16, y=0.995)
plt.tight_layout()
plt.show()

## Training Progression

View how the generated images improved over the course of training.

In [None]:
# Find sample images saved during training, ordered by iteration number
def _iteration_number(sample_file):
    """Return the trailing integer N of 'samples_iter_<N>.png', or -1 if it is not numeric."""
    suffix = sample_file.stem.rsplit('_', 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


# Plain sorted() is lexicographic (samples_iter_1000 < samples_iter_200), which would
# scramble the progression unless the numbers are zero-padded.
sample_files = sorted(OUTPUT_DIR.glob('samples_iter_*.png'),
                      key=lambda p: (_iteration_number(p), p.name))

if sample_files:
    # Show progression: up to 6 evenly spaced samples from first to latest
    num_checkpoints = min(6, len(sample_files))
    indices = np.linspace(0, len(sample_files) - 1, num_checkpoints, dtype=int)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    for ax, idx in zip(axes, indices):
        sample_file = sample_files[idx]
        # Copy the pixels and close the file right away (Image.open is lazy)
        with Image.open(sample_file) as im:
            ax.imshow(im.copy())
        ax.set_title(f'Iteration {sample_file.stem.split("_")[-1]}')

    # Turn off every panel, including unused ones when fewer than 6 samples exist
    for ax in axes:
        ax.axis('off')

    plt.suptitle('Training Progression', fontsize=16)
    plt.tight_layout()
    plt.show()

    print(f"Found {len(sample_files)} sample checkpoints")
else:
    print("No sample images found in output directory")

## Interactive Generation

Generate new batches of faces on demand.

In [None]:
def generate_faces(n=16, seed=None):
    """Generate n face images with the loaded generator and show them in a grid.

    Note: an earlier cell in this notebook also defines generate_faces(num_faces, seed);
    whichever cell ran last determines which definition is active.

    Args:
        n: Number of faces to generate (must be >= 1).
        seed: Optional seed, applied via torch.manual_seed (reseeds torch's global RNG).

    Raises:
        ValueError: If n < 1.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    if seed is not None:
        torch.manual_seed(seed)

    z = torch.randn(n, latent_dim, 1, 1, device=DEVICE)

    with torch.no_grad():
        images = generator(z)

    # Denormalize from [-1, 1] to [0, 1], clamp for imshow, and reorder to (N, H, W, C)
    images = ((images.cpu() + 1) / 2).clamp(0, 1)
    images = images.permute(0, 2, 3, 1).numpy()

    # squeeze=False keeps a 2-D Axes array even for n == 1 (a bare Axes has no .flatten())
    grid_size = int(np.ceil(np.sqrt(n)))
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < n:
            ax.imshow(images[i])
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Try it out!
print("Generating 16 random faces...")
generate_faces(16, seed=42)

## Model Summary

In [None]:
# Model parameter counts
def count_parameters(model):
    """Return the total number of trainable (requires_grad=True) scalar parameters in `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

g_params = count_parameters(generator)
d_params = count_parameters(discriminator)

print("=" * 60)
print("MODEL SUMMARY")
print("=" * 60)
print(f"Generator parameters:     {g_params:,}")
print(f"Discriminator parameters: {d_params:,}")
print(f"Total parameters:         {g_params + d_params:,}")
print("=" * 60)
print(f"\nLatent dimension: {latent_dim}")
print(f"Image size: {image_size}x{image_size}")
print(f"Training device: {DEVICE}")
# .get: these keys are optional (the loader earlier in the notebook also treats 'epoch' as optional)
print(f"Epochs trained: {checkpoint.get('epoch', 'N/A')}")
print(f"Total iterations: {checkpoint.get('iteration', 'N/A')}")
print("=" * 60)

## Conclusion

This notebook demonstrated:
- Loading a trained DCGAN checkpoint
- Visualizing training loss curves
- Generating new face images
- Viewing training progression
- Interactive face generation

You can point `CHECKPOINT_PATH` (set in the Configuration cell) at a different checkpoint to compare training runs (distributed vs. local) or checkpoints saved at different epochs!