In [None]:
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Imports

In [None]:
!git clone https://github.com/deepmind/flows_for_atomic_solids.git
!pip install -r flows_for_atomic_solids/requirements.txt

In [None]:
import distrax
import os
import pickle
import requests
import shutil
import subprocess

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import rcParams

from flows_for_atomic_solids.experiments import lennard_jones_config
from flows_for_atomic_solids.experiments import monatomic_water_config
from flows_for_atomic_solids.models import particle_models
from flows_for_atomic_solids.utils import observable_utils as obs_utils

# Global matplotlib styling shared by every figure in this notebook.
rcParams.update(
    {
        'figure.figsize': [8.0, 6.0],
        'font.size': 16,
        'axes.labelsize': 20,
        'axes.labelpad': 16,
        'axes.titlepad': 16,
        'xtick.labelsize': 16,
        'ytick.labelsize': 16,
        'legend.fontsize': 16,
        'lines.linewidth': 3,
    }
)

## Load system configuration and parameters

In [None]:
def get_model_params(system_name, local_path, timeout=60):
  """Downloads the pretrained parameters of `system_name` and unpickles them.

  The file is fetched from the public `dm_flows_for_atomic_solids` bucket and
  written to `<local_path>/flows_for_atomic_solids/<system_name>/params.pkl`,
  replacing any previous download.

  Args:
    system_name: system key, e.g. 'mw_cubic_512'; used as the sub-directory
      both in the bucket and under `local_path`.
    local_path: root directory under which the parameters are stored.
    timeout: seconds `requests` waits to connect / for data before giving up.

  Returns:
    The unpickled contents of `params.pkl`.

  Raises:
    requests.HTTPError: if the server answers with an error status.
  """
  bucket_url = 'https://storage.googleapis.com/dm_flows_for_atomic_solids'
  # URLs always use '/', so don't build them with os.path.join (which would
  # use '\\' on Windows).
  source_url = f'{bucket_url}/{system_name}/params.pkl'
  dest_folder = os.path.join(local_path, 'flows_for_atomic_solids', system_name)
  os.makedirs(dest_folder, exist_ok=True)
  print(f'Downloading: {source_url}')
  dest_path = os.path.join(dest_folder, 'params.pkl')
  partial_path = dest_path + '.part'
  with requests.get(source_url, stream=True, timeout=timeout) as r:
    # Check the status before touching the destination file.
    r.raise_for_status()
    # `r.raw` is the undecoded byte stream; have urllib3 undo any
    # Content-Encoding (e.g. gzip) so the bytes on disk are the pickle itself.
    r.raw.decode_content = True
    with open(partial_path, 'wb') as w:
      shutil.copyfileobj(r.raw, w)
  # Only publish the file once the transfer completed, so an interrupted
  # download never leaves a truncated params.pkl behind.
  os.replace(partial_path, dest_path)
  # NOTE: unpickling can execute arbitrary code; only load files from this
  # trusted bucket.
  with open(dest_path, 'rb') as f:
    return pickle.load(f)

# Root directory under which downloaded parameters are stored.
LOCAL_PATH = '/tmp'

# Pretrained systems: name -> (config module, kwargs for its `get_config`).
# 'mw_*' entries use the monatomic-water config with lattice 'cubic' or 'hex';
# 'lj_*' entries use the Lennard-Jones config. The number is the particle count.
available_systems = {
    'mw_cubic_64': (monatomic_water_config, {'num_particles': 64, 'lattice': 'cubic'}),
    'mw_cubic_216': (monatomic_water_config, {'num_particles': 216, 'lattice': 'cubic'}),
    'mw_cubic_512': (monatomic_water_config, {'num_particles': 512, 'lattice': 'cubic'}),
    'mw_hex_64': (monatomic_water_config, {'num_particles': 64, 'lattice': 'hex'}),
    'mw_hex_216': (monatomic_water_config, {'num_particles': 216, 'lattice': 'hex'}),
    'mw_hex_512': (monatomic_water_config, {'num_particles': 512, 'lattice': 'hex'}),
    'lj_256': (lennard_jones_config, {'num_particles': 256}),
    'lj_500': (lennard_jones_config, {'num_particles': 500}),
}

# Pick any key of `available_systems` to analyse a different system.
chosen_system = 'mw_cubic_512'

config_module, config_params = available_systems[chosen_system]
model_params = get_model_params(system_name=chosen_system, local_path=LOCAL_PATH)
config = config_module.get_config(**config_params)

## Create the model

In [None]:
# Unpack the simulation state of the chosen system.
state = config.state
# Edge lengths of the simulation box (their product is the box volume below).
box_length = state['upper'] - state['lower']
num_particles = state['num_particles']

@hk.transform
def base_sample_and_logprob_fun(num_samples, config=config, state=state):
  """Samples the flow's base distribution and returns (samples, log-probs).

  The base distribution is built from `config.model.kwargs.base`. When the
  model is translation invariant, it is built over one particle fewer and then
  wrapped in `particle_models.TranslationInvariant`.
  """
  base = config.model.kwargs.base
  translation_invariant = config.model.kwargs.translation_invariant
  num_base_particles = (
      state['num_particles'] - 1 if translation_invariant
      else state['num_particles'])
  proposal = base['constructor'](
      num_particles=num_base_particles,
      lower=state['lower'],
      upper=state['upper'],
      **base['kwargs'])
  if translation_invariant:
    proposal = particle_models.TranslationInvariant(proposal)
  return proposal.sample_and_log_prob(
      seed=hk.next_rng_key(), sample_shape=num_samples)

# Initialise the base model and JIT-compile its sampler. `num_samples` fixes
# the output shape, so it has to be a static argument under jit.
base_params = base_sample_and_logprob_fun.init(
    jax.random.PRNGKey(0), num_samples=1)
base_sample_and_logprob = jax.jit(
    base_sample_and_logprob_fun.apply, static_argnames=('num_samples',))

@hk.transform
def flow_sample_and_logprob_fun(num_samples, config=config, state=state):
  """Samples the flow model described by `config.model`; returns (samples, log-probs)."""
  model_config = config.model
  flow = model_config['constructor'](
      num_particles=state['num_particles'],
      lower=state['lower'],
      upper=state['upper'],
      **model_config['kwargs'])
  return flow.sample_and_log_prob(
      seed=hk.next_rng_key(), sample_shape=num_samples)

# JIT-compiled flow sampler; called below with the downloaded `model_params`.
# `num_samples` fixes the output shape, so it has to be static under jit.
flow_sample_and_logprob = jax.jit(
    flow_sample_and_logprob_fun.apply, static_argnames=('num_samples',))

# Potential energy of the target system, built from the config and JIT-compiled.
potential_energy_fn = config.test_energy['constructor'](
    **config.test_energy['kwargs'])
potential_energy = jax.jit(potential_energy_fn)

## Sample the model

In [None]:
def closest_power_of_2(x):
  """Returns the power of two nearest to `x` on a log2 scale.

  `log2(x)` is rounded to the nearest integer k and `2**k` is returned, e.g.
  3 -> 4, 64 -> 64, 0.5 -> 0.5, 0.3 -> 0.25. The result is an `int` when
  x >= 2**-0.5 and a float (negative power of two) for smaller x.

  Args:
    x: a positive, finite number.

  Raises:
    ValueError: if `x` is not positive and finite.
  """
  if not (np.isfinite(x) and x > 0):
    raise ValueError(
        f'closest_power_of_2 expects a positive finite number, got {x}.')
  # floor(y + 0.5) rounds to nearest for negative y as well; int() truncates
  # towards zero and would, e.g., map 0.5 to 1.
  return 2**int(np.floor(np.log2(x) + 0.5))

def gather_samples(n, batch_fun):
  """Calls `batch_fun(0)`, `batch_fun(1)`, ... until at least `n` items are gathered.

  Each call must return a pytree of arrays whose leaves share the same leading
  (batch) dimension. The batches are concatenated along that dimension and
  truncated to exactly `n` items. A `k / n` progress counter is printed and
  overwritten in place (with backspaces) after every batch.

  Args:
    n: number of items to gather.
    batch_fun: callable mapping a batch index `i` to a pytree of arrays.

  Returns:
    A pytree with the structure of a single batch whose leaves have leading
    dimension `n`.

  Raises:
    ValueError: if `batch_fun` returns an empty batch before `n` items have
      been gathered (the loop would otherwise never terminate).
  """
  digits = len(str(n))
  batches = []
  count = 0
  # Always fetch at least one batch so the output structure is known.
  while not batches or count < n:
    index = len(batches)
    batch = batch_fun(index)
    batch_len = len(jax.tree_util.tree_leaves(batch)[0])
    if batch_len == 0 and count < n:
      raise ValueError(
          f'batch_fun({index}) returned an empty batch after {count} of {n} '
          'items were gathered.')
    batches.append(batch)
    count += batch_len
    # Erase the previous counter; clipping to `n` keeps its width constant.
    prefix = ('\b' * (2 * digits + 3)) if len(batches) > 1 else ''
    print(f'{prefix}{min(count, n):0{digits}d} / {n}', end='')
  print()
  if len(batches) == 1:
    data = batches[0]
  else:
    # A single concatenation at the end avoids re-copying the accumulated
    # data after every batch (quadratic in the number of batches).
    data = jax.tree_util.tree_map(
        lambda *leaves: jnp.concatenate(leaves), *batches)
  return jax.tree_util.tree_map(lambda x: x[:n], data)

if jax.devices()[0].platform == 'cpu':
  print('WARNING: no accelerator found. The model will take a long time to '
        'compute energies and generate model samples, and may crash. \nA Colab '
        'kernel with a GPU or TPU accelerator is strongly recommended.')
  print()

# Batch sizes scale roughly as 1 / num_particles (rounded to a power of two);
# the base distribution is sampled in batches 32x larger than the flow.
batch_size = closest_power_of_2(32768/num_particles)
base_batch_size = batch_size * 32
energy_batch_size = closest_power_of_2(8192/num_particles)
N = 8192  # Number of samples drawn from each of the base and model distributions.


def _batched_energies(samples):
  """Returns `potential_energy` of every configuration in `samples` as a NumPy array.

  The energies are evaluated in chunks of `energy_batch_size` configurations.
  """
  return np.array(gather_samples(
      len(samples),
      lambda i: potential_energy(
          samples[i*energy_batch_size:(i+1)*energy_batch_size])))


# `jax.tree_map` is deprecated (and removed in recent JAX releases);
# `jax.tree_util.tree_map` is the supported equivalent.
print('Gathering base samples: ', end='')
base_samples, base_logprob = jax.tree_util.tree_map(
    np.array,
    gather_samples(N, lambda i: base_sample_and_logprob(
        base_params, jax.random.PRNGKey(-i-27), num_samples=base_batch_size)))
print('Computing base energies: ', end='')
base_energies = _batched_energies(base_samples)
print('Gathering model samples (initial jitting can take some time): ', end='')
model_samples, model_logprob = jax.tree_util.tree_map(
    np.array,
    gather_samples(N, lambda i: flow_sample_and_logprob(
        model_params, jax.random.PRNGKey(i+4200000), num_samples=batch_size)))
print('Computing model energies: ', end='')
model_energies = _batched_energies(model_samples)

## Energies + Radial distribution function (Fig. 2)

In [None]:
# Settings shared by both panels of Fig. 2.
base_color, flow_color = 'b', 'r'
num_bins = 200        # histogram bins for the energy distributions
max_gr_samples = 512  # number of samples used to estimate g(r)
beta = state.beta     # multiplies the energies, giving the beta*U/N axis

def equalize(x, y, n_points=100):
  """Resamples the polyline (x, y) at `n_points` points equally spaced in arc length.

  Args:
    x: 1-D sequence of x coordinates of the curve's vertices, in order.
    y: 1-D sequence of y coordinates, same length as `x`.
    n_points: number of output points; the first and last vertices are kept.

  Returns:
    Tuple `(new_x, new_y)` of arrays of length `n_points`.
  """
  segment_lengths = np.sqrt(np.diff(x)**2 + np.diff(y)**2)
  # Cumulative arc length at each vertex, starting from 0.
  arc_length = np.concatenate([[0], segment_lengths]).cumsum()
  targets = np.linspace(arc_length[0], arc_length[-1], n_points)
  return np.interp(targets, arc_length, x), np.interp(targets, arc_length, y)

# Radial distribution functions estimated from the first `max_gr_samples` samples.
base_gr = obs_utils.radial_distribution_function(
    coordinates=base_samples[:max_gr_samples], box_length=box_length)
model_gr = obs_utils.radial_distribution_function(
    coordinates=model_samples[:max_gr_samples], box_length=box_length)

fig, (ax1, ax2) = plt.subplots(
    ncols=2, figsize=(16.0, 5.0), dpi=180, gridspec_kw={'wspace': 0.45})

# Left panel: histograms of the energy per particle, beta * U / N.
ax = ax1
n_samples = len(model_energies)
for energies, color, label in ((base_energies, base_color, 'Base'),
                               (model_energies, flow_color, 'Model')):
  _, _, patches = ax.hist(
      beta * energies / num_particles, bins=num_bins, density=True,
      histtype='stepfilled', color=color, alpha=0.5, linewidth=0, label=label)

_ = ax.legend(loc='upper right', prop={'size': 18}, frameon=False)

ax.set_ylabel(r'Density', labelpad=-10)
ax.set_xlabel(r'$\beta U/N$', labelpad=-10)

# Keep only two ticks per axis to reduce clutter.
xticks = ax.get_xticks()
ax.set_xticks(xticks[[1, -2]])
ax.set_yticks(ax.get_yticks()[[0, -1]])

# Right panel: g(r). mW distances are divided by sigma = 2.3925;
# Lennard-Jones coordinates are already in reduced units.
ax = ax2
scale = 2.3925 if chosen_system.startswith('mw') else 1.0

ax.plot(base_gr[:, 0] / scale, base_gr[:, 1], label='Base', color=base_color,
        linewidth=2, linestyle=':')
ax.plot(model_gr[:, 0] / scale, model_gr[:, 1], color=flow_color, label='Model',
        linewidth=2, linestyle='--', dashes=[3, 3])
ax.set_ylabel(r'$g(r)$')
ax.set_xlabel(r'$r / \sigma$', labelpad=0)
_ = ax.legend(loc='upper right', prop={'size': 18}, frameon=False)

## Histogram of work values (Fig. 4)

In [None]:
def plot_model_vs_true_logprob(mlp, tlp_unnormalized, target_log_z,
                               ax, percentile=0.01, color='r', label=None,
                               fontsize=None, margin=0, n_plot_samples=10000):
  """Scatter-plots model log-density against normalized target log-density.

  Each sample x is drawn at (ln p_hat(x), ln q(x)), with
  ln p_hat(x) = tlp_unnormalized - target_log_z. A dashed line from the
  lower-left to the upper-right corner of the axes is added for reference.
  This is the y = x line whenever the x and y limits coincide, which this
  function arranges when it sets them. Repeated calls on the same `ax` widen
  the limits to cover every plotted set.

  Args:
    mlp: model log-probabilities ln q(x), one per sample.
    tlp_unnormalized: unnormalized target log-densities (e.g. -beta * U(x)).
    target_log_z: log normalizing constant subtracted from `tlp_unnormalized`.
    ax: matplotlib Axes to draw on.
    percentile: percent (not fraction) of points ignored at each end when
      choosing the axis limits.
    color: single marker colour.
    label: legend label of the scatter.
    fontsize: font size of the axis labels.
    margin: fraction of the axis span added on each side of the limits.
    n_plot_samples: only the first this-many points are drawn; the limits are
      still computed from all points.

  Returns:
    `ax`.
  """
  # Existing artists mean an earlier call already set limits; remember them so
  # this call widens the view instead of replacing it.
  if not ax.lines and not ax.collections:
    prev_xlim = None
    prev_ylim = None
  else:
    prev_xlim = ax.get_xlim()
    prev_ylim = ax.get_ylim()
  ax.set_xlabel(r'$\ln\ \hat{p}(x)$', fontsize=fontsize, labelpad=2)
  ax.set_ylabel(r'$\ln\ q(x)$', fontsize=fontsize, labelpad=2)

  tlp_normalized = tlp_unnormalized - target_log_z

  # Work out a suitable plot range according to the desired target percentile
  # of data points, as the union of the ranges of both log-densities.
  pymax = np.nanpercentile(mlp, 100 - percentile)
  pymin = np.nanpercentile(mlp, percentile)
  target_pymax = np.nanpercentile(tlp_normalized, 100 - percentile)
  target_pymin = np.nanpercentile(tlp_normalized, percentile)
  pymin = np.minimum(pymin, target_pymin)
  pymax = np.maximum(pymax, target_pymax)

  # Fade markers when many points are drawn; avoid dividing by zero on empty input.
  n_shown = min(n_plot_samples, len(tlp_normalized))
  alpha = min(1.0, 400 / n_shown) if n_shown > 0 else 1.0

  # `c` is a single colour, so no colormap is passed: matplotlib would ignore
  # it and emit a "No data for colormapping" warning.
  ax.scatter(
      tlp_normalized[:n_plot_samples],
      mlp[:n_plot_samples],
      c=color,
      alpha=alpha,
      linewidth=0,
      s=10,
      label=label)

  if np.isfinite(pymin) and np.isfinite(pymax):
    if prev_xlim is None:
      xlims = np.array([pymin, pymax])
      ylims = np.array([pymin, pymax])
    else:
      xlims = np.array([min(prev_xlim[0], pymin), max(prev_xlim[1], pymax)])
      ylims = np.array([min(prev_ylim[0], pymin), max(prev_ylim[1], pymax)])
    if margin:
      xlims = xlims + np.diff(xlims)*[-1, 1] * margin
      ylims = ylims + np.diff(ylims)*[-1, 1] * margin
    ax.set_xlim(*xlims)
    ax.set_ylim(*ylims)

  ax.plot(ax.get_xlim(), ax.get_ylim(), ls='--', c='.3', lw=1)
  ax.set_xticks([])
  ax.set_yticks([])
  return ax

def log_factorial(n):
  """Returns ln(n!) as the sum ln(1) + ln(2) + ... + ln(n) (0 for n = 0)."""
  return np.sum(np.log(np.arange(1, n + 1)))


# mW lengths are expressed in units of sigma = 2.3925 (the same scale used for
# g(r)); each of the three dimensions contributes -ln(sigma) to the log-volume.
dimensionless_logvolume_mw = -3 * np.log(2.3925)

# Plot settings, re-declared so this cell does not depend on the Fig. 2 cell.
flow_color, base_color = 'r', 'b'
num_bins = 200
beta = state.beta
# Volume of the simulation box and the resulting number density.
box_vol = np.prod(box_length)
density = num_particles / box_vol

# Per-system reference values obtained with MBAR. Below they are shifted by
# ln(N!)/N and converted into the target log-normalizer `target_log_z`.
# Presumably reduced free energies per particle -- TODO confirm units.
mbar_values = dict(
    mw_cubic_64=-25.16306,
    mw_cubic_216=-25.08234,
    mw_cubic_512=-25.06156,
    mw_hex_64=-25.18687,
    mw_hex_216=-25.08975,
    mw_hex_512=-25.06480,
    lj_256=3.10798,
    lj_500=3.12262,
)

fig, ax = plt.subplots(figsize=(8.0, 6.0), dpi=180)

# Per-particle work values (beta * U(x) + ln q(x)) / N for each distribution.
model_work = (beta*model_energies + model_logprob)/num_particles
base_work = (beta*base_energies + base_logprob)/num_particles
# Shift the MBAR reference by ln(N!)/N (presumably the indistinguishable-
# particle term) so it is comparable with the work values.
logz_from_mbar_value = mbar_values[chosen_system] - log_factorial(int(num_particles)) / num_particles
if chosen_system.startswith('mw'):  # monatomic water, scale dimensions
  model_work = model_work - dimensionless_logvolume_mw
  base_work = base_work - dimensionless_logvolume_mw
  target_log_z = -(logz_from_mbar_value + dimensionless_logvolume_mw) * num_particles
else:  # Lennard-Jones
  target_log_z = -logz_from_mbar_value * num_particles

_, _, patches = ax.hist(base_work, bins=num_bins, density=True, histtype='stepfilled', color=base_color, alpha=0.5, linewidth=3, label='Base')
ax.hist(base_work, bins=num_bins, density=True, histtype='step', linewidth=2, linestyle='-', color=patches[-1].get_facecolor())
_, _, patches = ax.hist(model_work, bins=num_bins, density=True, histtype='stepfilled', color=flow_color, alpha=0.5, linewidth=3, label='Model')
ax.hist(model_work, bins=num_bins, density=True, histtype='step', linewidth=2, linestyle='-', color=patches[-1].get_facecolor())
ax.axvline(logz_from_mbar_value, color='g', linewidth=1, ls='--')

ax.set_xticks([ax.get_xticks()[0], ax.get_xticks()[-1]])
ax.set_yticks([ax.get_yticks()[0], ax.get_yticks()[-1]])
ax.set_xlabel(r'$\beta \Phi / N$', labelpad=-10, fontsize=20)
# Use the explicit Axes API instead of the pyplot state machine.
ax.set_ylabel(r'Density', labelpad=-10, fontsize=20)
ax.tick_params(axis='both', labelsize=16)

xlim = ax.get_xlim()
ylim = ax.get_ylim()
# Scalar axis spans. np.diff would yield 1-element arrays; mixed with scalars,
# they form a ragged sequence that recent NumPy rejects when matplotlib
# reshapes the anchor tuple into a Bbox.
x_span = xlim[1] - xlim[0]
y_span = ylim[1] - ylim[0]

# Zoomed inset of the model work distribution next to the MBAR reference line.
axins2 = inset_axes(ax, width='70%', height='80%',
                    bbox_to_anchor=(logz_from_mbar_value, ylim[0] + .1*y_span, .35*x_span, .3*y_span),
                    bbox_transform=ax.transData)
_, _, patches = axins2.hist(model_work, bins=num_bins, density=True, histtype='stepfilled', color=flow_color, alpha=0.5, linewidth=3, label='Model')
axins2.axvline(logz_from_mbar_value, color='g', linewidth=1, ls='--')
axins2.set_xticks([axins2.get_xticks()[0], axins2.get_xticks()[-1]])
axins2.tick_params(axis='x', labelsize=12)
axins2.set_yticks([])
ax.annotate('',
            xy=(logz_from_mbar_value + 0.01*x_span, ylim[0] + 0.25*y_span), xycoords='data',
            xytext=(logz_from_mbar_value + .08*x_span, ylim[0] + 0.25*y_span), textcoords='data',
            arrowprops=dict(arrowstyle='->',
                            connectionstyle='arc3'),
            )

# Inset comparing model and target log-densities, sample by sample.
axins = inset_axes(ax, width='50%', height='50%')
n_plot_samples = 1000
mlp = model_logprob
tlp = -beta * model_energies
plot_model_vs_true_logprob(mlp=base_logprob, tlp_unnormalized=-beta * base_energies, target_log_z=target_log_z,
                           ax=axins, color=base_color, label='Base', fontsize=16,
                           n_plot_samples=n_plot_samples)

_ = plot_model_vs_true_logprob(mlp=mlp, tlp_unnormalized=tlp, target_log_z=target_log_z,
                               ax=axins, color=flow_color, label='Flow', fontsize=16, margin=0.05,
                               n_plot_samples=n_plot_samples)
_ = ax.legend(loc='center right', bbox_to_anchor=(0, 0, 1.0, 0.5), prop={'size': 18}, frameon=False)

## Free energy estimation with learned free energy perturbation (LFEP)

In [None]:
# LFEP estimate per particle, in the same convention as `mbar_values`:
# (ln N! - ln Z_hat) / N, where ln Z_hat is (presumably) the estimated log
# normalizer of the unnormalized target density -beta * U.
logZ_fep = (-obs_utils.compute_logz(model_logprob, -beta * model_energies)
            + log_factorial(num_particles)) / num_particles
if chosen_system.startswith('mw'):
  # Express lengths in units of sigma, as done for the MBAR reference.
  logZ_fep = logZ_fep - dimensionless_logvolume_mw

print(f'LFEP estimate: {logZ_fep:.5f}')