# Energy-Based Models - Home Assignment

Welcome to the assignment on Energy-Based Models (EBMs). There are three parts in this assignment:
1. You will implement a very simple EBM and train it on a 2D toy dataset using the Maximum Likelihood Estimation (MLE) method with Markov Chain Monte Carlo (MCMC) sampling.
2. You will change the training method to Noise Contrastive Estimation (NCE) and its variant, conditional NCE (cNCE), and train a model on 2D toy datasets.
3. You will train a slightly more complex EBM on the MNIST dataset using Sliced Score Matching (SSM) in another notebook (SSM.ipynb).

Throughout this assignment, there are several places where you will need to fill in the code. These are marked with `YOUR CODE HERE` comments. Further, there are several places where you will need to answer questions. These are marked with `YOUR ANSWER HERE` comments. You should replace the `YOUR CODE HERE` and `YOUR ANSWER HERE` comments with your code and answers. 

### Conda environment
You can use the same environment as in the Normalizing Flows assignment. Otherwise, you can create a new environment with the following command:

GPU version:
```
conda env create -f environments/ebm_gpu.yml
```
CPU version:
```
conda env create -f environments/ebm_cpu.yml
```


In [None]:
import os
import time

import numpy as np

import torch
import torch.nn as nn
import torch.autograd as autograd
from torch.optim import Adam
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

### Device settings
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

## Session 1: MLE with MCMC sampling

In this session, you will create a simple MLP that tries to learn a 2D toy dataset containing 8 Gaussian distributions. You will use the MLE method with MCMC sampling to train the model and sample from it. Let's take a look at the dataset first.

In [None]:
from toy_data import toy_dataset
from toy_data import vis_nce

### Feel free to try out other datasets
datasets = ['pinwheel', '8gaussians', 'checkerboard', '2spirals', 'rings']

### Plot samples from the training set
fig, ax = plt.subplots(1, 5, figsize=(20, 4))
for i, dataset_name in enumerate(datasets):
    X_train = toy_dataset.return_dataset(dataset_name, 1000)[0]
    toy_dataset.plot_2d_samples(ax[i], X_train[:][0])
    ax[i].set_title(dataset_name)
plt.show()

Now we need to define the MLP model. The MLP model should take a 2D input and output a single value, which is the energy of the corresponding input. The architecture is partially defined in the default values of arguments. Specifically, the model should have:
- 2 hidden layers with 100 units each,
- ReLU activation function for hidden layers,
- Linear activation function for the output layer.


In [None]:
### Define the model
class MLP(torch.nn.Module):
    def __init__(self, input_dim:int=2, hidden_dims:tuple=(100, 100), output_dim:int=1):
        super().__init__()
        self.net = torch.nn.Sequential(
            ### NOTE: YOUR CODE HERE
            nn.Linear(input_dim, hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], output_dim)
        )
        
    def forward(self, x):
        return self.net(x)

Now we need to write the Langevin sampling method. The accept-reject step is omitted, which means the actual algorithm is the Unadjusted Langevin Algorithm (ULA). The function should take the following arguments:
- `x0`: the input data as the initial states,
- `model`: the MLP model,
- `stepsize`: the stepsize of the Langevin dynamics,
- `n_steps`: the number of steps of the Langevin dynamics,
- `noise_scale`: the scale of the noise added to the Langevin dynamics (you can just use the default).
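As a point of reference, one ULA update is $x_{t+1} = x_t - \epsilon \nabla_x E(x_t) + \sqrt{2\epsilon}\,\xi_t$ with $\xi_t \sim \mathcal{N}(0, I)$. The following NumPy-only sketch applies that update to a Gaussian target whose energy gradient is known in closed form. The helper name `ula_gaussian` and all parameter values are illustrative assumptions, not part of the assignment code.

```python
import numpy as np

def ula_gaussian(x0, mean, var, stepsize, n_steps, rng):
    """Run ULA on the energy E(x) = ||x - mean||^2 / (2 * var) (hypothetical helper)."""
    x = np.array(x0, dtype=float)
    noise_scale = np.sqrt(2.0 * stepsize)
    for _ in range(n_steps):
        grad = (x - mean) / var            # analytic gradient of the energy
        noise = noise_scale * rng.standard_normal(x.shape)
        x = x - stepsize * grad + noise    # no accept-reject step (unadjusted)
    return x

rng = np.random.default_rng(0)
x0 = rng.standard_normal((2000, 2))        # 2000 independent chains in 2D
samples = ula_gaussian(x0, mean=2.0, var=0.2, stepsize=0.01, n_steps=500, rng=rng)
sample_mean = samples.mean(axis=0)
sample_var = samples.var(axis=0)
print(sample_mean, sample_var)
```

The empirical mean and variance should land close to the target values (2 and 0.2), up to Monte Carlo error and the small discretisation bias of ULA.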

In [None]:
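def sample_langevin(x0: torch.Tensor, model: nn.Module, stepsize: float, n_steps: int, noise_scale:float=None, intermediate_samples=False, inverse=False):
    """Draw samples using (unadjusted) Langevin dynamics
    x0: torch.Tensor, initial points
    model: An energy-based model (any callable mapping x to per-sample energies)
    stepsize: float, step size of the Langevin dynamics
    n_steps: int, number of Langevin steps
    noise_scale: Optional. float. If None, set to np.sqrt(stepsize * 2)
    intermediate_samples: bool. If True, return all samples and dynamics
    inverse: bool. If True, use the inverse dynamics (flip the sign of the gradient),
        e.g. when the model outputs log p(x) instead of the energy
    """
    if noise_scale is None:
        noise_scale = np.sqrt(stepsize * 2)

    # Work on a detached copy so the caller's tensor is not modified in place
    x = x0.detach().clone().requires_grad_(True)
    l_samples = [x.detach().to('cpu')]
    l_dynamics = []
    sign = 1.0 if inverse else -1.0
    for _ in range(n_steps):
        noise = torch.randn_like(x) * noise_scale
        ### NOTE: YOUR CODE HERE
        energy = model(x)
        grad = autograd.grad(energy.sum(), x)[0]
        # Descend the energy by default; ascend the model output if inverse=True
        dynamics = sign * stepsize * grad + noise
        # Detach so the autograd graph does not grow across steps
        x = (x + dynamics).detach().requires_grad_(True)

        l_samples.append(x.detach().to('cpu'))
        l_dynamics.append(dynamics.detach().to('cpu'))

    if intermediate_samples:
        return l_samples, l_dynamics
    else:
        return l_samples[-1]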

We can test our Langevin sampler by drawing samples from a Gaussian distribution and checking the intermediate samples and dynamics. We can use the negative log probability as the energy. Do you think the dynamics look reasonable? Compare these two plots and discuss the choice of the stepsize.

`YOUR ANSWER HERE`
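One way to reason about the step size: for a Gaussian target with variance $\sigma^2$ per dimension, the ULA recursion is linear, so its stationary variance can be worked out by hand as $\sigma^2 / (1 - \epsilon / (2\sigma^2))$, valid for $0 < \epsilon < 2\sigma^2$. The sketch below evaluates this for the two step sizes used in the test. The function name is hypothetical and the formula applies only to this Gaussian case.

```python
def ula_stationary_variance(var, stepsize):
    """Stationary per-dimension variance of ULA on a Gaussian target with variance `var`.

    Derived from x' = (1 - stepsize / var) * (x - mean) + mean + sqrt(2 * stepsize) * noise.
    """
    if not 0.0 < stepsize < 2.0 * var:
        raise ValueError("the chain diverges unless 0 < stepsize < 2 * var")
    return var / (1.0 - stepsize / (2.0 * var))

for eps in (0.01, 0.02):
    print(eps, ula_stationary_variance(0.2, eps))
```

Both values are slightly above the target variance of 0.2, and the gap widens as the step size grows. A larger step moves the chain faster but samples from a more biased distribution.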

In [None]:
init_states = torch.randn(50, 2)
def langevin_test(init_states, step_size:float=0.01):
    def target_two_dim_gaussian():
        return torch.distributions.MultivariateNormal(2*torch.ones(2), 0.2*torch.eye(2))

    def energy_temp(x):
        return - target_two_dim_gaussian().log_prob(x)

    real_samples = target_two_dim_gaussian().sample((500,))
    l_samples, l_dynamics = sample_langevin(init_states, energy_temp, step_size, 100, intermediate_samples=True)

    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    ax_list = axes.flatten()
    ckp_list = [1, 5, 10, 20, 40, 60]
    for i, ckp in enumerate(ckp_list):
        ax_list[i].scatter(real_samples[:, 0], real_samples[:, 1], s=1, c='g', label='Target')
        ax_list[i].scatter(l_samples[0][:, 0], l_samples[0][:, 1], s=5, c='b', label='Initial')
        ax_list[i].scatter(l_samples[ckp][:, 0], l_samples[ckp][:, 1], s=10, c='r', label='Sampled')
        toy_dataset.plot_2d_samples_with_langevin_dynamics(ax_list[i], l_samples[ckp], l_dynamics[ckp-1])
        ax_list[i].set_title(f'Step {ckp}')
    [ax.axis('equal') for ax in ax_list]
    ax_list[0].legend()
    fig.tight_layout()
    plt.show()

langevin_test(init_states, step_size=0.01)
langevin_test(init_states, step_size=0.02)


Now we have implemented the ULA. We can use it to train the model. One last thing we need to do is to define the loss function. The loss function should take the following arguments:
- `energy_pos`: the energy of the positive samples (data),
- `energy_neg`: the energy of the negative samples (samples from the model),
- `alpha`: the L2 regularization parameter, to limit the energy values.

Then, we are ready to train our EBM on the toy dataset "8gaussians"!
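
For reference, the form of this loss follows from the gradient of the negative log-likelihood. Writing $p_\theta(x) = \exp(-E_\theta(x)) / Z_\theta$ and $p_d$ for the data distribution (notation assumed here, matching the later sections), a short derivation gives:
$$
\nabla_\theta \, \mathbb{E}_{x \sim p_d}\left[-\log p_\theta(x)\right]
= \mathbb{E}_{x \sim p_d}\left[\nabla_\theta E_\theta(x)\right] + \nabla_\theta \log Z_\theta
= \mathbb{E}_{x \sim p_d}\left[\nabla_\theta E_\theta(x)\right] - \mathbb{E}_{x' \sim p_\theta}\left[\nabla_\theta E_\theta(x')\right].
$$
Replacing both expectations by batch averages, with the second taken over Langevin samples, gives the difference of mean energies. The `alpha` term is an extra regulariser and not part of the likelihood.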

In [None]:
def mle_loss_function(energy_pos, energy_neg, alpha:float=0.1):
    ### NOTE: YOUR CODE HERE
    reg = alpha * (torch.mean(energy_pos**2) + torch.mean(energy_neg**2))
    loss = energy_pos.mean() - energy_neg.mean() + reg
    
    return loss

### Load the training set
X_train = toy_dataset.return_dataset('8gaussians', 1000)[0]
dl_train = DataLoader(X_train, batch_size=128, shuffle=True, num_workers=8)

### Define the training parameters
n_epoch = 50     # number of epochs
stepsize = 0.1   # Langevin dynamics step size
n_step = 100     # The number of Langevin dynamics steps
alpha = 0.1      # Regularizer coefficient
batch_size = 128 # Batch size

model = MLP().to(device)
opt = Adam(model.parameters(), lr=1e-3)

X = torch.randn(1000, 2).to(device)
X = sample_langevin(X, model, stepsize, n_step, intermediate_samples=False).to('cpu').detach()
toy_dataset.vis_result(X_train, X, model, device=device)

for i_epoch in range(n_epoch):
        l_loss = []
        for pos_x, in dl_train:
            
            pos_x:torch.Tensor = pos_x.to(device)

            ### NOTE: YOUR CODE HERE
            # Initialise the chains on the same device as the model
            x_0 = torch.randn(X.size(), device=device)
            # No grad. from the sampling
            neg_x = sample_langevin(x_0, model, stepsize, n_step, intermediate_samples=False).to(device).detach()
            pos_out = model(pos_x)
            neg_out = model(neg_x)
            loss = mle_loss_function(pos_out, neg_out, alpha)

            opt.zero_grad()
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
            opt.step()
            
            l_loss.append(loss.item())
        print(f'Epoch {i_epoch+1}/{n_epoch}. Mean loss {round(np.mean(l_loss), 4)}')

        ### Visualize the results every 10 epochs
        if ((i_epoch+1) % 10 == 0):
            X = torch.randn(1000, 2).to(device)
            X = sample_langevin(X, model, stepsize, n_step, intermediate_samples=False).to('cpu').detach()
            toy_dataset.vis_result(X_train, X, model, device=device)

## Session 2: NCE and cNCE

Now we switch to different ways to train our EBMs. First of all, we will use the NCE method to train our EBM, for which we need to use a different network architecture. The difference here is that we need to model the partition function as a learnable parameter. Feel free to use any architecture you like. Why is it important to model the partition function as a learnable parameter?

Because the NCE criterion includes evaluations of $p_\theta(\cdot)$, which requires an explicitly normalised distribution. For the ML estimation above, we got around the normalisation issue by estimating the gradient of $\log Z_\theta$ from MCMC samples. Later, with the cNCE criterion, we will see that we only need to evaluate the model distribution up to a constant (i.e. we only need to be able to compute the energy). But for the standard NCE, we need an explicit normalisation.

The glass-half-full interpretation is that NCE allows us to learn a normalised distribution directly. The glass-half-empty version is that this is very limiting: consider a simple extension where we want to model a conditional distribution $p_\theta(x \mid y)$, with $y \in \mathbb{R}$ a simple real scalar. Then $Z_\theta(y)$ is a function of $y$, and this learnable-parameter trick will not work.
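
To make the role of the normalising constant concrete, here is a small NumPy sketch for a 1D energy $E(x) = x^2/2$, where $\log Z$ is known to be $\tfrac{1}{2}\log(2\pi)$. It is an illustration only: the grid bounds and the names `energy` and `log_z` are assumptions, and this kind of brute-force integration is not feasible in high dimensions, which is why NCE treats $\log Z$ as a parameter.

```python
import numpy as np

def energy(x):
    """Toy 1D energy; exp(-energy) is an unnormalised standard Gaussian."""
    return 0.5 * x**2

# Riemann-sum approximation of Z = integral of exp(-E(x)) dx on a wide grid
grid = np.linspace(-10.0, 10.0, 20001)
dx = grid[1] - grid[0]
log_z = np.log(np.sum(np.exp(-energy(grid))) * dx)

# With log Z subtracted, exp(log p) integrates to (approximately) one
log_p = -energy(grid) - log_z
total_mass = np.sum(np.exp(log_p)) * dx
print(log_z, total_mass)
```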

In [None]:
class MLPC(nn.Module):
    def __init__(self, x_dim):
        """
        Initialize EBM model, which is an MLP with additional estimated partition function c.

        Args:
            x_dim (int): The size of the input data.
        """

        super(MLPC, self).__init__()
        ### NOTE: YOUR CODE HERE
        # Re-use the MCMC-ML model for a somewhat fair comparison
        hidden_dims = (100, 100)
        input_dim = x_dim
        self.energy = torch.nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], 1)
        )
        self.log_z = nn.Parameter(torch.randn((1)))
        
    def forward(self, x):
        """Compute log p_theta(x)"""
        # NOTE YOUR CODE HERE
        return - self.energy(x) - self.log_z

    
### Test the model
batch_size = 16
x_dim = 10
x_rand = torch.randn((batch_size, x_dim))
ebm = MLPC(x_dim)
assert ebm(x_rand).shape == (batch_size, 1), "EBM output shape is wrong"

Let's define a noise distribution. Write code for your noise distribution; it can be any distribution of your choice, and looking into `torch.distributions` could be a good idea. The argument `params` depends on which distribution you choose. First, recall what should be considered when selecting an appropriate noise distribution:

There are very few formal requirements on the noise distribution $p_n$. We do need a certain level of support:
$$
p_n(x) > 0\; \forall x: p_d(x) > 0.
$$
Beyond that, to use the NCE criterion we need $p_n$ to be relatively simple to sample from, and we need to be able to evaluate its pdf exactly. Ideally, it should also be close to the model distribution $p_\theta$, but this is tricky to achieve without learnable/adaptive noise.

In [None]:
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.categorical import Categorical

class Gaussian:
    def __init__(self, x_dim, params) -> None:
        # NOTE YOUR CODE HERE
        mean = torch.zeros((x_dim,))
        # Assumption: params[0] is the variance of an isotropic Gaussian
        scale = float(params[0])
        cov = scale * torch.eye(x_dim)
        self.distribution = MultivariateNormal(loc=mean, covariance_matrix=cov)
        
    def sample(self, n_samples):
        '''
        returns n_samples from the distribution
        '''
        return self.distribution.sample((n_samples,))
        
    
    def log_prob(self, x):
        '''
        returns the log_probability of a batch of samples
        '''
        return self.distribution.log_prob(x).reshape((x.size(0), 1))

### Look at your noise
batch_size = 10000
x_dim = 2
distribution_params = [4] # TODO, depends on which distribution you use.
plot_lim = 4 # for plot, may need to be adjusted depending on your noise dist parameters
n_pts = 700 # for plot 2D hist bins
noise_dist = Gaussian(x_dim, distribution_params)
noise_samples = noise_dist.sample(batch_size)

assert noise_samples.shape == (batch_size, x_dim), "Noise distribution samples shapes are wrong"

log_probs = noise_dist.log_prob(noise_samples)
assert log_probs.shape == (batch_size, 1), "Noise distribution log-probability shapes are wrong"

# plot samples and pdf
fig, axs = plt.subplots(1, 2, figsize=(12,4.3), subplot_kw={'aspect': 'equal'})
vis_nce.plot_samples(noise_samples, axs[0], plot_lim, n_pts)
axs[0].set_title(f'{batch_size} noise samples')
test_grid = vis_nce.setup_grid(plot_lim, n_pts, device)
vis_nce.plot_noise(noise_dist, axs[1], test_grid, n_pts)

Now that we have coded an EBM model and a noise distribution of your choice, it is time to code the noise contrastive estimator (NCE). 

Describe the idea of the NCE and why it is useful. 
What is the purpose of using the parameter `k`?

`YOUR ANSWER HERE`
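
As a sanity check on the idea, the following NumPy sketch fits only the normalising constant of a 1D Gaussian model with NCE, using a grid search instead of gradient descent. Everything here is an illustrative assumption (the data, the noise standard deviation, the names `log_model`, `log_noise` and `nce_loss`); it is not the solution expected in the class below.

```python
import numpy as np

rng = np.random.default_rng(0)
k = 5                                       # noise samples per data sample
x = rng.standard_normal(20_000)             # data ~ N(0, 1)
y = 2.0 * rng.standard_normal(k * x.size)   # noise ~ N(0, 2^2)

def log_noise(u, std=2.0):
    """Exact log-density of the Gaussian noise distribution."""
    return -0.5 * (u / std) ** 2 - np.log(std) - 0.5 * np.log(2.0 * np.pi)

def log_model(u, c):
    """Model log-density -E(u) - c, where c plays the role of log Z."""
    return -0.5 * u**2 - c

def nce_loss(c):
    """Negative NCE objective, evaluated in log space for stability."""
    log_kq_x = np.log(k) + log_noise(x)
    log_kq_y = np.log(k) + log_noise(y)
    pos = log_model(x, c) - np.logaddexp(log_model(x, c), log_kq_x)
    neg = log_kq_y - np.logaddexp(log_model(y, c), log_kq_y)
    return -(pos.mean() + k * neg.mean())

candidates = np.linspace(0.0, 2.0, 201)
best_c = candidates[np.argmin([nce_loss(c) for c in candidates])]
print(best_c, 0.5 * np.log(2.0 * np.pi))
```

The minimiser should sit near the true value $\tfrac{1}{2}\log(2\pi)$, which illustrates that the classification problem can recover the normalising constant without any integration.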

In [None]:
class NCE:
    def __init__(self, ebm, noise_dist, k) -> None:
        """
        Initializes the NCE
        
        Args:
        ebm (EBM): The energy based model.
        noise_dist (torch.distributions.distribution): The noise distribution. Could be from torch.distributions. Can also be some custom distribution.
        k (int): The nr of noise data points for each target data point. \nu in the slides from lecture.
        """
        self.ebm = ebm
        self.noise_dist = noise_dist 
        self.k = k
    
    def loss(self, x) -> torch.Tensor:
        '''
        returns the NCE-loss given a batch x from the dataset
        '''
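        # NOTE YOUR CODE HERE
        log_k = float(np.log(self.k))
        # The noise distribution lives on the CPU, so evaluate it there and move the result
        # Positive (data) term: log p_theta(x) - log(p_theta(x) + k * q(x)), in log space
        log_p_x = self.ebm(x)
        log_kq_x = self.noise_dist.log_prob(x.detach().cpu()).to(x.device) + log_k
        pos_term = log_p_x - torch.logaddexp(log_p_x, log_kq_x)

        # Negative (noise) term: log(k * q(y)) - log(p_theta(y) + k * q(y))
        y = self.noise_dist.sample(x.shape[0] * self.k)
        log_kq_y = self.noise_dist.log_prob(y).to(x.device) + log_k
        log_p_y = self.ebm(y.to(x.device))
        neg_term = log_kq_y - torch.logaddexp(log_p_y, log_kq_y)
        # The NCE objective is maximised, so return its negative as the loss
        return -(torch.mean(pos_term) + self.k * torch.mean(neg_term))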

### Test the NCE loss
x_dim = 2
x_rand = torch.randn((batch_size, x_dim)).to(device)
k = 5
distribution_params = [4] # TODO, depends on which distribution you use.

ebm = MLPC(x_dim=x_dim).to(device)
noise_dist = Gaussian(x_dim, distribution_params)
nce = NCE(ebm, noise_dist, k)
loss = nce.loss(x_rand)
# Assert scalar
assert loss.shape == torch.Size([]), "Loss shape is wrong"

In [None]:
def train(epoch, train_loader, ebm, optimizer, nce):
    ebm.train()
    train_loss = 0
    for batch_idx, x in enumerate(train_loader):
        x = x.to(device)
        optimizer.zero_grad()

        loss = nce.loss(x)
        loss.backward()

        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 10 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(x),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
    # The loss is already a batch mean, so average over the number of batches
    avg_loss = train_loss / len(train_loader)
    print(
        "====> Epoch: {} Average loss: {:.4f}".format(
            epoch, avg_loss
        )
    )
    return avg_loss

def test(test_loader, ebm, nce):
    ebm.eval()
    test_loss = 0
    with torch.no_grad():
        for x in test_loader:
            x = x.to(device)

            # accumulate the per-batch mean loss
            test_loss += nce.loss(x).item()

    # Average over the number of batches, matching train()
    test_loss /= len(test_loader)

    print("====> Test set loss: {:.4f}".format(test_loss))
    return test_loss

The cell below will train an EBM with NCE on the selected generated dataset.

Feel free to explore the usage of the different datasets available and how the learning differs for different values of k and the noise distribution parameters.

In [None]:
# Specify parameters
n_epochs = 10
batch_size = 128
train_size = 0.8
n_samples = 10_000
distribution_params = [4]
k = 5
x_dim = 2
training_losses = np.zeros([n_epochs, 1])
test_losses = np.zeros([n_epochs, 1])
dataset_name = 'pinwheel' # one of '8gaussians', 'checkerboard', '2spirals', 'pinwheel'
results_dir = f'results/NCE/{dataset_name}'

os.makedirs(results_dir, exist_ok=True)
os.makedirs('models', exist_ok=True)

# Load generated data and create dataloaders
train_dataset = toy_dataset.return_dataset(dataset_name, int(train_size*n_samples))[0][:][0]
test_dataset  = toy_dataset.return_dataset(dataset_name, n_samples-int(train_size*n_samples))[0][:][0]
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

# Build model
ebm = MLPC(x_dim).to(device)
# Use defined noise dist
noise_dist = Gaussian(x_dim, distribution_params)
# Build NCE
nce = NCE(ebm, noise_dist, k)
optimizer = Adam(ebm.parameters())

# Perform training
start_time = time.time()
best_loss = torch.inf
for epoch in range(1, n_epochs + 1):
    train_loss = train(epoch, train_dataloader, ebm, optimizer, nce)
    test_loss = test(test_dataloader, ebm, nce)
    if test_loss < best_loss:
        torch.save(ebm, 'models/nce.pt')
        best_loss = test_loss
    training_losses[epoch - 1] = train_loss
    test_losses[epoch - 1] = test_loss

    vis_nce.plot_nce(train_dataset, ebm, noise_dist, device, os.path.join(results_dir, f'epoch_{epoch}.png'))

end_time = time.time()
time_elapsed = end_time - start_time
minutes, seconds = divmod(time_elapsed, 60)
print("Time elapsed during training: %d minutes and %d seconds" % (minutes, seconds))

# Plot training and test losses
plt.plot(range(1, n_epochs + 1), training_losses)
plt.plot(range(1, n_epochs + 1), test_losses)
plt.grid()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(["Training Loss", "Test Loss"])
plt.savefig(os.path.join(results_dir, 'train_test_loss.png'))

In [None]:
ebm(torch.randn(5,2))

Now we want to also implement the conditional NCE (cNCE). What is considered the advantage of cNCE over NCE? Are there any potential disadvantages?

`YOUR ANSWER HERE`
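
One property worth checking numerically is that the cNCE loss only depends on differences of log-densities, so the normalising constant drops out. The NumPy sketch below assumes a symmetric Gaussian perturbation (so the noise kernel cancels in the ratio) and uses a made-up `unnormalised_log_p`; it is an illustration, not the reference implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((256, 2))                 # stand-in for a data batch
x_noisy = x + 0.5 * rng.standard_normal(x.shape)  # symmetric Gaussian perturbation

def unnormalised_log_p(u):
    """Toy negative energy (hypothetical); any smooth function would do."""
    return -0.5 * np.sum(u**2, axis=1)

def cnce_loss(log_p_x, log_p_noisy):
    """-log( p(x) / (p(x) + p(x')) ), averaged over the batch."""
    return np.mean(np.logaddexp(0.0, log_p_noisy - log_p_x))

loss = cnce_loss(unnormalised_log_p(x), unnormalised_log_p(x_noisy))
# Shifting both log-densities by the same constant (a different log Z) changes nothing
shifted = cnce_loss(unnormalised_log_p(x) + 3.0, unnormalised_log_p(x_noisy) + 3.0)
print(loss, shifted)
```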

In [None]:
class cNCE:
    def __init__(self, ebm) -> None:
        """
        Initializes the conditional NCE
        
        Args:
        ebm (EBM): The energy based model.
        """
        self.ebm = ebm
    
    def add_noise(self, x):
        '''
        returns the noised batch x
        '''
        # NOTE YOUR CODE HERE
        std = 1/2
        eps = torch.randn_like(x)
        return x + std * eps
    
    def loss(self, x) -> torch.Tensor:
        '''
        returns the cNCE-loss given a batch x from the dataset
        '''
        # NOTE YOUR CODE HERE
        x_noise = self.add_noise(x)
        log_p_theta_x = self.ebm(x)
        log_p_theta_x_noise = self.ebm(x_noise)
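        tmp = torch.column_stack((log_p_theta_x, log_p_theta_x_noise))
        # keepdim=True keeps shape (batch, 1), so the subtraction does not broadcast to (batch, batch)
        loss = log_p_theta_x - torch.logsumexp(tmp, dim=1, keepdim=True)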

        return -loss.mean()        

In [None]:
# Specify parameters
n_epochs = 10
batch_size = 128
train_size = 0.8
n_samples = 10_000
x_dim = 2
training_losses = np.zeros([n_epochs, 1])
test_losses = np.zeros([n_epochs, 1])
dataset_name = 'pinwheel' # one of '8gaussians', 'checkerboard', '2spirals', 'pinwheel'
results_dir = f'results/cNCE/{dataset_name}'

os.makedirs(results_dir, exist_ok=True)
os.makedirs('models', exist_ok=True)

# Load generated data and create dataloaders
train_dataset = toy_dataset.return_dataset(dataset_name, int(train_size*n_samples))[0][:][0]
test_dataset  = toy_dataset.return_dataset(dataset_name, n_samples-int(train_size*n_samples))[0][:][0]
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

# Build model
ebm = MLPC(x_dim).to(device)
# Build cNCE
cnce = cNCE(ebm)
optimizer = Adam(ebm.parameters())

# Perform training
start_time = time.time()
best_loss = torch.inf
for epoch in range(1, n_epochs + 1):
    train_loss = train(epoch, train_dataloader, ebm, optimizer, cnce)
    test_loss = test(test_dataloader, ebm, cnce)
    if test_loss < best_loss:
        torch.save(ebm, 'models/cnce.pt')
        best_loss = test_loss
    training_losses[epoch - 1] = train_loss
    test_losses[epoch - 1] = test_loss

    # Note: plotting on CUDA raised an error on the author's machine; switch to the CPU here if that happens.
    vis_nce.plot_cnce(train_dataset, ebm, cnce, device, os.path.join(results_dir, f'epoch_{epoch}.png'))

end_time = time.time()
time_elapsed = end_time - start_time
minutes, seconds = divmod(time_elapsed, 60)
print("Time elapsed during training: %d minutes and %d seconds" % (minutes, seconds))

# Plot training and test losses
plt.plot(range(1, n_epochs + 1), training_losses)
plt.plot(range(1, n_epochs + 1), test_losses)
plt.grid()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(["Training Loss", "Test Loss"])
plt.savefig(os.path.join(results_dir, 'train_test_loss.png'))

It may be observed that the cNCE is harder to train than the NCE. What are your observations when comparing training with NCE and cNCE?

`YOUR ANSWER HERE`

Now that we have trained an EBM via NCE, we want to sample from it. Reuse the Langevin dynamics implemented above to do this. Consider the limitations of Langevin dynamics in your attempt to sample the target distribution. Choose the dataset you want and either the saved NCE or cNCE model.

In [None]:
step_size = 0.01
n_initializations = 10
steps = 3000

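# map_location loads the saved model onto the current device.
# Note: recent PyTorch releases may additionally require weights_only=False to load a fully pickled model.
ebm = torch.load('models/nce.pt', map_location=device).to(device)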
fig, axs = plt.subplots(1, 2, figsize=(12,4.3), subplot_kw={'aspect': 'equal'})
samples = []
for j in range(n_initializations):
    x = 3*torch.randn(2).to(device)
    xs, _ = sample_langevin(x, ebm, step_size, steps, intermediate_samples=True, inverse=True)
    samples += [x_.detach().cpu().numpy() for x_ in xs]

samples = torch.tensor(np.array(samples))

n_pts = 700
range_lim = 4

xx, yy, zz = vis_nce.setup_grid(range_lim, n_pts, device)
log_prob = ebm.to('cpu')(zz.to('cpu')).detach()
prob = log_prob.exp().cpu()
# plot
vis_nce.plot_samples(samples, axs[0], range_lim, n_pts)
axs[0].set_title('Langevin samples')

axs[1].pcolormesh(xx, yy, prob.view(n_pts,n_pts), cmap=plt.cm.jet)
axs[1].set_facecolor(plt.cm.jet(0.))
axs[1].set_title('Model density')