In [1]:
import argparse
import math
import os
import time
from functools import partial

import numpy as np
import torch
import visdom

import pyro
import pyro.contrib.examples.multi_mnist as multi_mnist
import pyro.optim as optim
import pyro.poutine as poutine
from components.AIR import AIR, latents_to_tensor
from pyro.contrib.examples.util import get_data_directory
from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO
from utils.viz import draw_many, tensor_to_objs

## Support functions

In [2]:
def count_accuracy(X, true_counts, air, batch_size):
    assert X.size(0) == true_counts.size(0), 'Size mismatch.'
    assert X.size(0) % batch_size == 0, 'Input size must be multiple of batch_size.'
    counts = torch.LongTensor(3, 4).zero_()
    error_latents = []
    error_indicators = []

    def count_vec_to_mat(vec, max_index):
        out = torch.LongTensor(vec.size(0), max_index + 1).zero_()
        out.scatter_(1, vec.type(torch.LongTensor).view(vec.size(0), 1), 1)
        return out

    for i in range(X.size(0) // batch_size):
        X_batch = X[i * batch_size:(i + 1) * batch_size]
        true_counts_batch = true_counts[i * batch_size:(i + 1) * batch_size]
        z_where, z_pres = air.guide(X_batch, batch_size)
        inferred_counts = sum(z.cpu() for z in z_pres).squeeze().data
        true_counts_m = count_vec_to_mat(true_counts_batch, 2)
        inferred_counts_m = count_vec_to_mat(inferred_counts, 3)
        counts += torch.mm(true_counts_m.t(), inferred_counts_m)
        error_ind = 1 - (true_counts_batch == inferred_counts)
        error_ix = error_ind.nonzero().squeeze()
        error_latents.append(latents_to_tensor((z_where, z_pres)).index_select(0, error_ix))
        error_indicators.append(error_ind)

    acc = counts.diag().sum().float() / X.size(0)
    error_indices = torch.cat(error_indicators).nonzero().squeeze()
    if X.is_cuda:
        error_indices = error_indices.cuda()
    return acc, counts, torch.cat(error_latents), error_indices


# Defines something like a truncated geometric. Like the geometric,
# this has the property that there's a constant difference in log prob
# between p(steps=n) and p(steps=n+1).
def make_prior(k):
    assert 0 < k <= 1
    u = 1 / (1 + k + k**2 + k**3)
    p0 = 1 - u
    p1 = 1 - (k * u) / p0
    p2 = 1 - (k**2 * u) / (p0 * p1)
    trial_probs = [p0, p1, p2]
    # dist = [1 - p0, p0 * (1 - p1), p0 * p1 * (1 - p2), p0 * p1 * p2]
    # print(dist)
    return lambda t: trial_probs[t]


# Implements "prior annealing" as described in this blog post:
# http://akosiorek.github.io/ml/2017/09/03/implementing-air.html

# That implementation does something very close to the following:
# --z-pres-prior (1 - 1e-15)
# --z-pres-prior-raw
# --anneal-prior exp
# --anneal-prior-to 1e-7
# --anneal-prior-begin 1000
# --anneal-prior-duration 1e6

# e.g. After 200K steps z_pres_p will have decayed to ~0.04

# These compute the value of a decaying value at time t.
# initial: initial value
# final: final value, reached after begin + duration steps
# begin: number of steps before decay begins
# duration: number of steps over which decay occurs
# t: current time step


def lin_decay(initial, final, begin, duration, t):
    assert duration > 0
    x = (final - initial) * (t - begin) / duration + initial
    return max(min(x, initial), final)


def exp_decay(initial, final, begin, duration, t):
    assert final > 0
    assert duration > 0
    # half_life = math.log(2) / math.log(initial / final) * duration
    decay_rate = math.log(initial / final) / duration
    x = initial * math.exp(-decay_rate * (t - begin))
    return max(min(x, initial), final)

## Args

In [3]:
parser = argparse.ArgumentParser(description="Pyro AIR example", argument_default=argparse.SUPPRESS)
parser.add_argument('-n', '--num-steps', type=int, default=int(1e8),
                    help='number of optimization steps to take')
parser.add_argument('-b', '--batch-size', type=int, default=64,
                    help='batch size')
parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4,
                    help='learning rate')
parser.add_argument('-blr', '--baseline-learning-rate', type=float, default=1e-3,
                    help='baseline learning rate')
parser.add_argument('--progress-every', type=int, default=1,
                    help='number of steps between writing progress to stdout')
parser.add_argument('--eval-every', type=int, default=0,
                    help='number of steps between evaluations')
parser.add_argument('--baseline-scalar', type=float,
                    help='scale the output of the baseline nets by this value')
parser.add_argument('--no-baselines', action='store_true', default=False,
                    help='do not use data dependent baselines')
parser.add_argument('--encoder-net', type=int, nargs='+', default=[200],
                    help='encoder net hidden layer sizes')
parser.add_argument('--decoder-net', type=int, nargs='+', default=[200],
                    help='decoder net hidden layer sizes')
parser.add_argument('--predict-net', type=int, nargs='+',
                    help='predict net hidden layer sizes')
parser.add_argument('--embed-net', type=int, nargs='+',
                    help='embed net architecture')
parser.add_argument('--bl-predict-net', type=int, nargs='+',
                    help='baseline predict net hidden layer sizes')
parser.add_argument('--non-linearity', type=str,
                    help='non linearity to use throughout')
parser.add_argument('--viz', action='store_true', default=True,
                    help='generate vizualizations during optimization')
parser.add_argument('--viz-every', type=int, default=10,
                    help='number of steps between vizualizations')
parser.add_argument('--visdom-env', default='main',
                    help='visdom enviroment name')
parser.add_argument('--load', type=str,
                    help='load previously saved parameters')
parser.add_argument('--save', type=str,
                    help='save parameters to specified file')
parser.add_argument('--save-every', type=int, default=1e4,
                    help='number of steps between parameter saves')
parser.add_argument('--cuda', action='store_true', default=False,
                    help='use cuda')
parser.add_argument('--jit', action='store_true', default=False,
                    help='use PyTorch jit')
parser.add_argument('-t', '--model-steps', type=int, default=3,
                    help='number of time steps')
parser.add_argument('--rnn-hidden-size', type=int, default=256,
                    help='rnn hidden size')
parser.add_argument('--encoder-latent-size', type=int, default=50,
                    help='attention window encoder/decoder latent space size')
parser.add_argument('--decoder-output-bias', type=float,
                    help='bias added to decoder output (prior to applying non-linearity)')
parser.add_argument('--decoder-output-use-sigmoid', action='store_true',
                    help='apply sigmoid function to output of decoder network')
parser.add_argument('--window-size', type=int, default=28,
                    help='attention window size')
parser.add_argument('--z-pres-prior', type=float, default=0.5,
                    help='prior success probability for z_pres')
parser.add_argument('--z-pres-prior-raw', action='store_true', default=False,
                    help='use --z-pres-prior directly as success prob instead of a geometric like prior')
parser.add_argument('--anneal-prior', choices='none lin exp'.split(), default='none',
                    help='anneal z_pres prior during optimization')
parser.add_argument('--anneal-prior-to', type=float, default=1e-7,
                    help='target z_pres prior prob')
parser.add_argument('--anneal-prior-begin', type=int, default=0,
                    help='number of steps to wait before beginning to anneal the prior')
parser.add_argument('--anneal-prior-duration', type=int, default=100000,
                    help='number of steps over which to anneal the prior')
parser.add_argument('--pos-prior-mean', type=float,
                    help='mean of the window position prior')
parser.add_argument('--pos-prior-sd', type=float,
                    help='std. dev. of the window position prior')
parser.add_argument('--scale-prior-mean', type=float,
                    help='mean of the window scale prior')
parser.add_argument('--scale-prior-sd', type=float,
                    help='std. dev. of the window scale prior')
parser.add_argument('--no-masking', action='store_true', default=False,
                    help='do not mask out the costs of unused choices')
parser.add_argument('--seed', type=int, help='random seed', default=None)
parser.add_argument('-v', '--verbose', action='store_true', default=False,
                    help='write hyper parameters and network architecture to stdout')

_StoreTrueAction(option_strings=['-v', '--verbose'], dest='verbose', nargs=0, const=True, default=False, type=None, choices=None, help='write hyper parameters and network architecture to stdout', metavar=None)

In [4]:
# vars(parser.parse_args(""))
args = argparse.Namespace(**vars(parser.parse_args("")))

In [5]:
if 'save' in args:
    if os.path.exists(args.save):
        raise RuntimeError('Output file "{}" already exists.'.format(args.save))

if args.seed is not None:
    pyro.set_rng_seed(args.seed)

# Build a function to compute z_pres prior probabilities.
if args.z_pres_prior_raw:
    def base_z_pres_prior_p(t):
        return args.z_pres_prior
else:
    base_z_pres_prior_p = make_prior(args.z_pres_prior)

# Wrap with logic to apply any annealing.
def z_pres_prior_p(opt_step, time_step):
    p = base_z_pres_prior_p(time_step)
    if args.anneal_prior == 'none':
        return p
    else:
        decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
        return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                     args.anneal_prior_duration, opt_step)



## Data loader

In [6]:
def load_data():
#     inpath = './air/.data'
#     X_np, Y = multi_mnist.load(inpath)
    X_np = np.load('/Users/chamathabeysinghe/Projects/monash/test/variational_auto_encoder/data/ANTS2/masks_50x50.npy')
    X_np = X_np.astype(np.float32)
    X_np /= 255.0
    X = torch.from_numpy(X_np)
    # Using FloatTensor to allow comparison with values sampled from
    # Bernoulli.
    counts = torch.FloatTensor([1 for objs in X_np])
    return X, counts

In [7]:
X, true_counts = load_data()
X_size = X.size(0)
if args.cuda:
    X = X.cuda()

In [8]:
X.shape

torch.Size([5339, 50, 50])

## Model

In [9]:
model_arg_keys = ['window_size',
                  'rnn_hidden_size',
                  'decoder_output_bias',
                  'decoder_output_use_sigmoid',
                  'baseline_scalar',
                  'encoder_net',
                  'decoder_net',
                  'predict_net',
                  'embed_net',
                  'bl_predict_net',
                  'non_linearity',
                  'pos_prior_mean',
                  'pos_prior_sd',
                  'scale_prior_mean',
                  'scale_prior_sd']
model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}

In [10]:
air = AIR(
        num_steps=args.model_steps,
        x_size=50,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

In [11]:
if 'load' in args:
    print('Loading parameters...')
    air.load_state_dict(torch.load(args.load))

## Visualize

In [12]:
if args.viz:
    vis = visdom.Visdom(env=args.visdom_env)
    z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
    vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

Setting up a new session...


## Optimizer

In [13]:
def isBaselineParam(module_name, param_name):
    return 'bl_' in module_name or 'bl_' in param_name

def per_param_optim_args(module_name, param_name):
    lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name) else args.learning_rate
    return {'lr': lr}

adam = optim.Adam(per_param_optim_args)
elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
svi = SVI(air.model, air.guide, adam, loss=elbo)

In [14]:
t0 = time.time()
examples_to_viz = X[5:10]

for i in range(1, args.num_steps + 1):

    loss = svi.step(X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

    if args.progress_every > 0 and i % args.progress_every == 0:
        print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
            i,
            (i * args.batch_size) / X_size,
            (time.time() - t0) / 3600,
            loss / X_size))

    if args.viz and i % args.viz_every == 0:
        trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
        z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
        z_wheres = tensor_to_objs(latents_to_tensor(z))

        # Show data with inferred objection positions.
        vis.images(draw_many(examples_to_viz, z_wheres))
        # Show reconstructions of data.
        vis.images(draw_many(recons, z_wheres))

    if args.eval_every > 0 and i % args.eval_every == 0:
        # Measure accuracy on subset of training data.
        acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
        print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
        if args.viz and error_ix.size(0) > 0:
            vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                       opts=dict(caption='errors ({})'.format(i)))

    if 'save' in args and i % args.save_every == 0:
        print('Saving parameters...')
        torch.save(air.state_dict(), args.save)

i=1, epochs=0.01, elapsed=0.00, elbo=-680.42
i=2, epochs=0.02, elapsed=0.00, elbo=-679.05
i=3, epochs=0.04, elapsed=0.00, elbo=-678.41
i=4, epochs=0.05, elapsed=0.00, elbo=-678.74
i=5, epochs=0.06, elapsed=0.00, elbo=-682.42
i=6, epochs=0.07, elapsed=0.00, elbo=-681.29
i=7, epochs=0.08, elapsed=0.00, elbo=-685.39
i=8, epochs=0.10, elapsed=0.00, elbo=-684.21
i=9, epochs=0.11, elapsed=0.00, elbo=-678.91
i=10, epochs=0.12, elapsed=0.00, elbo=-681.83
i=11, epochs=0.13, elapsed=0.00, elbo=-683.92
i=12, epochs=0.14, elapsed=0.00, elbo=-688.95
i=13, epochs=0.16, elapsed=0.00, elbo=-682.25
i=14, epochs=0.17, elapsed=0.00, elbo=-679.49
i=15, epochs=0.18, elapsed=0.00, elbo=-689.83
i=16, epochs=0.19, elapsed=0.00, elbo=-678.59
i=17, epochs=0.20, elapsed=0.00, elbo=-689.76
i=18, epochs=0.22, elapsed=0.00, elbo=-685.74
i=19, epochs=0.23, elapsed=0.00, elbo=-679.95
i=20, epochs=0.24, elapsed=0.00, elbo=-681.76
i=21, epochs=0.25, elapsed=0.00, elbo=-684.12
i=22, epochs=0.26, elapsed=0.00, elbo=-685.

i=179, epochs=2.15, elapsed=0.01, elbo=-694.62
i=180, epochs=2.16, elapsed=0.01, elbo=-692.25
i=181, epochs=2.17, elapsed=0.01, elbo=-689.71
i=182, epochs=2.18, elapsed=0.01, elbo=-691.09
i=183, epochs=2.19, elapsed=0.01, elbo=-691.85
i=184, epochs=2.21, elapsed=0.01, elbo=-695.07
i=185, epochs=2.22, elapsed=0.01, elbo=-690.82
i=186, epochs=2.23, elapsed=0.01, elbo=-693.75
i=187, epochs=2.24, elapsed=0.01, elbo=-688.88
i=188, epochs=2.25, elapsed=0.01, elbo=-694.11
i=189, epochs=2.27, elapsed=0.01, elbo=-693.01
i=190, epochs=2.28, elapsed=0.01, elbo=-693.20
i=191, epochs=2.29, elapsed=0.01, elbo=-693.64
i=192, epochs=2.30, elapsed=0.01, elbo=-689.25
i=193, epochs=2.31, elapsed=0.01, elbo=-692.45
i=194, epochs=2.33, elapsed=0.01, elbo=-694.92
i=195, epochs=2.34, elapsed=0.01, elbo=-692.96
i=196, epochs=2.35, elapsed=0.01, elbo=-694.56
i=197, epochs=2.36, elapsed=0.01, elbo=-694.36
i=198, epochs=2.37, elapsed=0.01, elbo=-692.30
i=199, epochs=2.39, elapsed=0.01, elbo=-693.89
i=200, epochs

i=354, epochs=4.24, elapsed=0.02, elbo=-694.20
i=355, epochs=4.26, elapsed=0.02, elbo=-695.41
i=356, epochs=4.27, elapsed=0.02, elbo=-697.81
i=357, epochs=4.28, elapsed=0.02, elbo=-695.69
i=358, epochs=4.29, elapsed=0.02, elbo=-695.36
i=359, epochs=4.30, elapsed=0.02, elbo=-697.42
i=360, epochs=4.32, elapsed=0.02, elbo=-696.73
i=361, epochs=4.33, elapsed=0.02, elbo=-694.22
i=362, epochs=4.34, elapsed=0.02, elbo=-694.50
i=363, epochs=4.35, elapsed=0.02, elbo=-694.97
i=364, epochs=4.36, elapsed=0.02, elbo=-697.42
i=365, epochs=4.38, elapsed=0.02, elbo=-697.87
i=366, epochs=4.39, elapsed=0.02, elbo=-694.05
i=367, epochs=4.40, elapsed=0.02, elbo=-696.58
i=368, epochs=4.41, elapsed=0.02, elbo=-692.42
i=369, epochs=4.42, elapsed=0.02, elbo=-698.59
i=370, epochs=4.44, elapsed=0.02, elbo=-696.65
i=371, epochs=4.45, elapsed=0.02, elbo=-695.82
i=372, epochs=4.46, elapsed=0.02, elbo=-695.04
i=373, epochs=4.47, elapsed=0.02, elbo=-696.46
i=374, epochs=4.48, elapsed=0.02, elbo=-695.37
i=375, epochs

i=529, epochs=6.34, elapsed=0.03, elbo=-700.45
i=530, epochs=6.35, elapsed=0.03, elbo=-698.18
i=531, epochs=6.37, elapsed=0.03, elbo=-696.20
i=532, epochs=6.38, elapsed=0.03, elbo=-697.21
i=533, epochs=6.39, elapsed=0.03, elbo=-699.31
i=534, epochs=6.40, elapsed=0.03, elbo=-698.21
i=535, epochs=6.41, elapsed=0.03, elbo=-696.83
i=536, epochs=6.43, elapsed=0.03, elbo=-697.62
i=537, epochs=6.44, elapsed=0.03, elbo=-699.21
i=538, epochs=6.45, elapsed=0.03, elbo=-698.55
i=539, epochs=6.46, elapsed=0.03, elbo=-698.40
i=540, epochs=6.47, elapsed=0.03, elbo=-699.09
i=541, epochs=6.49, elapsed=0.03, elbo=-695.44
i=542, epochs=6.50, elapsed=0.03, elbo=-696.56
i=543, epochs=6.51, elapsed=0.03, elbo=-700.27
i=544, epochs=6.52, elapsed=0.03, elbo=-697.95
i=545, epochs=6.53, elapsed=0.03, elbo=-698.57
i=546, epochs=6.55, elapsed=0.03, elbo=-700.19
i=547, epochs=6.56, elapsed=0.03, elbo=-699.10
i=548, epochs=6.57, elapsed=0.03, elbo=-697.48
i=549, epochs=6.58, elapsed=0.03, elbo=-699.58
i=550, epochs

i=705, epochs=8.45, elapsed=0.04, elbo=-701.57
i=706, epochs=8.46, elapsed=0.04, elbo=-695.99
i=707, epochs=8.47, elapsed=0.04, elbo=-699.89
i=708, epochs=8.49, elapsed=0.04, elbo=-700.33
i=709, epochs=8.50, elapsed=0.04, elbo=-700.42
i=710, epochs=8.51, elapsed=0.04, elbo=-700.79
i=711, epochs=8.52, elapsed=0.04, elbo=-697.19
i=712, epochs=8.53, elapsed=0.04, elbo=-699.74
i=713, epochs=8.55, elapsed=0.04, elbo=-701.19
i=714, epochs=8.56, elapsed=0.04, elbo=-699.98
i=715, epochs=8.57, elapsed=0.04, elbo=-701.57
i=716, epochs=8.58, elapsed=0.04, elbo=-699.66
i=717, epochs=8.59, elapsed=0.04, elbo=-700.44
i=718, epochs=8.61, elapsed=0.04, elbo=-699.82
i=719, epochs=8.62, elapsed=0.04, elbo=-701.00
i=720, epochs=8.63, elapsed=0.04, elbo=-699.11
i=721, epochs=8.64, elapsed=0.04, elbo=-700.86
i=722, epochs=8.65, elapsed=0.04, elbo=-700.17
i=723, epochs=8.67, elapsed=0.04, elbo=-699.10
i=724, epochs=8.68, elapsed=0.04, elbo=-699.50
i=725, epochs=8.69, elapsed=0.04, elbo=-696.91
i=726, epochs

i=879, epochs=10.54, elapsed=0.05, elbo=-702.38
i=880, epochs=10.55, elapsed=0.05, elbo=-698.60
i=881, epochs=10.56, elapsed=0.05, elbo=-699.86
i=882, epochs=10.57, elapsed=0.05, elbo=-702.04
i=883, epochs=10.58, elapsed=0.05, elbo=-700.92
i=884, epochs=10.60, elapsed=0.05, elbo=-701.10
i=885, epochs=10.61, elapsed=0.05, elbo=-700.80
i=886, epochs=10.62, elapsed=0.05, elbo=-701.01
i=887, epochs=10.63, elapsed=0.05, elbo=-700.80
i=888, epochs=10.64, elapsed=0.05, elbo=-702.02
i=889, epochs=10.66, elapsed=0.05, elbo=-701.05
i=890, epochs=10.67, elapsed=0.05, elbo=-699.60
i=891, epochs=10.68, elapsed=0.05, elbo=-701.48
i=892, epochs=10.69, elapsed=0.05, elbo=-702.59
i=893, epochs=10.70, elapsed=0.05, elbo=-700.94
i=894, epochs=10.72, elapsed=0.05, elbo=-701.01
i=895, epochs=10.73, elapsed=0.05, elbo=-702.13
i=896, epochs=10.74, elapsed=0.05, elbo=-699.56
i=897, epochs=10.75, elapsed=0.05, elbo=-700.02
i=898, epochs=10.76, elapsed=0.05, elbo=-700.89
i=899, epochs=10.78, elapsed=0.05, elbo=

i=1049, epochs=12.57, elapsed=0.05, elbo=-700.95
i=1050, epochs=12.59, elapsed=0.05, elbo=-701.94
i=1051, epochs=12.60, elapsed=0.05, elbo=-702.59
i=1052, epochs=12.61, elapsed=0.05, elbo=-700.59
i=1053, epochs=12.62, elapsed=0.05, elbo=-701.13
i=1054, epochs=12.63, elapsed=0.05, elbo=-700.41
i=1055, epochs=12.65, elapsed=0.05, elbo=-700.66
i=1056, epochs=12.66, elapsed=0.05, elbo=-701.08
i=1057, epochs=12.67, elapsed=0.05, elbo=-701.76
i=1058, epochs=12.68, elapsed=0.06, elbo=-700.48
i=1059, epochs=12.69, elapsed=0.06, elbo=-700.78
i=1060, epochs=12.71, elapsed=0.06, elbo=-701.79
i=1061, epochs=12.72, elapsed=0.06, elbo=-701.86
i=1062, epochs=12.73, elapsed=0.06, elbo=-700.89
i=1063, epochs=12.74, elapsed=0.06, elbo=-702.35
i=1064, epochs=12.75, elapsed=0.06, elbo=-700.99
i=1065, epochs=12.77, elapsed=0.06, elbo=-700.82
i=1066, epochs=12.78, elapsed=0.06, elbo=-701.91
i=1067, epochs=12.79, elapsed=0.06, elbo=-701.37
i=1068, epochs=12.80, elapsed=0.06, elbo=-699.70
i=1069, epochs=12.81

i=1217, epochs=14.59, elapsed=0.06, elbo=-700.94
i=1218, epochs=14.60, elapsed=0.06, elbo=-702.01
i=1219, epochs=14.61, elapsed=0.06, elbo=-701.07
i=1220, epochs=14.62, elapsed=0.06, elbo=-701.60
i=1221, epochs=14.64, elapsed=0.06, elbo=-702.09
i=1222, epochs=14.65, elapsed=0.06, elbo=-701.12
i=1223, epochs=14.66, elapsed=0.06, elbo=-700.52
i=1224, epochs=14.67, elapsed=0.06, elbo=-699.69
i=1225, epochs=14.68, elapsed=0.06, elbo=-702.09
i=1226, epochs=14.70, elapsed=0.06, elbo=-698.98
i=1227, epochs=14.71, elapsed=0.06, elbo=-701.67
i=1228, epochs=14.72, elapsed=0.06, elbo=-701.22
i=1229, epochs=14.73, elapsed=0.06, elbo=-702.08
i=1230, epochs=14.74, elapsed=0.06, elbo=-702.47
i=1231, epochs=14.76, elapsed=0.06, elbo=-701.39
i=1232, epochs=14.77, elapsed=0.06, elbo=-702.77
i=1233, epochs=14.78, elapsed=0.06, elbo=-700.79
i=1234, epochs=14.79, elapsed=0.06, elbo=-701.22
i=1235, epochs=14.80, elapsed=0.06, elbo=-700.11
i=1236, epochs=14.82, elapsed=0.06, elbo=-701.75
i=1237, epochs=14.83

i=1385, epochs=16.60, elapsed=0.07, elbo=-701.38
i=1386, epochs=16.61, elapsed=0.07, elbo=-701.53
i=1387, epochs=16.63, elapsed=0.07, elbo=-701.69
i=1388, epochs=16.64, elapsed=0.07, elbo=-701.88
i=1389, epochs=16.65, elapsed=0.07, elbo=-702.07
i=1390, epochs=16.66, elapsed=0.07, elbo=-701.73
i=1391, epochs=16.67, elapsed=0.07, elbo=-701.00
i=1392, epochs=16.69, elapsed=0.07, elbo=-698.61
i=1393, epochs=16.70, elapsed=0.07, elbo=-701.38
i=1394, epochs=16.71, elapsed=0.07, elbo=-701.13
i=1395, epochs=16.72, elapsed=0.07, elbo=-703.11
i=1396, epochs=16.73, elapsed=0.07, elbo=-702.22
i=1397, epochs=16.75, elapsed=0.07, elbo=-702.16
i=1398, epochs=16.76, elapsed=0.07, elbo=-700.86
i=1399, epochs=16.77, elapsed=0.07, elbo=-701.69
i=1400, epochs=16.78, elapsed=0.07, elbo=-701.38
i=1401, epochs=16.79, elapsed=0.07, elbo=-701.64
i=1402, epochs=16.81, elapsed=0.07, elbo=-701.17
i=1403, epochs=16.82, elapsed=0.07, elbo=-702.76
i=1404, epochs=16.83, elapsed=0.07, elbo=-701.60
i=1405, epochs=16.84

i=1553, epochs=18.62, elapsed=0.08, elbo=-701.36
i=1554, epochs=18.63, elapsed=0.08, elbo=-701.09
i=1555, epochs=18.64, elapsed=0.08, elbo=-701.49
i=1556, epochs=18.65, elapsed=0.08, elbo=-698.93
i=1557, epochs=18.66, elapsed=0.08, elbo=-702.14
i=1558, epochs=18.68, elapsed=0.08, elbo=-702.08
i=1559, epochs=18.69, elapsed=0.08, elbo=-703.52
i=1560, epochs=18.70, elapsed=0.08, elbo=-701.79
i=1561, epochs=18.71, elapsed=0.08, elbo=-702.62
i=1562, epochs=18.72, elapsed=0.08, elbo=-702.74
i=1563, epochs=18.74, elapsed=0.08, elbo=-701.42
i=1564, epochs=18.75, elapsed=0.08, elbo=-701.53
i=1565, epochs=18.76, elapsed=0.08, elbo=-701.78
i=1566, epochs=18.77, elapsed=0.08, elbo=-702.75
i=1567, epochs=18.78, elapsed=0.08, elbo=-701.17
i=1568, epochs=18.80, elapsed=0.08, elbo=-702.32
i=1569, epochs=18.81, elapsed=0.08, elbo=-701.82
i=1570, epochs=18.82, elapsed=0.08, elbo=-702.21
i=1571, epochs=18.83, elapsed=0.08, elbo=-702.97
i=1572, epochs=18.84, elapsed=0.08, elbo=-703.37
i=1573, epochs=18.86

i=1721, epochs=20.63, elapsed=0.09, elbo=-698.42
i=1722, epochs=20.64, elapsed=0.09, elbo=-703.38
i=1723, epochs=20.65, elapsed=0.09, elbo=-702.96
i=1724, epochs=20.67, elapsed=0.09, elbo=-701.92
i=1725, epochs=20.68, elapsed=0.09, elbo=-701.98
i=1726, epochs=20.69, elapsed=0.09, elbo=-701.38
i=1727, epochs=20.70, elapsed=0.09, elbo=-701.58
i=1728, epochs=20.71, elapsed=0.09, elbo=-701.70
i=1729, epochs=20.73, elapsed=0.09, elbo=-702.45
i=1730, epochs=20.74, elapsed=0.09, elbo=-700.79
i=1731, epochs=20.75, elapsed=0.09, elbo=-701.22
i=1732, epochs=20.76, elapsed=0.09, elbo=-698.39
i=1733, epochs=20.77, elapsed=0.09, elbo=-701.52
i=1734, epochs=20.79, elapsed=0.09, elbo=-702.41
i=1735, epochs=20.80, elapsed=0.09, elbo=-700.81
i=1736, epochs=20.81, elapsed=0.09, elbo=-702.37
i=1737, epochs=20.82, elapsed=0.09, elbo=-701.37
i=1738, epochs=20.83, elapsed=0.09, elbo=-702.36
i=1739, epochs=20.85, elapsed=0.09, elbo=-701.76
i=1740, epochs=20.86, elapsed=0.09, elbo=-703.42
i=1741, epochs=20.87

i=1890, epochs=22.66, elapsed=0.10, elbo=-700.88
i=1891, epochs=22.67, elapsed=0.10, elbo=-701.69
i=1892, epochs=22.68, elapsed=0.10, elbo=-703.02
i=1893, epochs=22.69, elapsed=0.10, elbo=-702.79
i=1894, epochs=22.70, elapsed=0.10, elbo=-701.64
i=1895, epochs=22.72, elapsed=0.10, elbo=-702.27
i=1896, epochs=22.73, elapsed=0.10, elbo=-701.84
i=1897, epochs=22.74, elapsed=0.10, elbo=-703.08
i=1898, epochs=22.75, elapsed=0.10, elbo=-702.71
i=1899, epochs=22.76, elapsed=0.10, elbo=-701.96
i=1900, epochs=22.78, elapsed=0.10, elbo=-703.37
i=1901, epochs=22.79, elapsed=0.10, elbo=-702.69
i=1902, epochs=22.80, elapsed=0.10, elbo=-703.88
i=1903, epochs=22.81, elapsed=0.10, elbo=-704.04
i=1904, epochs=22.82, elapsed=0.10, elbo=-700.68
i=1905, epochs=22.84, elapsed=0.10, elbo=-703.40
i=1906, epochs=22.85, elapsed=0.10, elbo=-698.09
i=1907, epochs=22.86, elapsed=0.10, elbo=-703.65
i=1908, epochs=22.87, elapsed=0.10, elbo=-701.73
i=1909, epochs=22.88, elapsed=0.10, elbo=-701.80
i=1910, epochs=22.90

i=2059, epochs=24.68, elapsed=0.11, elbo=-701.58
i=2060, epochs=24.69, elapsed=0.11, elbo=-703.92
i=2061, epochs=24.71, elapsed=0.11, elbo=-703.20
i=2062, epochs=24.72, elapsed=0.11, elbo=-702.60
i=2063, epochs=24.73, elapsed=0.11, elbo=-702.43
i=2064, epochs=24.74, elapsed=0.11, elbo=-698.12
i=2065, epochs=24.75, elapsed=0.11, elbo=-703.03
i=2066, epochs=24.77, elapsed=0.11, elbo=-702.56
i=2067, epochs=24.78, elapsed=0.11, elbo=-702.37
i=2068, epochs=24.79, elapsed=0.11, elbo=-702.71
i=2069, epochs=24.80, elapsed=0.11, elbo=-698.59
i=2070, epochs=24.81, elapsed=0.11, elbo=-704.23
i=2071, epochs=24.83, elapsed=0.11, elbo=-701.79
i=2072, epochs=24.84, elapsed=0.11, elbo=-703.16
i=2073, epochs=24.85, elapsed=0.11, elbo=-703.05
i=2074, epochs=24.86, elapsed=0.11, elbo=-700.77
i=2075, epochs=24.87, elapsed=0.11, elbo=-702.21
i=2076, epochs=24.89, elapsed=0.11, elbo=-701.85
i=2077, epochs=24.90, elapsed=0.11, elbo=-702.09
i=2078, epochs=24.91, elapsed=0.11, elbo=-701.33
i=2079, epochs=24.92

i=2227, epochs=26.70, elapsed=0.12, elbo=-701.77
i=2228, epochs=26.71, elapsed=0.12, elbo=-703.25
i=2229, epochs=26.72, elapsed=0.12, elbo=-703.67
i=2230, epochs=26.73, elapsed=0.12, elbo=-702.87
i=2231, epochs=26.74, elapsed=0.12, elbo=-703.59
i=2232, epochs=26.76, elapsed=0.12, elbo=-701.90
i=2233, epochs=26.77, elapsed=0.12, elbo=-703.76
i=2234, epochs=26.78, elapsed=0.12, elbo=-702.27
i=2235, epochs=26.79, elapsed=0.12, elbo=-701.70
i=2236, epochs=26.80, elapsed=0.12, elbo=-702.98
i=2237, epochs=26.82, elapsed=0.12, elbo=-702.39
i=2238, epochs=26.83, elapsed=0.12, elbo=-700.75
i=2239, epochs=26.84, elapsed=0.12, elbo=-703.45
i=2240, epochs=26.85, elapsed=0.12, elbo=-703.81
i=2241, epochs=26.86, elapsed=0.12, elbo=-701.65
i=2242, epochs=26.88, elapsed=0.12, elbo=-700.41
i=2243, epochs=26.89, elapsed=0.12, elbo=-703.02
i=2244, epochs=26.90, elapsed=0.12, elbo=-702.74
i=2245, epochs=26.91, elapsed=0.12, elbo=-703.82
i=2246, epochs=26.92, elapsed=0.12, elbo=-702.56
i=2247, epochs=26.94

i=2395, epochs=28.71, elapsed=0.13, elbo=-702.15
i=2396, epochs=28.72, elapsed=0.13, elbo=-702.94
i=2397, epochs=28.73, elapsed=0.13, elbo=-703.22
i=2398, epochs=28.75, elapsed=0.13, elbo=-701.71
i=2399, epochs=28.76, elapsed=0.13, elbo=-701.47
i=2400, epochs=28.77, elapsed=0.13, elbo=-702.57
i=2401, epochs=28.78, elapsed=0.13, elbo=-702.42
i=2402, epochs=28.79, elapsed=0.13, elbo=-703.09
i=2403, epochs=28.81, elapsed=0.13, elbo=-703.56
i=2404, epochs=28.82, elapsed=0.13, elbo=-702.26
i=2405, epochs=28.83, elapsed=0.13, elbo=-702.75
i=2406, epochs=28.84, elapsed=0.13, elbo=-701.25
i=2407, epochs=28.85, elapsed=0.13, elbo=-701.89
i=2408, epochs=28.87, elapsed=0.13, elbo=-702.53
i=2409, epochs=28.88, elapsed=0.13, elbo=-701.80
i=2410, epochs=28.89, elapsed=0.13, elbo=-701.29
i=2411, epochs=28.90, elapsed=0.13, elbo=-703.40
i=2412, epochs=28.91, elapsed=0.13, elbo=-703.65
i=2413, epochs=28.93, elapsed=0.13, elbo=-702.97
i=2414, epochs=28.94, elapsed=0.13, elbo=-702.66
i=2415, epochs=28.95

i=2563, epochs=30.72, elapsed=0.14, elbo=-702.21
i=2564, epochs=30.74, elapsed=0.14, elbo=-702.81
i=2565, epochs=30.75, elapsed=0.14, elbo=-702.34
i=2566, epochs=30.76, elapsed=0.14, elbo=-702.18
i=2567, epochs=30.77, elapsed=0.14, elbo=-701.68
i=2568, epochs=30.78, elapsed=0.14, elbo=-702.76
i=2569, epochs=30.80, elapsed=0.14, elbo=-701.04
i=2570, epochs=30.81, elapsed=0.14, elbo=-702.17
i=2571, epochs=30.82, elapsed=0.14, elbo=-702.32
i=2572, epochs=30.83, elapsed=0.14, elbo=-702.93
i=2573, epochs=30.84, elapsed=0.14, elbo=-703.12
i=2574, epochs=30.86, elapsed=0.14, elbo=-702.96
i=2575, epochs=30.87, elapsed=0.14, elbo=-702.91
i=2576, epochs=30.88, elapsed=0.14, elbo=-701.98
i=2577, epochs=30.89, elapsed=0.14, elbo=-704.10
i=2578, epochs=30.90, elapsed=0.14, elbo=-702.61
i=2579, epochs=30.92, elapsed=0.14, elbo=-702.18
i=2580, epochs=30.93, elapsed=0.14, elbo=-703.21
i=2581, epochs=30.94, elapsed=0.14, elbo=-703.58
i=2582, epochs=30.95, elapsed=0.14, elbo=-701.92
i=2583, epochs=30.96

i=2731, epochs=32.74, elapsed=0.15, elbo=-702.51
i=2732, epochs=32.75, elapsed=0.15, elbo=-701.50
i=2733, epochs=32.76, elapsed=0.15, elbo=-702.34
i=2734, epochs=32.77, elapsed=0.15, elbo=-703.26
i=2735, epochs=32.79, elapsed=0.15, elbo=-701.84
i=2736, epochs=32.80, elapsed=0.15, elbo=-701.56
i=2737, epochs=32.81, elapsed=0.15, elbo=-703.03
i=2738, epochs=32.82, elapsed=0.15, elbo=-702.52
i=2739, epochs=32.83, elapsed=0.15, elbo=-702.88
i=2740, epochs=32.85, elapsed=0.15, elbo=-703.05
i=2741, epochs=32.86, elapsed=0.15, elbo=-702.59
i=2742, epochs=32.87, elapsed=0.15, elbo=-701.23
i=2743, epochs=32.88, elapsed=0.15, elbo=-702.35
i=2744, epochs=32.89, elapsed=0.15, elbo=-702.67
i=2745, epochs=32.91, elapsed=0.15, elbo=-702.12
i=2746, epochs=32.92, elapsed=0.15, elbo=-702.17
i=2747, epochs=32.93, elapsed=0.15, elbo=-701.93
i=2748, epochs=32.94, elapsed=0.15, elbo=-703.01
i=2749, epochs=32.95, elapsed=0.15, elbo=-703.15
i=2750, epochs=32.96, elapsed=0.15, elbo=-703.36
i=2751, epochs=32.98

i=2899, epochs=34.75, elapsed=0.16, elbo=-702.65
i=2900, epochs=34.76, elapsed=0.16, elbo=-702.94
i=2901, epochs=34.78, elapsed=0.16, elbo=-702.68
i=2902, epochs=34.79, elapsed=0.16, elbo=-702.65
i=2903, epochs=34.80, elapsed=0.16, elbo=-703.80
i=2904, epochs=34.81, elapsed=0.16, elbo=-701.42
i=2905, epochs=34.82, elapsed=0.16, elbo=-703.53
i=2906, epochs=34.83, elapsed=0.16, elbo=-702.90
i=2907, epochs=34.85, elapsed=0.16, elbo=-703.50
i=2908, epochs=34.86, elapsed=0.16, elbo=-702.44
i=2909, epochs=34.87, elapsed=0.16, elbo=-703.96
i=2910, epochs=34.88, elapsed=0.16, elbo=-702.85
i=2911, epochs=34.89, elapsed=0.16, elbo=-703.79
i=2912, epochs=34.91, elapsed=0.16, elbo=-702.42
i=2913, epochs=34.92, elapsed=0.16, elbo=-702.50
i=2914, epochs=34.93, elapsed=0.16, elbo=-703.67
i=2915, epochs=34.94, elapsed=0.16, elbo=-702.83
i=2916, epochs=34.95, elapsed=0.16, elbo=-702.01
i=2917, epochs=34.97, elapsed=0.16, elbo=-701.67
i=2918, epochs=34.98, elapsed=0.16, elbo=-702.85
i=2919, epochs=34.99

KeyboardInterrupt: 