In [None]:
import os
import glob
import time
import json
import numpy as np
import theano as th
import theano.tensor as tt
import theano.tensor.slinalg as sla
import bmtools.exact.moments as mom
import bmtools.relaxations.gm_relaxations as gmr
import bmtools.relaxations.var_mixture as var
import bmtools.utils as utils
import matplotlib.pyplot as plt
import thermomc.continuous_temp as cont_temp
import thermomc.discrete_temp as disc_temp
import thermomc.control_funcs as ctrl
import seaborn as sns
sns.set_style('whitegrid')
%matplotlib inline

## Load model parameters and set up

In [None]:
# Paths are resolved relative to the parent of the notebook's working
# directory: model parameters are read from data/gaussian-bmr and experiment
# outputs are written under experiments/gaussian-bmr (created if missing).
base_dir = os.path.dirname(os.getcwd())
model_dir, exp_dir = [
    os.path.join(base_dir, sub_dir, 'gaussian-bmr')
    for sub_dir in ('data', 'experiments')
]
if not os.path.exists(exp_dir):
    os.makedirs(exp_dir)

In [None]:
# Single fixed seed: drives the numpy RNG used for initial states and is also
# passed to the Theano RandomStreams of every sampler below.
seed = 201702
rng = np.random.RandomState(seed=seed)

In [None]:
class PhiFunc(object):
    """Callable building the Theano expression for ``phi(x)``.

    For each row of `x` (reduction over the last axis) it evaluates

        0.5 * ||x||^2 - sum_k log cosh((x Q^T + b)_k) + log_zeta

    where `Q`, `b` and `log_zeta` are Theano tensors given at construction.
    """

    def __init__(self, Q, b, log_zeta):
        self.Q = Q
        self.b = b
        self.log_zeta = log_zeta

    def __call__(self, x):
        """Return the symbolic phi(x), summed over the last axis of `x`."""
        a = x.dot(self.Q.T) + self.b
        abs_a = tt.abs_(a)
        # log cosh(a) = |a| + log1p(exp(-2|a|)) - log 2. Unlike
        # tt.log(tt.cosh(a)) this never overflows (cosh overflows float64 for
        # |a| > ~710, e.g. for HMC proposals far out in the tails), and its
        # gradient is still tanh(a).
        log_cosh_a = abs_a + tt.log1p(tt.exp(-2. * abs_a)) - np.log(2.)
        return 0.5 * (x**2).sum(-1) - log_cosh_a.sum(-1) + self.log_zeta

class PsiFunc(object):
    """Callable building the Gaussian negative log-density ``psi(x)``.

    `L` is a lower-triangular factor of the covariance (Sigma = L L^T) and
    `m` the mean, both Theano tensors. For each row of `x` it evaluates

        0.5 * (z^T Sigma^{-1} z + D log(2 pi) + log det Sigma),  z = x - m,

    with D = m.shape[0] and log det Sigma = 2 * sum(log(diag(L))).
    """

    def __init__(self, L, m):
        self.L = L
        self.m = m

    def __call__(self, x):
        """Return the symbolic psi(x) for each row of `x`."""
        z = x - self.m
        # z^T (L L^T)^{-1} z = ||L^{-1} z||^2: a single triangular solve is
        # enough (instead of a lower then an upper solve), and the quadratic
        # form is a sum of squares, so it can never come out negative.
        w = sla.solve_lower_triangular(self.L, z.T)
        return 0.5 * (
            (w**2).sum(0) +
            self.m.shape[0] * tt.log(2 * np.pi) +
            2 * tt.log(self.L.diagonal()).sum()
        )

In [None]:
def sigmoid(x):
    """Elementwise logistic function 1 / (1 + exp(-x))."""
    denominator = 1. + np.exp(-x)
    return 1. / denominator

def sigmoidal_schedule(num_temp, scale):
    """Inverse-temperature schedule of `num_temp + 1` points from 0 to 1.

    Points are evenly spaced on [-1, 1], multiplied by `scale`, passed through
    the logistic `sigmoid`, then affinely rescaled so the first value is
    exactly 0 and the last exactly 1. Larger `scale` gives a steeper sigmoid.
    Requires `num_temp >= 1`.
    """
    grid = 2. * np.arange(num_temp + 1) / num_temp - 1.
    raw = sigmoid(scale * grid)
    lo, hi = raw[0], raw[-1]
    return (raw - lo) / (hi - lo)

def rmse(x, y):
    """Root-mean-square difference between `x` and `y` (broadcast elementwise).

    Inputs go through `np.asarray`, so plain Python scalars work as well as
    numpy arrays (the bare `((x - y)**2).mean()` fails on Python floats).
    """
    sq_err = np.square(np.asarray(x) - np.asarray(y))
    return sq_err.mean() ** 0.5

In [None]:
# Load each saved parameter set, compute its variational approximation and
# build the Theano energy callables (phi, psi) used by all samplers below.
dtype = 'float64'
relaxation_list = []
true_log_norm_list = []
true_mean_list = []
true_covar_list = []
var_log_norm_list = []
var_mean_list = []
var_covar_chol_list = []
phi_funcs = []
psi_funcs = []
# NOTE(review): sorted() is lexicographic, so with 10 or more files
# 'params_and_moms_10' would sort before 'params_and_moms_2'; `i` is the
# position in this order (used for output names), not the number in the name.
param_files = sorted(glob.glob(os.path.join(model_dir, 'params_and_moms_*.npz')))
for i, file_path in enumerate(param_files):
    # np.load on an .npz returns an NpzFile holding an open file handle; read
    # every array inside the `with` so the handle is closed afterwards.
    with np.load(file_path) as loaded:
        relaxation_list.append(gmr.IsotropicCovarianceGMRelaxation(
            loaded['weights'], loaded['biases'], True)
        )
        # Reference ('true') log normaliser, mean and covariance stored with
        # the parameters.
        true_log_norm_list.append(loaded['log_norm_const_x'])
        true_mean_list.append(loaded['expc_x'])
        true_covar_list.append(loaded['covar_x'])
    # Variational mean, covariance Cholesky factor and log normaliser; this
    # call consumes random state from `rng`.
    var_mean, var_covar_chol, var_log_norm = (
        var.mixture_of_variational_distributions_moments(
            relaxation_list[-1], rng
    ))
    var_log_norm_list.append(var_log_norm)
    var_mean_list.append(var_mean)
    var_covar_chol_list.append(var_covar_chol)
    np.savez(os.path.join(model_dir, 'var_moms_{0}.npz'.format(i)),
             var_mean=var_mean, var_covar_chol=var_covar_chol, var_log_norm=var_log_norm)
    Q = tt.constant(relaxation_list[-1].Q, 'Q' + str(i), 2, dtype)
    b = tt.constant(relaxation_list[-1].b, 'b' + str(i), 1, dtype)
    L = tt.constant(var_covar_chol, 'L' + str(i), 2, dtype)
    m = tt.constant(var_mean, 'm' + str(i), 1, dtype)
    log_zeta = tt.constant(var_log_norm, 'log_zeta' + str(i), 0, dtype)
    phi_funcs.append(PhiFunc(Q, b, log_zeta))
    psi_funcs.append(PsiFunc(L, m))
    print('Var. log norm RMSE: {0}'.format(rmse(var_log_norm, true_log_norm_list[-1])))
# Fail loudly here rather than with an IndexError in the sampler cells.
assert relaxation_list, 'No params_and_moms_*.npz files found in ' + model_dir

## Annealed Importance Sampling

In [None]:
# AIS experiment settings.
num_temps = [1000, 5000, 10000, 20000]  # schedule lengths to compare
temp_scale = 4.                         # steepness of the sigmoidal schedule
# HMC transition parameters (passed to the sampler via `hmc_params`).
dt = 0.5
num_step = 10
mom_resample_coeff = 1.
# Total runs = independent replicates x runs per replicate.
num_reps, num_runs_per_rep = 10, 100
num_runs = num_reps * num_runs_per_rep

In [None]:
# Compile one AIS run function per parameter set; each maps
# (initial positions, inverse-temperature schedule) to
# (final positions, log importance weights, acceptances).
pos = tt.matrix('pos')
inv_temps = tt.vector('inv_temps')
hmc_params = dict(dt=dt, n_step=num_step, mom_resample_coeff=mom_resample_coeff)
srng = tt.shared_randomstreams.RandomStreams(seed)
ais_sampler = disc_temp.AnnealedImportanceSampler(srng, False)
ais_run_funcs = []
for phi_func, psi_func in zip(phi_funcs, psi_funcs):
    pos_samples, log_weights, accepts, updates = ais_sampler.run(
        pos, None, inv_temps, phi_func, psi_func, hmc_params)
    ais_run_funcs.append(th.function(
        inputs=[pos, inv_temps],
        outputs=[pos_samples, log_weights, accepts],
        updates=updates))

In [None]:
# Run AIS for each loaded parameter set and schedule length. Printed numbers
# are the mean replicate RMSE divided by the RMSE of the variational
# approximation alone (< 1 means AIS improved on it).
for i in range(len(relaxation_list)):
    ais_exp_dir = os.path.join(exp_dir, 'ais', 'params-' + str(i))
    if not os.path.exists(ais_exp_dir):
        os.makedirs(ais_exp_dir)
    # Baseline errors of the variational approximation; independent of the
    # schedule, so computed once per parameter set.
    var_log_norm_rmse = rmse(true_log_norm_list[i], var_log_norm_list[i])
    var_mean_rmse = rmse(true_mean_list[i], var_mean_list[i])
    var_covar_rmse = rmse(true_covar_list[i],
                          var_covar_chol_list[i].dot(var_covar_chol_list[i].T))
    for num_temp in num_temps:
        settings = {
            'dt': dt,
            'num_temp': num_temp,
            'temp_scale': temp_scale,
            'n_step': num_step,
            'mom_resample_coeff': mom_resample_coeff
        }
        print('Parameters {0} num temps {1}'.format(i, num_temp))
        print('-' * 100)
        print(settings)
        settings_path = os.path.join(ais_exp_dir, 'settings-{0}.json'.format(num_temp))
        results_path = os.path.join(ais_exp_dir, 'results-{0}.npz'.format(num_temp))
        with open(settings_path, 'w') as f:
            json.dump(settings, f, indent=True)
        inv_temp_sched = sigmoidal_schedule(num_temp, temp_scale)
        num_dim = relaxation_list[i].n_dim_r
        # Draw initial states from the Gaussian N(var_mean, L L^T).
        pos_init = rng.normal(size=(num_runs, num_dim)).dot(
            var_covar_chol_list[i].T) + var_mean_list[i]
        start_time = time.time()
        pos_samples, log_weights, accepts = ais_run_funcs[i](
            pos_init, inv_temp_sched
        )
        sampling_time = time.time() - start_time
        print('Sampling time: {0:.2f}s'.format(sampling_time))
        log_norm_rmses = []
        mean_rmses = []
        covar_rmses = []
        # Split the runs into `num_reps` replicates of `num_runs_per_rep` runs.
        for lw, ps in zip(
                log_weights.reshape((num_reps, -1)),
                pos_samples.reshape((num_reps, num_runs_per_rep, -1))):
            # Log-mean-exp of the log weights, shifted by their max so the
            # exponentials cannot overflow (or all underflow to zero).
            lw_max = lw.max()
            w_shifted = np.exp(lw - lw_max)
            log_norm_rmses.append(
                rmse(lw_max + np.log(w_shifted.mean()) + var_log_norm_list[i],
                     true_log_norm_list[i])
            )
            # Self-normalised importance weights (the shift cancels out).
            probs = w_shifted / w_shifted.sum()
            mean_est = (probs[:, None] * ps).sum(0)
            mean_rmses.append(rmse(true_mean_list[i], mean_est))
            ps_zm = ps - mean_est
            covar_est = (ps_zm * probs[:, None]).T.dot(ps_zm)
            covar_rmses.append(rmse(true_covar_list[i], covar_est))
        print('RMSE log_norm={0:.2f} mean={1:.2f} covar={2:.2f}'
              .format(
                np.mean(log_norm_rmses) / var_log_norm_rmse,
                np.mean(mean_rmses) / var_mean_rmse,
                np.mean(covar_rmses) / var_covar_rmse
            )
        )
        np.savez(
            results_path,
            sampling_time=sampling_time,
            pos_samples=pos_samples,
            log_weights=log_weights,
            accepts=accepts,
            log_norm_rmses=np.array(log_norm_rmses),
            mean_rmses=np.array(mean_rmses),
            covar_rmses=np.array(covar_rmses),
            var_log_norm_rmse=var_log_norm_rmse,
            var_mean_rmse=var_mean_rmse,
            var_covar_rmse=var_covar_rmse
        )
        print('Saved to ' + results_path)

## Hamiltonian Annealed Importance Sampling

In [None]:
# Hamiltonian AIS settings. `mom_resample_coeff` is not set here: the H-AIS
# cells compute their own value from `dt`.
num_temps = [1000, 5000, 10000, 20000]  # schedule lengths to compare
temp_scale = 4.                         # steepness of the sigmoidal schedule
dt, num_step = 0.5, 1                   # HMC step size / number of steps
num_reps, num_runs_per_rep = 10, 500    # replicates x runs per replicate
num_runs = num_reps * num_runs_per_rep

In [None]:
# Compile one Hamiltonian-AIS run function per parameter set (sampler flag
# True, partial momentum resampling with coefficient sqrt(1 - 0.5**dt)).
pos = tt.matrix('pos')
inv_temps = tt.vector('inv_temps')
hmc_params = dict(
    dt=dt, n_step=num_step, mom_resample_coeff=(1. - 0.5**dt)**0.5)
srng = tt.shared_randomstreams.RandomStreams(seed)
ais_sampler = disc_temp.AnnealedImportanceSampler(srng, True)
ais_run_funcs = []
for phi_func, psi_func in zip(phi_funcs, psi_funcs):
    pos_samples, log_weights, accepts, updates = ais_sampler.run(
        pos, None, inv_temps, phi_func, psi_func, hmc_params)
    ais_run_funcs.append(th.function(
        inputs=[pos, inv_temps],
        outputs=[pos_samples, log_weights, accepts],
        updates=updates))

In [None]:
# Run Hamiltonian AIS for each loaded parameter set and schedule length.
# Printed numbers are the mean replicate RMSE divided by the variational
# approximation's RMSE (< 1 means H-AIS improved on it).
for i in range(len(relaxation_list)):
    ais_exp_dir = os.path.join(exp_dir, 'h-ais', 'params-' + str(i))
    if not os.path.exists(ais_exp_dir):
        os.makedirs(ais_exp_dir)
    # Baseline errors of the variational approximation; independent of the
    # schedule, so computed once per parameter set.
    var_log_norm_rmse = rmse(true_log_norm_list[i], var_log_norm_list[i])
    var_mean_rmse = rmse(true_mean_list[i], var_mean_list[i])
    var_covar_rmse = rmse(true_covar_list[i],
                          var_covar_chol_list[i].dot(var_covar_chol_list[i].T))
    for num_temp in num_temps:
        settings = {
            'dt': dt,
            'num_temp': num_temp,
            'temp_scale': temp_scale,
            'n_step': num_step,
            'mom_resample_coeff': (1. - 0.5**dt)**0.5
        }
        print('Parameters {0} num temp {1}'.format(i, num_temp))
        print('-' * 100)
        print(settings)
        settings_path = os.path.join(ais_exp_dir, 'settings-{0}.json'.format(num_temp))
        results_path = os.path.join(ais_exp_dir, 'results-{0}.npz'.format(num_temp))
        with open(settings_path, 'w') as f:
            json.dump(settings, f, indent=True)
        inv_temp_sched = sigmoidal_schedule(num_temp, temp_scale)
        num_dim = relaxation_list[i].n_dim_r
        # Draw initial states from the Gaussian N(var_mean, L L^T).
        pos_init = rng.normal(size=(num_runs, num_dim)).dot(
            var_covar_chol_list[i].T) + var_mean_list[i]
        start_time = time.time()
        pos_samples, log_weights, accepts = ais_run_funcs[i](
            pos_init, inv_temp_sched
        )
        sampling_time = time.time() - start_time
        print('Sampling time: {0:.2f}s'.format(sampling_time))
        log_norm_rmses = []
        mean_rmses = []
        covar_rmses = []
        # Split the runs into `num_reps` replicates of `num_runs_per_rep` runs.
        for lw, ps in zip(
                log_weights.reshape((num_reps, -1)),
                pos_samples.reshape((num_reps, num_runs_per_rep, -1))):
            # Log-mean-exp of the log weights, shifted by their max so the
            # exponentials cannot overflow (or all underflow to zero).
            lw_max = lw.max()
            w_shifted = np.exp(lw - lw_max)
            log_norm_rmses.append(
                rmse(lw_max + np.log(w_shifted.mean()) + var_log_norm_list[i],
                     true_log_norm_list[i])
            )
            # Self-normalised importance weights (the shift cancels out).
            probs = w_shifted / w_shifted.sum()
            mean_est = (probs[:, None] * ps).sum(0)
            mean_rmses.append(rmse(true_mean_list[i], mean_est))
            ps_zm = ps - mean_est
            covar_est = (ps_zm * probs[:, None]).T.dot(ps_zm)
            covar_rmses.append(rmse(true_covar_list[i], covar_est))
        print('RMSE log_norm={0:.2f} mean={1:.2f} covar={2:.2f}'
              .format(
                np.mean(log_norm_rmses) / var_log_norm_rmse,
                np.mean(mean_rmses) / var_mean_rmse,
                np.mean(covar_rmses) / var_covar_rmse
            )
        )
        np.savez(
            results_path,
            sampling_time=sampling_time,
            pos_samples=pos_samples,
            log_weights=log_weights,
            accepts=accepts,
            log_norm_rmses=np.array(log_norm_rmses),
            mean_rmses=np.array(mean_rmses),
            covar_rmses=np.array(covar_rmses),
            var_log_norm_rmse=var_log_norm_rmse,
            var_mean_rmse=var_mean_rmse,
            var_covar_rmse=var_covar_rmse
        )
        print('Saved to ' + results_path)
        print('-' * 100)

## Incremental RMSE helper

In [None]:
def rmse(x, y):
    """Root-mean-square difference between `x` and `y` (broadcast elementwise).

    Redefined here so this cell can run on its own. Inputs go through
    `np.asarray`, so plain Python scalars work as well as numpy arrays.
    """
    sq_err = np.square(np.asarray(x) - np.asarray(y))
    return sq_err.mean() ** 0.5

def calculate_incremental_rmses(x_samples, probs_1, probs_0,
                                true_log_norm, true_mean, true_covar):
    """Running RMSEs of weighted estimates as samples accumulate.

    After each sample index `s`, the weights and weighted statistics of
    samples 0..s are pooled over all chains to estimate:
      * log normaliser: log(sum probs_1) - log(sum probs_0),
      * mean: sum(probs_1 * x) / sum(probs_1),
      * covariance: sum(probs_1 * x x^T) / sum(probs_1) - mean mean^T,
    and each estimate is compared with the given reference value via `rmse`.

    Args:
        x_samples: array of shape (n_sample, n_chain, n_dim).
        probs_1, probs_0: per-sample, per-chain weights of shape
            (n_sample, n_chain), as returned by the tempering samplers.
        true_log_norm, true_mean, true_covar: reference values.

    Returns:
        Three arrays of length n_sample: log-normaliser, mean and covariance
        RMSE after each sample index.
    """
    n_sample, n_chain, n_dim = x_samples.shape
    # Per-chain running sums (reduced over chains at every step below).
    sum_probs_1_x = np.zeros((n_chain, n_dim))
    sum_probs_1_xx = np.zeros((n_chain, n_dim, n_dim))
    sum_probs_1 = np.zeros(n_chain)
    sum_probs_0 = np.zeros(n_chain)
    log_norm_rmses = np.full(n_sample, np.nan)
    mean_rmses = np.full(n_sample, np.nan)
    covar_rmses = np.full(n_sample, np.nan)
    for s in range(n_sample):
        p1 = probs_1[s]
        p0 = probs_0[s]
        x = x_samples[s]
        sum_probs_1_x += p1[:, None] * x
        sum_probs_1_xx += p1[:, None, None] * (x[:, :, None] * x[:, None, :])
        sum_probs_1 += p1
        sum_probs_0 += p0
        total_probs_1 = sum_probs_1.sum(0)
        log_norm_est = np.log(total_probs_1) - np.log(sum_probs_0.sum(0))
        mean_est = sum_probs_1_x.sum(0) / total_probs_1
        covar_est = sum_probs_1_xx.sum(0) / total_probs_1 - np.outer(mean_est, mean_est)
        log_norm_rmses[s] = rmse(log_norm_est, true_log_norm)
        mean_rmses[s] = rmse(mean_est, true_mean)
        covar_rmses[s] = rmse(covar_est, true_covar)
    return log_norm_rmses, mean_rmses, covar_rmses

## Simulated Tempering

In [None]:
# Simulated tempering settings.
num_temp = 1000   # number of inverse-temperature levels (schedule length)
temp_scale = 4.   # steepness of the sigmoidal schedule
# HMC transition parameters (passed to the sampler via `hmc_params`).
dt = 0.5
num_step = 20
mom_resample_coeff = 1.
# Total chains = independent replicates x chains per replicate.
num_reps, num_runs_per_rep = 10, 10
num_runs = num_reps * num_runs_per_rep

In [None]:
# Compile one simulated-tempering chain function per parameter set; each maps
# (initial positions, initial temperature indices, schedule, number of
# samples) to (positions, indices, probs_0, probs_1, acceptances).
pos = tt.matrix('pos')
idx = tt.lvector('idx')
inv_temps = tt.vector('inv_temps')
num_sample = tt.lscalar('num_sample')
hmc_params = dict(dt=dt, n_step=num_step, mom_resample_coeff=mom_resample_coeff)
srng = tt.shared_randomstreams.RandomStreams(seed)
st_sampler = disc_temp.SimulatedTemperingSampler(srng, False)
st_chain_funcs = []
for phi_func, psi_func in zip(phi_funcs, psi_funcs):
    chain_outputs = st_sampler.chain(
        pos, None, idx, inv_temps, 0, phi_func, psi_func, num_sample, hmc_params)
    pos_samples, idx_samples, probs_0, probs_1, accepts, updates = chain_outputs
    st_chain_funcs.append(th.function(
        inputs=[pos, idx, inv_temps, num_sample],
        outputs=[pos_samples, idx_samples, probs_0, probs_1, accepts],
        updates=updates))

In [None]:
# Run simulated tempering for each loaded parameter set, then track how the
# running estimates converge per replicate. Printed/plotted numbers are RMSEs
# divided by the variational approximation's RMSE (< 1 means improvement).
num_sample = 40000
for i in range(len(relaxation_list)):
    st_exp_dir = os.path.join(exp_dir, 'st', 'params-' + str(i))
    if not os.path.exists(st_exp_dir):
        os.makedirs(st_exp_dir)
    settings = {
        'dt': dt,
        'num_temp': num_temp,
        'num_sample': num_sample,
        'num_step': num_step,
        'temp_scale': temp_scale,
        'mom_resample_coeff': mom_resample_coeff
    }
    print('Parameters {0}'.format(i))
    print('-' * 100)
    print(settings)
    settings_path = os.path.join(st_exp_dir, 'settings.json')
    results_path = os.path.join(st_exp_dir, 'results.npz')
    with open(settings_path, 'w') as f:
        json.dump(settings, f, indent=True)
    inv_temp_sched = sigmoidal_schedule(num_temp, temp_scale)
    num_dim = relaxation_list[i].n_dim_r
    # Draw initial states from the Gaussian N(var_mean, L L^T); every chain
    # starts at temperature index 0.
    pos_init = rng.normal(size=(num_runs, num_dim)).dot(
        var_covar_chol_list[i].T) + var_mean_list[i]
    idx_init = np.zeros(num_runs, 'int64')
    start_time = time.time()
    pos_samples, idx_samples, probs_0, probs_1, accepts = st_chain_funcs[i](
        pos_init, idx_init, inv_temp_sched, num_sample
    )
    sampling_time = time.time() - start_time
    print('Sampling time: {0:.2f}s'.format(sampling_time))
    log_norm_rmses = np.empty((num_reps, num_sample))
    mean_rmses = np.empty((num_reps, num_sample))
    covar_rmses = np.empty((num_reps, num_sample))
    for r in range(num_reps):
        # Replicate r owns the disjoint chain block
        # [r * num_runs_per_rep, (r + 1) * num_runs_per_rep). The previous
        # slice `r:(r+1)*num_runs_per_rep` made replicates overlap and differ
        # in size (e.g. r=1 covered chains 1..19).
        chains = slice(r * num_runs_per_rep, (r + 1) * num_runs_per_rep)
        log_norm_rmses[r], mean_rmses[r], covar_rmses[r] = calculate_incremental_rmses(
            pos_samples[:, chains],
            probs_1[:, chains],
            probs_0[:, chains],
            true_log_norm_list[i] - var_log_norm_list[i],
            true_mean_list[i], true_covar_list[i]
        )
    var_log_norm_rmse = rmse(true_log_norm_list[i], var_log_norm_list[i])
    var_mean_rmse = rmse(true_mean_list[i], var_mean_list[i])
    var_covar_rmse = rmse(true_covar_list[i],
                          var_covar_chol_list[i].dot(var_covar_chol_list[i].T))
    print('RMSE log_norm={0:.2f} mean={1:.2f} covar={2:.2f}'
          .format(
            np.mean(log_norm_rmses[:, -1]) / var_log_norm_rmse,
            np.mean(mean_rmses[:, -1]) / var_mean_rmse,
            np.mean(covar_rmses[:, -1]) / var_covar_rmse
        )
    )
    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    axes[0].semilogy(log_norm_rmses.mean(0) / var_log_norm_rmse)
    axes[0].set_title('Log norm RMSE')
    axes[1].semilogy(mean_rmses.mean(0) / var_mean_rmse)
    axes[1].set_title('Mean RMSE')
    axes[2].semilogy(covar_rmses.mean(0) / var_covar_rmse)
    axes[2].set_title('Covariance RMSE')
    plt.show()
    np.savez(
        results_path,
        sampling_time=sampling_time,
        pos_samples=pos_samples,
        idx_samples=idx_samples,
        probs_1=probs_1,
        probs_0=probs_0,
        accepts=accepts,
        log_norm_rmses=log_norm_rmses,
        mean_rmses=mean_rmses,
        covar_rmses=covar_rmses,
        var_log_norm_rmse=var_log_norm_rmse,
        var_mean_rmse=var_mean_rmse,
        var_covar_rmse=var_covar_rmse
    )
    print('Saved to ' + results_path)
    print('-' * 100)

## Continuous tempering

### Gibbs

In [None]:
# Gibbs continuous tempering settings.
# HMC transition parameters (passed to the sampler via `hmc_params`).
dt, num_step, mom_resample_coeff = 0.5, 20, 1.
# Total chains = independent replicates x chains per replicate.
num_reps, num_runs_per_rep = 10, 10
num_runs = num_reps * num_runs_per_rep

In [None]:
# Compile one Gibbs continuous-tempering chain function per parameter set;
# each maps (initial positions, initial inverse temperatures, number of
# samples) to (positions, inverse temperatures, probs_0, probs_1, acceptances).
pos = tt.matrix('pos')
idx = tt.lvector('idx')  # not used in this cell
inv_temp = tt.vector('inv_temp')
num_sample = tt.lscalar('n_sample')
hmc_params = dict(dt=dt, n_step=num_step, mom_resample_coeff=mom_resample_coeff)
srng = tt.shared_randomstreams.RandomStreams(seed)
gct_sampler = cont_temp.GibbsContinuousTemperingSampler(srng, False)
gct_chain_funcs = []
for phi_func, psi_func in zip(phi_funcs, psi_funcs):
    chain_outputs = gct_sampler.chain(
        pos, None, inv_temp, phi_func, psi_func, num_sample, hmc_params)
    pos_samples, inv_temp_samples, probs_0, probs_1, accepts, updates = chain_outputs
    gct_chain_funcs.append(th.function(
        inputs=[pos, inv_temp, num_sample],
        outputs=[pos_samples, inv_temp_samples, probs_0, probs_1, accepts],
        updates=updates))

In [None]:
# Run Gibbs continuous tempering for each loaded parameter set, then track
# convergence of the running estimates per replicate. Printed/plotted numbers
# are RMSEs divided by the variational approximation's RMSE.
num_sample = 60000
for i in range(len(relaxation_list)):
    gct_exp_dir = os.path.join(exp_dir, 'gibbs-ct', 'params-' + str(i))
    if not os.path.exists(gct_exp_dir):
        os.makedirs(gct_exp_dir)
    settings = {
        'dt': dt,
        'num_sample': num_sample,
        'num_step': num_step,
        'mom_resample_coeff': mom_resample_coeff
    }
    print('Parameters {0}'.format(i))
    print('-' * 100)
    print(settings)
    settings_path = os.path.join(gct_exp_dir, 'settings.json')
    results_path = os.path.join(gct_exp_dir, 'results.npz')
    with open(settings_path, 'w') as f:
        json.dump(settings, f, indent=True)
    num_dim = relaxation_list[i].n_dim_r
    # Draw initial states from the Gaussian N(var_mean, L L^T); every chain
    # starts at inverse temperature 0.
    pos_init = rng.normal(size=(num_runs, num_dim)).dot(
        var_covar_chol_list[i].T) + var_mean_list[i]
    inv_temp_init = np.zeros(num_runs)
    start_time = time.time()
    pos_samples, inv_temp_samples, probs_0, probs_1, accepts = gct_chain_funcs[i](
        pos_init, inv_temp_init, num_sample
    )
    sampling_time = time.time() - start_time
    print('Sampling time: {0:.2f}s'.format(sampling_time))
    log_norm_rmses = np.empty((num_reps, num_sample))
    mean_rmses = np.empty((num_reps, num_sample))
    covar_rmses = np.empty((num_reps, num_sample))
    for r in range(num_reps):
        # Replicate r owns the disjoint chain block
        # [r * num_runs_per_rep, (r + 1) * num_runs_per_rep). The previous
        # slice `r:(r+1)*num_runs_per_rep` made replicates overlap and differ
        # in size.
        chains = slice(r * num_runs_per_rep, (r + 1) * num_runs_per_rep)
        log_norm_rmses[r], mean_rmses[r], covar_rmses[r] = calculate_incremental_rmses(
            pos_samples[:, chains],
            probs_1[:, chains],
            probs_0[:, chains],
            true_log_norm_list[i] - var_log_norm_list[i],
            true_mean_list[i], true_covar_list[i]
        )
    var_log_norm_rmse = rmse(true_log_norm_list[i], var_log_norm_list[i])
    var_mean_rmse = rmse(true_mean_list[i], var_mean_list[i])
    var_covar_rmse = rmse(true_covar_list[i],
                          var_covar_chol_list[i].dot(var_covar_chol_list[i].T))
    print('RMSE log_norm={0:.2f} mean={1:.2f} covar={2:.2f}'
          .format(
            np.mean(log_norm_rmses[:, -1]) / var_log_norm_rmse,
            np.mean(mean_rmses[:, -1]) / var_mean_rmse,
            np.mean(covar_rmses[:, -1]) / var_covar_rmse
        )
    )
    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    axes[0].semilogy(log_norm_rmses.mean(0) / var_log_norm_rmse)
    axes[0].set_title('Log norm RMSE')
    axes[1].semilogy(mean_rmses.mean(0) / var_mean_rmse)
    axes[1].set_title('Mean RMSE')
    axes[2].semilogy(covar_rmses.mean(0) / var_covar_rmse)
    axes[2].set_title('Covariance RMSE')
    plt.show()
    np.savez(
        results_path,
        sampling_time=sampling_time,
        pos_samples=pos_samples,
        inv_temp_samples=inv_temp_samples,
        probs_1=probs_1,
        probs_0=probs_0,
        accepts=accepts,
        log_norm_rmses=log_norm_rmses,
        mean_rmses=mean_rmses,
        covar_rmses=covar_rmses,
        var_log_norm_rmse=var_log_norm_rmse,
        var_mean_rmse=var_mean_rmse,
        var_covar_rmse=var_covar_rmse
    )
    print('Saved to ' + results_path)
    print('-' * 100)

### Joint

In [None]:
# Joint continuous tempering settings.
temp_scale = 1.   # passed to SigmoidalControlFunction in the next cell
# HMC transition parameters (passed to the sampler via `hmc_params`).
dt, num_step, mom_resample_coeff = 0.5, 20, 1.
# Total chains = independent replicates x chains per replicate.
num_reps, num_runs_per_rep = 10, 10
num_runs = num_reps * num_runs_per_rep

In [None]:
# Compile one joint continuous-tempering chain function per parameter set;
# each maps (initial positions, initial temperature-control values, number of
# samples) to (positions, inverse temperatures, probs_0, probs_1, acceptances).
pos = tt.matrix('pos')
tmp_ctrl = tt.vector('tmp_ctrl')
num_sample = tt.lscalar('n_sample')
ctrl_func = ctrl.SigmoidalControlFunction(temp_scale)
hmc_params = dict(dt=dt, n_step=num_step, mom_resample_coeff=mom_resample_coeff)
srng = tt.shared_randomstreams.RandomStreams(seed)
jct_sampler = cont_temp.JointContinuousTemperingSampler(srng, False)
jct_chain_funcs = []
for phi_func, psi_func in zip(phi_funcs, psi_funcs):
    chain_outputs = jct_sampler.chain(
        pos, tmp_ctrl, None, phi_func, psi_func, ctrl_func, num_sample, hmc_params)
    # The sampled control variable (second output) is not returned.
    (pos_samples, tmp_ctrl_sample, inv_temp_samples,
     probs_0, probs_1, accepts, updates) = chain_outputs
    jct_chain_funcs.append(th.function(
        inputs=[pos, tmp_ctrl, num_sample],
        outputs=[pos_samples, inv_temp_samples, probs_0, probs_1, accepts],
        updates=updates))

In [None]:
# Run joint continuous tempering for each loaded parameter set, then track
# convergence of the running estimates per replicate. Printed/plotted numbers
# are RMSEs divided by the variational approximation's RMSE.
num_sample = 50000
for i in range(len(relaxation_list)):
    jct_exp_dir = os.path.join(exp_dir, 'joint-ct', 'params-' + str(i))
    if not os.path.exists(jct_exp_dir):
        os.makedirs(jct_exp_dir)
    settings = {
        'dt': dt,
        'num_sample': num_sample,
        'num_step': num_step,
        'temp_scale': temp_scale,
        'mom_resample_coeff': mom_resample_coeff
    }
    print('Parameters {0}'.format(i))
    print('-' * 100)
    print(settings)
    settings_path = os.path.join(jct_exp_dir, 'settings.json')
    results_path = os.path.join(jct_exp_dir, 'results.npz')
    with open(settings_path, 'w') as f:
        json.dump(settings, f, indent=True)
    num_dim = relaxation_list[i].n_dim_r
    # Draw initial states from the Gaussian N(var_mean, L L^T); every chain's
    # control variable starts at -10.
    pos_init = rng.normal(size=(num_runs, num_dim)).dot(
        var_covar_chol_list[i].T) + var_mean_list[i]
    tmp_ctrl_init = np.zeros(num_runs) - 10.
    start_time = time.time()
    pos_samples, inv_temp_samples, probs_0, probs_1, accepts = jct_chain_funcs[i](
        pos_init, tmp_ctrl_init, num_sample
    )
    sampling_time = time.time() - start_time
    print('Sampling time: {0:.2f}s'.format(sampling_time))
    log_norm_rmses = np.empty((num_reps, num_sample))
    mean_rmses = np.empty((num_reps, num_sample))
    covar_rmses = np.empty((num_reps, num_sample))
    for r in range(num_reps):
        # Replicate r owns the disjoint chain block
        # [r * num_runs_per_rep, (r + 1) * num_runs_per_rep). The previous
        # slice `r:(r+1)*num_runs_per_rep` made replicates overlap and differ
        # in size.
        chains = slice(r * num_runs_per_rep, (r + 1) * num_runs_per_rep)
        log_norm_rmses[r], mean_rmses[r], covar_rmses[r] = calculate_incremental_rmses(
            pos_samples[:, chains],
            probs_1[:, chains],
            probs_0[:, chains],
            true_log_norm_list[i] - var_log_norm_list[i],
            true_mean_list[i], true_covar_list[i]
        )
    var_log_norm_rmse = rmse(true_log_norm_list[i], var_log_norm_list[i])
    var_mean_rmse = rmse(true_mean_list[i], var_mean_list[i])
    var_covar_rmse = rmse(true_covar_list[i],
                          var_covar_chol_list[i].dot(var_covar_chol_list[i].T))
    print('RMSE log_norm={0:.2f} mean={1:.2f} covar={2:.2f}'
          .format(
            np.mean(log_norm_rmses[:, -1]) / var_log_norm_rmse,
            np.mean(mean_rmses[:, -1]) / var_mean_rmse,
            np.mean(covar_rmses[:, -1]) / var_covar_rmse
        )
    )
    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    axes[0].semilogy(log_norm_rmses.mean(0) / var_log_norm_rmse)
    axes[0].set_title('Log norm RMSE')
    axes[1].semilogy(mean_rmses.mean(0) / var_mean_rmse)
    axes[1].set_title('Mean RMSE')
    axes[2].semilogy(covar_rmses.mean(0) / var_covar_rmse)
    axes[2].set_title('Covariance RMSE')
    plt.show()
    np.savez(
        results_path,
        sampling_time=sampling_time,
        pos_samples=pos_samples,
        inv_temp_samples=inv_temp_samples,
        probs_1=probs_1,
        probs_0=probs_0,
        accepts=accepts,
        log_norm_rmses=log_norm_rmses,
        mean_rmses=mean_rmses,
        covar_rmses=covar_rmses,
        var_log_norm_rmse=var_log_norm_rmse,
        var_mean_rmse=var_mean_rmse,
        var_covar_rmse=var_covar_rmse
    )
    print('Saved to ' + results_path)
    print('-' * 100)