In [None]:
import os
# Force theano.scan to allow garbage collection to fix issue with memory leak
os.environ['THEANO_FLAGS'] = 'scan.allow_gc=True'
import time
import theano as th
import theano.tensor as tt
import pymc3 as pm
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from thermomc import discrete_temp
import seaborn as sns
%matplotlib inline

## Set plot styling options

In [None]:
# `sns.set` resets the plotting context to its default ('notebook'), so it has
# to run before `set_context`; otherwise the 'paper' context is discarded.
sns.set(font='sans')
sns.set_context('paper')
# White background with black labels and tick text.
sns.set_style('white', {
    'font.family': 'sans',
    'axes.labelcolor': '0.',
    'text.color': '0.',
    'xtick.color': '0.',
    'ytick.color': '0.'
})
palette = sns.color_palette('husl', 5)

## Set up directories

In [None]:
# All paths hang off the parent of the notebook's working directory.
base_dir = os.path.dirname(os.getcwd())
# Input data location and output directory for saved estimates / figures.
data_dir, exp_dir = [
    os.path.join(base_dir, *sub_path)
    for sub_path in (('data', 'radon'), ('experiments', 'pymc3-radon'))
]
if not os.path.exists(exp_dir):
    os.makedirs(exp_dir)
# Single seed shared by the NumPy random state below and the PyMC3 / Theano
# calls later in the notebook.
seed = 201702
rng = np.random.RandomState(seed)

## Load data

In [None]:
# Grouped radon data; later cells read the keys 'n_counties', 'county',
# 'x', 'u', 'y' and 'n_data' from this archive.
radon_data_path = os.path.join(data_dir, 'radon-group.npz')
data = np.load(radon_data_path)

## Additional PyMC3 classes

In [None]:
class StandardLogistic(pm.distributions.Continuous):
    """Standard logistic distribution (location 0, scale 1).

    Density p(x) = sigmoid(x) * (1 - sigmoid(x)). Its CDF is the sigmoid, so
    sigmoid(X) is uniformly distributed on (0, 1) when X has this density.
    """

    def __init__(self, *args, **kwargs):
        super(StandardLogistic, self).__init__(*args, **kwargs)
        # Mean and mode are both 0; presumably read by PyMC3 when choosing a
        # default test value for the variable.
        self.mean = self.mode = 0.

    def random(self, point=None, size=None, repeat=None):
        """Draw samples using the notebook-level ``rng`` random state.

        Returns samples of shape ``size + self.shape`` (a scalar when both are
        empty). ``point`` and ``repeat`` are accepted for interface
        compatibility and ignored: the distribution has no parameters to draw.
        """
        dist_shape = tuple(int(d) for d in np.atleast_1d(self.shape))
        if size is None:
            size_shape = ()
        else:
            size_shape = tuple(int(s) for s in np.atleast_1d(size))
        out_shape = size_shape + dist_shape
        return rng.logistic(size=out_shape if out_shape else None)

    def logp(self, value):
        """Log density log(sigmoid(x)) + log(1 - sigmoid(x)), evaluated stably.

        Uses log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) =
        -softplus(x), which stay finite for large |x| where 1 - sigmoid(x)
        would round to 0 and its log to -inf.
        """
        return -tt.nnet.softplus(value) - tt.nnet.softplus(-value)

class ExtendedModel(pm.Model):
    """Continuously tempered extension of a PyMC3 target model.

    Adds a scalar ``temp_ctrl`` variable with a standard logistic prior and
    defines the joint log density

        beta * (log p(x) - log_norm_est) + (1 - beta) * log q(x)
        + log p(temp_ctrl),

    where ``beta = sigmoid(temp_ctrl)`` is the inverse temperature,
    ``log p(x)`` is ``target_model.logpt`` and ``q`` is a diagonal Gaussian
    base density. Since the standard logistic CDF is the sigmoid, ``beta`` has
    a uniform prior on (0, 1).

    Args:
        target_model: PyMC3 model whose variables and factors are shared with
            this one.
        base_means: mapping from (transformed) continuous variable name to the
            base density mean, e.g. ``var_params.means`` from ADVI.
        base_stds: mapping from variable name to the base density standard
            deviation, with the same keys as ``base_means``.
        log_norm_est: current estimate (log zeta) of the target's log
            normalising constant.

    Besides the target's deterministics, ``inv_temp``, ``prob_0`` and
    ``prob_1`` are registered as deterministics so they appear in traces.
    """
    
    def __init__(self, target_model, base_means, base_stds, log_norm_est):
        super(ExtendedModel, self).__init__()
        # Share every variable and log-density factor of the target model so
        # the extended model is defined over the same x plus temp_ctrl.
        self.named_vars.update(target_model.named_vars)
        self.free_RVs += target_model.free_RVs
        self.observed_RVs += target_model.observed_RVs
        self.deterministics += target_model.deterministics
        self.potentials  += target_model.potentials
        self.missing_values += target_model.missing_values
        self.target_model = target_model
        self.base_means = base_means
        self.base_stds = base_stds
        self.log_norm_est = log_norm_est
        # Created inside this model's context so it is registered here; logpt
        # later looks it up via self.named_vars['temp_ctrl'].
        with self:
            temp_ctrl = StandardLogistic('temp_ctrl')
        # Inverse temperature beta in (0, 1).
        inv_temp = tt.nnet.sigmoid(temp_ctrl)
        inv_temp.name = 'inv_temp'
        # delta = log zeta - log p(x) + log q(x). Given x, logpt makes the
        # conditional density of beta proportional to exp(-beta * delta) on
        # [0, 1], i.e. delta * exp(-beta * delta) / (1 - exp(-delta)).
        # prob_0 and prob_1 are that density at beta = 0 and beta = 1. The
        # switch returns the delta -> 0 limit (1, a uniform density) instead
        # of evaluating 0 / 0.
        delta = self.log_norm_est - self.target_model.logpt + self.base_logpt
        prob_0 = tt.switch(tt.eq(delta, 0.), tt.ones_like(delta),
                           -delta / tt.expm1(-delta))
        prob_0.name = 'prob_0'
        prob_1 = tt.switch(tt.eq(delta, 0.), tt.ones_like(delta),
                           delta / tt.expm1(delta))
        prob_1.name = 'prob_1'
        # Appended directly (not via pm.Deterministic) so the values are
        # recorded in sampling traces alongside the target's deterministics.
        self.deterministics += [inv_temp, prob_0, prob_1]
        
    @property
    def base_logpt(self):
        """Log density of the diagonal Gaussian base distribution q.

        Sums independent normal log densities over every element of every
        continuous variable of the target model, with the per-variable mean
        and standard deviation looked up by name in base_means / base_stds.
        """
        base_logp = 0
        for cont_var in self.target_model.cont_vars:
            mean = self.base_means[cont_var.name]
            std = self.base_stds[cont_var.name]
            base_logp -= tt.sum(
                0.5 * ((cont_var - mean) / std)**2 + 
                0.5 * tt.log(2 * np.pi) + tt.log(std)
            )
        return base_logp
    
    @property
    def logpt(self):
        """Joint log density of the extended model.

        beta * (log p(x) - log_norm_est) + (1 - beta) * log q(x)
        + log p(temp_ctrl), with beta = sigmoid(temp_ctrl).
        """
        target_logp = self.target_model.logpt
        base_logp = self.base_logpt
        temp_ctrl = self.named_vars['temp_ctrl']
        inv_temp = tt.nnet.sigmoid(temp_ctrl)
        return (inv_temp * target_logp - inv_temp * self.log_norm_est + 
                (1 - inv_temp) * base_logp + temp_ctrl.logpt)

## Define PyMC3 model

In [None]:
# Hierarchical linear regression for the radon data in a non-centred
# parameterisation: the county intercepts alpha and the two coefficients beta
# are formed as mu + sigma * n, with n having a standard normal prior.
with pm.Model() as model:
    # Half-Cauchy(2.5) scale and standard normal location for the intercepts.
    sigma_alpha = pm.HalfCauchy('sigma_alpha', beta=2.5)
    mu_alpha = pm.Normal('mu_alpha', mu=0, sd=1)
    # One standard normal component per county (zero mean, identity Cholesky).
    n_alpha = pm.MvNormal('n_alpha', mu=tt.zeros(data['n_counties']), 
                          chol=tt.eye(data['n_counties']), 
                          shape=data['n_counties'])
    alpha = mu_alpha + sigma_alpha * n_alpha
    sigma_beta = pm.HalfCauchy('sigma_beta', beta=2.5)
    mu_beta = pm.Normal('mu_beta', mu=0, sd=1)
    n_beta = pm.MvNormal('n_beta', mu=tt.zeros(2), chol=tt.eye(2), shape=2)
    # beta[0] multiplies data['x'] and beta[1] multiplies data['u'] below.
    beta = mu_beta + sigma_beta * n_beta
    # Observation noise scale.
    epsilon = pm.HalfCauchy('epsilon', beta=2.5)
    # data['county'] indexes each observation's county intercept.
    y_hat = alpha[data['county']] + data['x'] * beta[0] + data['u'] * beta[1]
    # Gaussian likelihood with mean y_hat and standard deviation epsilon.
    y = pm.Normal('y', y_hat, epsilon, observed=data['y'])

## Fit initial base density with ADVI

In [None]:
# ADVI settings: 30000 iterations at most, stopping early once the relative
# change in the objective falls below tol_obj.
advi_kwargs = dict(n=30000, accurate_elbo=False, learning_rate=1e-3,
                   tol_obj=1e-3, random_seed=seed)
advi_start_time = time.time()
with model:
    var_params = pm.variational.advi(**advi_kwargs)
advi_run_time = time.time() - advi_start_time

In [None]:
# Average the final 100 ELBO values to smooth the stochastic objective; this
# is used as the initial estimate of log Z (log zeta).
final_elbo_window = var_params.elbo_vals[-100:]
var_log_norm_est = final_elbo_window.mean()
print('Var log norm est={0:.2f}'.format(var_log_norm_est))

In [None]:
advi_estimate_path = os.path.join(exp_dir, 'advi-estimate.npz')
np.savez(advi_estimate_path,
         var_log_norm_est=var_log_norm_est,
         advi_run_time=advi_run_time)

In [None]:
# Extended model whose base density is the ADVI Gaussian approximation and
# whose log zeta is the ADVI log Z estimate.
extended_model = ExtendedModel(
    target_model=model,
    base_means=var_params.means,
    base_stds=var_params.stds,
    log_norm_est=var_log_norm_est,
)

## Annealed importance sampling for $\log Z$ 'ground truth'

Manually define energy functions (negative log unnormalised densities) for the target and base distributions, so that the separate Theano-based annealed importance sampling implementation in `thermomc` can be used.

In [None]:
# ADVI estimate of log Z, used as the reference normaliser zeta. Adding it to
# the target energy below means AIS estimates log(Z / zeta); the AIS cell adds
# log_zeta back when forming its log Z estimate.
log_zeta = tt.constant(var_log_norm_est, name='log_zeta')

def phi_func(x):
    """Target energy: log_zeta minus the log unnormalised posterior.

    Hand-written counterpart of the priors and likelihood of `model`, with the
    three scale parameters in log space (including the log-Jacobian terms).

    Args:
        x: Theano matrix with one state per row (the AIS cell passes
            `pos_init` of shape (n_run, dim)). Column layout:
            0 log sigma_alpha, 1 mu_alpha, 2 log sigma_beta, 3 mu_beta,
            4 log epsilon, 5:7 n_beta, 7:7+n_counties n_alpha.

    Returns:
        Theano vector of energies, one per row of `x`.
    """
    sigma_alpha_log = x[:, 0]
    sigma_alpha = tt.exp(sigma_alpha_log)
    mu_alpha = x[:, 1]
    sigma_beta_log = x[:, 2]
    sigma_beta = tt.exp(sigma_beta_log)
    mu_beta = x[:, 3]
    epsilon_log = x[:, 4]
    epsilon = tt.exp(epsilon_log)
    n_beta = x[:, 5:7]
    n_alpha = x[:, 7:7 + data['n_counties']]
    # Non-centred parameterisation, broadcast over the rows of x.
    alpha = mu_alpha[:, None] + sigma_alpha[:, None] * n_alpha
    beta = mu_beta[:, None] + sigma_beta[:, None] * n_beta
    alpha.name = 'alpha'
    beta.name = 'beta'
    # Predicted mean for every (row, observation) pair.
    y_hat = (
        alpha[:, data['county']] + 
        data['x'][None, :] * beta[:, 0][:, None] + 
        data['u'][None, :] * beta[:, 1][:, None]
    )
    return log_zeta + (
        # Standard normal priors on n_alpha, n_beta, mu_alpha and mu_beta.
        0.5 * (n_alpha**2).sum(-1) + 0.5 * data['n_counties'] * tt.log(2. * np.pi) +
        0.5 * (n_beta**2).sum(-1) + 0.5 * 2 * tt.log(2. * np.pi) +
        0.5 * (mu_alpha**2) + 0.5 * tt.log(2. * np.pi) +
        0.5 * (mu_beta**2) + 0.5 * tt.log(2. * np.pi) +
        # Half-Cauchy(2.5) priors on sigma_alpha, sigma_beta and epsilon; the
        # trailing -log(scale) terms are the log-Jacobians of the log transform.
        tt.log1p((sigma_alpha / 2.5)**2) - tt.log(2) + 
        tt.log(np.pi) + tt.log(2.5) - tt.log(sigma_alpha) +
        tt.log1p((sigma_beta / 2.5)**2) - tt.log(2) + 
        tt.log(np.pi) + tt.log(2.5) - tt.log(sigma_beta) +
        tt.log1p((epsilon / 2.5)**2) - tt.log(2) + 
        tt.log(np.pi) + tt.log(2.5) - tt.log(epsilon) +
        # Gaussian likelihood of data['y'] with mean y_hat and scale epsilon.
        0.5 * (((data['y'][None, :] - y_hat) / epsilon[:, None])**2).sum(-1) + 
        0.5 * data['n_data'] * tt.log(2. * np.pi) +
        data['n_data'] * tt.log(epsilon)
    )

def psi_func(x):
    """Base energy: negative log density of the diagonal Gaussian base.

    Uses the same column layout of `x` as `phi_func`, with per-variable means
    and standard deviations taken from `extended_model`.

    Args:
        x: Theano matrix with one state per row.

    Returns:
        Theano vector of base energies, one per row of `x`.
    """
    columns = [
        ('sigma_alpha_log_', x[:, 0]),
        ('mu_alpha', x[:, 1]),
        ('sigma_beta_log_', x[:, 2]),
        ('mu_beta', x[:, 3]),
        ('epsilon_log_', x[:, 4]),
        ('n_beta', x[:, 5:7]),
        ('n_alpha', x[:, 7:7+data['n_counties']]),
    ]
    energy = 0
    for var_name, block in columns:
        loc = extended_model.base_means[var_name]
        scale = extended_model.base_stds[var_name]
        elementwise = (
            0.5 * ((block - loc) / scale)**2 + 
            0.5 * tt.log(2 * np.pi) + tt.log(scale)
        )
        # Vector-valued variables span several columns: sum over them so each
        # row contributes a single energy value.
        energy += elementwise.sum(-1) if block.ndim == 2 else elementwise
    return energy

In [None]:
# Seeded Theano random streams used by the AIS sampler.
ais_rand_streams = tt.shared_randomstreams.RandomStreams(seed)
# NOTE(review): the meaning of the second positional argument (False) is
# defined by thermomc's AnnealedImportanceSampler; confirm there.
ais_sampler = discrete_temp.AnnealedImportanceSampler(ais_rand_streams, False)

In [None]:
# Symbolic inputs get a `_sym` suffix so they are not confused with the numeric
# `dt`, `n_step` and `log_weights` assigned in later cells. Re-running this
# cell would otherwise rebind those names to symbolic variables and break the
# AIS run cell.
pos_sym = tt.matrix('pos')
inv_temps_sym = tt.vector('inv_temps')
dt_sym = tt.scalar('dt')
n_step_sym = tt.lscalar('n_step')
hmc_params = {
    'dt': dt_sym,
    'n_step': n_step_sym,
    # NOTE(review): presumably full momentum resampling between transitions;
    # confirm against thermomc.
    'mom_resample_coeff': 1.,
}
pos_samples, log_weights_sym, accepts, updates = ais_sampler.run(
    pos_sym, None, inv_temps_sym, phi_func, psi_func, hmc_params)
# Compiled function returning the per-run log importance weights and the mean
# HMC acceptance rate.
ais_chain_func = th.function(
    [pos_sym, inv_temps_sym, dt_sym, n_step_sym],
    [log_weights_sym, accepts.mean()],
    updates=updates
)

In [None]:
def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, evaluated elementwise with NumPy."""
    return 1. / (1. + np.exp(-x))

def sigmoidal_schedule(num_temp, scale):
    """Inverse-temperature schedule for annealed importance sampling.

    Evaluates a sigmoid at ``num_temp + 1`` evenly spaced points on
    ``[-scale, scale]`` and affinely rescales the result so the schedule
    starts exactly at 0 and ends exactly at 1. A larger ``scale`` places more
    of the inverse temperatures near the two ends of the schedule.

    Args:
        num_temp: Number of annealing steps (must be >= 1). The returned array
            has ``num_temp + 1`` entries.
        scale: Steepness of the sigmoid. ``scale == 0`` returns the linear
            schedule, which is the limit of the sigmoidal schedule as
            ``scale -> 0``; the rescaling would otherwise compute 0 / 0.

    Returns:
        1D float array of inverse temperatures increasing from 0 to 1.

    Raises:
        ValueError: if ``num_temp < 1``.
    """
    if num_temp < 1:
        raise ValueError(
            'num_temp must be at least 1, got {0}'.format(num_temp))
    grid = 2. * np.arange(num_temp + 1) / num_temp - 1.
    if scale == 0:
        return (grid + 1.) / 2.
    inv_temp_sched = sigmoid(scale * grid)
    return (
        (inv_temp_sched - inv_temp_sched[0]) / 
        (inv_temp_sched[-1] - inv_temp_sched[0])
    )

In [None]:
# Flattened AIS state layout: the five scalar parameters in the column order
# used by phi_func / psi_func, followed by the n_beta and n_alpha vectors.
ais_scalar_param_names = (
    'sigma_alpha_log_', 'mu_alpha', 'sigma_beta_log_', 'mu_beta', 'epsilon_log_')
ais_vector_param_names = ('n_beta', 'n_alpha')

def _flatten_base_params(param_dict):
    """Concatenate per-variable base parameters into one AIS-ordered vector."""
    pieces = [param_dict[name][None] for name in ais_scalar_param_names]
    pieces.extend(param_dict[name] for name in ais_vector_param_names)
    return np.concatenate(pieces)

var_mean = _flatten_base_params(extended_model.base_means)
var_std = _flatten_base_params(extended_model.base_stds)

In [None]:
# AIS configuration: independent runs, number of annealing steps, and the
# HMC step size / number of steps passed to the sampler as hmc_params.
n_run, n_temp = 100, 10000
dt, n_step = 0.025, 4

In [None]:
inv_temp_sched = sigmoidal_schedule(n_temp, 4.)
# Start every AIS run from an exact draw of the Gaussian base density.
n_init = rng.normal(size=(n_run, var_mean.shape[0]))
pos_init = var_mean + n_init * var_std
start_time = time.time()
log_weights, accept = ais_chain_func(pos_init, inv_temp_sched, dt, n_step)
ais_time = time.time() - start_time
print('Accept={0:.2f} Time={1:.1f}s'.format(float(accept), ais_time))
# log Z ~= log zeta + log(mean(exp(log_weights))). Evaluate the log-mean-exp
# after subtracting the largest log weight so exp() cannot overflow, or
# underflow every weight to 0, when the log weights are large in magnitude.
max_log_weight = np.max(log_weights)
if np.isfinite(max_log_weight):
    log_mean_weight = max_log_weight + np.log(
        np.mean(np.exp(log_weights - max_log_weight)))
else:
    log_mean_weight = np.log(np.mean(np.exp(log_weights)))
ais_log_norm_est = log_zeta.value + log_mean_weight
print('AIS log norm est={0:.2f}'.format(ais_log_norm_est))

In [None]:
ais_estimate_path = os.path.join(exp_dir, 'ais-estimate.npz')
np.savez(ais_estimate_path,
         ais_log_norm_est=ais_log_norm_est,
         log_weights=log_weights,
         ais_time=ais_time)

## NUTS run in extended space

In [None]:
rng.seed(seed)
num_reps = 10
nuts_1_times = np.full(num_reps, np.nan)
nuts_1_traces = [None] * num_reps
nuts_1_seeds = rng.randint(1000000, size=num_reps)
# NUTS scaling, passed as a covariance (is_cov=True): variance 5**2 for
# temp_ctrl and the squared ADVI standard deviations for the other variables.
scales = {'temp_ctrl': np.array(5.)}
scales.update(var_params.stds)
scales = extended_model.dict_to_array(scales)**2
with model:
    vp_trace = pm.variational.sample_vp(var_params, num_reps, random_seed=seed)
# Indexing a trace builds a new point dict on every access, so materialise the
# initial points once. Otherwise the temp_ctrl value assigned below would be
# written to a throwaway dict and never reach pm.sample.
start = [vp_trace[i] for i in range(num_reps)]
for r in range(num_reps):
    # Random initial temperature control from a standard logistic distribution.
    start[r]['temp_ctrl'] = rng.logistic()
    with extended_model:
        step = pm.NUTS(scaling=scales, is_cov=True)
        start_time = time.time()
        nuts_1_traces[r] = pm.sample(2500, step, start=start[r], init=None, 
                                     progressbar=True, tune=100, random_seed=nuts_1_seeds[r])
        nuts_1_times[r] = time.time() - start_time

In [None]:
# Per-chain sampler diagnostics: divergence count and mean tree acceptance.
for chain_idx, chain_trace in enumerate(nuts_1_traces):
    print('Chain {0}'.format(chain_idx))
    n_diverging = chain_trace.get_sampler_stats('diverging', combine=True).sum()
    mean_accept = chain_trace.get_sampler_stats('mean_tree_accept', combine=True).mean()
    print('  diverging={0}, mean accept={1:.2f}'.format(n_diverging, mean_accept))

In [None]:
# Running log Z estimate after each draw of every chain:
# log zeta + log(cumsum prob_1) - log(cumsum prob_0).
nuts_1_log_norm_ests = np.stack([
    extended_model.log_norm_est
    + np.log(np.cumsum(np.array(chain.get_values('prob_1')), axis=0))
    - np.log(np.cumsum(np.array(chain.get_values('prob_0')), axis=0))
    for chain in nuts_1_traces
])

In [None]:
# NOTE: 'intial' (sic) is the existing output filename; it is left unchanged
# so that anything reading this file still finds it.
initial_nuts_path = os.path.join(exp_dir, 'intial-nuts-estimate.npz')
np.savez(initial_nuts_path,
         log_norm_ests=nuts_1_log_norm_ests,
         run_time=nuts_1_times)

In [None]:
# One trace plot of temp_ctrl per initial NUTS chain.
for chain_trace in nuts_1_traces:
    pm.plots.traceplot(chain_trace, ['temp_ctrl'])

In [None]:
# One trace plot of the inverse temperature per initial NUTS chain.
for chain_trace in nuts_1_traces:
    pm.plots.traceplot(chain_trace, ['inv_temp'])

In [None]:
fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(1, 1, 1)
# CT NUTS running estimates, placed after the ADVI run on the time axis (the
# NUTS chains are initialised from the ADVI fit).
_ = sns.tsplot(
    data=nuts_1_log_norm_ests, 
    time=advi_run_time + np.linspace(0, nuts_1_times.mean(), nuts_1_log_norm_ests.shape[-1]),
    color=palette[2],
    err_style="ci_band", ci=[95], ax=ax, condition='CT NUTS', lw=1.5
)
# ELBO history averaged over consecutive blocks of `elbo_block` iterations.
# The history need not be a whole number of blocks long (e.g. if ADVI stops
# early), so an incomplete trailing block is dropped instead of letting the
# reshape raise.
elbo_block = 100
n_elbo_blocks = var_params.elbo_vals.shape[0] // elbo_block
elbo_block_means = (
    var_params.elbo_vals[:n_elbo_blocks * elbo_block]
    .reshape(n_elbo_blocks, elbo_block)
    .mean(-1)
)
ax.plot(np.linspace(0, advi_run_time, n_elbo_blocks), elbo_block_means, lw=1.)
# Reference levels: final ADVI estimate and AIS 'ground truth'.
ax.plot([0, 60], [var_log_norm_est, var_log_norm_est], 'r-.', lw=1.)
ax.plot([0, 60], [ais_log_norm_est, ais_log_norm_est], 'k--', lw=1.)
ax.set_xlim(0, 60)
ax.set_ylim(-1100, -1070)
# Vertical marker at the end of the ADVI run.
ax.plot([advi_run_time, advi_run_time], ax.get_ylim(), 'k:', lw=1.5)
ax.set_xlabel('Time / s')
ax.set_ylabel('Log marginal likelihood est.')
ax.legend(['CT NUTS', 'ADVI', 'ADVI (final)', 'AIS'], loc='lower right', ncol=2)
fig.tight_layout()
fig.savefig(os.path.join(exp_dir, 'hier-lin-regression-marg-lik.pdf'))

## Update base density and $\log \zeta$

In [None]:
rng.seed(seed)
extended_model_list = [None] * num_reps
scales_list = [None] * num_reps
# NOTE(review): base_params_list is never filled in this cell.
base_params_list = [None] * num_reps
start = [None] * num_reps
for r, trace in enumerate(nuts_1_traces):
    # prob_1 weights turn the extended-space samples into estimates of
    # expectations under the target (inv_temp = 1) distribution. Fetch it
    # once per trace rather than once per parameter.
    prob_1 = trace.get_values('prob_1')
    sum_prob_1 = prob_1.sum()
    sum_prob_0 = trace.get_values('prob_0').sum()
    # Refined log Z estimate from this chain.
    log_norm_est = extended_model.log_norm_est + np.log(sum_prob_1) - np.log(sum_prob_0)
    param_values = {
        param: trace.get_values(param) for param in var_params.means.keys()
    }
    # Weighted means / standard deviations define the refined Gaussian base.
    # Transposing puts the sample axis last, so the weights broadcast along it
    # for vector-valued parameters.
    sample_means = {
        param: (prob_1 * values.T).sum(-1) / sum_prob_1
        for param, values in param_values.items()
    }
    sample_stds = {
        param: ((prob_1 *
                 (values - sample_means[param]).T**2).sum(-1)
                / sum_prob_1)**0.5
        for param, values in param_values.items()
    }
    # Indexing a trace builds a new point dict on every access. Keep the drawn
    # point in a one-element list so the temp_ctrl value set below persists and
    # is what `start[r][0]` returns in the final NUTS run.
    start[r] = [pm.variational.sample_vp(
        {'means': sample_means, 'stds': sample_stds}, 1, model=model,
        random_seed=nuts_1_seeds[r])[0]]
    start[r][0]['temp_ctrl'] = rng.logistic()
    extended_model_list[r] = ExtendedModel(
        model, sample_means, sample_stds, log_norm_est
    )
    # NUTS scaling as a covariance: variance 1 for temp_ctrl and squared
    # weighted sample standard deviations for the other variables.
    scales_list[r] = {'temp_ctrl': np.array(1.)}
    scales_list[r].update(sample_stds)
    scales_list[r] = extended_model_list[r].dict_to_array(scales_list[r])**2

## Final NUTS run

In [None]:
rng.seed(seed)
nuts_2_times = np.full(num_reps, np.nan)
nuts_2_traces = [None] * num_reps
# NOTE(review): re-seeding with the same seed and drawing the same number of
# values makes nuts_2_seeds identical to nuts_1_seeds.
nuts_2_seeds = rng.randint(1000000, size=num_reps)
for r in range(num_reps):
    # Each replicate is sampled under its own refined extended model, using
    # the NUTS scaling built from that replicate's base density.
    with extended_model_list[r]:
        step = pm.NUTS(scaling=scales_list[r], is_cov=True)
        start_time = time.time()
        nuts_2_traces[r] = pm.sample(5000, step, start=start[r][0], init=None, 
                                     progressbar=True, tune=500, random_seed=nuts_2_seeds[r])
        nuts_2_times[r] = time.time() - start_time

In [None]:
# Per-chain sampler diagnostics for the final runs.
for chain_idx, chain_trace in enumerate(nuts_2_traces):
    print('Chain {0}'.format(chain_idx))
    n_diverging = chain_trace.get_sampler_stats('diverging', combine=True).sum()
    mean_accept = chain_trace.get_sampler_stats('mean_tree_accept', combine=True).mean()
    print('  diverging={0}, mean accept={1:.2f}'.format(n_diverging, mean_accept))

In [None]:
# Running log Z estimates from the final runs. Each chain is paired with the
# extended model (and therefore the log zeta) it was sampled under.
nuts_2_log_norm_ests = np.stack([
    extended_model_list[r].log_norm_est + 
    np.log(np.array(nuts_2_traces[r].get_values('prob_1')).cumsum(0)) - 
    np.log(np.array(nuts_2_traces[r].get_values('prob_0')).cumsum(0))
    for r in range(num_reps)
])

In [None]:
final_nuts_path = os.path.join(exp_dir, 'final-nuts-estimate.npz')
np.savez(final_nuts_path,
         log_norm_ests=nuts_2_log_norm_ests,
         run_time=nuts_2_times)

In [None]:
# One trace plot of temp_ctrl per final NUTS chain.
for chain_trace in nuts_2_traces:
    pm.plots.traceplot(chain_trace, ['temp_ctrl'])

In [None]:
# One trace plot of the inverse temperature per final NUTS chain.
for chain_trace in nuts_2_traces:
    pm.plots.traceplot(chain_trace, ['inv_temp'])