# Generative Adversarial Network (GAN)

In [None]:
from typing import Sequence
import functools
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import flax.linen as nn
import haiku as hk
import optax
import chex
from tqdm import tqdm
from utils import BatchManager, load_dataset, save_samples

## Configuration and Data Loading

In [None]:
# Per-dataset settings, keyed by dataset name; change `dataset_name` to switch datasets.
dataset_name = 'checkerboard'
dataset_configs = {
    'checkerboard': {'epochs': 1000},
    'gaussian_mixture': {'epochs': 1000},
    'pinwheel': {'epochs': 1000},
    'spiral': {'epochs': 1000},
}

# Hyperparameters shared by every dataset are merged into the selected entry in place.
config = dataset_configs[dataset_name]
config.update({
    'batch_size': 128,
    'd_learning_rate': 1e-3,
    'g_learning_rate': 1e-3,
    'd_layer_dim': [30, 20, 10, 1],  # Discriminator layer widths
    'g_layer_dim': [16, 8, 4, 2],  # Generator layer widths
    'latent_dim': 10,
    'k': 1,  # Discriminator training iterations
})
X_train, X_test = load_dataset(dataset_name)

## Model Definition

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator with a sigmoid output.

    Every width in ``layer_dim`` except the last produces a
    Dense -> ReLU -> Dropout(0.1) stage; the last width is the size of the
    final Dense layer, whose output is passed through a sigmoid.
    """
    layer_dim: Sequence[int]

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        hidden_widths = self.layer_dim[:-1]
        out_width = self.layer_dim[-1]

        h = x
        for width in hidden_widths:
            h = nn.relu(nn.Dense(width)(h))
            # Dropout is deterministic (a no-op) unless `train` is True.
            h = nn.Dropout(rate=0.1, deterministic=not train)(h)
        return nn.sigmoid(nn.Dense(out_width)(h))


discriminator = Discriminator(layer_dim=config['d_layer_dim'])

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to samples.

    Every width in ``layer_dim`` except the last produces a Dense -> ReLU
    stage; the last width is the size of the final, un-activated Dense layer.
    """
    layer_dim: Sequence[int]

    @nn.compact
    def __call__(self, z: jax.Array) -> jax.Array:
        h = z
        for width in self.layer_dim[:-1]:
            h = nn.relu(nn.Dense(width)(h))
        # Linear output layer: no activation is applied to the generated sample.
        return nn.Dense(self.layer_dim[-1])(h)


generator = Generator(layer_dim=config['g_layer_dim'])

## Training Preparation
Set up the optimizers, loss functions, and other training utilities.

In [None]:
# Single PRNG key stream (seed 0) from which later cells draw fresh keys via next().
prng_seq = hk.PRNGSequence(jax.random.PRNGKey(0))

# Separate Adam optimizers so the two networks can use different learning rates.
g_optimizer = optax.adam(config['g_learning_rate'])
d_optimizer = optax.adam(config['d_learning_rate'])

In [None]:
@functools.partial(jax.jit, static_argnames=['train'])
def gan_loss(d_params:chex.ArrayTree, g_params:chex.ArrayTree, batch: jax.Array, key: chex.PRNGKey, train: bool):
    """Compute the discriminator and generator losses for one batch.

    Args:
        d_params: Discriminator parameters passed to ``discriminator.apply``.
        g_params: Generator parameters passed to ``generator.apply``.
        batch: Real samples; its leading dimension sets how many fake samples are drawn.
        key: PRNG key, split into one key for the latent draw and one dropout
            key for each of the two discriminator passes.
        train: Static flag forwarded to the discriminator as ``train``.

    Returns:
        ``(d_loss, g_loss)`` with ``d_loss = -mean(log D(x) + log(1 - D(G(z))))``
        and ``g_loss = -mean(log D(G(z)))``.
    """
    # Generate sample
    key, z_key = jax.random.split(key)
    z_batch = jax.random.normal(z_key, (batch.shape[0], config['latent_dim']))
    fake_batch = generator.apply(g_params, z_batch)

    # Apply discriminator (independent dropout keys for the real and fake passes)
    key1, key2 = jax.random.split(key)
    real_preds = discriminator.apply(d_params, batch, train=train, rngs={'dropout': key1})
    fake_preds = discriminator.apply(d_params, fake_batch, train=train, rngs={'dropout': key2})

    # The discriminator output is a sigmoid, which can saturate to exactly 0 or 1
    # in floating point. log(0) would make the loss infinite and the gradients NaN,
    # so the log argument is floored at the smallest positive normal number. This
    # leaves the result unchanged wherever the plain log was already finite.
    tiny = jnp.finfo(real_preds.dtype).tiny

    def safe_log(p):
        return jnp.log(jnp.maximum(p, tiny))

    # Compute loss
    d_loss = -jnp.mean(safe_log(real_preds) + safe_log(1 - fake_preds))
    # Generator maximises log D(G(z)) rather than minimising log(1 - D(G(z))).
    g_loss = -jnp.mean(safe_log(fake_preds))

    return d_loss, g_loss

@jax.jit
def do_batch_update(batch, d_params, g_params, opt_d_state, opt_g_state, key):
    """Run ``config['k']`` discriminator updates followed by one generator update.

    Args:
        batch: Real samples used for every update in this step.
        d_params: Current discriminator parameters.
        g_params: Current generator parameters.
        opt_d_state: Optimizer state belonging to ``d_optimizer``.
        opt_g_state: Optimizer state belonging to ``g_optimizer``.
        key: PRNG key; a fresh sub-key is split off for each discriminator
            update and the remaining key is used for the generator update.

    Returns:
        ``(d_params, g_params, opt_d_state, opt_g_state)`` after the updates.
    """
    def discriminator_objective(params, rng):
        # First output of gan_loss is the discriminator loss.
        return gan_loss(params, g_params, batch, rng, train=True)[0]

    # Train discriminator
    for _ in range(config['k']):
        key, d_key = jax.random.split(key)
        d_grad = jax.grad(discriminator_objective)(d_params, d_key)
        d_updates, opt_d_state = d_optimizer.update(d_grad, opt_d_state)
        d_params = optax.apply_updates(d_params, d_updates)

    # Train generator against the freshly updated discriminator
    def generator_objective(params):
        # Second output of gan_loss is the generator loss.
        return gan_loss(d_params, params, batch, key, train=True)[1]

    g_grad = jax.grad(generator_objective)(g_params)
    g_updates, opt_g_state = g_optimizer.update(g_grad, opt_g_state)
    g_params = optax.apply_updates(g_params, g_updates)

    return d_params, g_params, opt_d_state, opt_g_state

## Training Loop

In [None]:
# Initialise the generator from a dummy latent batch of size 1.
g_init_key = next(prng_seq)
dummy_z = jax.random.normal(next(prng_seq), (1, config['latent_dim']))
g_params = generator.init(g_init_key, dummy_z)

# Initialise the discriminator from the first training sample (dropout disabled).
d_params = discriminator.init(next(prng_seq), X_train[:1, ...], train=False)

# Optimizer states matching the freshly initialised parameters.
opt_d_state = d_optimizer.init(d_params)
opt_g_state = g_optimizer.init(g_params)

bm = BatchManager(X_train, config['batch_size'], key=next(prng_seq))

# Per-epoch loss histories, appended to by the training loop.
d_train_losses, g_train_losses = [], []
d_test_losses, g_test_losses = [], []

In [None]:
for epoch in tqdm(range(config['epochs']), "Epoch"):
    # One pass over the training data, one parameter update per batch.
    for _ in range(bm.num_batches):
        batch = next(bm)
        step_key = next(prng_seq)
        d_params, g_params, opt_d_state, opt_g_state = do_batch_update(
            batch, d_params, g_params, opt_d_state, opt_g_state, step_key)

    # Evaluate on the full train and test sets with dropout disabled.
    # The discriminator loss is recorded with its sign flipped.
    for eval_data, d_history, g_history in (
        (X_train, d_train_losses, g_train_losses),
        (X_test, d_test_losses, g_test_losses),
    ):
        d_eval_loss, g_eval_loss = gan_loss(
            d_params, g_params, eval_data, next(prng_seq), train=False)
        d_history.append(-d_eval_loss)
        g_history.append(g_eval_loss)

## Training Results Visualization

In [None]:
# Smooth every loss history with a uniform moving average before plotting
window_size = 10
window = np.full(window_size, 1.0 / window_size)

loss_histories = [
    ('D Train', d_train_losses),
    ('G Train', g_train_losses),
    ('D Test', d_test_losses),
    ('G Test', g_test_losses),
]
smoothed_curves = [
    (curve_label, np.convolve(history, window, mode='valid'))
    for curve_label, history in loss_histories
]

# 'valid' mode keeps only fully overlapping windows, so the curves are shorter
# than the histories; the x positions are shifted by half a window to compensate.
half_window = window_size // 2
x_pos = np.arange(half_window, half_window + smoothed_curves[0][1].shape[0])

for curve_label, curve in smoothed_curves:
    plt.plot(x_pos, curve, label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Smoothed Loss')
plt.legend()
plt.show()

## Sample Generation and Visualization

In [None]:
# Draw samples from the generator
def generate_samples(g_params: chex.ArrayTree, key: chex.PRNGKey, num_samples:int):
    """Generate ``num_samples`` samples from the generator.

    Args:
        g_params: Generator parameters passed to ``generator.apply``.
        key: PRNG key used to draw the standard-normal latent vectors.
        num_samples: Number of latent vectors (and therefore samples) to draw.

    Returns:
        The output of ``generator.apply`` for the drawn latent batch.
    """
    latent_shape = (num_samples, config['latent_dim'])
    return generator.apply(g_params, jax.random.normal(key, latent_shape))

In [None]:
# Draw 2000 samples using the current generator parameters
samples = generate_samples(g_params, next(prng_seq), 2000)

# Overlay the generated points on the train and test data
plt.scatter(samples[:, 0], samples[:, 1], marker='.', label='Sampled')
for data_points, data_label in ((X_train, 'Train'), (X_test, 'Test')):
    plt.scatter(data_points[:, 0], data_points[:, 1], alpha=0.2, marker='o', label=data_label)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.axis('equal')
plt.xlim([-4, 4])
plt.ylim([-4, 4])
plt.legend()
plt.show()

In [None]:
# Save model outputs: pass the generated samples to the `save_samples` helper from
# `utils`, together with the tag 'gan' (presumably the model name; confirm in utils)
# and the dataset name. Where and how they are stored is decided by that helper.
save_samples('gan', dataset_name, samples)