In [None]:
import os
import glob
from collections import OrderedDict
import numpy as np
import scipy.linalg as la
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import CubicSpline
import seaborn as sns
%matplotlib inline

In [None]:
# Five-colour HUSL palette; the plotting cell indexes into it per sampler.
palette = sns.color_palette('husl', 5)

# Global seaborn configuration. The call order is kept as in the original.
# NOTE(review): sns.set() appears to re-apply its own default context, which
# would override the 'paper' context chosen on the line before it -- confirm
# against the seaborn docs before reordering (doing so changes font sizes).
sns.set_context('paper')
sns.set(font='sans')

# 'white' style with a sans font and every text / tick element drawn in
# grey level '0.' (matplotlib greyscale string for black).
black_text_rc = {
    rc_key: '0.'
    for rc_key in ('axes.labelcolor', 'text.color', 'xtick.color', 'ytick.color')
}
black_text_rc['font.family'] = 'sans'
sns.set_style('white', black_text_rc)

In [None]:
# Project root is taken to be the parent of the current working directory
# (os.path.split(...)[0] is exactly what os.path.dirname returns).
base_dir = os.path.split(os.getcwd())[0]

# Both locations share the same 'mnist-iwae' leaf under a different top-level
# folder: 'data' holds the model files, 'experiments' the sampler results.
model_dir, exp_dir = (
    os.path.join(base_dir, top_level, 'mnist-iwae')
    for top_level in ('data', 'experiments')
)

In [None]:
# Read the three reference arrays stored alongside the model. The lower and
# upper values are drawn as the 'BDMC lower' / 'BDMC upper' lines in the plot.
bounds_path = os.path.join(model_dir, 'joint-sample-and-log-norm-bounds.npz')
with np.load(bounds_path) as samples_and_log_norm_bounds:
    # Arrays are pulled out while the archive is still open.
    log_zeta, log_norm_lower, log_norm_upper = (
        samples_and_log_norm_bounds[array_name]
        for array_name in ('log_zeta', 'log_norm_lower', 'log_norm_upper')
    )

In [None]:
# AIS results: log-normaliser estimates and the matching sampling times.
ais_results_path = os.path.join(exp_dir, 'ais-results.npz')
with np.load(ais_results_path) as loaded:
    ais_log_norm_ests, ais_times = (
        loaded[array_name] for array_name in ('log_norm_ests', 'sampling_times')
    )

In [None]:
# ST results (labelled 'ST' in the plot): estimates and sampling times.
st_results_path = os.path.join(exp_dir, 'st-results.npz')
with np.load(st_results_path) as loaded:
    st_log_norm_ests, st_times = (
        loaded[array_name] for array_name in ('log_norm_ests', 'sampling_times')
    )

In [None]:
# Gibbs CT results (labelled 'Gibbs CT' in the plot): estimates and times.
gct_results_path = os.path.join(exp_dir, 'gct-results.npz')
with np.load(gct_results_path) as loaded:
    gct_log_norm_ests, gct_times = (
        loaded[array_name] for array_name in ('log_norm_ests', 'sampling_times')
    )

In [None]:
# Joint CT results (labelled 'Joint CT' in the plot). Note the '-alt' suffix
# on the file name, unlike the other result archives.
jct_results_path = os.path.join(exp_dir, 'jct-results-alt.npz')
with np.load(jct_results_path) as loaded:
    jct_log_norm_ests, jct_times = (
        loaded[array_name] for array_name in ('log_norm_ests', 'sampling_times')
    )

In [None]:
# Plot settings.
num_data = 1000  # multiplier applied to every estimate and bound before plotting
max_time = 400   # right-hand limit of the time axis (axis label says seconds)
skip = 10        # keep every `skip`-th column of the running estimates
ci = [95]        # confidence level(s) handed to sns.tsplot


def _plot_running_estimates(ax, log_norm_ests, times, color, ls, condition,
                            skip=skip, scale=num_data, ci=ci):
    """Draw one sampler's subsampled running estimates with a CI band.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    log_norm_ests : 2D array; columns are subsampled with stride `skip`.
        Presumably rows are independent chains -- TODO confirm.
    times : array whose mean over axis 0 gives the time at the right-hand
        end of the curve. Presumably one total run time per chain -- confirm.
    color, ls, condition : colour, line style and legend label for tsplot.
    skip, scale, ci : column stride, multiplier applied to the estimates and
        confidence level(s); default to the cell-level settings.
    """
    subsampled = log_norm_ests[:, ::skip] * scale
    # The number of time points must equal the number of columns actually
    # kept by the strided slice. The previous `shape[1] / skip` was a float
    # under Python 3 (rejected by np.linspace) and, even truncated, was off
    # by one whenever shape[1] is not a multiple of `skip`.
    time_axis = np.linspace(0, 1, subsampled.shape[1]) * times.mean(0)
    sns.tsplot(
        data=subsampled, time=time_axis,
        color=color, ls=ls,
        err_style="ci_band", ci=ci, ax=ax, condition=condition
    )


fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(111)

# Running estimates for the three samplers that report per-iteration values.
_plot_running_estimates(ax, st_log_norm_ests, st_times, palette[1], '--', 'ST')
_plot_running_estimates(ax, gct_log_norm_ests, gct_times, palette[2], ':', 'Gibbs CT')
_plot_running_estimates(ax, jct_log_norm_ests, jct_times, palette[3], '-', 'Joint CT')

# AIS is plotted as separate markers with error bars (no interpolation), at
# the mean sampling time over axis 0 of `ais_times`.
sns.tsplot(
    data=ais_log_norm_ests * num_data,
    time=ais_times.mean(0), interpolate=False,
    color=palette[0], ms=7,
    err_style="ci_bars", ci=ci, ax=ax, condition='AIS'
)

# Horizontal reference lines for the upper and lower bounds, spanning the full
# time range. `ax` is passed explicitly rather than relying on the implicit
# current axes.
sns.tsplot(time=[0, max_time], data=[log_norm_upper * num_data, log_norm_upper * num_data],
           color='k', ls='-.', lw=1., ax=ax, condition='BDMC upper')
sns.tsplot(time=[0, max_time], data=[log_norm_lower * num_data, log_norm_lower * num_data],
           color='r', ls='-.', lw=1., ax=ax, condition='BDMC lower')

ax.legend(ncol=2)
ax.set_xlim(0, max_time)
ax.set_xlabel('Time / s')
ax.set_ylim(-95.7 * num_data, -95.3 * num_data)
ax.set_ylabel('Log marginal likelihood est.')
fig.tight_layout(pad=0)
fig.savefig('mnist-marginal-likelihood-est.pdf')