In [None]:
import os
#os.environ['THEANO_FLAGS'] = 'device=gpu0,floatX=float32'
import time
import numpy as np
import matplotlib.pyplot as plt
import theano as th
import theano.tensor as tt
import theano.d3viz as d3v
import theano.sandbox.rng_mrg as rand
import thermomc.discrete_temp as disc_temp
%matplotlib inline

## Load parameters

In [None]:
# Model parameters live in <repo>/data/omni-iwae, one directory above the
# notebook's working directory.
base_dir = os.path.dirname(os.getcwd())
model_dir = os.path.join(base_dir, 'data', 'omni-iwae')


def _npz_in_model_dir(fname):
    """Open the archive ``fname`` located in `model_dir` with ``np.load``."""
    return np.load(os.path.join(model_dir, fname))


decoder = _npz_in_model_dir('decoder_params.npz')
encoder_h = _npz_in_model_dir('encoder_h_params.npz')
encoder_mean = _npz_in_model_dir('encoder_mean_params.npz')
encoder_std = _npz_in_model_dir('encoder_std_params.npz')

## Define model functions

In [None]:
def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``."""
    denom = 1. + np.exp(-x)
    return 1. / denom

def sigmoidal_schedule(num_temp, scale):
    """Sigmoidally spaced inverse-temperature schedule running from 0 to 1.

    The points are ``sigmoid(scale * t)`` for ``num_temp + 1`` equally spaced
    ``t`` in [-1, 1]. They are affinely rescaled so the first entry is exactly
    0 and the last exactly 1, which makes the spacing finest near both ends.

    Args:
        num_temp: number of annealing steps; must be >= 1.
        scale: steepness of the sigmoid; must be non-zero.

    Returns:
        1-D array of length ``num_temp + 1``.

    Raises:
        ValueError: if `num_temp` < 1 or `scale` == 0. In those cases the
            rescaling would divide by zero and silently return NaNs.
    """
    if num_temp < 1:
        raise ValueError('num_temp must be >= 1, got %r' % (num_temp,))
    if scale == 0:
        raise ValueError('scale must be non-zero')
    inv_temp_sched = sigmoid(
        scale * (2. * np.arange(num_temp + 1) / num_temp - 1.))
    return (
        (inv_temp_sched - inv_temp_sched[0]) /
        (inv_temp_sched[-1] - inv_temp_sched[0])
    )

def rmse(x, y):
    """Root-mean-square difference between arrays ``x`` and ``y`` (all elements pooled)."""
    sq_err = (x - y) ** 2
    return sq_err.mean() ** 0.5

In [None]:
# Maps layer names stored in the ``*_params.npz`` archives to element-wise
# NumPy functions. The Theano map for the decoder spells the exponential
# layer 'nnet.Exp', so both spellings are accepted here to keep the two maps
# consistent.
non_linearity_map_np = {
    'nnet.Tanh': np.tanh,
    'nnet.Sigmoid': sigmoid,
    'nnet.Exponential': np.exp,
    'nnet.Exp': np.exp,
}

In [None]:
def mean_x_gvn_z_np(z):
    h = z
    for i, layer in enumerate(decoder['layers']):
        if layer == 'nnet.Linear':
            W = decoder['W' + str(i)]
            b = decoder['b' + str(i)]
            h = h.dot(W) + b
        else:
            h = non_linearity_map_np[layer](h)
    return h

def mean_and_std_z_gvn_x_np(x):
    h = x
    for i, layer in enumerate(encoder_h['layers']):
        if layer == 'nnet.Linear':
            W = encoder_h['W' + str(i)]
            b = encoder_h['b' + str(i)]
            h = h.dot(W) + b
        else:
            h = non_linearity_map_np[layer](h)
    std = h * 1.
    for i, layer in enumerate(encoder_std['layers']):
        if layer == 'nnet.Linear':
            W = encoder_std['W' + str(i)]
            b = encoder_std['b' + str(i)]
            std = std.dot(W) + b
        else:
            std = non_linearity_map_np[layer](std)
    mean = h * 1.
    for i, layer in enumerate(encoder_mean['layers']):
        if layer == 'nnet.Linear':
            W = encoder_mean['W' + str(i)]
            b = encoder_mean['b' + str(i)]
            mean = mean.dot(W) + b
        else:
            mean = non_linearity_map_np[layer](mean)
    return mean, std

def log_prob_x_gvn_z_np(x, z):
    """Bernoulli log-likelihood log p(x | z), summed over the last axis.

    The per-element means come from the decoder (`mean_x_gvn_z_np`).
    """
    p = mean_x_gvn_z_np(z)
    elementwise = x * np.log(p) + (1 - x) * np.log(1 - p)
    return elementwise.sum(-1)

def log_prob_z_np(z):
    """Standard-normal prior log density log p(z), summed over the last axis."""
    dim = z.shape[-1]
    sq_norm = (z ** 2).sum(-1)
    return -0.5 * sq_norm - 0.5 * dim * np.log(2 * np.pi)

def log_prob_x_and_z_np(x, z):
    """Joint log density log p(x, z) = log p(x | z) + log p(z)."""
    log_lik = log_prob_x_gvn_z_np(x, z)
    log_prior = log_prob_z_np(z)
    return log_lik + log_prior

def log_prob_z_gvn_x_np(x, z, means=None, stds=None):
    """Log density of ``z`` under the diagonal Gaussian q(z | x).

    Args:
        x: observations. Only used to run the encoder when `means` or `stds`
            is not supplied.
        z: latent values; the last axis is the latent dimension.
        means, stds: optional precomputed encoder outputs for ``x``. Either
            one that is ``None`` is filled in from `mean_and_std_z_gvn_x_np`.

    Returns:
        Log density summed over the last axis.
    """
    if means is None or stds is None:
        # Use the NumPy encoder defined in this notebook. Keep whichever of
        # means / stds the caller already supplied.
        enc_means, enc_stds = mean_and_std_z_gvn_x_np(x)
        if means is None:
            means = enc_means
        if stds is None:
            stds = enc_stds
    return -(
        0.5 * ((z - means)**2 / stds**2).sum(-1) +
        0.5 * z.shape[-1] * np.log(2 * np.pi) +
        np.log(stds).sum(-1)
    )

In [None]:
class PhiFunc(object):
    """Energy ``-log p(x, z) + log_zeta`` built as a Theano expression.

    The prior p(z) is standard normal. The likelihood is Bernoulli, with means
    produced by a feed-forward decoder. In this notebook an instance is passed
    to `ais_sampler.run` as its `phi_func` argument.

    Args:
        x: observations (0/1 values in this notebook).
        weights, biases, non_linearities: per-layer decoder parameters. Layer
            k computes ``non_linearities[k](h.dot(weights[k]) + biases[k])``.
        log_zeta: offset added to the energy.
    """

    def __init__(self, x, weights, biases, non_linearities, log_zeta):
        self.x, self.log_zeta = x, log_zeta
        self.weights, self.biases = weights, biases
        self.non_linearities = non_linearities

    def mean_x_gvn_z(self, z):
        """Decoder forward pass: Bernoulli means of x given latents ``z``."""
        act = z
        layers = zip(self.weights, self.biases, self.non_linearities)
        for layer_W, layer_b, act_fn in layers:
            act = act_fn(act.dot(layer_W) + layer_b)
        return act

    def __call__(self, z):
        """Energy of ``z``, with the sums taken over the last axis."""
        x = self.x
        p = self.mean_x_gvn_z(z)
        neg_log_prior = (0.5 * (z**2).sum(-1) +
                         0.5 * z.shape[-1] * tt.log(2 * np.pi))
        log_lik = (x * tt.log(p) + (1 - x) * tt.log(1 - p)).sum(-1)
        return neg_log_prior - log_lik + self.log_zeta

class PsiFunc(object):
    """Energy of a diagonal Gaussian, ``-log N(z; mean, diag(std**2))``.

    Calling an instance returns a Theano expression summed over the last
    axis. In this notebook an instance is passed to `ais_sampler.run` as its
    `psi_func` argument.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, z):
        """Negative Gaussian log density of ``z``."""
        standardized = (z - self.mean) / self.std
        quad = 0.5 * (standardized**2).sum(-1)
        log_norm = 0.5 * self.mean.shape[-1] * tt.log(2 * np.pi)
        return quad + log_norm + tt.log(self.std).sum(-1)

## Generate $\mathbf{x},\,\mathbf{z}$ pair from joint

In [None]:
# One seed drives both the NumPy RNG below and Theano's MRG streams later.
seed = 201702
rng = np.random.RandomState(seed=seed)

In [None]:
# latent dimensionality, number of (x, z) pairs, AIS chains per data point
latent_dim, n_samples, n_reps = 50, 1000, 16

In [None]:
# Ancestral sampling from the joint: z from the standard-normal prior, then
# each pixel of x is 1 with probability equal to its decoder mean.
zs = rng.normal(size=(n_samples, latent_dim))
means = mean_x_gvn_z_np(zs)
unif_draws = rng.uniform(size=means.shape)
xs = (unif_draws < means) * 1

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
# Tile the first 100 decoder-mean images into a 10 x 10 grid of 28 x 28
# patches, filling column by column.
im_grid = np.zeros((280, 280))
for idx, img in enumerate(means[:100]):
    col, row = divmod(idx, 10)
    im_grid[row * 28:(row + 1) * 28, col * 28:(col + 1) * 28] = img.reshape(28, 28)
ax.imshow(im_grid, cmap='Greys', interpolation='None')
ax.set_xticks([])
ax.set_yticks([])

## Calculate importance-weighted $\log \zeta$ approximations

In [None]:
# Importance-sampling estimate of log p(x) for each data point, using the
# Gaussian encoder q(z | x) as the proposal:
#   log_zeta = log mean_k exp(log p(x, z_k) - log q(z_k | x)),  z_k ~ q(z | x).
# The names below are distinct from `log_weights` etc. used by the AIS cells,
# so re-running this cell cannot overwrite the AIS results that get saved.
is_start_time = time.time()
n_k = 100  # importance samples per data point
mean_z_gvn_x, std_z_gvn_x = mean_and_std_z_gvn_x_np(xs)
# The terms of log q(z | x) that do not depend on the noise draw.
log_2pi_term = 0.5 * mean_z_gvn_x.shape[-1] * np.log(2 * np.pi)
log_std_sum = np.log(std_z_gvn_x).sum(-1)
is_log_weights = np.empty((n_k,) + mean_z_gvn_x.shape[:-1])
for k in range(n_k):
    eps = rng.normal(size=mean_z_gvn_x.shape)
    z_k = mean_z_gvn_x + std_z_gvn_x * eps
    # (z_k - mean) / std == eps, so the quadratic term uses eps directly.
    log_q_z_k = -0.5 * (eps**2).sum(-1) - log_2pi_term - log_std_sum
    is_log_weights[k] = log_prob_x_and_z_np(xs, z_k) - log_q_z_k
# Log-mean-exp over the sample axis, shifted by the maximum for stability.
is_max = is_log_weights.max(0)
log_zeta = np.log(np.exp(is_log_weights - is_max[None, :]).mean(0)) + is_max
log_zeta_calc_time = time.time() - is_start_time
print(log_zeta.mean(), log_zeta_calc_time)

## Create repeated model parameters / samples

In [None]:
# Repeat every data point n_reps times (consecutively along axis 0) so each
# gets n_reps AIS chains, and wrap the arrays as floatX Theano constants.
float_x = th.config.floatX
log_zeta_rep = tt.constant(log_zeta.repeat(n_reps), 'log_zeta', 1, float_x)
zs_rep = zs.repeat(n_reps, 0)
xs_rep = tt.constant(xs.repeat(n_reps, 0), 'x', 2, float_x)
mean_z_gvn_x_rep = tt.constant(
    mean_z_gvn_x.repeat(n_reps, 0), 'mean_z_gvn_x', 2, float_x)
std_z_gvn_x_rep = tt.constant(
    std_z_gvn_x.repeat(n_reps, 0), 'std_z_gvn_x', 2, float_x)

## Create model and AIS sampler objects

In [None]:
# Theano element-wise functions for the layer names stored in the decoder
# archive. 'nnet.Exponential' (the spelling `non_linearity_map_np` uses) is
# accepted alongside 'nnet.Exp' so both maps understand the same names.
non_linearity_map = {
    'nnet.Tanh': tt.tanh,
    'nnet.Sigmoid': tt.nnet.sigmoid,
    'nnet.Exp': tt.exp,
    'nnet.Exponential': tt.exp,
}
decoder_layers = list(decoder['layers'])
non_linearities = [non_linearity_map[name] for name in decoder_layers if name != 'nnet.Linear']
# PhiFunc applies each (W, b, f) triple as f(h.dot(W) + b), and zip() would
# silently drop unmatched layers. So require strictly alternating
# Linear / non-linearity layers, and read W<i>, b<i> at the actual Linear
# positions (as `mean_x_gvn_z_np` does) instead of assuming i = 0, 2, 4.
linear_idxs = [i for i, name in enumerate(decoder_layers) if name == 'nnet.Linear']
if (len(decoder_layers) != 2 * len(linear_idxs) or
        linear_idxs != list(range(0, len(decoder_layers), 2))):
    raise ValueError(
        'PhiFunc needs alternating Linear / non-linearity decoder layers, '
        'got %s' % (decoder_layers,))
weights = [
    tt.constant(decoder['W' + str(idx)], 'dec_W' + str(j), 2, th.config.floatX)
    for j, idx in enumerate(linear_idxs)]
biases = [
    tt.constant(decoder['b' + str(idx)], 'dec_b' + str(j), 2, th.config.floatX)
    for j, idx in enumerate(linear_idxs)]
phi_func = PhiFunc(xs_rep, weights, biases, non_linearities, log_zeta_rep)
psi_func = PsiFunc(mean_z_gvn_x_rep, std_z_gvn_x_rep)

In [None]:
# The AIS sampler uses Theano MRG random streams seeded with the same `seed`
# as the NumPy RNG. NOTE(review): the meaning of the positional `False`
# argument is defined by thermomc.discrete_temp; TODO confirm and document it.
ais_rng_streams = rand.MRG_RandomStreams(seed)
ais_sampler = disc_temp.AnnealedImportanceSampler(ais_rng_streams, False)

In [None]:
# The HMC step size is a symbolic input so it can be chosen at call time.
# NOTE: the "AIS settings" cell later rebinds the global name `dt` to a float.
dt = tt.scalar('dt')
hmc_params = dict(
    dt=dt,
    n_step=10,
    mom_resample_coeff=1.,
)

In [None]:
pos = tt.matrix('pos')
inv_temps = tt.vector('inv_temp_sched')
# The symbolic graph outputs get `_sym` suffixes. This keeps them separate
# from the NumPy arrays that the AIS run cells later bind to `pos_samples`,
# `log_weights` and `accepts`.
pos_samples_sym, log_weights_sym, accepts_sym, ais_updates = ais_sampler.run(
    pos, None, inv_temps, phi_func, psi_func, hmc_params
)
# Take the step-size input from `hmc_params` rather than the global `dt`.
# The "AIS settings" cell rebinds `dt` to a float, so re-running this cell
# afterwards would otherwise pass a float as a graph input.
ais_run = th.function(
    [pos, inv_temps, hmc_params['dt']],
    [pos_samples_sym, log_weights_sym, accepts_sym],
    updates=ais_updates
)

## AIS settings

In [None]:
# Annealing settings shared by the forward and reverse AIS runs.
num_temps = 10000  # annealing steps; the schedule has num_temps + 1 points
temp_scale = 4.    # steepness of the sigmoidal schedule
inv_temp_sched = sigmoidal_schedule(num_temps, temp_scale)
# NOTE: this rebinds `dt` (previously the symbolic step-size variable) to the
# numeric step size passed to `ais_run`.
dt = 0.08

## Forward AIS run

In [None]:
forward_ais_start_time = time.time()
# Start each chain from a draw of the Gaussian q(z | x) given by the encoder.
init_noise = rng.normal(size=zs_rep.shape)
pos_init = init_noise * std_z_gvn_x_rep.value + mean_z_gvn_x_rep.value
fwd_pos_init = pos_init.astype(th.config.floatX)
fwd_sched = inv_temp_sched.astype(th.config.floatX)
pos_samples, log_weights, accepts = ais_run(fwd_pos_init, fwd_sched, dt)
forward_ais_time = time.time() - forward_ais_start_time
print(accepts.mean(), forward_ais_time)

## Reverse AIS run

In [None]:
reverse_ais_start_time = time.time()
# Start from the latents `zs` that generated the observations, and traverse
# the same schedule in reverse order.
rev_pos_init = zs_rep.astype(th.config.floatX)
rev_sched = inv_temp_sched[::-1].astype(th.config.floatX)
rev_pos_samples, rev_log_weights, rev_accepts = ais_run(rev_pos_init, rev_sched, dt)
reverse_ais_time = time.time() - reverse_ais_start_time
print(rev_accepts.mean(), reverse_ais_time)

## Calculate stochastic lower and upper bounds on $\log \mathbb{P}[\mathbf{x} = \boldsymbol{x}]$

In [None]:
def _log_mean_exp(a, axis=-1):
    """Numerically stable ``log(mean(exp(a), axis))``."""
    a_max = np.max(a, axis=axis, keepdims=True)
    return (np.log(np.mean(np.exp(a - a_max), axis=axis)) +
            np.squeeze(a_max, axis=axis))


# `log_zeta` is added per data point inside the AIS target energy (PhiFunc).
# Averaged over data points, its mean is added back to the AIS estimates.
log_norm_approx = log_zeta.mean()
# Forward AIS: log of the mean importance weight over the n_reps chains of
# each data point, averaged over data points (stochastic lower bound).
log_norm_lower = log_norm_approx + _log_mean_exp(
    log_weights.reshape((-1, n_reps))).mean(0)
# Reverse AIS, started from the exact posterior samples `zs`. Assuming
# exp(rev_log_weights) are unbiased estimates of the *inverse* normaliser,
# the bidirectional-MC upper bound is -log mean_k exp(rev_log_weights_k).
# The alternative log mean_k exp(-rev_log_weights_k) is never smaller than
# this (arithmetic mean >= harmonic mean), i.e. it is a looser bound.
log_norm_upper = log_norm_approx - _log_mean_exp(
    rev_log_weights.reshape((-1, n_reps))).mean(0)
print(log_norm_approx, log_norm_lower, log_norm_upper)

In [None]:
# Persist the sampled data, the AIS log-weights, the log_zeta estimates, the
# resulting bounds and the timings next to the model parameters.
bounds_results = dict(
    xs=xs,
    zs=zs,
    seed=seed,
    fwd_log_weights=log_weights,
    rev_log_weights=rev_log_weights,
    log_zeta=log_zeta,
    log_norm_approx=log_norm_approx,
    log_norm_lower=log_norm_lower,
    log_norm_upper=log_norm_upper,
    reverse_ais_time=reverse_ais_time,
    forward_ais_time=forward_ais_time,
    log_zeta_calc_time=log_zeta_calc_time,
    mean_z_gvn_x=mean_z_gvn_x,
    std_z_gvn_x=std_z_gvn_x,
    mean_x_gvn_z=means,
)
np.savez(os.path.join(model_dir, 'joint-sample-and-log-norm-bounds.npz'),
         **bounds_results)