# Exercise 3: Energy based models (solutions)
Exercise by [Jes Frellsen](https://frellsen.org) (Technical University of Denmark), August 2024 (version 1.0).

The main task in this programming exercise is to learn and sample from an EBM based on two simple 2D toy datasets. We have provided you with a file for the toy data:
* `ToyData.py` contains the code for generating data from the two toy models.

You can download `ToyData.py` using the following command:

In [None]:
! curl -O https://raw.githubusercontent.com/frellsen/02901-2024/main/ToyData.py

## Toy data
First we visualize the probability densities for the toy datasets.

When we create an object of the `Chequerboard` or `TwoGaussians` class, we can call its forward method, which returns a `Distribution` object from `torch.distributions`. The `Distribution` class implements a method for calculating the log probability (`log_prob(...)`), which we will use to make the plots below, and a method for sampling from the distribution (`sample(...)`), which we will later use for creating our training data.
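
As a rough illustration of the `Distribution` interface, the sketch below builds a two-component Gaussian mixture directly with `torch.distributions`. It is only a stand-in for `ToyData.TwoGaussians`: the means, scale and mixture weights are assumptions chosen for the example, not the values used in `ToyData.py`.

```python
import torch
import torch.distributions as td

# Stand-in for ToyData.TwoGaussians (assumed means, scale and weights, for illustration only)
mix = td.Categorical(probs=torch.tensor([0.5, 0.5]))
comp = td.Independent(
    td.Normal(loc=torch.tensor([[-1.0, -1.0], [1.0, 1.0]]), scale=torch.tensor(0.5)), 1
)
dist = td.MixtureSameFamily(mix, comp)

x = dist.sample((5,))     # five 2D samples, shape (5, 2)
log_p = dist.log_prob(x)  # one log-density per sample, shape (5,)
```

Because `log_prob` broadcasts over leading dimensions, the same call also works on a grid of coordinates of shape `(H, W, 2)`, which is how the density plots are produced.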

In [None]:
import torch
import ToyData
import matplotlib.pyplot as plt
import numpy as np

# Make a density plot of the Checkerboard distribution
toy = ToyData.Chequerboard()
coordinates = [[[x,y] for x in np.linspace(*toy.xlim, 1000)] for y in np.linspace(*toy.ylim, 1000)]
prob = torch.exp(toy().log_prob(torch.tensor(coordinates)))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
im = ax1.imshow(prob, extent=[toy.xlim[0], toy.xlim[1], toy.ylim[0], toy.ylim[1]], origin='lower', cmap='YlOrRd')
ax1.set_xlim(toy.xlim)
ax1.set_ylim(toy.ylim)
ax1.set_aspect('equal')
cbar1 = fig.colorbar(im, ax=ax1)
ax1.set_title('Checkerboard distribution')
cbar1.set_label('Probability density')

# Make a density plot of the Gaussian distribution
toy = ToyData.TwoGaussians()
coordinates = [[[x,y] for x in np.linspace(*toy.xlim, 1000)] for y in np.linspace(*toy.ylim, 1000)]
prob = torch.exp(toy().log_prob(torch.tensor(coordinates)))

im = ax2.imshow(prob, extent=[toy.xlim[0], toy.xlim[1], toy.ylim[0], toy.ylim[1]], origin='lower', cmap='YlOrRd')
ax2.set_xlim(toy.xlim)
ax2.set_ylim(toy.ylim)
ax2.set_aspect('equal')
ax2.set_title('Two Gaussians distribution')
cbar2 = fig.colorbar(im, ax=ax2)
cbar2.set_label('Probability density')

### Exercise 3.1: Langevin diffusion

First you will implement Langevin diffusion to sample from the toy distribution (without the Metropolis correction).

Complete the function below by adding the first-order Euler discretization step $\mathbf{x}_{t+1} = \mathbf{x}_{t} + \frac{\epsilon^2}{2} \nabla_\mathbf{x} \log p_\theta(\mathbf{x}_t) + \epsilon \mathbf{z}$, where $\epsilon^2$ is the stepsize and $\mathbf{z} \sim \mathcal{N}(0,\mathbf{I})$.
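
One possible way to write a single update is sketched here as a standalone helper. The name `langevin_step` is hypothetical, and the sketch assumes that the `stepsize` argument plays the role of $\epsilon^2$ as stated above, so the noise is scaled by its square root. The target is a standard 2D Gaussian chosen purely for illustration.

```python
import torch

def langevin_step(log_prob, x, stepsize):
    """One unadjusted Langevin update; `stepsize` is taken to be epsilon^2 (assumption)."""
    x = x.detach().requires_grad_(True)
    # Summing over chains gives per-chain gradients, since the chains are independent
    grad = torch.autograd.grad(log_prob(x).sum(), x)[0]
    noise = torch.randn_like(x)
    return (x + 0.5 * stepsize * grad + stepsize ** 0.5 * noise).detach()

# Standard 2D Gaussian target, used only as an illustration
log_prob = lambda x: -0.5 * (x ** 2).sum(dim=-1)

torch.manual_seed(0)
x = torch.full((2000, 2), 3.0)  # 2000 independent chains started away from the mode
for _ in range(500):
    x = langevin_step(log_prob, x, stepsize=0.05)
```

After enough steps the chains should be spread roughly like samples from the target, up to the discretization bias that the omitted Metropolis correction would otherwise remove.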

In [None]:
def langevin_dynamics(log_prob, x0, n_steps, stepsize):
    """
    Perform Langevin dynamics to sample from a distribution defined by log_prob.
    
    Args:
    log_prob: A function that takes a sample and returns the unnormalised log-probability of the sample.
    x0: The starting sample.
    n_steps: The number of steps to run Langevin dynamics.
    stepsize: The step size to use for Langevin dynamics.
    """
    x = x0.clone().detach().requires_grad_(True)
    samples = []
    for i in range(n_steps):
        log_prob_val = log_prob(x)
        grad = torch.autograd.grad(log_prob_val, x)[0]
        x = ... # Add your code here
        samples.append(x.clone().detach())
    return torch.stack(samples).squeeze()

Test your code by sampling from the two Gaussians (the implementation of the chequerboard does not support gradients) and plotting the samples with the true density.

In [None]:
# Sample from the Two Gaussians distribution using Langevin dynamics
toy = ToyData.TwoGaussians()
log_prob = toy().log_prob
x0 = torch.tensor([[0.0, 0.0]])
n_steps = 10000
stepsize = 0.01
noisy_x = langevin_dynamics(log_prob, x0, n_steps, stepsize)

# Make a density plot of the toy distribution
coordinates = [[[x,y] for x in np.linspace(*toy.xlim, 1000)] for y in np.linspace(*toy.ylim, 1000)]
prob = torch.exp(toy().log_prob(torch.tensor(coordinates)))
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
im = ax.imshow(prob, extent=[toy.xlim[0], toy.xlim[1], toy.ylim[0], toy.ylim[1]], origin='lower', cmap='YlOrRd')

# Make a scatter plot of the samples
ax.scatter(noisy_x[:, 0], noisy_x[:, 1], s=1, c='black', alpha=0.5)
ax.set_xlim(toy.xlim)
ax.set_ylim(toy.ylim)
ax.set_aspect('equal')
ax.set_title('Samples from the Two Gaussians distribution')


### Exercise 3.2: MLE of EBMs

Next, we will implement MLE-based learning for EBMs using the Langevin dynamics sampler.

**Training loop**: First, we have implemented a generic training loop for learning the model.

In [None]:
from tqdm.notebook import tqdm

def train(model, optimizer, data_loader, epochs, device):
    """
    Train an energy-based model (EBM).

    Parameters:
    model:
       The model to train.
    optimizer: [torch.optim.Optimizer]
         The optimizer to use for training.
    data_loader: [torch.utils.data.DataLoader]
            The data loader to use for training.
    epochs: [int]
        Number of epochs to train for.
    device: [torch.device]
        The device to use for training.
    """
    model.train()

    total_steps = len(data_loader)*epochs
    progress_bar = tqdm(range(total_steps), desc="Training")

    for epoch in range(epochs):
        data_iter = iter(data_loader)
        for x in data_iter:
            if isinstance(x, (list, tuple)):
                x = x[0]
            x = x.to(device)
            optimizer.zero_grad()
            loss = model.loss(x)
            loss.backward()
            optimizer.step()

            # Update progress bar
            progress_bar.set_postfix(loss=f"⠀{loss.item():12.4f}", epoch=f"{epoch+1}/{epochs}")
            progress_bar.update()

**EBM implementation:** Below we define a simple EBM model class.

In [None]:
import torch.nn as nn

class EBM_Base(nn.Module):
    def __init__(self, energy_fn):
        """
        Base class for Energy-Based Models (EBMs).
        """
        super(EBM_Base, self).__init__()
        self.energy_fn = energy_fn

    def forward(self, x):
        """
        Compute the energy of the input samples.
        """
        energy = self.energy_fn(x)
        return energy

    def log_prob(self, x):
        """
        Compute the unnormalised log probability of the input samples.
        """
        return -self(x)

**It is your job to complete the loss function of the `EBM_MLE` class**, based on the expression for the MLE gradients $\nabla_\theta \log p_\theta(\mathbf{x}) = -\nabla_\theta E_\theta(\mathbf{x}) + \mathbb{E}_{\mathbf{x}'\sim p_\theta(\mathbf{x}')}[\nabla_\theta E_\theta(\mathbf{x}')]$.
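
A common way to turn this gradient expression into a loss is to use a surrogate whose gradient is a Monte Carlo estimate of $-\nabla_\theta \log p_\theta(\mathbf{x})$: the mean energy of the data minus the mean energy of samples from the model. The sketch below shows the idea on stand-in tensors. Here `x_data` and `x_model` are random placeholders (assumptions for the example); in the exercise the model samples would come from the Langevin sampler and be detached.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
energy_fn = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))  # toy energy network

x_data = torch.randn(8, 2)         # stand-in for a batch of training data
x_model = torch.randn(8, 2) + 1.0  # stand-in for (detached) samples from the model

# Surrogate loss: mean energy on data minus mean energy on model samples
loss = energy_fn(x_data).mean() - energy_fn(x_model).mean()
loss.backward()
```

Minimising this surrogate lowers the energy of the data and raises the energy of the model's own samples. Its value is not a log-likelihood and can be negative.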

In [None]:
class EBM_MLE(EBM_Base):
    def __init__(self, energy_fn):
        super(EBM_MLE, self).__init__(energy_fn)

    def loss(self, x):
        """
        Compute the loss.
        """
        samples = langevin_dynamics(self.log_prob, x[0], 100, 0.001).detach()
        loss = ...  # Add your code here
        return loss

**Training data**: Next, we generate some training data from the TwoGaussians dataset and create a `data_loader`. We generate a dataset with 500,000 data points and use a batch size of 100. Generating this many points is feasible, since it is only a two-dimensional dataset.
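
As a quick check of what such a loader yields, the sketch below wraps a random tensor (a stand-in for the toy samples, with sizes chosen only for the example) in a `DataLoader`. When the dataset is a plain tensor, each batch is itself a tensor of shape `(batch_size, D)` rather than a tuple.

```python
import torch

data = torch.randn(1000, 2)  # stand-in for samples drawn from the toy distribution
loader = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True)

batch = next(iter(loader))   # a plain tensor, not a (data, label) tuple
n_batches = len(loader)      # number of batches per epoch
```

This is why the training loop only unpacks `x[0]` when the batch happens to be a list or tuple.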

In [None]:
# Generate the data
import ToyData

batch_size = 100
n_data = 500000

toy = ToyData.TwoGaussians()
train_loader = torch.utils.data.DataLoader(toy().sample((n_data,)), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(toy().sample((n_data,)), batch_size=batch_size, shuffle=True)

**Initialize the model and run the training loop**: Finally we initialize the model using a simple fully connected energy function and run the training loop. *Remember that this will not work before you have completed the loss function above.*

In [None]:
# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Get the dimension of the dataset
D = next(iter(train_loader)).shape[1]

# Define the energy function
energy_fn = nn.Sequential(
  nn.Linear(D, 64),
  nn.ReLU(),
  nn.Linear(64, 32),
  nn.ReLU(),
  nn.Linear(32, 16),
  nn.ReLU(),
  nn.Linear(16, 1),
)

# Define EBM model
model = EBM_MLE(energy_fn).to(device)

# Define optimizer
optimizer = torch.optim.Adam(model.parameters())

# Train model
epochs = 1
train(model, optimizer, train_loader, epochs, device)

**Plotting:** Then we plot the learned density and the true density.

In [None]:
# Make two density plots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Get the density of the true distribution
coordinates = torch.tensor([[[x,y] for x in np.linspace(*toy.xlim, 1000)] for y in np.linspace(*toy.ylim, 1000)])
prob = torch.exp(toy().log_prob(coordinates))

# Plot the density of the true distribution
im_true = ax1.imshow(prob, extent=[toy.xlim[0], toy.xlim[1], toy.ylim[0], toy.ylim[1]], origin='lower', cmap='YlOrRd')
ax1.set_xlim(toy.xlim)
ax1.set_ylim(toy.ylim)
ax1.set_aspect('equal')
cbar1 = fig.colorbar(im_true, ax=ax1)
ax1.set_title('True distribution')
cbar1.set_label('Probability density')

# Get the density of the EBM
prob = torch.exp(model.log_prob(coordinates.float().to(device)).detach()).squeeze(-1).cpu()

# Plot the density of the EBM
im_learned = ax2.imshow(prob, extent=[toy.xlim[0], toy.xlim[1], toy.ylim[0], toy.ylim[1]], origin='lower', cmap='YlOrRd')
ax2.set_xlim(toy.xlim)
ax2.set_ylim(toy.ylim)
ax2.set_aspect('equal')
cbar2 = fig.colorbar(im_learned, ax=ax2)
ax2.set_title('Learned distribution')
cbar2.set_label('Unnormalised density')

### Exercise 3.3: Denoising score matching

Finally, you will implement learning for EBMs with denoising score matching (DSM), using Gaussian noise with standard deviation `std`.

Complete the `loss` method of `EBM_DSM`. Train the model and compare the learned density to the true density.
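
For orientation, a vectorised sketch of a DSM loss is given below. It is one possible formulation, not necessarily the intended solution: the helper name `dsm_loss` is hypothetical, it processes the whole batch at once instead of looping over samples, and it uses the fact that for $\tilde{\mathbf{x}} = \mathbf{x} + \sigma\mathbf{z}$ the score of the Gaussian noise kernel is $-(\tilde{\mathbf{x}} - \mathbf{x})/\sigma^2$. The log-density is a hand-written standard Gaussian, used only for illustration.

```python
import torch

def dsm_loss(log_prob, x, std):
    """Denoising score matching loss for a batch x of shape (N, D) with Gaussian noise."""
    noisy_x = (x + std * torch.randn_like(x)).requires_grad_(True)
    # Model score: gradient of the (unnormalised) log-density at the noisy points
    model_score = torch.autograd.grad(log_prob(noisy_x).sum(), noisy_x, create_graph=True)[0]
    # Score of the Gaussian noise kernel q(noisy_x | x)
    noise_score = -(noisy_x - x) / std ** 2
    return 0.5 * ((model_score - noise_score) ** 2).sum(dim=-1).mean()

# Illustration with a hand-written log-density (standard Gaussian, score = -x)
log_prob = lambda x: -0.5 * (x ** 2).sum(dim=-1)
torch.manual_seed(0)
loss = dsm_loss(log_prob, torch.randn(256, 2), std=0.1)
```

Passing `create_graph=True` keeps the score differentiable, so that calling `backward()` on the loss can reach the parameters of the energy function when `log_prob` comes from a network.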

In [None]:
class EBM_DSM(EBM_Base):
    def __init__(self, energy_fn):
        super(EBM_DSM, self).__init__(energy_fn)

    def loss(self, x):
        """
        Compute the loss.
        """
        std = 0.1

        losses = []

        # Loop over the batch and compute the loss for each sample
        for i in range(x.shape[0]):
            noisy_x = torch.normal(x[i], std).requires_grad_(True)
            stein_score = torch.autograd.grad(self.log_prob(noisy_x), noisy_x, create_graph=True)[0]
            noise_score = ...  # Add your code here
            loss = ...  # Add your code here
            losses.append(loss)

        return torch.stack(losses).mean()

### Exercise 3.4: Chequerboard (optional)

Evaluate qualitatively how well the two learning methods work on the chequerboard toy data set.
