In [None]:
import os
import numpy as np
import theano as th
import theano.tensor as tt
import matplotlib
import matplotlib.pyplot as plt
from thermomc import continuous_temp, discrete_temp, control_funcs, hmc
import seaborn as sns
%matplotlib inline

## Create experiment directory

In [None]:
# Figures are saved under <parent of the working directory>/experiments/1d-bimodal.
base_dir = os.path.dirname(os.getcwd())
exp_dir = os.path.join(base_dir, 'experiments', '1d-bimodal')
# Create the directory (with any missing parents) on the first run only.
exp_dir_missing = not os.path.exists(exp_dir)
if exp_dir_missing:
    os.makedirs(exp_dir)

## Set plot style settings

In [None]:
# Seaborn figure styling. ``sns.set`` resets every seaborn setting, including
# the plotting context (its default context is 'notebook'). Calling
# ``sns.set_context('paper')`` first and ``sns.set`` afterwards therefore
# silently discarded the paper context, so the context is passed to
# ``sns.set`` directly. 'sans-serif' is matplotlib's generic sans family name;
# 'sans' is not one and only triggers font-lookup warnings.
sns.set(context='paper', font='sans-serif')
# White background. '0.' is a matplotlib greyscale colour string (black),
# used here for axis labels, text and tick colours.
sns.set_style('white', {
    'font.family': 'sans-serif',
    'axes.labelcolor': '0.',
    'text.color': '0.',
    'xtick.color': '0.',
    'ytick.color': '0.'
})

In [None]:
# Render all figure text with LaTeX, using a sans-serif font throughout. This
# needs a LaTeX installation that provides the stix, amsmath and helvet
# packages. The preamble is passed as one newline-joined string. matplotlib
# 3.3 deprecated the list form and later releases reject it. Older releases
# also accept a string: they split it on commas, and this preamble has none.
params = {
    'text.latex.preamble': '\n'.join([
        r'\usepackage[notextcomp]{stix}',
        r'\usepackage{amsmath}',
        r'\usepackage{helvet}',
        r'\renewcommand{\rmdefault}{\sfdefault}',
    ]),
    'font.family': 'sans-serif',
    'font.size': 12,
    'text.usetex': True,
}
matplotlib.rcParams.update(params)

## Define model functions

In [None]:
# Two-component Gaussian mixture target: component means, standard
# deviations and mixture weights.
mu_1, mu_2 = -8, 8
sigma_1, sigma_2 = 1, 2
p_1, p_2 = 0.6, 0.4

# Base distribution parameters: the mean and standard deviation of the
# mixture itself, via Var[x] = sum_k p_k (sigma_k^2 + mu_k^2) - E[x]^2.
var_mean = p_1 * mu_1 + p_2 * mu_2
var_std = (
    p_1 * sigma_1**2 + p_2 * sigma_2**2
    + p_1 * mu_1**2 + p_2 * mu_2**2
    - var_mean**2
) ** 0.5

def phi(x):
    """Target energy: negative log density of a two-component Gaussian mixture.

    Computes, elementwise on the symbolic tensor ``x``,

        phi(x) = -log(p_1 N(x; mu_1, sigma_1^2) + p_2 N(x; mu_2, sigma_2^2)).

    When ``x`` is 2D, the result is summed over the last axis, giving one
    energy per row.

    The weighted component densities are combined in log space with a
    max-shifted log-sum-exp. Summing ``exp(...) + exp(...)`` directly
    underflows to 0 once ``x`` is far from both modes, which gives an infinite
    energy and NaN gradients. If floatX is float32, this already happens
    around x = -30, inside the range plotted in this notebook.
    """
    # log(p_k * N(x; mu_k, sigma_k^2)) for each component. The constant parts
    # are converted to plain Python floats, as in the original expression.
    # NumPy float64 scalars could upcast a float32 graph.
    log_comp_1 = (
        float(np.log(p_1) - 0.5 * np.log(2 * np.pi * sigma_1**2))
        - 0.5 * ((x - mu_1) / sigma_1)**2
    )
    log_comp_2 = (
        float(np.log(p_2) - 0.5 * np.log(2 * np.pi * sigma_2**2))
        - 0.5 * ((x - mu_2) / sigma_2)**2
    )
    # Use log(e^a + e^b) = m + log(e^(a-m) + e^(b-m)) with m = max(a, b).
    # The larger shifted term is exactly e^0 = 1, so the sum inside the log
    # never underflows.
    log_max = tt.maximum(log_comp_1, log_comp_2)
    ret_val = -(log_max + tt.log(
        tt.exp(log_comp_1 - log_max) + tt.exp(log_comp_2 - log_max)))
    if x.ndim == 2:
        return ret_val.sum(-1)
    return ret_val

def psi(x):
    """Base energy: negative log density of the Gaussian N(var_mean, var_std^2).

    Evaluated elementwise on the symbolic tensor ``x``. A 2D ``x`` is reduced
    by summing over its last axis, giving one energy per row.
    """
    standardised = (x - var_mean) / var_std
    energy = 0.5 * standardised**2 + 0.5 * tt.log(2 * np.pi) + tt.log(var_std)
    return energy.sum(-1) if x.ndim == 2 else energy

def joint_energy_u(x, u):
    """Joint energy over target state ``x`` and unbounded temperature control ``u``.

    The inverse temperature is ``beta = sigmoid(u)``, and the energy
    interpolates between the target energy ``phi`` and the base energy ``psi``.
    The extra term ``-log(beta * (1 - beta))`` is minus the log-Jacobian of
    the change of variables u -> beta (d beta / d u = beta * (1 - beta)).

    Both ``1 - beta`` and the log-Jacobian are computed directly from ``u``.
    ``1 - sigmoid(u)`` cancels to exactly 0 for large ``u``, which made
    ``log(beta * (1 - beta))`` return -inf and the energy +inf.
    """
    beta = tt.nnet.sigmoid(u)
    # 1 - sigmoid(u) == sigmoid(-u), evaluated without cancellation.
    one_minus_beta = tt.nnet.sigmoid(-u)
    # -log(beta * (1 - beta)) == softplus(-u) + softplus(u), since
    # log sigmoid(u) == -softplus(-u).
    neg_log_jacobian = tt.nnet.softplus(-u) + tt.nnet.softplus(u)
    return beta * phi(x) + one_minus_beta * psi(x) + neg_log_jacobian

def joint_energy_beta(x, beta):
    """Joint energy over target state ``x`` and inverse temperature ``beta``.

    Linear interpolation between the base energy ``psi`` (beta = 0) and the
    target energy ``phi`` (beta = 1).
    """
    target_part = beta * phi(x)
    base_part = (1 - beta) * psi(x)
    return target_part + base_part

# Symbolic input vectors for compiling the energy functions.
x = tt.vector('x')
u = tt.vector('u')
beta = tt.vector('beta')

# Compiled, NumPy-callable versions of the energies. They take 1D arrays and
# are used to evaluate densities and energies on plotting grids.
phi_func = th.function([x], phi(x))
psi_func = th.function([x], psi(x))
joint_energy_u_func = th.function([x, u], joint_energy_u(x, u))
joint_energy_beta_func = th.function([x, beta], joint_energy_beta(x, beta))

## Visualise target and base densities

In [None]:
# Target mixture density exp(-phi) against the Gaussian base density exp(-psi).
fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(1, 1, 1)
xs = np.linspace(-20, 20, 200)
ax.plot(xs, np.exp(-phi_func(xs)))
ax.plot(xs, np.exp(-psi_func(xs)))
# Legend labels are assigned to the lines in plotting order (target, then base).
ax.legend([r'Target $\,\frac{1}{Z}\,\exp[-\phi(x)]$', r'Base $\,\exp[-\psi(x)]$'])
ax.set_xlabel(r'Target state $x$')
ax.set_ylabel(r'Probability density')
# Format tick labels with formatters attached to the axes, rather than
# freezing strings built from the current ticks. ``set_*ticklabels`` leaves
# the automatic locator in place, so frozen labels stop matching the ticks if
# the limits change, and newer matplotlib warns about that combination.
ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.0f'))
ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.2f'))
ax.tick_params(labelsize=9)
fig.tight_layout(pad=0)
fig.savefig(os.path.join(exp_dir, 'bimodal-gm-target-and-gaussian-base.pdf'))

## Visualise joint energy / density

In [None]:
# Joint energy as a function of target state x and temperature control u,
# drawn as a heat map with 30 white contour lines over it.
fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(1, 1, 1)
x_lin = np.linspace(-20, 20, 100)
u_lin = np.linspace(-8, 8, 100)
# With meshgrid's default 'xy' indexing the grids have shape (len(u_lin), len(x_lin)).
x_grid, u_grid = np.meshgrid(x_lin, u_lin)
# The compiled energy takes flat vectors, so evaluate on the flattened grid
# and reshape back to the grid shape.
energies_u = joint_energy_u_func(
    x_grid.flatten(), u_grid.flatten()).reshape(x_grid.shape)
ax.pcolormesh(x_grid, u_grid, (energies_u), cmap='magma', shading='gouraud')
ax.contour(x_grid, u_grid, energies_u, 30, linewidths=0.2, colors='w')

In [None]:
# Joint density exp(-E(x, beta)) over target state x and inverse temperature
# beta in [0, 1]; beta = 0 is the base density, beta = 1 the target.
fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(1, 1, 1)
x_lin = np.linspace(-30, 30, 100)
beta_lin = np.linspace(0, 1, 100)
x_grid, beta_grid = np.meshgrid(x_lin, beta_lin)
# Evaluate the compiled energy on the flattened grid and restore the grid shape.
energies_beta = joint_energy_beta_func(
    x_grid.flatten(), beta_grid.flatten()).reshape(x_grid.shape)
ax.pcolormesh(x_grid,  beta_grid, np.exp(-energies_beta), cmap='magma', shading='gouraud')

## Continuously tempered HMC

In [None]:
# seed: shared by the Theano random streams and the NumPy RNG below.
# temp_scale: passed to the sigmoidal temperature control function.
seed, temp_scale = 1234, 1.

In [None]:
# Symbolic Theano inputs for the continuously tempered HMC chain: a matrix of
# target positions (presumably (n_chain, n_dim) -- confirm in thermomc), a
# vector of temperature-control values, and the integrator settings.
pos = tt.matrix('pos')
tmp_ctrl = tt.vector('tmp_ctrl')
dt, mom_resample_coeff = tt.scalars('dt', 'mom_resample_coeff')
# Integer (int32) step and sample counts.
n_step, n_sample = tt.iscalars('n_step', 'n_sample')
# Integrator settings forwarded to the sampler as a dict.
hmc_params = {'dt': dt, 'mom_resample_coeff': mom_resample_coeff, 'n_step': n_step}
# Maps the temperature control to an inverse temperature; presumably a sigmoid
# scaled by temp_scale (cf. joint_energy_u) -- confirm in thermomc.
ctrl_func = control_funcs.SigmoidalControlFunction(temp_scale)
# NOTE(review): the meaning of the positional False argument is defined in
# thermomc.continuous_temp -- confirm before changing it.
sampler = continuous_temp.JointContinuousTemperingSampler(
    tt.shared_randomstreams.RandomStreams(seed), False
)

In [None]:
# Build the symbolic sampling chain and compile it. Outputs: per-iteration
# positions, temperature controls, inverse temperatures, two probability
# arrays (probs_1 is later used as importance weights for the target
# histogram; presumably probs_0/probs_1 relate to beta = 0 / beta = 1 --
# confirm in thermomc), accept indicators, and the random-stream updates.
# NOTE(review): the None argument is presumably an initial momentum
# (None = sample a fresh one) -- confirm in thermomc.
(pos_samples, tmp_ctrl_samples, inv_temp_samples, 
 probs_0, probs_1, accepts, updates) = sampler.chain(
    pos, tmp_ctrl, None, phi, psi, ctrl_func, n_sample, hmc_params)
# The updates must be passed to th.function so the shared random stream state
# advances between calls.
jct_chain_func = th.function(
    [pos, tmp_ctrl, dt, n_step, n_sample, mom_resample_coeff],
    [pos_samples, tmp_ctrl_samples, inv_temp_samples, 
     probs_0, probs_1, accepts],
    updates=updates
)

In [None]:
# NumPy RNG seeded like the Theano streams. It draws the initial position
# (shape (1, 1), presumably one chain with a 1D state) and then the initial
# temperature control (shape (1,)), in that order.
rng = np.random.RandomState(seed)
pos_init, tmp_ctrl_init = rng.normal(size=(1, 1)), rng.normal(size=(1,))

In [None]:
# Run the continuously tempered HMC chain.
#
# The run-time settings get their own names so they do not overwrite the
# symbolic Theano inputs ``dt``, ``n_step``, ``n_sample`` and
# ``mom_resample_coeff`` created in the sampler-setup cell. Previously they
# were rebound to plain numbers, so re-running the chain-compilation cell
# afterwards passed numbers to ``th.function`` instead of symbolic variables.
run_dt = 1.
run_n_step = 20
run_n_sample = 1000
run_mom_resample_coeff = 1.
(pos_samples, tmp_ctrl_samples, inv_temp_samples,
 probs_0, probs_1, accepts) = jct_chain_func(
    pos_init, tmp_ctrl_init, run_dt, run_n_step, run_n_sample,
    run_mom_resample_coeff
)
# Mean of the accept outputs (presumably per-iteration accept indicators,
# i.e. the acceptance rate).
print(accepts.mean())

In [None]:
# Joint density exp(-E(x, u)) with the chain's (x, u) samples overlaid as
# green points.
fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(1, 1, 1)
x_lin = np.linspace(-20, 20, 100)
u_lin = np.linspace(-10, 10, 100)
x_grid, u_grid = np.meshgrid(x_lin, u_lin)
# Evaluate the compiled energy on the flattened grid and restore the grid shape.
energies_u = joint_energy_u_func(
    x_grid.flatten(), u_grid.flatten()).reshape(x_grid.shape)
ax.pcolormesh(x_grid, u_grid, np.exp(-energies_u), cmap='magma', shading='gouraud')
# First chain, first state dimension, against that chain's temperature control.
ax.plot(pos_samples[:, 0, 0], tmp_ctrl_samples[:, 0], 'g.', ms=4)

In [None]:
# Joint energy E(x, u) as a heat map with white contours, with the chain's
# (x, u) samples overlaid.
fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(1, 1, 1)
x_lin = np.linspace(-20, 20, 100)
u_lin = np.linspace(-10, 10, 100)
x_grid, u_grid = np.meshgrid(x_lin, u_lin)
# Evaluate the compiled energy on the flattened grid and restore the grid shape.
energies_u = joint_energy_u_func(
    x_grid.flatten(), u_grid.flatten()).reshape(x_grid.shape)
ax.pcolormesh(x_grid, u_grid, (energies_u), cmap='magma', shading='gouraud')
ax.contour(x_grid, u_grid, energies_u, 30, linewidths=0.2, colors='w')
ax.plot(pos_samples[:, 0, 0], tmp_ctrl_samples[:, 0], '.')

In [None]:
# Joint density exp(-E(x, beta)) with the chain's (x, beta) samples overlaid.
fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(1, 1, 1)
x_lin = np.linspace(-30, 30, 100)
beta_lin = np.linspace(0, 1, 100)
x_grid, beta_grid = np.meshgrid(x_lin, beta_lin)
# Evaluate the compiled energy on the flattened grid and restore the grid shape.
energies_beta = joint_energy_beta_func(
    x_grid.flatten(), beta_grid.flatten()).reshape(x_grid.shape)
ax.pcolormesh(x_grid,  beta_grid, np.exp(-energies_beta), cmap='magma', shading='gouraud')
ax.plot(pos_samples[:, 0, 0], inv_temp_samples[:, 0], '.')

In [None]:
# Joint energy E(x, beta) as a heat map with white contours, with the chain's
# (x, beta) samples overlaid.
fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(1, 1, 1)
x_lin = np.linspace(-30, 30, 100)
beta_lin = np.linspace(0, 1, 100)
x_grid, beta_grid = np.meshgrid(x_lin, beta_lin)
# Evaluate the compiled energy on the flattened grid and restore the grid shape.
energies_beta = joint_energy_beta_func(
    x_grid.flatten(), beta_grid.flatten()).reshape(x_grid.shape)
ax.pcolormesh(x_grid, beta_grid, (energies_beta), cmap='magma', shading='gouraud')
ax.contour(x_grid, beta_grid, energies_beta, 30, linewidths=0.2, colors='w')
ax.plot(pos_samples[:, 0, 0], inv_temp_samples[:, 0], '.')

In [None]:
# Trace a trajectory starting from a late state of the chain above, using a
# small step size and a short inner loop.
#
# The sampler's random stream is reseeded so this run is reproducible.
# NOTE(review): mom_resample_coeff = 0 presumably keeps the momentum between
# iterations, so successive samples form one continuous trajectory, and
# accepts.min() checks that no step was rejected -- confirm in thermomc.
#
# Run-time settings use their own names so the symbolic Theano inputs
# (dt, n_step, n_sample, mom_resample_coeff) stay intact for re-compilation.
sampler.srng.seed(seed)
j = -100
# Slicing with 0:1 keeps the chain axis, so the shapes match the inputs
# the compiled chain function was called with before.
pos_init = pos_samples[j, 0:1]
tmp_ctrl_init = tmp_ctrl_samples[j, 0:1]
run_dt = 0.1
run_n_step = 2
run_n_sample = 100
run_mom_resample_coeff = 0.
(pos_traj, tmp_ctrl_traj, inv_temp_traj,
 _, _, accepts) = jct_chain_func(
    pos_init, tmp_ctrl_init, run_dt, run_n_step, run_n_sample,
    run_mom_resample_coeff
)
# Print function call: the bare ``print x`` statement is a SyntaxError on
# Python 3 and inconsistent with the rest of the notebook.
print(accepts.min())

In [None]:
# Joint energy E(x, u) with white contours, overlaid with every second state
# of the traced trajectory, saved for the paper.
fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(1, 1, 1)
x_lin = np.linspace(-20, 20, 200)
u_lin = np.linspace(-10, 10, 200)
x_grid, u_grid = np.meshgrid(x_lin, u_lin)
# Evaluate the compiled energy on the flattened grid and restore the grid shape.
energies_u = joint_energy_u_func(
    x_grid.flatten(), u_grid.flatten()).reshape(x_grid.shape)
ax.pcolormesh(x_grid, u_grid, (energies_u), cmap='magma', shading='gouraud')
ax.contour(x_grid, u_grid, energies_u, 30, linewidths=0.3, colors='w')
ax.plot(pos_traj[::2, 0, 0], tmp_ctrl_traj[::2, 0], '.-', lw=1, color='limegreen')
ax.set_xlabel('Target state $x$')
ax.set_ylabel('Temperature control $u$')
# Integer tick labels come from formatters attached to the axes. Freezing
# strings built from get_*ticks() leaves the automatic locator in place, so
# those labels stop matching the ticks if the limits change.
ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.0f'))
ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.0f'))
ax.tick_params(labelsize=9)
fig.tight_layout(pad=0)
fig.savefig(os.path.join(exp_dir, 'jct-energy-and-trajectory.pdf'))

In [None]:
# Joint density over (x, beta), with black contours, overlaid with the
# chain's samples. The heat map shows the negated density so that, under
# 'magma', high density appears dark. Marker size grows with probs_1.
fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(1, 1, 1)
x_lin = np.linspace(-20, 20, 200)
beta_lin = np.linspace(0, 1, 200)
x_grid, beta_grid = np.meshgrid(x_lin, beta_lin)
# Evaluate the compiled energy on the flattened grid and restore the grid shape.
energies_beta = joint_energy_beta_func(
    x_grid.flatten(), beta_grid.flatten()).reshape(x_grid.shape)
ax.pcolormesh(x_grid, beta_grid, -np.exp(-energies_beta), cmap='magma', shading='gouraud')
ax.contour(x_grid, beta_grid, np.exp(-energies_beta), 15, linewidths=0.3, colors='k')
xs, betas, p1s = pos_samples[:, 0, 0], inv_temp_samples[:, 0], probs_1[:, 0]
ax.scatter(xs, betas, c='k', s=(5 * p1s + 1), linewidths=0.)
ax.set_xlabel('Target state $x$')
ax.set_ylabel(r'Inverse temperature $\beta$')
ax.set_xlim(x_lin[0], x_lin[-1])
ax.set_ylim(beta_lin[0], beta_lin[-1])
# Tick labels come from formatters attached to the axes. Freezing strings
# built from get_*ticks() leaves the automatic locator in place, so those
# labels stop matching the ticks if the limits change.
ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.0f'))
ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
ax.tick_params(labelsize=9)
fig.tight_layout(pad=0)
fig.savefig(os.path.join(exp_dir, 'jct-prob-dens-and-joint-samples.pdf'))

## Standard HMC on the target density

In [None]:
# Fresh symbolic Theano inputs for a plain (untempered) HMC baseline. These
# rebind the same names used for the tempered sampler.
pos = tt.matrix('pos')
dt, mom_resample_coeff = tt.scalars('dt', 'mom_resample_coeff')
n_step, n_sample = tt.iscalars('n_step', 'n_sample')
hmc_params = {'dt': dt, 'mom_resample_coeff': mom_resample_coeff, 'n_step': n_step}
# NOTE(review): the meaning of the positional False argument is defined in
# thermomc.hmc -- confirm before changing it.
hmc_sampler = hmc.HamiltonianSampler(
    tt.shared_randomstreams.RandomStreams(seed), False
)

In [None]:
# Build the symbolic HMC chain on the target energy phi alone and compile it.
# Outputs: positions, momenta, accept indicators and random-stream updates.
# NOTE(review): the None argument is presumably an initial momentum
# (None = sample a fresh one) -- confirm in thermomc.
(pos_samples_hmc, mom_samples, accepts, updates) = hmc_sampler.hmc_chain(
    pos, None, phi, n_sample, **hmc_params)
# The updates must be passed so the shared random stream advances between calls.
hmc_chain_func = th.function(
    [pos, dt, n_step, n_sample, mom_resample_coeff],
    [pos_samples_hmc, accepts],
    updates=updates
)

In [None]:
# Run the plain HMC baseline from 10 random initial states (shape (10, 1),
# presumably 10 chains with a 1D state), with the same integrator settings as
# the tempered run.
#
# Run-time settings use their own names so the symbolic Theano inputs
# (dt, n_step, n_sample, mom_resample_coeff) are not overwritten with numbers.
# Overwriting them would break a re-run of the HMC compilation cell.
run_dt = 1.
run_n_step = 20
run_n_sample = 1000
run_mom_resample_coeff = 1.
pos_init = rng.normal(size=(10, 1))
pos_samples_hmc, accepts = hmc_chain_func(
    pos_init, run_dt, run_n_step, run_n_sample, run_mom_resample_coeff
)
# Mean of the accept outputs (presumably the acceptance rate).
print(accepts.mean())

In [None]:
# Compare sample histograms with the target density.
# Top: plain HMC, first chain only.
# Bottom: continuously tempered HMC samples of x, weighted by probs_1.
fig, axes = plt.subplots(2, 1, sharex=True, figsize=(6, 3))
xs = np.linspace(-20, 20, 200)
axes[0].plot(xs, np.exp(-phi_func(xs)))
axes[1].plot(xs, np.exp(-phi_func(xs)))
# ``density=True`` normalises each histogram to integrate to one. It replaces
# the ``normed`` keyword, which was deprecated in matplotlib 2.1 and is no
# longer accepted since 3.1.
axes[0].hist(pos_samples_hmc[:, 0, 0], 15, density=True, alpha=0.8)
axes[1].hist(pos_samples[:, 0, 0], 50, weights=probs_1[:, 0], density=True, alpha=0.8)
axes[1].set_xlabel(r'Target state $x$')
axes[0].set_ylabel(r'Probability density')
axes[1].set_ylabel(r'Probability density')
# Labels map to artists in plotting order: the density line, then the histogram.
axes[0].legend(['Target', 'HMC'], ncol=1, loc='upper left')
axes[1].legend(['Target', 'CT HMC'], ncol=1, loc='upper left')
fig.tight_layout(pad=0)
fig.savefig(os.path.join(exp_dir, 'jct-and-hmc-target-histograms.pdf'))