# Tutorial

The goal of this tutorial colab is to train the normalizing flow model proposed in reference [1] on a small system comprising 8 particles of monatomic water (mW) in the cubic ice phase. 

All the relevant code is available on GitHub [2] as supplemental material to the publication, but some parts of the logic are deliberately left out here and must be implemented before the model can train. In particular, the functions `get_loss`, `get_eval_metrics`, and `get_normalised_target_log_probs` below are incomplete: supply the missing code as specified in the *Todo* field of each docstring.

Training the model with the hyperparameters (config) below does not require any hardware accelerators and the model should reach an effective sample size (ESS) of about 10% on a public CPU kernel in under 10 minutes.


<br/>

**References**

[1] Wirnsberger, Papamakarios, Ibarz et al., *Normalizing flows for atomic solids*, Mach. Learn.: Sci. Technol. 3 025009 (2022), [link](https://iopscience.iop.org/article/10.1088/2632-2153/ac6b16).

[2] Supplemental code for *Normalizing flows for atomic solids* on GitHub: [deepmind/flows_for_atomic_solids](https://github.com/deepmind/flows_for_atomic_solids).

[3] Jarzynski, *Targeted free energy perturbation*, Phys. Rev. E 65, 046122 (2002), [link](https://journals.aps.org/pre/abstract/10.1103/PhysRevE.65.046122).

[4] Wirnsberger, Ballard et al., *Targeted free energy estimation via learned mappings*, J. Chem. Phys. 153, 144112 (2020), [link](https://doi.org/10.1063/5.0018903).

[5] Nicoli et al., *Asymptotically unbiased estimation of physical observables with neural samplers*, Phys. Rev. E 101, 023304 (2020), [link](https://journals.aps.org/pre/abstract/10.1103/PhysRevE.101.023304).

[6] Frenkel and Smit, *Understanding molecular simulation*, 2nd edition, San Diego (2002), [link](https://www.sciencedirect.com/book/9780122673511/understanding-molecular-simulation).

## Imports

In [None]:
!git clone https://github.com/deepmind/flows_for_atomic_solids.git
!pip install -r flows_for_atomic_solids/requirements.txt

In [None]:
from typing import Callable, Dict, Tuple, Union
from absl import app
from absl import flags
import chex
import distrax
from flows_for_atomic_solids.experiments import monatomic_water_config
from flows_for_atomic_solids.experiments import utils
from flows_for_atomic_solids.utils import observable_utils as obs_utils
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import rcParams
import optax

Array = chex.Array
Numeric = Union[Array, float]

rcParams.update({
    'font.size': 16, 'xtick.labelsize': 16, 'ytick.labelsize': 16,
    'legend.fontsize': 16, 'lines.linewidth': 3, 'axes.titlepad': 16,
    'axes.labelpad': 16, 'axes.labelsize': 20,
    'figure.figsize': [8.0, 6.0]})

## Model specification

All relevant model and training hyperparameters, such as the number of flow layers, the batch size, the learning rate, the energy function, and the thermodynamic state, are defined in a `ConfigDict`. Training the model with the hyperparameters used in reference [1] (see, for example, [monatomic_water_config.py](https://github.com/deepmind/flows_for_atomic_solids/blob/main/experiments/monatomic_water_config.py)) would require multiple accelerators, such as GPUs. We therefore override some of the parameters of that config to reduce the model size and to optimise for quick training on a CPU.

With the settings below, the model should train to about 10% ESS in less than ten minutes on a CPU. The model will still improve if you train it longer (by increasing `num_training_steps`).

In [None]:
# Get the default config for 8-particle cubic ice. 
config = monatomic_water_config.get_config(num_particles=8, lattice='cubic')

# Update specific hyperparameters for fast training. 
config.train.learning_rate=1e-3
config.model.kwargs.bijector.kwargs.num_layers=4
config.model.kwargs.bijector.kwargs.num_bins=16
config.model.kwargs.bijector.kwargs.conditioner.kwargs.num_frequencies=2
config.model.kwargs.bijector.kwargs.conditioner.kwargs.embedding_size=32
config.test.batch_size=16384
config.test.test_every=500

# `state` is a dictionary that contains information about the thermodynamic 
# state, such as the number of particles, the inverse temperature, and the 
# extents of the simulation box.
state = config.state

# Defines the number of training steps (the number of parameter updates).
num_training_steps=501

## Training the model

In this section, we define the training objective and a set of evaluation metrics that allows us to monitor the training progress. 

### Training objective

We train our model $q(x)$ to approximate the target Boltzmann distribution 

\begin{equation*}
  p(x) = \frac{1}{Z} e^{-\beta U(x)}
\end{equation*}


by minimizing the Kullback&ndash;Leibler divergence $D(q || p)$, which serves as the loss function:

\begin{equation*}
    D(q || p) = {\langle{\ln{ q(x)} - \ln{p(x)}}\rangle}_q = {\langle{\ln{ q(x)} + \beta U(x)}\rangle}_q + \ln{Z}.  
\end{equation*}

The last term, $\ln Z$, is the logarithm of the normalizing constant and it can be ignored as its gradient with respect to the model parameters vanishes. The function $U(x)$ is a given energy function (here the mW potential) and $\beta = (k_\text{B} T)^{-1}$ is the inverse temperature with $k_\text{B}$ being the Boltzmann constant.
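
As a sanity check of this objective, the following sketch estimates ${\langle \ln q(x) + \beta U(x) \rangle}_q$ by Monte Carlo for a one-dimensional toy problem rather than for the mW system. All names (`toy_energy`, `gaussian_log_prob`, `toy_loss`) and the choices $U(x) = x^2/2$, $\beta = 1$ and a Gaussian $q$ are assumptions made for illustration only. In this setting $p$ is a standard normal with $\ln Z = \frac{1}{2} \ln 2\pi$, so the loss equals $D(q || p) - \ln Z$ and should reach its minimum, $-\ln Z$, when $q = p$.

```python
import numpy as np


def toy_energy(x):
  """Harmonic toy energy U(x) = x^2 / 2 (an assumption, not the mW potential)."""
  return 0.5 * x**2


def gaussian_log_prob(x, mu, sigma):
  """Log density of a one-dimensional Gaussian, playing the role of ln q(x)."""
  return (-0.5 * ((x - mu) / sigma)**2 - np.log(sigma)
          - 0.5 * np.log(2.0 * np.pi))


def toy_loss(mu, sigma, beta=1.0, num_samples=100_000, seed=0):
  """Monte Carlo estimate of <ln q(x) + beta U(x)>_q with x drawn from q."""
  rng = np.random.default_rng(seed)
  x = mu + sigma * rng.standard_normal(num_samples)
  return np.mean(gaussian_log_prob(x, mu, sigma) + beta * toy_energy(x))


log_z = 0.5 * np.log(2.0 * np.pi)             # exact ln Z of the toy target
loss_matched = toy_loss(mu=0.0, sigma=1.0)    # q = p, so D(q || p) = 0
loss_shifted = toy_loss(mu=1.0, sigma=1.0)    # D(q || p) = mu^2 / 2 = 0.5

print(f'loss at q = p    : {loss_matched:.4f} (-ln Z = {-log_z:.4f})')
print(f'loss for shifted q: {loss_shifted:.4f}')
```

The shifted model has a loss that is larger by the divergence $D(q || p) = 0.5$, which illustrates that dropping the constant $\ln Z$ does not change the location of the minimum.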

In [None]:
def get_loss(model: distrax.Distribution, energy_fn: Callable[[Array], Array],
             beta: Numeric, num_samples: int) -> Tuple[Array, Dict[str, Array]]:
  """Returns the loss and some additional metrics.

  Args:
    model: our model from which we can sample.
    energy_fn: a function that takes a batch of samples and returns a batch of
      energies.
    beta: the inverse temperature.
    num_samples: the number of samples to be used for computing the loss.

  Returns:
    The scalar loss and a dictionary containing energies, model log probs and 
    target log probs.

  Todo: 
    Draw a batch of samples from the model and implement the quantities that 
    are currently set to zero.
  """
  rng_key = hk.next_rng_key()

  zeros = jnp.zeros(num_samples)
  loss = 0.                # <ln q(x) + \beta U(x)>
  energy = zeros           # U(x)
  model_log_prob = zeros   # ln q(x)
  target_log_prob = zeros  # ln p*(x), the unnormalised target log prob
  
  stats = {
      'energy': energy,
      'model_log_prob': model_log_prob,
      'target_log_prob': target_log_prob
  }
  return loss, stats

### Evaluation metrics

After every `config.test.test_every` training steps, we evaluate a set of metrics to monitor the training progress.
<br/><br/>

**Normalizing constant:**
To estimate $\ln Z$, we can use a targeted free energy estimator [3&ndash;5]. We first compute the forward work values

\begin{equation*}
  \beta \Phi(x) = \beta U(x) + \ln{q(x)}
\end{equation*}

and then the asymptotically unbiased estimate

\begin{equation*}
\ln{Z} = \ln {\langle \exp(-\beta \Phi(x)) \rangle}_q.
\end{equation*}
<br/>
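
The estimator above can be sketched on a toy problem where $\ln Z$ is known exactly. The example assumes $\beta U(x) = x^2/2$ in one dimension (so that $Z = \sqrt{2\pi}$) and a Gaussian model $q$ with standard deviation `sigma_q`. These choices and the helper `log_mean_exp` are hypothetical and only serve to illustrate the formula. The average is evaluated in log space to avoid overflow in the exponential.

```python
import numpy as np


def log_mean_exp(a):
  """Numerically stable evaluation of ln(mean(exp(a)))."""
  a_max = np.max(a)
  return a_max + np.log(np.mean(np.exp(a - a_max)))


rng = np.random.default_rng(0)
sigma_q = 1.5                                   # assumed width of the toy model
x = sigma_q * rng.standard_normal(200_000)      # x ~ q

log_q = (-0.5 * (x / sigma_q)**2 - np.log(sigma_q)
         - 0.5 * np.log(2.0 * np.pi))           # ln q(x)
beta_u = 0.5 * x**2                             # beta U(x) of the toy target

beta_phi = beta_u + log_q                       # forward work values
log_z_estimate = log_mean_exp(-beta_phi)        # ln <exp(-beta Phi)>_q
log_z_exact = 0.5 * np.log(2.0 * np.pi)

print(f'estimate: {log_z_estimate:.4f}, exact: {log_z_exact:.4f}')
```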

**Expected energy:**
We can compute an asymptotically unbiased estimate of the expected potential energy, ${\langle U \rangle}$, via self-normalised importance sampling

\begin{equation*}
{\langle U \rangle} = \frac{ \sum_n w_n U(x_n)} {\sum_n w_n},
\end{equation*}

where $w_n = p^*(x_n)/q(x_n)$, $x_n \sim q(x)$ and $p^*(x) = Z p(x)$.
<br/><br/>
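
To illustrate the reweighting, the sketch below compares the plain average of $U$ over model samples with the importance-weighted average for a one-dimensional toy system. It assumes $U(x) = x^2/2$ and $\beta = 1$, for which the exact value is $\langle U \rangle = 1/2$, and a deliberately too-wide Gaussian model. All variable names are illustrative and not part of the tutorial code.

```python
import numpy as np

rng = np.random.default_rng(1)
sigma_q = 1.5                                   # assumed (too wide) toy model
beta = 1.0
x = sigma_q * rng.standard_normal(200_000)      # x_n ~ q

energy = 0.5 * x**2                             # U(x_n)
log_q = (-0.5 * (x / sigma_q)**2 - np.log(sigma_q)
         - 0.5 * np.log(2.0 * np.pi))           # ln q(x_n)

# ln w_n = ln p*(x_n) - ln q(x_n). Subtracting the maximum before taking the
# exponential leaves the ratio below unchanged and avoids overflow.
log_w = -beta * energy - log_q
w = np.exp(log_w - np.max(log_w))

energy_model_average = np.mean(energy)          # <U>_q = sigma_q^2 / 2
energy_reweighted = np.sum(w * energy) / np.sum(w)

print(f'<U> under q      : {energy_model_average:.3f}')
print(f'<U> after weights: {energy_reweighted:.3f} (exact 0.5)')
```

Because the weights enter both the numerator and the denominator, any constant factor in $w_n$, including the unknown $Z$, cancels.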

**Effective sample size (ESS):**
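The effective sample size, which ranges from 1 to the number of samples and is often quoted as a fraction of the latter (as in the 10% figure above), can be estimated as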

\begin{equation*}
  {\text{ESS}} = \frac{ {\left( \sum_n w_n \right)}^2} {\sum_n w_n^2}.
\end{equation*}
<br/>
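
The behaviour of this formula at its two extremes can be checked with a few lines of NumPy. The helper `effective_sample_size` below is a hypothetical name and works on log weights, which is an assumption about how the weights are stored. Shifting all log weights by a constant does not change the result.

```python
import numpy as np


def effective_sample_size(log_weights):
  """Returns (sum_n w_n)^2 / sum_n w_n^2 computed from log weights."""
  w = np.exp(log_weights - np.max(log_weights))  # rescaling leaves ESS unchanged
  return np.sum(w)**2 / np.sum(w**2)


# Equal weights: every sample counts fully.
ess_uniform = effective_sample_size(np.zeros(1000))

# One dominant weight: effectively a single sample.
ess_dominated = effective_sample_size(np.array([0.0, -50.0, -50.0, -50.0]))

# As a fraction of the number of samples.
ess_fraction = ess_uniform / 1000

print(ess_uniform, ess_dominated, ess_fraction)
```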

**Helmholtz free energy:**
Knowing the normalizing constant, we can compute the Helmholtz free energy $F$ via the relation

\begin{equation*}
e^{-\beta F}  = \frac{Z}{N! \Lambda^{3N}},
\end{equation*}

where $N$ is the number of particles in the system and
$\Lambda = 2.3925~\overset{\circ}{\text{A}}$ is the thermal de Broglie wavelength, which we set to the same value as the $\sigma$ parameter of the mW potential.
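
Taking the logarithm of the relation above gives $\beta F / N = \left(-\ln Z + \ln N! + 3N \ln \Lambda\right)/N$. The sketch below evaluates this expression with `math.lgamma` for the factorial term. The function name is hypothetical, and the example assumes that $\Lambda$ is expressed in the same length unit as the coordinates used to compute $Z$. As a check, it uses an ideal gas in a cubic box of side $L$, for which $Z = L^{3N}$.

```python
import math


def beta_f_per_particle(log_z, num_particles, de_broglie_wavelength):
  """Returns beta * F / N from ln Z via exp(-beta F) = Z / (N! Lambda^(3N))."""
  log_factorial = math.lgamma(num_particles + 1)          # ln N!
  beta_f = (-log_z + log_factorial
            + 3 * num_particles * math.log(de_broglie_wavelength))
  return beta_f / num_particles


# Ideal gas of N particles in a cubic box of side L: ln Z = 3 N ln L.
num_particles = 8
box_side = 2.0
log_z_ideal = 3 * num_particles * math.log(box_side)

beta_f_ideal = beta_f_per_particle(log_z_ideal, num_particles, 1.0)

# A single particle in a box of volume Lambda^3 has F = 0.
beta_f_single = beta_f_per_particle(3 * math.log(2.5), 1, 2.5)

print(beta_f_ideal, beta_f_single)
```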



In [None]:
def get_eval_metrics(loss: Array, beta: Numeric, num_particles: int,
                     energy: Array, model_log_prob: Array,
                     target_log_prob: Array) -> Dict[str, Array]:
  """Returns the evaluation metrics.

  Args:
    loss: a scalar containing the loss.
    beta: the inverse temperature.
    num_particles: the number of particles.
    energy: an array with shape [batch_size] containing energy values.
    model_log_prob: an array with shape [batch_size] containing log probs 
      under the model.
    target_log_prob: an array with shape [batch_size] containing log probs 
      under the target.

  Returns:
     A dictionary containing the evaluation metrics.

  Todo: 
    Implement the metrics that are currently set to zero.
  """
  energy_biased = 0.     # <U(x)>_q
  energy_unbiased = 0.   # <U(x)>_p
  ess = 0.               # ESS
  logz = 0.              # log Z
  beta_f = 0.            # \beta F/N
  metrics = {
      'loss': loss,
      'energy_biased': energy_biased,
      'energy_unbiased': energy_unbiased,
      'ess': ess,
      'logz': logz,
      'beta_f': beta_f,
  }
  return metrics

### Training loop


In [None]:
def create_model():
  return config.model['constructor'](
      num_particles=state.num_particles,
      lower=state.lower,
      upper=state.upper,
      **config.model['kwargs'])

def train(num_iterations: int):
  energy_fn_train = config.train_energy.constructor(
      **config.train_energy.kwargs)
  energy_fn_test = config.test_energy.constructor(**config.test_energy.kwargs)
  lr_schedule_fn = utils.get_lr_schedule(
      config.train.learning_rate, config.train.learning_rate_decay_steps,
      config.train.learning_rate_decay_factor)
  optimizer = optax.chain(
      optax.scale_by_adam(),
      optax.scale_by_schedule(lr_schedule_fn),
      optax.scale(-1))
  if config.train.max_gradient_norm is not None:
    optimizer = optax.chain(
        optax.clip_by_global_norm(config.train.max_gradient_norm), optimizer)

  def loss_fn():
    """Loss function for training."""
    model = create_model()
    loss, stats = get_loss(
        model=model,
        energy_fn=energy_fn_train,
        beta=state.beta,
        num_samples=config.train.batch_size)
    train_metrics = dict(loss=loss, energy_biased=stats['energy'].mean())
    return loss, train_metrics

  def print_formatted(mode, step, metrics):
    """Output the training progress with nice formatting."""
    print(f'{mode}[{step}]') 
    for k in sorted(metrics.keys()):
      print(f'   {k:<20}: {metrics[k]:g}')
    print('-' * 34)
    
  def eval_fn():
    """Evaluation function."""
    model = create_model()
    loss, stats = get_loss(
        model=model,
        energy_fn=energy_fn_test,
        beta=state.beta,
        num_samples=config.test.batch_size)
    return get_eval_metrics(loss=loss, beta=state.beta, 
                            num_particles=state.num_particles, **stats)
 
  print('Initialising system.')
  rng_key = jax.random.PRNGKey(config.train.seed)
  init_fn, apply_fn = hk.transform(loss_fn)
  _, apply_eval_fn = hk.transform(eval_fn)

  rng_key, init_key = jax.random.split(rng_key)
  params = init_fn(init_key)
  opt_state = optimizer.init(params)

  def _loss(params, rng):
    loss, metrics = apply_fn(params, rng)
    return loss, metrics

  jitted_loss = jax.jit(jax.value_and_grad(_loss, has_aux=True))
  jitted_eval = jax.jit(apply_eval_fn)
  
  step = 0
  print('Beginning of training.')
  while step < num_iterations:
    rng_key, loss_key = jax.random.split(rng_key)
    (_, metrics), g = jitted_loss(params, loss_key)

    if (step % 50) == 0:
      print_formatted('Train', step, metrics)

    if (step % config.test.test_every) == 0:
      rng_key, val_key = jax.random.split(rng_key)
      metrics = jitted_eval(params, val_key)
      print_formatted('Valid', step, metrics)
     
    # Update parameters.
    updates, opt_state = optimizer.update(g, opt_state, params)
    params = optax.apply_updates(params, updates)
    step += 1
    
  return params

params = train(num_training_steps)
print('Done')

Reference values for comparison with the validation metrics `energy_unbiased` and `beta_f` printed above:
- $\langle U \rangle_p \approx -94.64~\text{kcal/mol}$
- $\beta F/N \approx -25.86$ 


## Analyzing the trained model

### Sampling from the model


The flow model is a probability distribution $q$ that is related to the base density $b$ via a diffeomorphism $f$. We can sample from $q$ by first sampling $z$ from $b$ and then taking $x = f(z)$. The probability density of $x$ is given by

\begin{equation*}
    q(x) = b(z)|\det J_f(z)|^{-1},
\end{equation*}

where $J_f$ is the Jacobian of $f$. 

Here, $q$ and $b$ are implemented as `distrax.Distribution` objects, which provide a method `sample_and_log_prob` that returns a batch of samples together with the corresponding log probs.
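
The change-of-variables formula can be verified by hand for a simple case. The sketch below assumes a one-dimensional standard normal base density and an affine map $f(z) = a z + c$, whose Jacobian determinant is simply $a$. This toy flow is an assumption for illustration and is unrelated to the trained model. The resulting $\ln q(x)$ should agree with the closed-form log density of a Gaussian with mean $c$ and standard deviation $|a|$.

```python
import numpy as np


def standard_normal_log_prob(z):
  """Log density of the base distribution b."""
  return -0.5 * z**2 - 0.5 * np.log(2.0 * np.pi)


a, c = -2.0, 0.5                      # parameters of the toy map f(z) = a z + c
rng = np.random.default_rng(0)

z = rng.standard_normal(5)            # z ~ b
x = a * z + c                         # x = f(z)

# ln q(x) = ln b(z) - ln |det J_f(z)|, with det J_f(z) = a here.
log_q = standard_normal_log_prob(z) - np.log(np.abs(a))

# Closed form: x is Gaussian with mean c and standard deviation |a|.
log_q_closed_form = (-0.5 * ((x - c) / np.abs(a))**2 - np.log(np.abs(a))
                     - 0.5 * np.log(2.0 * np.pi))

print(np.max(np.abs(log_q - log_q_closed_form)))
```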

In [None]:
@hk.transform
def sample_and_log_prob(num_samples: int):
  """Returns samples and log probs from the base and the trained model."""
  key = hk.next_rng_key()
  model = create_model()
  return (
      model.sample_and_log_prob(seed=key, sample_shape=num_samples),
      model._base_model._flow_model.distribution.sample_and_log_prob(
          seed=key, sample_shape=num_samples))

((model_samples, model_log_probs), 
 (base_samples, base_log_probs)) = sample_and_log_prob.apply(params, jax.random.PRNGKey(42), 4096)

### Estimating normalised target log probs

To evaluate

\begin{equation*}
\ln p(x) = -\beta U(x) - \ln Z,
\end{equation*}

we need the exact value of $\ln Z$, which is unknown. However, we have a targeted estimate, $\widehat{\ln Z}$, that we can use to compute approximately normalised target log probs,

\begin{equation*}
\ln \hat p(x) = -\beta U(x) - \widehat{\ln Z}.
\end{equation*}


In [None]:
energy_fn = config.test_energy.constructor(**config.test_energy.kwargs)

def get_normalised_target_log_probs(  
    model_samples: Array,
    model_log_probs: Array) -> Array:
  """Returns the (approximately) normalised target log probs.

  Args:
    model_samples: samples drawn from the model.
    model_log_probs: model log probs for the `model_samples`.
 
  Returns:
    approximately normalised target log probs.

  Todo: 
    Estimate the normalising constant and use it to estimate the normalised
    target log probs.
  """
  normalised_target_log_probs = jnp.zeros_like(model_log_probs)
  return normalised_target_log_probs

normalised_target_log_probs = get_normalised_target_log_probs(model_samples, model_log_probs)

### Plotting model vs. target log probs
We evaluate both the model log probs $\ln q$ and the approximate target log probs $\ln \hat p$ on the same batch of samples drawn from the model, and then plot them against each other.

In [None]:
xmin = min(normalised_target_log_probs.min(), model_log_probs.min(), base_log_probs.min())
xmax = max(normalised_target_log_probs.max(), model_log_probs.max(), base_log_probs.max())
x = np.linspace(xmin, xmax, 100)
plt.xlim((xmin, xmax))
plt.ylim((xmin, xmax))
plt.scatter(normalised_target_log_probs, base_log_probs, c='blue', alpha=0.2, label='base')
plt.scatter(normalised_target_log_probs, model_log_probs, c='red', alpha=0.2, label='model')
plt.plot(x, x, linestyle='--', c='black')
plt.xlabel(r'$\ln \hat{p}(x)$')
plt.ylabel(r'$\ln q(x)$')
plt.legend()
plt.show()

### Plotting the radial distribution function

The radial distribution function $g(r)$ is the ratio of the average number density at a distance $r$ from an arbitrary reference atom to the average number density in an ideal gas at the same overall density (see reference [6] for more details).
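
The definition can be turned into a histogram estimate as sketched below. This is a minimal NumPy version for a periodic cubic box, written independently of `obs_utils.radial_distribution_function`, so the function name `toy_rdf`, its signature and its normalisation convention are assumptions. Pair distances are computed with the minimum-image convention and binned up to half the box length, and each bin count is divided by the number of neighbours expected in an ideal gas of the same density.

```python
import numpy as np


def toy_rdf(coords, box_length, num_bins=20):
  """Estimates g(r) for particles in a periodic cubic box.

  Args:
    coords: array of shape [num_frames, num_particles, 3].
    box_length: side length of the cubic box.
    num_bins: number of histogram bins between 0 and box_length / 2.

  Returns:
    The bin centres and the corresponding g(r) values.
  """
  num_frames, num_particles, _ = coords.shape
  edges = np.linspace(0.0, box_length / 2.0, num_bins + 1)
  pair_index = np.triu_indices(num_particles, k=1)   # every pair once
  counts = np.zeros(num_bins)
  for frame in coords:
    diff = frame[:, None, :] - frame[None, :, :]
    diff -= box_length * np.round(diff / box_length)  # minimum image
    dist = np.linalg.norm(diff, axis=-1)[pair_index]
    counts += np.histogram(dist, bins=edges)[0]
  shell_volume = 4.0 / 3.0 * np.pi * (edges[1:]**3 - edges[:-1]**3)
  density = num_particles / box_length**3
  # Average number of neighbours per particle in each shell (factor 2 because
  # every pair was counted once), divided by the ideal-gas expectation.
  neighbours = 2.0 * counts / (num_frames * num_particles)
  g = neighbours / (density * shell_volume)
  centres = 0.5 * (edges[1:] + edges[:-1])
  return centres, g


# Uniformly distributed particles should give g(r) close to 1.
rng = np.random.default_rng(0)
box_length = 10.0
ideal_gas = rng.uniform(0.0, box_length, size=(200, 50, 3))
r, g = toy_rdf(ideal_gas, box_length)
print(np.round(g[10:], 2))
```

With this normalisation a finite ideal gas of $N$ particles gives $g(r) \approx (N-1)/N$ rather than exactly 1, because each particle has only $N-1$ possible neighbours.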

#### Reference values
The reference values for $g(r)$ were obtained with Hamiltonian Monte Carlo and are stored in the array `reference_rdf`. 

In [None]:
reference_rdf = np.array([
       [1.5499999e-02, 0.0000000e+00],
       [4.6499986e-02, 0.0000000e+00],
       [7.7499993e-02, 0.0000000e+00],
       [1.0849999e-01, 0.0000000e+00],
       [1.3949999e-01, 0.0000000e+00],
       [1.7050001e-01, 0.0000000e+00],
       [2.0149997e-01, 0.0000000e+00],
       [2.3250003e-01, 0.0000000e+00],
       [2.6350003e-01, 0.0000000e+00],
       [2.9449993e-01, 0.0000000e+00],
       [3.2549998e-01, 0.0000000e+00],
       [3.5650003e-01, 0.0000000e+00],
       [3.8749993e-01, 0.0000000e+00],
       [4.1849995e-01, 0.0000000e+00],
       [4.4950002e-01, 0.0000000e+00],
       [4.8050013e-01, 0.0000000e+00],
       [5.1150000e-01, 0.0000000e+00],
       [5.4249990e-01, 0.0000000e+00],
       [5.7349980e-01, 0.0000000e+00],
       [6.0449988e-01, 0.0000000e+00],
       [6.3549995e-01, 0.0000000e+00],
       [6.6650003e-01, 0.0000000e+00],
       [6.9749999e-01, 0.0000000e+00],
       [7.2850013e-01, 0.0000000e+00],
       [7.5949979e-01, 0.0000000e+00],
       [7.9049981e-01, 0.0000000e+00],
       [8.2149994e-01, 0.0000000e+00],
       [8.5249996e-01, 0.0000000e+00],
       [8.8349998e-01, 0.0000000e+00],
       [9.1449994e-01, 0.0000000e+00],
       [9.4549960e-01, 0.0000000e+00],
       [9.7649974e-01, 0.0000000e+00],
       [1.0074998e+00, 0.0000000e+00],
       [1.0384998e+00, 0.0000000e+00],
       [1.0695000e+00, 0.0000000e+00],
       [1.1005000e+00, 0.0000000e+00],
       [1.1315001e+00, 0.0000000e+00],
       [1.1625001e+00, 0.0000000e+00],
       [1.1934999e+00, 0.0000000e+00],
       [1.2245001e+00, 0.0000000e+00],
       [1.2555002e+00, 0.0000000e+00],
       [1.2864996e+00, 0.0000000e+00],
       [1.3174998e+00, 0.0000000e+00],
       [1.3484998e+00, 0.0000000e+00],
       [1.3794997e+00, 0.0000000e+00],
       [1.4104997e+00, 0.0000000e+00],
       [1.4414998e+00, 0.0000000e+00],
       [1.4724998e+00, 0.0000000e+00],
       [1.5035000e+00, 0.0000000e+00],
       [1.5345000e+00, 0.0000000e+00],
       [1.5655000e+00, 0.0000000e+00],
       [1.5965002e+00, 0.0000000e+00],
       [1.6275002e+00, 0.0000000e+00],
       [1.6585003e+00, 0.0000000e+00],
       [1.6894995e+00, 0.0000000e+00],
       [1.7204996e+00, 0.0000000e+00],
       [1.7514995e+00, 0.0000000e+00],
       [1.7824997e+00, 0.0000000e+00],
       [1.8134997e+00, 0.0000000e+00],
       [1.8444999e+00, 0.0000000e+00],
       [1.8755000e+00, 0.0000000e+00],
       [1.9065002e+00, 0.0000000e+00],
       [1.9374998e+00, 0.0000000e+00],
       [1.9684998e+00, 0.0000000e+00],
       [1.9995000e+00, 0.0000000e+00],
       [2.0304999e+00, 0.0000000e+00],
       [2.0615001e+00, 0.0000000e+00],
       [2.0924993e+00, 0.0000000e+00],
       [2.1235003e+00, 0.0000000e+00],
       [2.1544995e+00, 0.0000000e+00],
       [2.1854997e+00, 0.0000000e+00],
       [2.2164996e+00, 4.4473531e-04],
       [2.2474999e+00, 6.4883090e-04],
       [2.2784998e+00, 3.1564615e-03],
       [2.3095002e+00, 1.0650708e-02],
       [2.3404999e+00, 2.8518641e-02],
       [2.3714993e+00, 8.4110245e-02],
       [2.4024999e+00, 1.9078535e-01],
       [2.4334996e+00, 3.9773461e-01],
       [2.4645002e+00, 7.6246047e-01],
       [2.4954996e+00, 1.3214793e+00],
       [2.5265002e+00, 2.0774045e+00],
       [2.5574994e+00, 2.9314387e+00],
       [2.5885000e+00, 3.8366539e+00],
       [2.6194994e+00, 4.5187459e+00],
       [2.6505001e+00, 4.9712882e+00],
       [2.6814997e+00, 5.1065240e+00],
       [2.7125003e+00, 4.8710790e+00],
       [2.7434998e+00, 4.3332777e+00],
       [2.7744992e+00, 3.6022041e+00],
       [2.8054998e+00, 2.8767633e+00],
       [2.8364992e+00, 2.1645274e+00],
       [2.8675001e+00, 1.5309756e+00],
       [2.8984995e+00, 1.0224787e+00],
       [2.9295001e+00, 6.5660924e-01],
       [2.9604993e+00, 4.1806403e-01],
       [2.9915004e+00, 2.4293379e-01],
       [3.0224993e+00, 1.4242552e-01],
       [3.0535002e+00, 7.6981105e-02],
       [3.0844996e+00, 4.6274789e-02]])

#### Comparison

In [None]:
box_length = config.test_energy.kwargs.box_length
model_rdf = obs_utils.radial_distribution_function(model_samples, box_length, num_bins=100)
base_rdf = obs_utils.radial_distribution_function(base_samples, box_length, num_bins=100)
plt.plot(base_rdf[:, 0], base_rdf[:, 1], linestyle='--', c='blue', label='base')
plt.plot(model_rdf[:, 0], model_rdf[:, 1], linestyle='-', c='red', label='model')
plt.plot(reference_rdf[:, 0], reference_rdf[:, 1], linestyle='dotted', c='black', label='reference')
plt.xlabel(r'$r$')
plt.ylabel(r'$g(r)$')
plt.legend()
plt.show()