# Wasserstein GAN with Gradient Penalty (WGAN-GP)

## Core Idea

WGAN-GP replaces the Jensen-Shannon (JS) divergence minimized by the vanilla GAN with the Wasserstein-1 distance (Earth Mover's Distance),
which provides meaningful gradients even when the generator and data distributions have disjoint supports.
The Lipschitz constraint required by the Kantorovich-Rubinstein duality is enforced via a gradient penalty
rather than weight clipping.

## Mathematical Foundation

### Wasserstein-1 Distance

$$W_1(p_{\text{data}}, p_g) = \inf_{\gamma \in \Pi(p_{\text{data}}, p_g)} \mathbb{E}_{(x,y) \sim \gamma}[\|x - y\|]$$

**Interpretation:** Minimum cost to transport mass from $p_g$ to $p_{\text{data}}$, where cost = distance moved.
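
As a small worked example of the transport interpretation, the sketch below computes $W_1$ between two equal-size one-dimensional samples. It relies on the fact that in one dimension the optimal coupling pairs sorted samples; the helper name `wasserstein1_1d` and the toy data are assumptions made for illustration and are not part of the notebook's training code.

```python
import numpy as np


def wasserstein1_1d(x: np.ndarray, y: np.ndarray) -> float:
    """W1 between two equal-size 1-D empirical distributions.

    In one dimension the optimal coupling matches sorted samples, so the
    transport cost is the mean absolute difference of the order statistics.
    """
    if x.ndim != 1 or x.shape != y.shape:
        raise ValueError("expected two 1-D arrays of equal length")
    return float(np.mean(np.abs(np.sort(x) - np.sort(y))))


rng = np.random.default_rng(0)
real = rng.normal(loc=0.0, scale=1.0, size=10_000)
shifted = real + 3.0  # the same sample translated by 3 units

w_same = wasserstein1_1d(real, real)      # no mass has to move
w_shift = wasserstein1_1d(real, shifted)  # every unit of mass moves by 3
print(f"W1(real, real)    = {w_same:.4f}")
print(f"W1(real, shifted) = {w_shift:.4f}")
```

Translating a sample by 3 gives a distance of 3 (up to floating-point error), which matches the "cost = distance moved" reading.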

### Kantorovich-Rubinstein Duality

$$W_1(p_{\text{data}}, p_g) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x \sim p_{\text{data}}}[f(x)] - \mathbb{E}_{x \sim p_g}[f(x)]$$

where $\|f\|_L \leq 1$ denotes 1-Lipschitz functions: $|f(x_1) - f(x_2)| \leq \|x_1 - x_2\|$ for all $x_1, x_2$.

**Proof Sketch:**
1. Primal: Optimal transport problem over joint distributions $\gamma$
2. Dual: Maximize over 1-Lipschitz functions (Kantorovich potentials)
3. Strong duality holds on compact metric spaces, and more generally on Polish spaces when both distributions have finite first moments
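
The dual form can be checked numerically on a toy case. In the sketch below the generated sample is the data sample shifted by 3, so the primal value is $W_1 = 3$. The candidate functions are arbitrary 1-Lipschitz choices picked for illustration: none can exceed 3, and $f(x) = -x$ attains it.

```python
import numpy as np

rng = np.random.default_rng(0)
x_data = rng.normal(size=5_000)
x_gen = x_data + 3.0  # shifted copy, so W1 between the two samples is 3


def dual_value(f) -> float:
    """Evaluate E_data[f(x)] - E_g[f(x)] for a candidate potential f."""
    return float(np.mean(f(x_data)) - np.mean(f(x_gen)))


# All candidates are 1-Lipschitz on the real line.
candidates = {
    "-x": lambda x: -x,
    "sin": np.sin,
    "tanh": np.tanh,
    "abs": np.abs,
}
values = {name: dual_value(f) for name, f in candidates.items()}

for name, value in values.items():
    print(f"f = {name:>4}: dual objective = {value:.4f}")
```

The critic in WGAN-GP plays the role of $f$: training it searches for the potential that maximizes this gap.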

### Gradient Penalty (Gulrajani et al., 2017)

Instead of weight clipping, enforce Lipschitz constraint via:

$$\mathcal{L}_{\text{GP}} = \lambda \mathbb{E}_{\hat{x} \sim p_{\hat{x}}}\left[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\right]$$

where $\hat{x} = \epsilon x_{\text{real}} + (1-\epsilon) x_{\text{fake}}$, $\epsilon \sim U[0,1]$.

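**Why interpolated points?** The optimal critic has gradient norm 1 almost everywhere along straight lines
connecting points coupled under the optimal transport plan between $p_{\text{data}}$ and $p_g$ (Proposition 1 in original paper). Because that plan is unknown, the penalty samples uniformly along lines between random real/fake pairs instead.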

### Complete Objective

**Critic (minimize):**
$$\mathcal{L}_D = \underbrace{\mathbb{E}_{\tilde{x} \sim p_g}[D(\tilde{x})] - \mathbb{E}_{x \sim p_{\text{data}}}[D(x)]}_{-W_1 \text{ estimate}} + \underbrace{\lambda \mathbb{E}_{\hat{x}}[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2]}_{\text{Gradient Penalty}}$$

**Generator (minimize):**
$$\mathcal{L}_G = -\mathbb{E}_{z \sim p_z}[D(G(z))]$$

## Problem Statement

Vanilla GAN suffers from:
1. **Gradient vanishing:** When supports are disjoint, $\text{JSD}(p_{\text{data}} \| p_g) = \log 2$ (constant)
2. **Mode collapse:** Generator finds few modes that fool discriminator
3. **Training instability:** Sensitive to architecture and hyperparameters

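WGAN-GP addresses these by optimizing a distance that, under mild regularity assumptions, varies continuously with the generator's parameters, so the critic supplies informative gradients even when the supports do not overlap.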
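
The vanishing-gradient point (item 1 above) can be illustrated with two point masses on a grid. In this sketch the JS divergence stays at $\log 2$ no matter how far apart the masses are, while $W_1$ grows with the separation. The helper names `js_divergence` and `w1_on_grid` are hypothetical, and the 1-D identity $W_1 = \int |F - G|\,dx$ (difference of CDFs) is used to compute the distance.

```python
import numpy as np


def js_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """Jensen-Shannon divergence (natural log) of two discrete distributions."""
    m = 0.5 * (p + q)

    def kl(a: np.ndarray, b: np.ndarray) -> float:
        mask = a > 0  # terms with a = 0 contribute nothing
        return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def w1_on_grid(p: np.ndarray, q: np.ndarray, grid: np.ndarray) -> float:
    """W1 on a sorted 1-D grid: integral of the absolute CDF difference."""
    cdf_gap = np.abs(np.cumsum(p) - np.cumsum(q))[:-1]
    return float(np.sum(cdf_gap * np.diff(grid)))


grid = np.linspace(0.0, 10.0, 101)
p = np.zeros_like(grid)
p[0] = 1.0  # "data": point mass at 0

results = []
for idx in (10, 50, 100):  # "generator": point mass at 1, 5, 10
    q = np.zeros_like(grid)
    q[idx] = 1.0
    results.append((grid[idx], js_divergence(p, q), w1_on_grid(p, q, grid)))

for location, jsd, w1 in results:
    print(f"mass at {location:4.1f}: JSD = {jsd:.4f}, W1 = {w1:.4f}")
```

Because the JS divergence is constant here, it gives the generator no signal about which direction reduces the gap; $W_1$ does.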

## Algorithm Comparison

| Aspect | GAN | WGAN (clip) | WGAN-GP |
|--------|-----|-------------|----------|
| Distance | JS divergence | Wasserstein-1 | Wasserstein-1 |
| Lipschitz enforcement | None | Weight clipping | Gradient penalty |
| Discriminator/critic output | Probability | Unbounded | Unbounded |
| Training stability (qualitative) | Poor | Better | Best |
| Critic capacity | Full | Limited | Full |

## Complexity Analysis

- **Time per iteration:** $O(n_{\text{critic}} \cdot (T_D + T_{\text{GP}}) + T_G)$
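  - $T_{\text{GP}}$: one extra critic forward and backward pass at $\hat{x}$, plus a second-order backward pass when the penalty is differentiated; roughly a constant multiple of $T_D$ rather than $O(d)$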
  - Typically $n_{\text{critic}} = 5$
- **Space:** $O(|\theta_D| + |\theta_G|)$ plus intermediate activations for gradient penalty

In [None]:
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid

In [None]:
@dataclass
class WGANGPConfig:
    r"""Configuration for WGAN-GP.
    
    Core Idea:
        Hyperparameters tuned for stable Wasserstein training.
    
    Mathematical Theory:
        - gp_lambda: Weight $\lambda$ in GP term. Default 10 from original paper.
        - n_critic: Train critic n times per generator update for accurate $W_1$ estimate.
    """
    latent_dim: int = 100
    hidden_dim: int = 256
    image_channels: int = 1
    image_size: int = 28
    
    lr: float = 1e-4
    beta1: float = 0.0
    beta2: float = 0.9
    
    gp_lambda: float = 10.0
    n_critic: int = 5
    
    batch_size: int = 64
    num_epochs: int = 50
    
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
    seed: int = 42
    
    @property
    def image_dim(self) -> int:
        return self.image_channels * self.image_size * self.image_size

In [None]:
class Generator(nn.Module):
    r"""Generator network for WGAN-GP.
    
    Core Idea:
        Maps latent code $z$ to image space. Architecture identical to vanilla GAN;
        WGAN-GP changes only affect the critic and loss computation.
    
    Mathematical Theory:
        $G: \mathbb{R}^{d_z} \to [-1, 1]^{d_x}$ via composition of affine + nonlinear layers.
        BatchNorm stabilizes training by normalizing intermediate activations.
    """
    
    def __init__(self, config: WGANGPConfig) -> None:
        super().__init__()
        self.config = config
        
        self.net = nn.Sequential(
            nn.Linear(config.latent_dim, config.hidden_dim),
            nn.BatchNorm1d(config.hidden_dim),
            nn.ReLU(True),
            nn.Linear(config.hidden_dim, config.hidden_dim * 2),
            nn.BatchNorm1d(config.hidden_dim * 2),
            nn.ReLU(True),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim * 4),
            nn.BatchNorm1d(config.hidden_dim * 4),
            nn.ReLU(True),
            nn.Linear(config.hidden_dim * 4, config.image_dim),
            nn.Tanh(),
        )
        self._init_weights()
    
    def _init_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def forward(self, z: Tensor) -> Tensor:
        out = self.net(z)
        return out.view(-1, self.config.image_channels, self.config.image_size, self.config.image_size)

In [None]:
class Critic(nn.Module):
    r"""Critic network for WGAN-GP (not 'Discriminator').
    
    Core Idea:
        Outputs unbounded scalar (not probability). Higher values indicate
        'more real'. The critic approximates the optimal Kantorovich potential.
    
    Mathematical Theory:
        $D: \mathbb{R}^{d_x} \to \mathbb{R}$ approximates 1-Lipschitz function.
        No sigmoid: output is Wasserstein potential, not probability.
        No BatchNorm: can interfere with gradient penalty computation.
    
    Problem Statement:
        BatchNorm in the critic makes each output depend on the whole batch,
        whereas the penalty is defined on the gradient norm for each input
        independently. Use LayerNorm or no normalization.
    """
    
    def __init__(self, config: WGANGPConfig) -> None:
        super().__init__()
        self.config = config
        
        self.net = nn.Sequential(
            nn.Linear(config.image_dim, config.hidden_dim * 4),
            nn.LeakyReLU(0.2, True),
            nn.Linear(config.hidden_dim * 4, config.hidden_dim * 2),
            nn.LeakyReLU(0.2, True),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.LeakyReLU(0.2, True),
            nn.Linear(config.hidden_dim, 1),
        )
        self._init_weights()
    
    def _init_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def forward(self, x: Tensor) -> Tensor:
        x_flat = x.view(x.size(0), -1)
        return self.net(x_flat)
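
The docstring above mentions LayerNorm as an alternative to dropping normalization entirely. The following is a minimal sketch of such a variant; `LayerNormCritic` and its default sizes are hypothetical and not used elsewhere in this notebook. The check at the end illustrates the property that matters for the penalty: each sample's score is the same whether it is scored alone or inside a batch.

```python
import torch
import torch.nn as nn
from torch import Tensor


class LayerNormCritic(nn.Module):
    """Hypothetical critic variant that normalizes each sample independently."""

    def __init__(self, image_dim: int = 784, hidden_dim: int = 256) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_dim, hidden_dim * 2),
            nn.LayerNorm(hidden_dim * 2),  # statistics over features, not batch
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),  # unbounded scalar score, no sigmoid
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x.flatten(start_dim=1))


torch.manual_seed(0)
critic = LayerNormCritic()
batch = torch.randn(4, 1, 28, 28)

with torch.no_grad():
    full = critic(batch)       # scores for all four samples together
    alone = critic(batch[:1])  # score for the first sample on its own

print(full.shape)
print(torch.allclose(full[:1], alone, atol=1e-5))
```

With `nn.BatchNorm1d` in training mode the two scores would generally differ, because the batch statistics change with the batch contents.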

In [None]:
def compute_gradient_penalty(
    critic: Critic,
    real: Tensor,
    fake: Tensor,
    device: torch.device
) -> Tensor:
    r"""Compute gradient penalty for WGAN-GP.
    
    Core Idea:
        Penalize critic gradients that deviate from unit norm along
        interpolated points between real and fake samples.
    
    Mathematical Theory:
        For 1-Lipschitz $f$: $\|\nabla f(x)\| \leq 1$ almost everywhere.
        Optimal critic has $\|\nabla D^*(\hat{x})\| = 1$ on interpolation lines.
        
        $$\text{GP} = \mathbb{E}_{\hat{x}}[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2]$$
        
        where $\hat{x} = \epsilon x_{\text{real}} + (1-\epsilon) x_{\text{fake}}$
    
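    Complexity:
        Time: one extra forward and backward pass through the critic at the
            interpolated points; create_graph=True also records the graph
            needed to backpropagate through the penalty (double backward).
        Space: O(B * d) for the gradient tensor plus the retained autograd graph.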
    
    Args:
        critic: Critic network
        real: Real samples [B, C, H, W]
        fake: Generated samples [B, C, H, W]
        device: Computation device
    
    Returns:
        Scalar gradient penalty loss
    """
    batch_size = real.size(0)
    
    epsilon = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)
    
    d_interpolated = critic(interpolated)
    
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]
    
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty
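
A quick sanity check of the penalty, sketched on a hypothetical linear critic $f(x) = w^\top x$ whose input gradient is $w$ everywhere. With $\|w\| = 1$ the penalty should be about 0, and with $\|w\| = 3$ it should be about $(3 - 1)^2 = 4$. The block restates the penalty for flat `[B, d]` inputs so it runs on its own; `gradient_penalty` and `linear_critic` are illustrative names, not part of the code above.

```python
import torch
import torch.nn as nn
from torch import Tensor


def gradient_penalty(critic: nn.Module, real: Tensor, fake: Tensor) -> Tensor:
    """Same penalty as above, restated for flat [B, d] inputs."""
    eps = torch.rand(real.size(0), 1, device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty is differentiable
    )
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


def linear_critic(scale: float, dim: int = 8) -> nn.Linear:
    """Linear map whose weight vector has Euclidean norm `scale`."""
    layer = nn.Linear(dim, 1, bias=False)
    with torch.no_grad():
        w = torch.ones(1, dim)
        layer.weight.copy_(scale * w / w.norm())
    return layer


torch.manual_seed(0)
real, fake = torch.randn(16, 8), torch.randn(16, 8)

gp_unit = gradient_penalty(linear_critic(1.0), real, fake).item()
gp_steep = gradient_penalty(linear_critic(3.0), real, fake).item()
print(f"||w|| = 1: GP = {gp_unit:.6f}")
print(f"||w|| = 3: GP = {gp_steep:.6f}")
```

The penalty is two-sided: a critic with gradient norm below 1 is penalized as well, which is how the formula above is written.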

In [None]:
class WGANGPTrainer:
    r"""Training orchestrator for WGAN-GP.
    
    Core Idea:
        Alternating optimization with n_critic critic updates per generator update.
        More critic updates = better Wasserstein estimate = better generator gradients.
    
    Mathematical Theory:
        Critic update: $\theta_D \leftarrow \theta_D - \alpha \nabla_{\theta_D} \mathcal{L}_D$
        Generator update: $\theta_G \leftarrow \theta_G - \alpha \nabla_{\theta_G} \mathcal{L}_G$
        
        Adam with $\beta_1=0$ (no momentum) recommended for WGAN-GP.
    """
    
    def __init__(self, config: WGANGPConfig) -> None:
        self.config = config
        self.device = torch.device(config.device)
        
        torch.manual_seed(config.seed)
        
        self.generator = Generator(config).to(self.device)
        self.critic = Critic(config).to(self.device)
        
        self.optimizer_g = torch.optim.Adam(
            self.generator.parameters(),
            lr=config.lr,
            betas=(config.beta1, config.beta2)
        )
        self.optimizer_c = torch.optim.Adam(
            self.critic.parameters(),
            lr=config.lr,
            betas=(config.beta1, config.beta2)
        )
        
        self.fixed_noise = torch.randn(64, config.latent_dim, device=self.device)
        self.history: Dict[str, List[float]] = {
            "loss_c": [], "loss_g": [], "gp": [], "wasserstein": []
        }
    
    def _train_critic(self, real: Tensor) -> Tuple[float, float, float]:
        batch_size = real.size(0)
        self.optimizer_c.zero_grad()
        
        z = torch.randn(batch_size, self.config.latent_dim, device=self.device)
        fake = self.generator(z).detach()
        
        c_real = self.critic(real).mean()
        c_fake = self.critic(fake).mean()
        wasserstein = c_real - c_fake
        
        gp = compute_gradient_penalty(self.critic, real, fake, self.device)
        
        loss_c = -wasserstein + self.config.gp_lambda * gp
        loss_c.backward()
        self.optimizer_c.step()
        
        return loss_c.item(), wasserstein.item(), gp.item()
    
    def _train_generator(self, batch_size: int) -> float:
        self.optimizer_g.zero_grad()
        
        z = torch.randn(batch_size, self.config.latent_dim, device=self.device)
        fake = self.generator(z)
        
        loss_g = -self.critic(fake).mean()
        loss_g.backward()
        self.optimizer_g.step()
        
        return loss_g.item()
    
    def train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
        self.generator.train()
        self.critic.train()
        
        epoch_metrics = {k: 0.0 for k in self.history.keys()}
        num_batches = len(dataloader)
        
        for real, _ in dataloader:
            real = real.to(self.device)
            batch_size = real.size(0)
            
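            # Note: the same real batch is reused for all n_critic steps (only the
            # noise is resampled); Algorithm 1 in the paper draws a fresh real
            # batch per step. Only the final critic step's metrics are logged.
            for _ in range(self.config.n_critic):
                loss_c, w_dist, gp = self._train_critic(real)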
            
            loss_g = self._train_generator(batch_size)
            
            epoch_metrics["loss_c"] += loss_c
            epoch_metrics["loss_g"] += loss_g
            epoch_metrics["gp"] += gp
            epoch_metrics["wasserstein"] += w_dist
        
        for k in epoch_metrics:
            epoch_metrics[k] /= num_batches
            self.history[k].append(epoch_metrics[k])
        
        return epoch_metrics
    
    @torch.no_grad()
    def generate_samples(self, num_samples: int = 64) -> Tensor:
        self.generator.eval()
        z = torch.randn(num_samples, self.config.latent_dim, device=self.device)
        return self.generator(z).cpu()

In [None]:
def create_dataloader(config: WGANGPConfig) -> DataLoader:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=config.batch_size, shuffle=True, drop_last=True, num_workers=2)

In [None]:
def visualize_samples(samples: Tensor, title: str = "Generated Samples") -> None:
    grid = make_grid(samples, nrow=8, normalize=True, value_range=(-1, 1))
    plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.title(title)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

In [None]:
def plot_training_curves(history: Dict[str, List[float]]) -> None:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    axes[0].plot(history["wasserstein"], label="Wasserstein Distance", alpha=0.8)
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Distance")
    axes[0].set_title("Wasserstein Estimate (should decrease toward 0 as samples improve)")
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    axes[1].plot(history["gp"], label="Gradient Penalty", alpha=0.8, color="orange")
    axes[1].axhline(y=0, color="r", linestyle="--", alpha=0.5)
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("GP Value")
    axes[1].set_title("Gradient Penalty (should stay near 0)")
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

In [None]:
if __name__ == "__main__":
    config = WGANGPConfig(num_epochs=50, batch_size=64)
    dataloader = create_dataloader(config)
    trainer = WGANGPTrainer(config)
    
    print(f"Generator parameters: {sum(p.numel() for p in trainer.generator.parameters()):,}")
    print(f"Critic parameters: {sum(p.numel() for p in trainer.critic.parameters()):,}")
    print(f"Device: {config.device}")
    
    for epoch in range(config.num_epochs):
        metrics = trainer.train_epoch(dataloader)
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{config.num_epochs}] "
                  f"W: {metrics['wasserstein']:.4f} GP: {metrics['gp']:.4f} "
                  f"Loss_C: {metrics['loss_c']:.4f} Loss_G: {metrics['loss_g']:.4f}")
            samples = trainer.generate_samples(64)
            visualize_samples(samples, f"Epoch {epoch+1}")
    
    plot_training_curves(trainer.history)

## Summary

WGAN-GP improves upon vanilla GAN by:

1. **Wasserstein distance:** Provides meaningful gradients even when distributions don't overlap
2. **Gradient penalty:** Enforces Lipschitz constraint without limiting model capacity
3. **Multiple critic updates:** Better distance estimate leads to better generator gradients

Key implementation details:
- Critic outputs unbounded scalar (no sigmoid)
- No BatchNorm in critic (interferes with GP)
- Adam with $\beta_1 = 0$ for stability
- $\lambda = 10$ for gradient penalty weight
- $n_{\text{critic}} = 5$ critic updates per generator update

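The critic's Wasserstein estimate (`c_real - c_fake`) typically rises early while the critic learns, then decreases toward zero as the generated distribution moves closer to the data distribution.
The gradient penalty should stay small, indicating that the critic's gradient norms remain close to 1 at the interpolated points.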