In [None]:
import os
#os.environ['THEANO_FLAGS'] = 'device=gpu0,floatX=float32'
import pprint as pp
import glob
import time
import json
import numpy as np
import scipy.linalg as la
import theano as th
import theano.tensor as tt
import theano.sandbox.rng_mrg as rand
import theano.tensor.slinalg as sla
import matplotlib.pyplot as plt
import thermomc.continuous_temp as cont_temp
import thermomc.discrete_temp as disc_temp
import thermomc.control_funcs as ctrl
import seaborn as sns
# Use seaborn's 'whitegrid' style for the figures drawn in this notebook.
sns.set_style('whitegrid')
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

## Create experiment directory

In [None]:
# All paths are resolved relative to the parent of the current working
# directory.
base_dir = os.path.dirname(os.getcwd())
# <base>/data/omni-iwae is read for the model files, <base>/experiments/omni-iwae
# receives the settings and results written by this notebook.
model_dir, exp_dir = (
    os.path.join(base_dir, top_level, 'omni-iwae')
    for top_level in ('data', 'experiments')
)
# Create the output directory (and any missing parents) on first use.
if not os.path.exists(exp_dir):
    os.makedirs(exp_dir)

## Seed random number generator

In [None]:
# Single integer seed; it is reused for the NumPy generator below and for the
# Theano random streams created further down.
seed = 201703
# NumPy generator used to draw the initial chain states.
rng = np.random.RandomState(seed=seed)

## Load parameters

In [None]:
# Open the five .npz archives stored in model_dir, in the same order as the
# names on the left-hand side.
(decoder,
 encoder_h,
 encoder_mean,
 encoder_std,
 samples_and_log_norm_bounds) = (
    np.load(os.path.join(model_dir, file_name))
    for file_name in (
        'decoder_params.npz',
        'encoder_h_params.npz',
        'encoder_mean_params.npz',
        'encoder_std_params.npz',
        'joint-sample-and-log-norm-bounds.npz',
    )
)

In [None]:
# Pull the individual arrays out of the loaded archive; the key order matches
# the order of the target names.
(x,
 var_mean_z_gvn_x,
 var_std_z_gvn_x,
 log_zeta,
 log_norm_lower,
 log_norm_upper) = (
    samples_and_log_norm_bounds[key]
    for key in (
        'xs',
        'mean_z_gvn_x',
        'std_z_gvn_x',
        'log_zeta',
        'log_norm_lower',
        'log_norm_upper',
    )
)

## Define model functions

In [None]:
def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), evaluated elementwise with NumPy."""
    denominator = 1. + np.exp(-x)
    return 1. / denominator

def sigmoidal_schedule(num_temp, scale):
    """Build a sigmoid-shaped schedule with num_temp + 1 entries.

    The sigmoid is evaluated on num_temp + 1 equally spaced points between
    -scale and +scale, then shifted and rescaled so that the first entry is 0
    and the last entry is 1.

    Args:
        num_temp: Number of intervals; the result has num_temp + 1 entries.
        scale: Half-width of the range the sigmoid is evaluated on.

    Returns:
        1D NumPy array of length num_temp + 1.
    """
    # 2 * k / num_temp - 1 runs from -1 (k = 0) to +1 (k = num_temp).
    unit_grid = 2. * np.arange(num_temp + 1) / num_temp - 1.
    inv_temp_sched = sigmoid(scale * unit_grid)
    first, last = inv_temp_sched[0], inv_temp_sched[-1]
    return (inv_temp_sched - first) / (last - first)

def rmse(x, y):
    """Return the root of the mean squared difference between x and y."""
    squared_diff = (x - y)**2
    mean_squared = squared_diff.mean()
    return mean_squared**0.5

In [None]:
class PhiFunc(object):
    """Callable computing the 'phi' term handed to the samplers.

    With z = u * var_std + var_mean and m = mean_x_gvn_z(z), a call returns,
    summing over the last axis,

        0.5 * |z|^2 + 0.5 * D * log(2 * pi)
        - sum(x * log(m) + (1 - x) * log(1 - m))
        + log_zeta - sum(log(var_std))

    where D is the size of the last axis of z. The first line is the negative
    log density of a standard normal at z and the second has the form of a
    Bernoulli log-likelihood of x with means m.
    """

    def __init__(self, x, weights, biases, non_linearities, log_zeta, var_mean, var_std):
        """Store the data and the parameters used by mean_x_gvn_z and __call__.

        Args:
            x: Observations used in the likelihood term.
            weights: Sequence of per-layer weight matrices.
            biases: Sequence of per-layer biases.
            non_linearities: Sequence of per-layer activation callables.
            log_zeta: Offset added to the returned value.
            var_mean: Shift applied to u to obtain z.
            var_std: Scale applied to u to obtain z.
        """
        # Data and additive offset.
        self.x = x
        self.log_zeta = log_zeta
        # Layer parameters of the z -> mean(x) map.
        self.weights = weights
        self.biases = biases
        self.non_linearities = non_linearities
        # Affine reparameterisation u -> z.
        self.var_mean = var_mean
        self.var_std = var_std

    def mean_x_gvn_z(self, z):
        """Push z through the layers: h <- f(h.dot(W) + b) for each (W, b, f)."""
        activation = z
        layers = zip(self.weights, self.biases, self.non_linearities)
        for weight, bias, squash in layers:
            activation = squash(activation.dot(weight) + bias)
        return activation

    def __call__(self, u):
        """Evaluate the expression given in the class docstring at u."""
        z = u * self.var_std + self.var_mean
        mean = self.mean_x_gvn_z(z)
        # Negative log density of a standard normal at z.
        neg_log_std_normal = (
            0.5 * (z**2).sum(-1) + 0.5 * z.shape[-1] * tt.log(2 * np.pi)
        )
        # Bernoulli-form log-likelihood of x under the means.
        log_lik = (
            self.x * tt.log(mean) + (1 - self.x) * tt.log(1 - mean)
        ).sum(-1)
        log_scale_sum = tt.log(self.var_std).sum(-1)
        return neg_log_std_normal - log_lik + self.log_zeta - log_scale_sum

class PsiFunc(object):
    """Callable computing the 'psi' term handed to the samplers.

    Returns 0.5 * |u|^2 + 0.5 * D * log(2 * pi), summed over the last axis of u
    (D is the size of that axis), i.e. the negative log density of a standard
    normal evaluated at u.
    """

    def __call__(self, u):
        """Evaluate the expression given in the class docstring at u."""
        half_sq_norm = 0.5 * (u**2).sum(-1)
        log_norm_const = 0.5 * u.shape[-1] * tt.log(2 * np.pi)
        return half_sq_norm + log_norm_const

In [None]:
# Activation for each layer-type name stored in the decoder archive.
non_linearity_map = {
    'nnet.Tanh': tt.tanh,
    'nnet.Sigmoid': tt.nnet.sigmoid,
    'nnet.Exp': tt.exp
}
# Keep one activation per non-linear entry of decoder['layers']; entries named
# 'nnet.Linear' are skipped.
non_linearities = []
for name in decoder['layers']:
    if name != 'nnet.Linear':
        non_linearities.append(non_linearity_map[name])
# Layer i reads the arrays stored under 'W<2i>' / 'b<2i>' and wraps each as a
# Theano constant with ndim 2 and dtype floatX.
weights = []
biases = []
for i in range(3):
    weights.append(tt.constant(
        decoder['W' + str(i * 2)], 'dec_W' + str(i),
        2, dtype=th.config.floatX))
for i in range(3):
    biases.append(tt.constant(
        decoder['b' + str(i * 2)], 'dec_b' + str(i),
        2, dtype=th.config.floatX))

In [None]:
# Number of data points; the run loops start num_runs * num_data chains.
num_data = 1000
# Number of columns of each initial chain state.
latent_dim = 50

## Annealed Importance Sampling

In [None]:
# Schedule sizes to try; each value is passed to sigmoidal_schedule together
# with temp_scale.
num_temps = [50, 100, 200, 500, 1000, 2000]
temp_scale = 4.
# Values forwarded to the sampler through hmc_params.
dt = 0.4
num_step = 10
mom_resample_coeff = 1.
# Chains per data point, and number of times the whole experiment is repeated.
num_runs = 16
num_reps = 10

In [None]:
# Bundle the current dt / num_step / mom_resample_coeff values under the keys
# the sampler is given.
hmc_params = dict(
    dt=dt,
    n_step=num_step,
    mom_resample_coeff=mom_resample_coeff,
)

### Create repeated model parameters / samples

In [None]:
# Repeat every per-data-point quantity num_runs times (consecutive copies of
# each entry / row) and wrap it as a floatX Theano constant.
x_rep = tt.constant(
    x.repeat(num_runs, axis=0),
    name='x', ndim=2, dtype=th.config.floatX)
log_zeta_rep = tt.constant(
    log_zeta.repeat(num_runs),
    name='log_zeta', ndim=1, dtype=th.config.floatX)
var_mean_z_gvn_x_rep = tt.constant(
    var_mean_z_gvn_x.repeat(num_runs, axis=0),
    name='var_mean_z_gvn_x', ndim=2, dtype=th.config.floatX)
var_std_z_gvn_x_rep = tt.constant(
    var_std_z_gvn_x.repeat(num_runs, axis=0),
    name='var_std_z_gvn_x', ndim=2, dtype=th.config.floatX)

### Create model objects

In [None]:
# The two callables passed to the sampler, built on the repeated constants.
psi_func = PsiFunc()
phi_func = PhiFunc(
    x=x_rep,
    weights=weights,
    biases=biases,
    non_linearities=non_linearities,
    log_zeta=log_zeta_rep,
    var_mean=var_mean_z_gvn_x_rep,
    var_std=var_std_z_gvn_x_rep,
)

In [None]:
# Theano random stream seeded with the notebook seed.
ais_srng = rand.MRG_RandomStreams(seed)
# NOTE(review): the meaning of the second positional argument (False) is not
# visible in this notebook - confirm against disc_temp.AnnealedImportanceSampler.
ais_sampler = disc_temp.AnnealedImportanceSampler(ais_srng, False)

In [None]:
# Symbolic inputs: initial chain states and the inverse-temperature schedule.
pos = tt.matrix('pos')
inv_temps = tt.vector('inv_temp_sched')
# Build the symbolic AIS run; only the log weights and accept values are
# returned by the compiled function.
ais_outputs = ais_sampler.run(
    pos, None, inv_temps, phi_func, psi_func, hmc_params)
pos_samples, log_weights, accepts, updates = ais_outputs
ais_run = th.function(
    inputs=[pos, inv_temps],
    outputs=[log_weights, accepts],
    updates=updates,
)

### AIS runs

In [None]:
# Reset both random streams to the fixed seed before the experiment starts.
rng.seed(seed)
ais_sampler.srng.seed(seed)
# Settings written next to the results.
ais_settings = {
    'dt': dt,
    'num_temps': num_temps,
    'temp_scale': temp_scale,
    'num_step': num_step,
    'mom_resample_coeff': mom_resample_coeff,
    'num_runs': num_runs,
}
print('-' * 100)
pp.pprint(ais_settings)
print('-' * 100)
settings_path = os.path.join(exp_dir, 'ais-settings.json')
results_path = os.path.join(exp_dir, 'ais-results.npz')
with open(settings_path, 'w') as f:
    json.dump(ais_settings, f, indent=True)
# One row per repeat, one column per schedule size; NaN marks entries that have
# not been filled in.
ais_sampling_times = np.full((num_reps, len(num_temps)), np.nan)
ais_log_norm_ests = np.full((num_reps, len(num_temps)), np.nan)
for i in range(num_reps):
    print('Repeat {0}'.format(i + 1))
    print('-' * 100)
    for t, num_temp in enumerate(num_temps):
        print('Num temps {0}'.format(num_temp))
        print('-' * 100)
        inv_temp_sched = sigmoidal_schedule(num_temp, temp_scale)
        pos_init = rng.normal(size=(num_runs * num_data, latent_dim))
        # Only the compiled sampler call is timed.
        start_time = time.time()
        log_weights, accepts = ais_run(
            pos_init.astype(th.config.floatX), inv_temp_sched.astype(th.config.floatX))
        ais_sampling_times[i, t] = time.time() - start_time
        print('Sampling time: {0:.2f}s'.format(ais_sampling_times[i, t]))
        print('Accept: mean={0:.2f} min={1:.2f} max={2:.2f}'
              .format(accepts.mean(), accepts.min(), accepts.max()))
        # Average the weights over the num_runs chains of each data point (the
        # repeated inputs hold those chains consecutively), take the log, then
        # average over data points and add the mean of log_zeta.
        # NOTE(review): np.exp on raw log weights can under/overflow for
        # extreme values; a log-sum-exp form would be more robust.
        ais_log_norm_ests[i, t] = log_zeta.mean() + np.log(
            np.exp(log_weights.reshape((-1, num_runs))).mean(-1)).mean(0)
        # .format must be applied to the string, not to the return value of
        # print() (which is None on Python 3).
        print('Log norm est={0:.2f}'.format(ais_log_norm_ests[i, t]))
        print('-' * 100)
np.savez(
    results_path,
    sampling_times=ais_sampling_times,
    log_norm_ests=ais_log_norm_ests,
)
print('Saved to ' + results_path)

In [None]:
# Scatter every AIS estimate against its sampling time and mark the stored
# lower / upper log-normaliser bounds with dashed lines.
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(1, 1, 1)
for i in range(num_reps):
    ax.plot(ais_sampling_times[i, :], ais_log_norm_ests[i, :], 'ro')
# The bound lines extend to the mean time of the largest schedule.
for bound in (log_norm_upper, log_norm_lower):
    ax.plot([0, ais_sampling_times[:, -1].mean()], [bound, bound], 'k--')
ax.set_xlabel('Time / s')
ax.set_ylabel('$\\log Z$ estimate')
plt.show()

## Simulated tempering

In [None]:
# Schedule: num_temp and temp_scale are the arguments of sigmoidal_schedule.
num_temp = 1000
temp_scale = 4.
# Values forwarded to the sampler through hmc_params.
dt = 0.4
num_step = 10
mom_resample_coeff = 1.
# Chains per data point.
num_runs = 2

In [None]:
# Bundle the current dt / num_step / mom_resample_coeff values under the keys
# the sampler is given.
hmc_params = dict(
    dt=dt,
    n_step=num_step,
    mom_resample_coeff=mom_resample_coeff,
)

In [None]:
# Rebuild the repeated constants for the current num_runs: every entry / row is
# repeated num_runs times consecutively and wrapped as a floatX constant.
x_rep = tt.constant(
    x.repeat(num_runs, axis=0),
    name='x', ndim=2, dtype=th.config.floatX)
log_zeta_rep = tt.constant(
    log_zeta.repeat(num_runs),
    name='log_zeta', ndim=1, dtype=th.config.floatX)
var_mean_z_gvn_x_rep = tt.constant(
    var_mean_z_gvn_x.repeat(num_runs, axis=0),
    name='var_mean_z_gvn_x', ndim=2, dtype=th.config.floatX)
var_std_z_gvn_x_rep = tt.constant(
    var_std_z_gvn_x.repeat(num_runs, axis=0),
    name='var_std_z_gvn_x', ndim=2, dtype=th.config.floatX)

In [None]:
# The two callables passed to the sampler, built on the repeated constants.
psi_func = PsiFunc()
phi_func = PhiFunc(
    x=x_rep,
    weights=weights,
    biases=biases,
    non_linearities=non_linearities,
    log_zeta=log_zeta_rep,
    var_mean=var_mean_z_gvn_x_rep,
    var_std=var_std_z_gvn_x_rep,
)

In [None]:
# Symbolic inputs: chain states, integer index per chain, the inverse
# temperature schedule and the number of samples to draw.
pos = tt.matrix('pos')
idx = tt.lvector('idx')
inv_temps = tt.vector('inv_temps')
num_sample = tt.lscalar('num_sample')
# NOTE(review): the meaning of the False argument and of the literal 0 passed
# to chain() is not visible in this notebook - confirm against
# disc_temp.SimulatedTemperingSampler.
st_sampler = disc_temp.SimulatedTemperingSampler(
    rand.MRG_RandomStreams(seed), False)
st_outputs = st_sampler.chain(
    pos, None, idx, inv_temps, 0, phi_func, psi_func, num_sample, hmc_params)
pos_samples, idx_samples, probs_0, probs_1, accepts, updates = st_outputs
# Only probs_0, probs_1 and the accept values are returned by the compiled
# function.
st_chain = th.function(
    inputs=[pos, idx, inv_temps, num_sample],
    outputs=[probs_0, probs_1, accepts],
    updates=updates,
)

### ST runs

In [None]:
# Number of repeats and of data points.
num_reps = 10
num_data = 1000
# Samples per chain and how many leading samples are dropped from the
# estimate. num_sample is rebound here from the symbolic scalar used to build
# the chain to a plain integer.
num_sample = 3000
num_warm_up = 0

In [None]:
# Reset both random streams to the fixed seed before the experiment starts.
rng.seed(seed)
st_sampler.srng.seed(seed)
# Settings written next to the results.
st_settings = {
    'dt': dt,
    'num_step': num_step,
    'mom_resample_coeff': mom_resample_coeff,
    'num_warmup': num_warm_up,
    'temp_scale': temp_scale,
    'num_temp': num_temp,
    'num_runs': num_runs,
}
inv_temp_sched = sigmoidal_schedule(num_temp, temp_scale)
print('-' * 100)
pp.pprint(st_settings)
print('-' * 100)
settings_path = os.path.join(exp_dir, 'st-settings.json')
with open(settings_path, 'w') as f:
    json.dump(st_settings, f, indent=True)
results_path = os.path.join(exp_dir, 'st-results.npz')
# NaN marks entries that have not been filled in.
st_log_norm_ests = np.full((num_reps, num_sample - num_warm_up), np.nan)
st_sampling_times = np.full(num_reps, np.nan)
# Total number of chains: num_runs per data point. num_data must equal the
# number of rows of x, since the repeated constants are built from x.
num_chains = num_runs * num_data
for i in range(num_reps):
    print('Repeat {0}'.format(i + 1))
    pos_init = rng.normal(size=(num_chains, latent_dim))
    idx_init = rng.randint(low=0, high=num_temp, size=num_chains)
    # Only the compiled sampler call is timed.
    start_time = time.time()
    probs_0, probs_1, accepts = st_chain(
        pos_init.astype(th.config.floatX),
        idx_init,
        inv_temp_sched.astype(th.config.floatX),
        num_sample
    )
    st_sampling_times[i] = time.time() - start_time
    # Drop the warm-up samples, average over the num_runs chains of each data
    # point (stored consecutively), and accumulate over samples.
    probs_1_cum = probs_1[num_warm_up:].reshape(
        (-1, num_data, num_runs)).mean(-1).cumsum(0)
    probs_0_cum = probs_0[num_warm_up:].reshape(
        (-1, num_data, num_runs)).mean(-1).cumsum(0)
    # Running estimate per sample count, averaged over data points.
    st_log_norm_ests[i] = (
        log_zeta.mean() + np.log(probs_1_cum) - np.log(probs_0_cum)
    ).mean(1)
    print('Sampling time: {0:.2f}s'.format(st_sampling_times[i]))
    print('Accept: mean={0:.2f} min={1:.2f} max={2:.2f}'
          .format(accepts.mean(), accepts.min(), accepts.max()))
    print('Log norm est={0:.2f}'.format(st_log_norm_ests[i][-1]))
    print('-' * 100)
np.savez(
    results_path,
    sampling_times=st_sampling_times,
    log_norm_ests=st_log_norm_ests,
)
print('Saved to ' + results_path)
print('-' * 100)

In [None]:
print('Log norm est final mean={0:.4f}'.format(st_log_norm_ests[:, -1].mean()))
# One curve per repeat: running estimate against elapsed time, where sample k
# is placed at k / num_sample of that repeat's total sampling time.
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(1, 1, 1)
for i in range(num_reps):
    elapsed = (np.arange(num_warm_up, num_sample) *
               st_sampling_times[i] / num_sample)
    ax.plot(elapsed, st_log_norm_ests[i])
# Dashed lines at the stored upper / lower log-normaliser bounds.
for bound in (log_norm_upper, log_norm_lower):
    ax.plot([0, st_sampling_times.mean()], [bound, bound], 'k--')
ax.set_xlabel('Time / s')
ax.set_ylabel('$\\log Z$ estimate')
plt.show()

## Gibbs continuous tempering

In [None]:
# Chains per data point.
num_runs = 2
# Values forwarded to the sampler through hmc_params.
dt = 0.4
num_step = 10
mom_resample_coeff = 1.

In [None]:
# Bundle the current dt / num_step / mom_resample_coeff values under the keys
# the sampler is given.
hmc_params = dict(
    dt=dt,
    n_step=num_step,
    mom_resample_coeff=mom_resample_coeff,
)

In [None]:
# Rebuild the repeated constants for the current num_runs: every entry / row is
# repeated num_runs times consecutively and wrapped as a floatX constant.
x_rep = tt.constant(
    x.repeat(num_runs, axis=0),
    name='x', ndim=2, dtype=th.config.floatX)
log_zeta_rep = tt.constant(
    log_zeta.repeat(num_runs),
    name='log_zeta', ndim=1, dtype=th.config.floatX)
var_mean_z_gvn_x_rep = tt.constant(
    var_mean_z_gvn_x.repeat(num_runs, axis=0),
    name='var_mean_z_gvn_x', ndim=2, dtype=th.config.floatX)
var_std_z_gvn_x_rep = tt.constant(
    var_std_z_gvn_x.repeat(num_runs, axis=0),
    name='var_std_z_gvn_x', ndim=2, dtype=th.config.floatX)

In [None]:
# The two callables passed to the sampler, built on the repeated constants.
psi_func = PsiFunc()
phi_func = PhiFunc(
    x=x_rep,
    weights=weights,
    biases=biases,
    non_linearities=non_linearities,
    log_zeta=log_zeta_rep,
    var_mean=var_mean_z_gvn_x_rep,
    var_std=var_std_z_gvn_x_rep,
)

In [None]:
# Symbolic inputs: chain states, one inverse temperature per chain and the
# number of samples to draw.
pos = tt.matrix('pos')
inv_temp = tt.vector('inv_temp')
num_sample = tt.lscalar('n_sample')
# NOTE(review): the meaning of the False argument is not visible in this
# notebook - confirm against cont_temp.GibbsContinuousTemperingSampler.
gct_sampler = cont_temp.GibbsContinuousTemperingSampler(
    rand.MRG_RandomStreams(seed), False)
gct_outputs = gct_sampler.chain(
    pos, None, inv_temp, phi_func, psi_func, num_sample, hmc_params)
(pos_samples, inv_temp_samples,
 probs_0, probs_1, accepts, updates) = gct_outputs
# Only probs_0, probs_1 and the accept values are returned by the compiled
# function.
gct_chain = th.function(
    inputs=[pos, inv_temp, num_sample],
    outputs=[probs_0, probs_1, accepts],
    updates=updates,
)

### Gibbs CT runs

In [None]:
# Number of repeats and of data points.
num_reps = 10
num_data = 1000
# Samples per chain and how many leading samples are dropped from the
# estimate. num_sample is rebound here from the symbolic scalar used to build
# the chain to a plain integer.
num_sample = 10000
num_warm_up = 0

In [None]:
# Reset both random streams to the fixed seed before the experiment starts.
rng.seed(seed)
gct_sampler.srng.seed(seed)
# Settings written next to the results.
gct_settings = {
    'dt': dt,
    'num_step': num_step,
    'mom_resample_coeff': mom_resample_coeff,
    'num_warmup': num_warm_up,
    'num_runs': num_runs,
}
print('-' * 100)
pp.pprint(gct_settings)
print('-' * 100)
settings_path = os.path.join(exp_dir, 'gct-settings.json')
with open(settings_path, 'w') as f:
    json.dump(gct_settings, f, indent=True)
results_path = os.path.join(exp_dir, 'gct-results.npz')
# NaN marks entries that have not been filled in.
gct_log_norm_ests = np.full((num_reps, num_sample - num_warm_up), np.nan)
gct_sampling_times = np.full(num_reps, np.nan)
# Total number of chains: num_runs per data point. num_data must equal the
# number of rows of x, since the repeated constants are built from x.
num_chains = num_runs * num_data
for i in range(num_reps):
    print('Repeat {0}'.format(i + 1))
    pos_init = rng.normal(size=(num_chains, latent_dim))
    # Initial inverse temperatures in (0, 1): sigmoid of standard normal draws.
    inv_temp_init = sigmoid(rng.normal(size=num_chains))
    # Only the compiled sampler call is timed.
    start_time = time.time()
    probs_0, probs_1, accepts = gct_chain(
        pos_init.astype(th.config.floatX),
        inv_temp_init.astype(th.config.floatX),
        num_sample
    )
    gct_sampling_times[i] = time.time() - start_time
    # Drop the warm-up samples, average over the num_runs chains of each data
    # point (stored consecutively), and accumulate over samples.
    probs_1_cum = probs_1[num_warm_up:].reshape(
        (-1, num_data, num_runs)).mean(-1).cumsum(0)
    probs_0_cum = probs_0[num_warm_up:].reshape(
        (-1, num_data, num_runs)).mean(-1).cumsum(0)
    # Running estimate per sample count, averaged over data points.
    gct_log_norm_ests[i] = (
        log_zeta.mean() + np.log(probs_1_cum) - np.log(probs_0_cum)
    ).mean(1)
    print('Sampling time: {0:.2f}s'.format(gct_sampling_times[i]))
    print('Accept: mean={0:.2f} min={1:.2f} max={2:.2f}'
          .format(accepts.mean(), accepts.min(), accepts.max()))
    print('Log norm est={0:.2f}'.format(gct_log_norm_ests[i][-1]))
    print('-' * 100)
np.savez(
    results_path,
    sampling_times=gct_sampling_times,
    log_norm_ests=gct_log_norm_ests,
)
print('Saved to ' + results_path)

In [None]:
print('Log norm est final mean={0:.4f}'.format(gct_log_norm_ests[:, -1].mean()))
# One curve per repeat: running estimate against elapsed time, where sample k
# is placed at k / num_sample of that repeat's total sampling time.
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(1, 1, 1)
for i in range(num_reps):
    elapsed = (np.arange(num_warm_up, num_sample) *
               gct_sampling_times[i] / num_sample)
    ax.plot(elapsed, gct_log_norm_ests[i])
# Dashed lines at the stored upper / lower log-normaliser bounds.
for bound in (log_norm_upper, log_norm_lower):
    ax.plot([0, gct_sampling_times.mean()], [bound, bound], 'k--')
ax.set_xlabel('Time / s')
ax.set_ylabel('$\\log Z$ estimate')
plt.show()

## Joint continuous tempering

In [None]:
# Chains per data point.
num_runs = 2
# Passed to the sigmoidal control function and used to scale the initial
# temperature-control values.
temp_scale = 5.
# Values forwarded to the sampler through hmc_params.
dt = 0.4
num_step = 10
mom_resample_coeff = 1.

In [None]:
# Bundle the current dt / num_step / mom_resample_coeff values under the keys
# the sampler is given.
hmc_params = dict(
    dt=dt,
    n_step=num_step,
    mom_resample_coeff=mom_resample_coeff,
)

In [None]:
# Rebuild the repeated constants for the current num_runs: every entry / row is
# repeated num_runs times consecutively and wrapped as a floatX constant.
x_rep = tt.constant(
    x.repeat(num_runs, axis=0),
    name='x', ndim=2, dtype=th.config.floatX)
log_zeta_rep = tt.constant(
    log_zeta.repeat(num_runs),
    name='log_zeta', ndim=1, dtype=th.config.floatX)
var_mean_z_gvn_x_rep = tt.constant(
    var_mean_z_gvn_x.repeat(num_runs, axis=0),
    name='var_mean_z_gvn_x', ndim=2, dtype=th.config.floatX)
var_std_z_gvn_x_rep = tt.constant(
    var_std_z_gvn_x.repeat(num_runs, axis=0),
    name='var_std_z_gvn_x', ndim=2, dtype=th.config.floatX)

In [None]:
# The two callables passed to the sampler, built on the repeated constants.
psi_func = PsiFunc()
phi_func = PhiFunc(
    x=x_rep,
    weights=weights,
    biases=biases,
    non_linearities=non_linearities,
    log_zeta=log_zeta_rep,
    var_mean=var_mean_z_gvn_x_rep,
    var_std=var_std_z_gvn_x_rep,
)

In [None]:
# Symbolic inputs: chain states, one temperature-control value per chain and
# the number of samples to draw.
pos = tt.matrix('pos')
tmp_ctrl = tt.vector('tmp_ctrl')
num_sample = tt.lscalar('n_sample')
ctrl_func = ctrl.SigmoidalControlFunction(temp_scale)
# Same keys and values as the hmc_params defined for this section.
hmc_params = dict(
    dt=dt,
    n_step=num_step,
    mom_resample_coeff=mom_resample_coeff,
)
# NOTE(review): the meaning of the False argument is not visible in this
# notebook - confirm against cont_temp.JointContinuousTemperingSampler.
jct_sampler = cont_temp.JointContinuousTemperingSampler(
    rand.MRG_RandomStreams(seed), False)
jct_outputs = jct_sampler.chain(
    pos, tmp_ctrl, None, phi_func, psi_func, ctrl_func, num_sample, hmc_params)
(pos_samples, tmp_ctrl_samples, inv_temp_samples,
 probs_0, probs_1, accepts, updates) = jct_outputs
# Only probs_0, probs_1 and the accept values are returned by the compiled
# function.
jct_chain = th.function(
    inputs=[pos, tmp_ctrl, num_sample],
    outputs=[probs_0, probs_1, accepts],
    updates=updates,
)

### Joint CT runs

In [None]:
# Number of repeats and of data points.
num_reps = 10
num_data = 1000
# Samples per chain and how many leading samples are dropped from the
# estimate. num_sample is rebound here from the symbolic scalar used to build
# the chain to a plain integer.
num_sample = 7500
num_warm_up = 0

In [None]:
# Reset both random streams to the fixed seed before the experiment starts.
rng.seed(seed)
jct_sampler.srng.seed(seed)
# Settings written next to the results.
jct_settings = {
    'dt': dt,
    'num_step': num_step,
    'mom_resample_coeff': mom_resample_coeff,
    'num_warmup': num_warm_up,
    'temp_scale': temp_scale,
    'num_runs': num_runs,
}
print('-' * 100)
pp.pprint(jct_settings)
print('-' * 100)
settings_path = os.path.join(exp_dir, 'jct-settings.json')
with open(settings_path, 'w') as f:
    json.dump(jct_settings, f, indent=True)
results_path = os.path.join(exp_dir, 'jct-results.npz')
# NaN marks entries that have not been filled in.
jct_log_norm_ests = np.full((num_reps, num_sample - num_warm_up), np.nan)
jct_sampling_times = np.full(num_reps, np.nan)
# Total number of chains: num_runs per data point. num_data must equal the
# number of rows of x, since the repeated constants are built from x.
num_chains = num_runs * num_data
for i in range(num_reps):
    print('Repeat {0}'.format(i + 1))
    pos_init = rng.normal(size=(num_chains, latent_dim))
    # Initial temperature-control values: standard normal draws times temp_scale.
    tmp_ctrl_init = rng.normal(size=num_chains) * temp_scale
    # Only the compiled sampler call is timed.
    start_time = time.time()
    probs_0, probs_1, accepts = jct_chain(
        pos_init.astype(th.config.floatX),
        tmp_ctrl_init.astype(th.config.floatX),
        num_sample
    )
    jct_sampling_times[i] = time.time() - start_time
    # Drop the warm-up samples, average over the num_runs chains of each data
    # point (stored consecutively), and accumulate over samples.
    probs_1_cum = probs_1[num_warm_up:].reshape(
        (-1, num_data, num_runs)).mean(-1).cumsum(0)
    probs_0_cum = probs_0[num_warm_up:].reshape(
        (-1, num_data, num_runs)).mean(-1).cumsum(0)
    # Running estimate per sample count, averaged over data points.
    jct_log_norm_ests[i] = (
        log_zeta.mean() + np.log(probs_1_cum) - np.log(probs_0_cum)
    ).mean(1)
    print('Sampling time: {0:.2f}s'.format(jct_sampling_times[i]))
    print('Accept: mean={0:.2f} min={1:.2f} max={2:.2f}'
          .format(accepts.mean(), accepts.min(), accepts.max()))
    print('Log norm est={0:.2f}'.format(jct_log_norm_ests[i][-1]))
    print('-' * 100)
np.savez(
    results_path,
    sampling_times=jct_sampling_times,
    log_norm_ests=jct_log_norm_ests,
)
print('Saved to ' + results_path)

In [None]:
print('Log norm est final mean={0:.4f}'.format(jct_log_norm_ests[:, -1].mean()))
# One curve per repeat: running estimate against elapsed time, where sample k
# is placed at k / num_sample of that repeat's total sampling time.
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(1, 1, 1)
for i in range(num_reps):
    elapsed = (np.arange(num_warm_up, num_sample) *
               jct_sampling_times[i] / num_sample)
    ax.plot(elapsed, jct_log_norm_ests[i])
# Dashed lines at the stored upper / lower log-normaliser bounds.
for bound in (log_norm_upper, log_norm_lower):
    ax.plot([0, jct_sampling_times.mean()], [bound, bound], 'k--')
ax.set_xlabel('Time / s')
ax.set_ylabel('$\\log Z$ estimate')
plt.show()

## Plot all

In [None]:
# Overlay every repeat of every method: AIS estimates as red dots, the
# tempering methods as running-estimate curves against elapsed time.
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111)
for i in range(ais_sampling_times.shape[0]):
    ax.plot(ais_sampling_times[i], ais_log_norm_ests[i], 'ro')
# Each tempering method is drawn with its OWN settings, times and estimates
# (the JCT curves previously read the warm-up count from gct_settings).
tempering_results = [
    (st_settings, st_sampling_times, st_log_norm_ests, 'g-'),
    (gct_settings, gct_sampling_times, gct_log_norm_ests, 'b-'),
    (jct_settings, jct_sampling_times, jct_log_norm_ests, 'c-'),
]
for settings, sampling_times, log_norm_ests, line_fmt in tempering_results:
    n_warm_up = settings['num_warmup']
    # The estimate arrays hold one column per post-warm-up sample, so the full
    # chain length is warm-up + columns. This keeps the time axis the same
    # length as the estimates when num_warmup > 0 and is unchanged when it is 0.
    n_total = n_warm_up + log_norm_ests.shape[1]
    for i in range(sampling_times.shape[0]):
        ax.plot(np.arange(n_warm_up, n_total) * sampling_times[i] / n_total,
                log_norm_ests[i], line_fmt)
# Dashed lines at the stored upper / lower log-normaliser bounds.
ax.plot([0, 350], [log_norm_upper, log_norm_upper], 'k--')
ax.plot([0, 350], [log_norm_lower, log_norm_lower], 'k--')
ax.set_xlabel('Time / s')
ax.set_ylabel('$\\log Z$ estimate')
#ax.set_ylim(-110.2, -109.8)
plt.show()

In [None]:
# Same comparison as the per-repeat plot, but averaged over repeats.
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111)
ax.plot(ais_sampling_times.mean(0), ais_log_norm_ests.mean(0), 'ro')
tempering_results = [
    (st_settings, st_sampling_times, st_log_norm_ests, 'g-'),
    (gct_settings, gct_sampling_times, gct_log_norm_ests, 'b-'),
    (jct_settings, jct_sampling_times, jct_log_norm_ests, 'c-'),
]
for settings, sampling_times, log_norm_ests, line_fmt in tempering_results:
    n_warm_up = settings['num_warmup']
    # The estimate arrays hold one column per post-warm-up sample, so the full
    # chain length is warm-up + columns. This keeps the time axis the same
    # length as the estimates when num_warmup > 0 and is unchanged when it is 0.
    n_total = n_warm_up + log_norm_ests.shape[1]
    ax.plot(np.arange(n_warm_up, n_total) * sampling_times.mean(0) / n_total,
            log_norm_ests.mean(0), line_fmt)
# Dashed lines at the stored upper / lower log-normaliser bounds.
ax.plot([0, 350], [log_norm_upper, log_norm_upper], 'k--')
ax.plot([0, 350], [log_norm_lower, log_norm_lower], 'k--')
ax.set_xlabel('Time / s')
ax.set_ylabel('$\\log Z$ estimate')
#ax.set_ylim(-110.8, -109.8)
ax.set_xlim(0, 350)
plt.show()

In [None]:
# Tile the rows of x, each reshaped to 28 x 28, into one image and save it.
fig = plt.figure(figsize=(16, 10))
ax = fig.add_subplot(1, 1, 1)
# 25 tile rows by 40 tile columns of 28 x 28 pixels.
im_grid = np.zeros((25 * 28, 40 * 28))
for i, im in enumerate(x):
    # Tiles fill column by column: 25 images per column.
    col, row = divmod(i, 25)
    top, left = row * 28, col * 28
    im_grid[top:top + 28, left:left + 28] = im.reshape(28, 28)
ax.imshow(im_grid, cmap='Greys', interpolation='None')
# Strip ticks, frame and padding so only the tiles remain.
ax.set_xticks([])
ax.set_yticks([])
ax.set_frame_on(False)
fig.tight_layout(pad=0)
fig.savefig(os.path.join(exp_dir, 'omni-samples.pdf'))