# Energy-Based Model

In [None]:
from typing import Sequence
import functools
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import flax.linen as nn
import haiku as hk
import optax
import chex
from tqdm import tqdm
from utils import BatchManager, load_dataset, save_samples

## Configuration and Data Loading

In [None]:
# Per-dataset overrides (currently only the epoch count). Hyper-parameters shared
# by every dataset are merged in when the active run config is built below.
dataset_name = 'checkerboard'
dataset_configs = {
    'checkerboard': {
        'epochs': 200
    },
    'gaussian_mixture': {
        'epochs': 200
    },
    'pinwheel': {
        'epochs': 200
    },
    'spiral': {
        'epochs': 200
    }
}

# Build a fresh dict for the active run instead of aliasing the table entry;
# assigning into the aliased entry would leak these shared settings back into
# ``dataset_configs[dataset_name]``.
config = {
    **dataset_configs[dataset_name],
    'batch_size': 128,
    'learning_rate': 1e-3,
    # Dense layer widths: hidden layers first, then the output width (1 => scalar output).
    'mlp_layer_dim': [64, 64, 1],
}
X_train, X_test = load_dataset(dataset_name)

## Model Definition

In [None]:
class MLP(nn.Module):
    """Fully connected network with swish activations between Dense layers.

    With a final width of 1 the network produces one scalar per input point.
    The Langevin sampler and the training gradient treat that scalar as an
    unnormalised log-density f(x).
    """

    # Output width of each Dense layer; the last entry is the output width.
    layer_dim: Sequence[int]

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the network to ``x`` along its last axis.

        Args:
            x: a single point or a batch of points; features are on the last axis.

        Returns:
            The network output. When the final width is 1 the trailing feature axis
            is removed, so a single point maps to a scalar and a batch of B points
            maps to shape (B,), including B == 1.
        """
        for f in self.layer_dim[:-1]:
            x = nn.Dense(f)(x)
            x = nn.swish(x)
        x = nn.Dense(self.layer_dim[-1])(x)
        # Remove only the trailing width-1 feature axis. A bare ``squeeze()`` would
        # also drop a size-1 batch axis, so the output rank would depend on batch size.
        if x.shape[-1] == 1:
            x = x[..., 0]
        return x


model = MLP(layer_dim=config['mlp_layer_dim'])

## Training Preparation
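Set up the optimizer, the Langevin sampler used to draw model samples, the contrastive-divergence gradient of the log-likelihood, and the jitted parameter-update step.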

In [None]:
# Adam optimiser whose step size comes from the run configuration.
optimizer = optax.adam(config['learning_rate'])
# Key stream seeded with 0; later cells draw keys from it with ``next(prng_seq)``.
prng_seq = hk.PRNGSequence(jax.random.PRNGKey(seed=0))

@functools.partial(jax.jit, static_argnames=("num_steps",))
def langevin_sampling(
    params: chex.ArrayTree,
    key: chex.PRNGKey,
    step_size: float,
    initial_samples: jax.Array,
    num_steps: int,
) -> jax.Array:
    """Run unadjusted Langevin dynamics driven by the global ``model``.

    Every step applies, independently to each sample,
        x <- x + step_size * grad_x f(x) + sqrt(2 * step_size) * eps,  eps ~ N(0, I),
    where ``f(x) = model.apply(params, x)``. The chains therefore target a
    density proportional to exp(f(x)), up to discretisation error.

    Args:
        params: model parameters used to evaluate f.
        key: PRNG key; split once per step to draw the Gaussian noise.
        step_size: Langevin step size (traced, not static).
        initial_samples: starting states, one chain per leading-axis entry.
        num_steps: number of Langevin steps; static so ``jit`` can unroll the scan length.

    Returns:
        The final states, with the same shape as ``initial_samples``.
    """
    # Gradient of the scalar model output w.r.t. a single input point, mapped over chains.
    per_sample_score = jax.vmap(jax.grad(lambda point: model.apply(params, point)))

    def step(carry, _):
        x, rng = carry
        rng, noise_key = jax.random.split(rng)
        eps = jax.random.normal(noise_key, shape=x.shape)
        drift = step_size * per_sample_score(x)
        diffusion = jnp.sqrt(2 * step_size) * eps
        return (x + drift + diffusion, rng), None

    (final_states, _), _ = jax.lax.scan(
        step, (initial_samples, key), None, length=num_steps
    )
    return final_states

def ce_loss_grad(
    params: chex.ArrayTree,
    batch: jax.Array,
    key: chex.PRNGKey,
    step_size: float = 5e-3,
    num_steps: int = 1000,
    init_scale: float = 2.0,
) -> chex.ArrayTree:
    """Return the contrastive-divergence gradient of the negative log-likelihood.

    The model output f(x) is an unnormalised log-density (p(x) ∝ exp(f(x))), so
    the gradient to *minimise* is  sum_model grad f - sum_data grad f.
    Model samples come from ``langevin_sampling`` started at a wide Gaussian.
    Sums rather than means are used over the batch, keeping the original
    gradient scale.

    Args:
        params: model parameters.
        batch: data points, presumably shaped (batch, dim) — TODO confirm.
        key: PRNG key; split into a Langevin-noise key and an initialisation key.
        step_size: Langevin step size.
        num_steps: number of Langevin steps (must be a Python int under ``jit``).
        init_scale: standard deviation of the Gaussian used to initialise chains.

    Returns:
        A pytree with the same structure as ``params``.
    """
    key_chain, key_init = jax.random.split(key)
    initial = init_scale * jax.random.normal(key_init, shape=batch.shape)
    batch_model = langevin_sampling(params, key_chain, step_size, initial, num_steps)

    # Differentiate the summed objective once rather than vmapping per-sample
    # parameter gradients and summing afterwards. The gradient of a sum is the sum
    # of gradients, but this avoids materialising a (batch, *param_shape) copy of
    # every parameter. ``batch_model`` is a closed-over constant, so no gradient
    # flows through the sampler.
    def objective(p: chex.ArrayTree) -> jax.Array:
        return jnp.sum(model.apply(p, batch_model)) - jnp.sum(model.apply(p, batch))

    return jax.grad(objective)(params)

@jax.jit
def do_batch_update(
    batch: jax.Array,
    params: chex.ArrayTree,
    opt_state: optax.OptState,
    key: chex.PRNGKey,
) -> tuple[chex.ArrayTree, optax.OptState]:
    """Apply one optimiser step using the contrastive-divergence gradient.

    Args:
        batch: one batch of training data.
        params: current model parameters.
        opt_state: current optimiser state.
        key: PRNG key forwarded to ``ce_loss_grad`` for Langevin sampling.

    Returns:
        ``(params, opt_state)`` after the update. No loss value is computed.
    """
    grad = ce_loss_grad(params, batch, key)
    # Pass ``params`` so optimisers that need them (e.g. weight decay) also work;
    # plain Adam ignores the argument.
    updates, opt_state = optimizer.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

## Training Loop

In [None]:
# Initialise parameters from a one-row slice of the training data (used for shape inference).
init_rng = next(prng_seq)
params = model.init(init_rng, X_train[:1])
opt_state = optimizer.init(params)
# Batch iterator over the training set; the key presumably seeds its shuffling — confirm in utils.
bm = BatchManager(X_train, config['batch_size'], key=next(prng_seq))

In [None]:
# Per-epoch mean energy on the train and test splits.
# The model output f(x) is an unnormalised log-density: Langevin sampling ascends
# grad f, and the update step raises f on data. The energy is therefore E(x) = -f(x),
# and it should decrease on the data as training progresses.
train_energies = []
test_energies = []
for epoch in tqdm(range(config['epochs']), "Epoch"):
    for _ in range(bm.num_batches):
        batch = next(bm)
        params, opt_state = do_batch_update(batch, params, opt_state, key=next(prng_seq))
    # Store plain Python floats so the history does not keep device arrays alive.
    train_energy = -float(jnp.mean(model.apply(params, X_train)))
    test_energy = -float(jnp.mean(model.apply(params, X_test)))
    train_energies.append(train_energy)
    test_energies.append(test_energy)

## Training Results Visualization

In [None]:
# Plot the per-epoch energies together with a centred moving average.
window_size = 10
window = np.ones(window_size) / window_size

plt.plot(train_energies, label='Train Energy')
plt.plot(test_energies, label='Test Energy')

# A 'valid' convolution needs at least ``window_size`` points. With fewer,
# np.convolve swaps its operands and returns a meaningless result, so skip smoothing.
if len(train_energies) >= window_size:
    train_losses_f = np.convolve(train_energies, window, mode='valid')
    test_losses_f = np.convolve(test_energies, window, mode='valid')
    # Output i averages epochs i .. i + window_size - 1, so its centre is
    # i + (window_size - 1) / 2 (not i + window_size // 2).
    x_pos = np.arange(train_losses_f.shape[0]) + (window_size - 1) / 2
    plt.plot(x_pos, train_losses_f, label='Smoothed Train Energies')
    plt.plot(x_pos, test_losses_f, label='Smoothed Test Energies')

plt.xlabel('Epoch')
plt.ylabel('Energy')
plt.legend()
plt.show()

## Sample Generation and Visualization

In [None]:
# Draw 2000 model samples: 1000 Langevin steps of size 5e-3, started from a
# zero-mean Gaussian with standard deviation 2. The sampler key is drawn before
# the initialisation key.
chain_key = next(prng_seq)
start_key = next(prng_seq)
start_states = 2 * jax.random.normal(start_key, shape=(2000, 2))
samples = langevin_sampling(params, chain_key, 5e-3, start_states, 1000)

# Overlay the samples on the train and test data (first two coordinates).
for points, style in (
    (samples, {'marker': '.', 'label': 'Sampled'}),
    (X_train, {'alpha': 0.2, 'marker': 'o', 'label': 'Train'}),
    (X_test, {'alpha': 0.2, 'marker': 'o', 'label': 'Test'}),
):
    plt.scatter(points[:, 0], points[:, 1], **style)

plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.axis('equal')
plt.xlim([-4, 4])
plt.ylim([-4, 4])
plt.legend()
plt.show()

In [None]:
# Save the generated Langevin samples for this dataset under the 'ebm' model tag
# (storage location and format are determined by utils.save_samples).
save_samples('ebm', dataset_name, samples)