In [1]:
import argparse
import math
import os
import time
from functools import partial

import numpy as np
import torch
import visdom

import pyro
import pyro.contrib.examples.multi_mnist as multi_mnist
import pyro.optim as optim
import pyro.poutine as poutine
from components.AIR import AIR, latents_to_tensor
from pyro.contrib.examples.util import get_data_directory
from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO
from utils.viz import draw_many, tensor_to_objs
from utils.visualizer import plot_mnist_sample
from utils.visualizer import plot_mnist_sample



## Support functions

In [2]:
def count_accuracy(X, true_counts, air, batch_size):
    """Compare the object counts inferred by the guide with the true counts.

    The data is processed in consecutive batches of `batch_size`. For each
    batch the guide is run once and the inferred count of an example is the
    sum of its sampled z_pres values over all time steps.

    Args:
        X: input tensor; the first dimension indexes examples.
        true_counts: 1-D tensor of true object counts (values 0..2), one per
            example in `X`.
        air: object exposing `guide(X_batch, batch_size)` that returns
            `(z_where, z_pres)`, each a per-time-step sequence of tensors.
        batch_size: number of examples per guide call; must divide X.size(0).

    Returns:
        A tuple `(acc, counts, error_latents, error_indices)`:
        acc -- fraction of examples whose inferred count equals the true count,
        counts -- 3x4 LongTensor confusion matrix (rows: true count 0..2,
            columns: inferred count 0..3),
        error_latents -- `latents_to_tensor` output restricted to the
            mis-counted examples, concatenated over batches,
        error_indices -- 1-D tensor of indices into `X` of the mis-counted
            examples (moved to CUDA when `X` is on CUDA).

    Raises:
        AssertionError: if the sizes of `X` and `true_counts` differ or
            `X.size(0)` is not a multiple of `batch_size`.
    """
    assert X.size(0) == true_counts.size(0), 'Size mismatch.'
    assert X.size(0) % batch_size == 0, 'Input size must be multiple of batch_size.'
    counts = torch.zeros(3, 4, dtype=torch.long)
    error_latents = []
    error_indicators = []

    def count_vec_to_mat(vec, max_index):
        # One-hot encode every count in `vec` as a row of length max_index + 1.
        out = torch.zeros(vec.size(0), max_index + 1, dtype=torch.long)
        out.scatter_(1, vec.cpu().long().view(-1, 1), 1)
        return out

    for start in range(0, X.size(0), batch_size):
        stop = start + batch_size
        X_batch = X[start:stop]
        # Counts are compared on the CPU because the inferred counts are moved there.
        true_counts_batch = true_counts[start:stop].cpu()
        z_where, z_pres = air.guide(X_batch, batch_size)
        # reshape(-1) (rather than squeeze()) keeps a 1-D tensor even for a
        # batch of one, so the size(0)/index_select calls below stay valid.
        inferred_counts = sum(z.cpu() for z in z_pres).detach().reshape(-1)
        true_counts_m = count_vec_to_mat(true_counts_batch, 2)
        inferred_counts_m = count_vec_to_mat(inferred_counts, 3)
        counts += torch.mm(true_counts_m.t(), inferred_counts_m)
        # `!=` instead of `1 - (a == b)`: subtracting from a bool tensor is
        # rejected by current PyTorch releases.
        error_ind = true_counts_batch != inferred_counts
        error_ix = error_ind.nonzero().reshape(-1)
        error_latents.append(latents_to_tensor((z_where, z_pres)).index_select(0, error_ix))
        error_indicators.append(error_ind)

    acc = counts.diag().sum().float() / X.size(0)
    # Always 1-D, so callers can rely on error_indices.size(0).
    error_indices = torch.cat(error_indicators).nonzero().reshape(-1)
    if X.is_cuda:
        error_indices = error_indices.cuda()
    return acc, counts, torch.cat(error_latents), error_indices


# Defines something like a truncated geometric. Like the geometric,
# this has the property that there's a constant difference in log prob
# between p(steps=n) and p(steps=n+1).
def make_prior(k, num_steps=3):
    """Build the per-time-step success probabilities of a truncated geometric.

    The implied distribution over the number of steps n in 0..num_steps is
    proportional to k**n. The returned callable maps a time step t to the
    probability of continuing at step t given that step t was reached.

    Args:
        k: ratio between consecutive step-count probabilities, 0 < k <= 1.
        num_steps: maximum number of steps (number of trials). Defaults to 3,
            which reproduces the original fixed three-step prior.

    Returns:
        A function `t -> probability` valid for t in 0..num_steps - 1
        (indexing a list, so other values of t follow list-index rules).

    Raises:
        AssertionError: if `k` is outside (0, 1] or `num_steps` < 1.
    """
    assert 0 < k <= 1
    assert num_steps >= 1
    # Normaliser of 1 + k + k**2 + ... + k**num_steps, accumulated left to right.
    normalizer = 1
    for power in range(1, num_steps + 1):
        normalizer += k**power
    u = 1 / normalizer  # probability of taking zero steps

    trial_probs = []
    reach_prob = 1  # probability of reaching time step t at all
    for t in range(num_steps):
        # P(stop at t | reached t) = k**t * u / reach_prob
        p_t = 1 - (k**t * u) / reach_prob
        trial_probs.append(p_t)
        reach_prob = reach_prob * p_t
    return lambda t: trial_probs[t]


# Implements "prior annealing" as described in this blog post:
# http://akosiorek.github.io/ml/2017/09/03/implementing-air.html

# That implementation does something very close to the following:
# --z-pres-prior (1 - 1e-15)
# --z-pres-prior-raw
# --anneal-prior exp
# --anneal-prior-to 1e-7
# --anneal-prior-begin 1000
# --anneal-prior-duration 1e6

# e.g. After 200K steps z_pres_p will have decayed to ~0.04

# The functions below compute the value of a decaying quantity at time step t.
# initial: initial value
# final: final value, reached after begin + duration steps
# begin: number of steps before decay begins
# duration: number of steps over which decay occurs
# t: current time step


def lin_decay(initial, final, begin, duration, t):
    """Linearly move from `initial` to `final`, clamped to [final, initial].

    The value stays at `initial` until step `begin`, then changes linearly
    and reaches `final` at step `begin + duration`.
    """
    assert duration > 0
    value = (final - initial) * (t - begin) / duration + initial
    # Cap at `initial` first, then floor at `final`.
    if initial < value:
        value = initial
    if final > value:
        value = final
    return value


def exp_decay(initial, final, begin, duration, t):
    """Exponentially decay from `initial` to `final`, clamped to [final, initial].

    The decay rate is chosen so that the unclamped curve equals `final`
    exactly `duration` steps after `begin`.
    """
    assert final > 0
    assert duration > 0
    rate = math.log(initial / final) / duration
    elapsed = t - begin
    value = initial * math.exp(-rate * elapsed)
    # Cap at `initial` first, then floor at `final`.
    if initial < value:
        value = initial
    if final > value:
        value = final
    return value

## Args

In [3]:
# Command-line interface of the AIR example. argument_default=SUPPRESS means an
# option without an explicit default is simply absent from the parsed namespace,
# which is why later cells test membership with `'name' in args`.
parser = argparse.ArgumentParser(description="Pyro AIR example", argument_default=argparse.SUPPRESS)

# --- Optimisation and reporting ---
parser.add_argument('-n', '--num-steps', type=int, default=int(1e8),
                    help='number of optimization steps to take')
parser.add_argument('-b', '--batch-size', type=int, default=64,
                    help='batch size')
parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4,
                    help='learning rate')
parser.add_argument('-blr', '--baseline-learning-rate', type=float, default=1e-3,
                    help='baseline learning rate')
parser.add_argument('--progress-every', type=int, default=1,
                    help='number of steps between writing progress to stdout')
parser.add_argument('--eval-every', type=int, default=0,
                    help='number of steps between evaluations')

# --- Network architecture ---
parser.add_argument('--baseline-scalar', type=float,
                    help='scale the output of the baseline nets by this value')
parser.add_argument('--no-baselines', action='store_true', default=False,
                    help='do not use data dependent baselines')
parser.add_argument('--encoder-net', type=int, nargs='+', default=[200],
                    help='encoder net hidden layer sizes')
parser.add_argument('--decoder-net', type=int, nargs='+', default=[200],
                    help='decoder net hidden layer sizes')
parser.add_argument('--predict-net', type=int, nargs='+',
                    help='predict net hidden layer sizes')
parser.add_argument('--embed-net', type=int, nargs='+',
                    help='embed net architecture')
parser.add_argument('--bl-predict-net', type=int, nargs='+',
                    help='baseline predict net hidden layer sizes')
parser.add_argument('--non-linearity', type=str,
                    help='non linearity to use throughout')

# --- Visualisation ---
# Visualisation is on by default in this notebook; --no-viz is the only way to
# turn it off because a store_true flag with default=True can never yield False.
parser.add_argument('--viz', action='store_true', default=True,
                    help='generate vizualizations during optimization')
parser.add_argument('--no-viz', dest='viz', action='store_false',
                    help='do not generate visualizations during optimization')
parser.add_argument('--viz-every', type=int, default=100,
                    help='number of steps between vizualizations')
parser.add_argument('--visdom-env', default='main',
                    help='visdom enviroment name')

# --- Checkpointing ---
parser.add_argument('--load', type=str,
                    help='load previously saved parameters')
parser.add_argument('--save', type=str,
                    help='save parameters to specified file')
# The default is an int so it agrees with type=int (argparse does not convert
# non-string defaults, so 1e4 would have stayed a float).
parser.add_argument('--save-every', type=int, default=int(1e4),
                    help='number of steps between parameter saves')

# --- Runtime ---
parser.add_argument('--cuda', action='store_true', default=False,
                    help='use cuda')
parser.add_argument('--jit', action='store_true', default=False,
                    help='use PyTorch jit')

# --- Model ---
parser.add_argument('-t', '--model-steps', type=int, default=3,
                    help='number of time steps')
parser.add_argument('--rnn-hidden-size', type=int, default=256,
                    help='rnn hidden size')
parser.add_argument('--encoder-latent-size', type=int, default=50,
                    help='attention window encoder/decoder latent space size')
parser.add_argument('--decoder-output-bias', type=float,
                    help='bias added to decoder output (prior to applying non-linearity)')
parser.add_argument('--decoder-output-use-sigmoid', action='store_true',
                    help='apply sigmoid function to output of decoder network')
parser.add_argument('--window-size', type=int, default=28,
                    help='attention window size')

# --- z_pres prior and its annealing ---
parser.add_argument('--z-pres-prior', type=float, default=0.5,
                    help='prior success probability for z_pres')
parser.add_argument('--z-pres-prior-raw', action='store_true', default=False,
                    help='use --z-pres-prior directly as success prob instead of a geometric like prior')
parser.add_argument('--anneal-prior', choices='none lin exp'.split(), default='none',
                    help='anneal z_pres prior during optimization')
parser.add_argument('--anneal-prior-to', type=float, default=1e-7,
                    help='target z_pres prior prob')
parser.add_argument('--anneal-prior-begin', type=int, default=0,
                    help='number of steps to wait before beginning to anneal the prior')
parser.add_argument('--anneal-prior-duration', type=int, default=100000,
                    help='number of steps over which to anneal the prior')

# --- Position / scale priors ---
parser.add_argument('--pos-prior-mean', type=float,
                    help='mean of the window position prior')
parser.add_argument('--pos-prior-sd', type=float,
                    help='std. dev. of the window position prior')
parser.add_argument('--scale-prior-mean', type=float,
                    help='mean of the window scale prior')
parser.add_argument('--scale-prior-sd', type=float,
                    help='std. dev. of the window scale prior')

# --- Miscellaneous ---
parser.add_argument('--no-masking', action='store_true', default=False,
                    help='do not mask out the costs of unused choices')
parser.add_argument('--seed', type=int, help='random seed', default=None)
parser.add_argument('-v', '--verbose', action='store_true', default=False,
                    help='write hyper parameters and network architecture to stdout')

_StoreTrueAction(option_strings=['-v', '--verbose'], dest='verbose', nargs=0, const=True, default=False, type=None, choices=None, help='write hyper parameters and network architecture to stdout', metavar=None)

In [4]:
# Inside the notebook there is no real command line: parse an explicit empty
# argument list so every option takes its declared default. Options declared
# without a default stay absent from `args` (the parser uses SUPPRESS).
args = parser.parse_args([])

In [5]:
# Refuse to run when a save target was requested but that file already exists.
if 'save' in args and os.path.exists(args.save):
    raise RuntimeError('Output file "{}" already exists.'.format(args.save))

# Seed Pyro's random number generators only when a seed was supplied.
if args.seed is not None:
    pyro.set_rng_seed(args.seed)

# Choose the function that maps a model time step to the (un-annealed)
# z_pres prior success probability.
if not args.z_pres_prior_raw:
    # Geometric-like prior derived from --z-pres-prior.
    base_z_pres_prior_p = make_prior(args.z_pres_prior)
else:
    # Use --z-pres-prior as-is at every time step.
    def base_z_pres_prior_p(t):
        return args.z_pres_prior

# Adds optional annealing on top of the base z_pres prior.
def z_pres_prior_p(opt_step, time_step):
    """Return the z_pres prior probability for `time_step` at optimisation step `opt_step`.

    With --anneal-prior 'none' the base probability is returned unchanged;
    otherwise it is decayed towards --anneal-prior-to by the selected schedule.
    """
    base_p = base_z_pres_prior_p(time_step)
    schedule = args.anneal_prior
    if schedule == 'none':
        return base_p
    anneal = {'lin': lin_decay, 'exp': exp_decay}[schedule]
    return anneal(
        base_p,
        args.anneal_prior_to,
        args.anneal_prior_begin,
        args.anneal_prior_duration,
        opt_step,
    )



## Data loader

In [8]:
def load_data(path='/Users/chamathabeysinghe/Projects/monash/test/variational_auto_encoder/data/synthetic/smaill_size-150_rad-5.npy'):
    """Load the image array from a .npy file and build the per-image counts.

    Args:
        path: location of the .npy file. The default is the author's local
            dataset; pass an explicit path to run the notebook elsewhere.

    Returns:
        A tuple `(X, counts)`:
        X -- float32 tensor holding the loaded array divided by 255
            (presumably 8-bit pixel intensities scaled to [0, 1] — confirm
            against the dataset),
        counts -- float32 tensor with one entry per image, every entry 1.
            NOTE(review): this assumes each image contains exactly one
            object — confirm for the dataset in use.
    """
    X_np = np.load(path).astype(np.float32)
    X_np /= 255.0
    X = torch.from_numpy(X_np)
    # Using a float tensor to allow comparison with values sampled from
    # Bernoulli.
    counts = torch.ones(X.size(0), dtype=torch.float32)
    return X, counts

In [9]:
#     X_np = np.load('/Users/chamathabeysinghe/Projects/monash/test/variational_auto_encoder/data/synthetic/size-50_rad-5.npy')

# Load the dataset once; X_size (number of examples) is reused for reporting.
X, true_counts = load_data()
X_size = len(X)
# Move the inputs to the GPU only when --cuda was requested.
X = X.cuda() if args.cuda else X
    
    

In [10]:
# Display the dataset dimensions as a quick sanity check.
X.shape
# plot_mnist_sample(X[2])

torch.Size([10000, 150, 150])

## Model

In [11]:
# Names of the optional AIR keyword arguments that can be supplied via `args`.
model_arg_keys = [
    'window_size',
    'rnn_hidden_size',
    'decoder_output_bias',
    'decoder_output_use_sigmoid',
    'baseline_scalar',
    'encoder_net',
    'decoder_net',
    'predict_net',
    'embed_net',
    'bl_predict_net',
    'non_linearity',
    'pos_prior_mean',
    'pos_prior_sd',
    'scale_prior_mean',
    'scale_prior_sd',
]
# Keep only the options that are actually present on `args`; absent ones are
# not passed on at all.
model_args = dict(
    (name, getattr(args, name))
    for name in model_arg_keys
    if name in args
)

In [12]:
# Instantiate the AIR object whose model/guide/prior are used below.
# x_size=150 agrees with the 150x150 inputs shown by X.shape earlier in the
# notebook (presumably the image side length — confirm against components.AIR).
# The two `no_*` flags are inverted into positive `use_*` switches, and the
# optional hyper-parameters collected in model_args are forwarded as keywords.
air = AIR(
        num_steps=args.model_steps,
        x_size=150,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

In [13]:
# Optionally restore previously saved parameters.
if 'load' in args:
    print('Loading parameters...')
    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    # map_location='cpu' lets a checkpoint written on a GPU machine be read on
    # a CPU-only one; load_state_dict then copies the values into the model's
    # existing parameters.
    air.load_state_dict(torch.load(args.load, map_location='cpu'))

## Visualize

In [14]:
# When visualisation is enabled, connect to visdom and show draws from the
# prior before any training. `vis` is reused by the training loop below.
if args.viz:
    vis = visdom.Visdom(env=args.visdom_env)
    # partial(z_pres_prior_p, 0) fixes the optimisation step at 0, i.e. the
    # prior before any annealing has taken place; 5 is presumably the number
    # of samples to draw — confirm against AIR.prior.
    z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
    vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

Setting up a new session...


## Optimizer

In [15]:
def isBaselineParam(module_name, param_name):
    """Return True when either name contains the substring 'bl_'."""
    return any('bl_' in name for name in (module_name, param_name))

def per_param_optim_args(module_name, param_name):
    """Return the optimiser arguments for one parameter.

    Parameters recognised by isBaselineParam get --baseline-learning-rate;
    all others get --learning-rate.
    """
    if isBaselineParam(module_name, param_name):
        return {'lr': args.baseline_learning_rate}
    return {'lr': args.learning_rate}

# Adam with per-parameter learning rates supplied by per_param_optim_args.
adam = optim.Adam(per_param_optim_args)
# Pick the ELBO implementation according to --jit, then instantiate it.
elbo = (JitTraceGraph_ELBO if args.jit else TraceGraph_ELBO)()
# Tie the model, guide, optimiser and loss together.
svi = SVI(air.model, air.guide, adam, loss=elbo)

In [16]:
# Training loop: one SVI step per iteration, with optional progress output,
# visdom visualisation, count-accuracy evaluation and parameter saving.
t0 = time.time()
# Fixed handful of inputs that is re-visualised every --viz-every steps.
examples_to_viz = X[5:10]

for i in range(1, args.num_steps + 1):

    # partial(z_pres_prior_p, i) binds the current optimisation step so the
    # z_pres prior can be annealed as training progresses.
    loss = svi.step(X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

    if args.progress_every > 0 and i % args.progress_every == 0:
        # elapsed is reported in hours; the value labelled 'elbo' is the
        # return value of svi.step divided by the dataset size.
        # NOTE(review): svi.step is documented as returning a loss estimate —
        # confirm the sign convention intended for the 'elbo' label.
        print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
            i,
            (i * args.batch_size) / X_size,
            (time.time() - t0) / 3600,
            loss / X_size))

    if args.viz and i % args.viz_every == 0:
        print('Drawing')
        # Record one run of the guide on the fixed examples, then replay the
        # prior against that trace to obtain latents and reconstructions.
        trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
        z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
        z_wheres = tensor_to_objs(latents_to_tensor(z))

        # Show data with inferred object positions.
        vis.images(draw_many(examples_to_viz, z_wheres))
        # Show reconstructions of data.
        vis.images(draw_many(recons, z_wheres))

    if args.eval_every > 0 and i % args.eval_every == 0:
        # Measure count accuracy over all of X, evaluated in batches of 1000.
        acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
        print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
        # Show up to five of the mis-counted examples.
        if args.viz and error_ix.size(0) > 0:
            vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                       opts=dict(caption='errors ({})'.format(i)))

    # Each save writes to the same file, args.save.
    if 'save' in args and i % args.save_every == 0:
        print('Saving parameters...')
        torch.save(air.state_dict(), args.save)

i=1, epochs=0.01, elapsed=0.00, elbo=-5980.85
i=2, epochs=0.01, elapsed=0.00, elbo=-5974.87
i=3, epochs=0.02, elapsed=0.00, elbo=-5910.69
i=4, epochs=0.03, elapsed=0.00, elbo=-6041.07
i=5, epochs=0.03, elapsed=0.00, elbo=-5935.86
i=6, epochs=0.04, elapsed=0.00, elbo=-6002.75
i=7, epochs=0.04, elapsed=0.00, elbo=-5948.34
i=8, epochs=0.05, elapsed=0.00, elbo=-5942.32
i=9, epochs=0.06, elapsed=0.00, elbo=-5992.71
i=10, epochs=0.06, elapsed=0.00, elbo=-5939.23
i=11, epochs=0.07, elapsed=0.00, elbo=-6000.89
i=12, epochs=0.08, elapsed=0.00, elbo=-6030.49
i=13, epochs=0.08, elapsed=0.00, elbo=-5895.71
i=14, epochs=0.09, elapsed=0.00, elbo=-5973.95
i=15, epochs=0.10, elapsed=0.00, elbo=-5972.18
i=16, epochs=0.10, elapsed=0.00, elbo=-6014.98
i=17, epochs=0.11, elapsed=0.00, elbo=-5929.12
i=18, epochs=0.12, elapsed=0.01, elbo=-6042.53
i=19, epochs=0.12, elapsed=0.01, elbo=-6003.73
i=20, epochs=0.13, elapsed=0.01, elbo=-6025.98
i=21, epochs=0.13, elapsed=0.01, elbo=-5999.86
i=22, epochs=0.14, ela

i=174, epochs=1.11, elapsed=0.05, elbo=-6075.88
i=175, epochs=1.12, elapsed=0.05, elbo=-6096.17
i=176, epochs=1.13, elapsed=0.05, elbo=-6090.50
i=177, epochs=1.13, elapsed=0.05, elbo=-6086.23
i=178, epochs=1.14, elapsed=0.05, elbo=-6055.39
i=179, epochs=1.15, elapsed=0.05, elbo=-6107.24
i=180, epochs=1.15, elapsed=0.05, elbo=-6052.59
i=181, epochs=1.16, elapsed=0.05, elbo=-6045.95
i=182, epochs=1.16, elapsed=0.05, elbo=-6068.63
i=183, epochs=1.17, elapsed=0.05, elbo=-6103.47
i=184, epochs=1.18, elapsed=0.05, elbo=-6006.92
i=185, epochs=1.18, elapsed=0.05, elbo=-6084.17
i=186, epochs=1.19, elapsed=0.05, elbo=-6067.16
i=187, epochs=1.20, elapsed=0.06, elbo=-6040.41
i=188, epochs=1.20, elapsed=0.06, elbo=-6049.00
i=189, epochs=1.21, elapsed=0.06, elbo=-6116.11
i=190, epochs=1.22, elapsed=0.06, elbo=-6095.85
i=191, epochs=1.22, elapsed=0.06, elbo=-6087.33
i=192, epochs=1.23, elapsed=0.06, elbo=-6054.59
i=193, epochs=1.24, elapsed=0.06, elbo=-6062.66
i=194, epochs=1.24, elapsed=0.06, elbo=-

i=345, epochs=2.21, elapsed=0.10, elbo=-6048.24
i=346, epochs=2.21, elapsed=0.10, elbo=-6083.41
i=347, epochs=2.22, elapsed=0.10, elbo=-6104.36
i=348, epochs=2.23, elapsed=0.10, elbo=-6064.01
i=349, epochs=2.23, elapsed=0.10, elbo=-6112.11
i=350, epochs=2.24, elapsed=0.10, elbo=-6058.60
i=351, epochs=2.25, elapsed=0.10, elbo=-6072.73
i=352, epochs=2.25, elapsed=0.10, elbo=-6040.37
i=353, epochs=2.26, elapsed=0.10, elbo=-6077.85
i=354, epochs=2.27, elapsed=0.10, elbo=-6092.15
i=355, epochs=2.27, elapsed=0.10, elbo=-6060.84
i=356, epochs=2.28, elapsed=0.11, elbo=-6072.59
i=357, epochs=2.28, elapsed=0.11, elbo=-6040.69
i=358, epochs=2.29, elapsed=0.11, elbo=-6072.01
i=359, epochs=2.30, elapsed=0.11, elbo=-6118.56
i=360, epochs=2.30, elapsed=0.11, elbo=-5997.87
i=361, epochs=2.31, elapsed=0.11, elbo=-6109.18
i=362, epochs=2.32, elapsed=0.11, elbo=-6154.90
i=363, epochs=2.32, elapsed=0.11, elbo=-6098.61
i=364, epochs=2.33, elapsed=0.11, elbo=-6036.82
i=365, epochs=2.34, elapsed=0.11, elbo=-

i=516, epochs=3.30, elapsed=0.15, elbo=-6117.89
i=517, epochs=3.31, elapsed=0.15, elbo=-6104.64
i=518, epochs=3.32, elapsed=0.15, elbo=-6068.47
i=519, epochs=3.32, elapsed=0.15, elbo=-6109.25
i=520, epochs=3.33, elapsed=0.15, elbo=-6111.25
i=521, epochs=3.33, elapsed=0.16, elbo=-6038.84
i=522, epochs=3.34, elapsed=0.16, elbo=-6065.47
i=523, epochs=3.35, elapsed=0.16, elbo=-6049.77
i=524, epochs=3.35, elapsed=0.16, elbo=-6119.18
i=525, epochs=3.36, elapsed=0.16, elbo=-6110.65
i=526, epochs=3.37, elapsed=0.16, elbo=-6149.15
i=527, epochs=3.37, elapsed=0.16, elbo=-6063.76
i=528, epochs=3.38, elapsed=0.16, elbo=-6020.06
i=529, epochs=3.39, elapsed=0.16, elbo=-6079.76
i=530, epochs=3.39, elapsed=0.16, elbo=-6098.23
i=531, epochs=3.40, elapsed=0.16, elbo=-6137.51
i=532, epochs=3.40, elapsed=0.16, elbo=-6151.36
i=533, epochs=3.41, elapsed=0.16, elbo=-6122.17
i=534, epochs=3.42, elapsed=0.16, elbo=-6054.50
i=535, epochs=3.42, elapsed=0.16, elbo=-6048.66
i=536, epochs=3.43, elapsed=0.16, elbo=-

i=687, epochs=4.40, elapsed=0.22, elbo=-6101.14
i=688, epochs=4.40, elapsed=0.22, elbo=-6179.91
i=689, epochs=4.41, elapsed=0.22, elbo=-6088.23
i=690, epochs=4.42, elapsed=0.22, elbo=-6100.35
i=691, epochs=4.42, elapsed=0.22, elbo=-6106.76
i=692, epochs=4.43, elapsed=0.22, elbo=-6115.37
i=693, epochs=4.44, elapsed=0.22, elbo=-6084.13
i=694, epochs=4.44, elapsed=0.22, elbo=-6101.28
i=695, epochs=4.45, elapsed=0.22, elbo=-6060.24
i=696, epochs=4.45, elapsed=0.22, elbo=-6074.81
i=697, epochs=4.46, elapsed=0.22, elbo=-6111.43
i=698, epochs=4.47, elapsed=0.22, elbo=-6145.64
i=699, epochs=4.47, elapsed=0.22, elbo=-6138.53
i=700, epochs=4.48, elapsed=0.22, elbo=-6077.08
Drawing
i=701, epochs=4.49, elapsed=0.22, elbo=-6051.84
i=702, epochs=4.49, elapsed=0.22, elbo=-6091.96
i=703, epochs=4.50, elapsed=0.22, elbo=-6149.04
i=704, epochs=4.51, elapsed=0.22, elbo=-6148.23
i=705, epochs=4.51, elapsed=0.22, elbo=-6081.72
i=706, epochs=4.52, elapsed=0.22, elbo=-6070.88
i=707, epochs=4.52, elapsed=0.22

i=858, epochs=5.49, elapsed=0.28, elbo=-6063.24
i=859, epochs=5.50, elapsed=0.28, elbo=-6109.42
i=860, epochs=5.50, elapsed=0.28, elbo=-6112.24
i=861, epochs=5.51, elapsed=0.28, elbo=-6115.38
i=862, epochs=5.52, elapsed=0.28, elbo=-6046.29
i=863, epochs=5.52, elapsed=0.28, elbo=-6080.27
i=864, epochs=5.53, elapsed=0.28, elbo=-6093.05
i=865, epochs=5.54, elapsed=0.28, elbo=-6087.72
i=866, epochs=5.54, elapsed=0.28, elbo=-6066.19
i=867, epochs=5.55, elapsed=0.28, elbo=-6099.63
i=868, epochs=5.56, elapsed=0.28, elbo=-6128.37
i=869, epochs=5.56, elapsed=0.28, elbo=-6105.33
i=870, epochs=5.57, elapsed=0.28, elbo=-6054.86
i=871, epochs=5.57, elapsed=0.28, elbo=-6088.04
i=872, epochs=5.58, elapsed=0.28, elbo=-6094.66
i=873, epochs=5.59, elapsed=0.28, elbo=-6074.69
i=874, epochs=5.59, elapsed=0.29, elbo=-6097.92
i=875, epochs=5.60, elapsed=0.29, elbo=-6080.96
i=876, epochs=5.61, elapsed=0.29, elbo=-6133.79
i=877, epochs=5.61, elapsed=0.29, elbo=-6079.52
i=878, epochs=5.62, elapsed=0.29, elbo=-

i=1028, epochs=6.58, elapsed=0.34, elbo=-6074.44
i=1029, epochs=6.59, elapsed=0.34, elbo=-6148.71
i=1030, epochs=6.59, elapsed=0.34, elbo=-6111.15
i=1031, epochs=6.60, elapsed=0.34, elbo=-6077.26
i=1032, epochs=6.60, elapsed=0.34, elbo=-6097.71
i=1033, epochs=6.61, elapsed=0.35, elbo=-6047.96
i=1034, epochs=6.62, elapsed=0.35, elbo=-6089.42
i=1035, epochs=6.62, elapsed=0.35, elbo=-6126.93
i=1036, epochs=6.63, elapsed=0.35, elbo=-6091.85
i=1037, epochs=6.64, elapsed=0.35, elbo=-6041.23
i=1038, epochs=6.64, elapsed=0.35, elbo=-6080.31
i=1039, epochs=6.65, elapsed=0.35, elbo=-6089.29
i=1040, epochs=6.66, elapsed=0.35, elbo=-6117.60
i=1041, epochs=6.66, elapsed=0.35, elbo=-6057.91
i=1042, epochs=6.67, elapsed=0.35, elbo=-6138.38
i=1043, epochs=6.68, elapsed=0.35, elbo=-6106.68
i=1044, epochs=6.68, elapsed=0.35, elbo=-6067.34
i=1045, epochs=6.69, elapsed=0.35, elbo=-6119.80
i=1046, epochs=6.69, elapsed=0.35, elbo=-6106.46
i=1047, epochs=6.70, elapsed=0.35, elbo=-6120.47
i=1048, epochs=6.71,

i=1196, epochs=7.65, elapsed=0.40, elbo=-6045.94
i=1197, epochs=7.66, elapsed=0.40, elbo=-6070.69
i=1198, epochs=7.67, elapsed=0.40, elbo=-6090.20
i=1199, epochs=7.67, elapsed=0.40, elbo=-6074.06
i=1200, epochs=7.68, elapsed=0.40, elbo=-6123.36
Drawing
i=1201, epochs=7.69, elapsed=0.40, elbo=-6121.12
i=1202, epochs=7.69, elapsed=0.40, elbo=-6126.70
i=1203, epochs=7.70, elapsed=0.40, elbo=-6080.17
i=1204, epochs=7.71, elapsed=0.40, elbo=-6114.87
i=1205, epochs=7.71, elapsed=0.40, elbo=-6067.71
i=1206, epochs=7.72, elapsed=0.40, elbo=-6070.99
i=1207, epochs=7.72, elapsed=0.40, elbo=-6129.88
i=1208, epochs=7.73, elapsed=0.40, elbo=-6097.92
i=1209, epochs=7.74, elapsed=0.40, elbo=-6172.10
i=1210, epochs=7.74, elapsed=0.41, elbo=-6062.69
i=1211, epochs=7.75, elapsed=0.41, elbo=-6111.98
i=1212, epochs=7.76, elapsed=0.41, elbo=-6026.27
i=1213, epochs=7.76, elapsed=0.41, elbo=-6088.44
i=1214, epochs=7.77, elapsed=0.41, elbo=-6092.59
i=1215, epochs=7.78, elapsed=0.41, elbo=-6120.65
i=1216, epoc

i=1363, epochs=8.72, elapsed=0.45, elbo=-6080.16
i=1364, epochs=8.73, elapsed=0.45, elbo=-6138.54
i=1365, epochs=8.74, elapsed=0.45, elbo=-6103.73
i=1366, epochs=8.74, elapsed=0.45, elbo=-6093.65
i=1367, epochs=8.75, elapsed=0.45, elbo=-6082.60
i=1368, epochs=8.76, elapsed=0.45, elbo=-6014.20
i=1369, epochs=8.76, elapsed=0.45, elbo=-6102.97
i=1370, epochs=8.77, elapsed=0.45, elbo=-6084.49
i=1371, epochs=8.77, elapsed=0.46, elbo=-6110.47
i=1372, epochs=8.78, elapsed=0.46, elbo=-6106.12
i=1373, epochs=8.79, elapsed=0.46, elbo=-6076.39
i=1374, epochs=8.79, elapsed=0.46, elbo=-6093.91
i=1375, epochs=8.80, elapsed=0.46, elbo=-6034.43
i=1376, epochs=8.81, elapsed=0.46, elbo=-6081.78
i=1377, epochs=8.81, elapsed=0.46, elbo=-6050.90
i=1378, epochs=8.82, elapsed=0.46, elbo=-6129.38
i=1379, epochs=8.83, elapsed=0.46, elbo=-6074.86
i=1380, epochs=8.83, elapsed=0.46, elbo=-6067.18
i=1381, epochs=8.84, elapsed=0.46, elbo=-6122.33
i=1382, epochs=8.84, elapsed=0.46, elbo=-6089.12
i=1383, epochs=8.85,

i=1530, epochs=9.79, elapsed=0.50, elbo=-6114.38
i=1531, epochs=9.80, elapsed=0.50, elbo=-6114.75
i=1532, epochs=9.80, elapsed=0.51, elbo=-6138.75
i=1533, epochs=9.81, elapsed=0.51, elbo=-6089.86
i=1534, epochs=9.82, elapsed=0.51, elbo=-6095.99
i=1535, epochs=9.82, elapsed=0.51, elbo=-6123.00
i=1536, epochs=9.83, elapsed=0.51, elbo=-6102.09
i=1537, epochs=9.84, elapsed=0.51, elbo=-6127.11
i=1538, epochs=9.84, elapsed=0.51, elbo=-6084.63
i=1539, epochs=9.85, elapsed=0.51, elbo=-6115.43
i=1540, epochs=9.86, elapsed=0.51, elbo=-6056.21
i=1541, epochs=9.86, elapsed=0.51, elbo=-6066.62
i=1542, epochs=9.87, elapsed=0.51, elbo=-6115.15
i=1543, epochs=9.88, elapsed=0.51, elbo=-6113.82
i=1544, epochs=9.88, elapsed=0.51, elbo=-6061.98
i=1545, epochs=9.89, elapsed=0.51, elbo=-6079.83
i=1546, epochs=9.89, elapsed=0.51, elbo=-6120.12
i=1547, epochs=9.90, elapsed=0.51, elbo=-6062.53
i=1548, epochs=9.91, elapsed=0.51, elbo=-6087.68
i=1549, epochs=9.91, elapsed=0.51, elbo=-6122.88
i=1550, epochs=9.92,

i=1695, epochs=10.85, elapsed=0.56, elbo=-6147.88
i=1696, epochs=10.85, elapsed=0.56, elbo=-6071.86
i=1697, epochs=10.86, elapsed=0.56, elbo=-6081.06
i=1698, epochs=10.87, elapsed=0.56, elbo=-6123.07
i=1699, epochs=10.87, elapsed=0.56, elbo=-6090.95
i=1700, epochs=10.88, elapsed=0.56, elbo=-6114.12
Drawing
i=1701, epochs=10.89, elapsed=0.56, elbo=-6084.99
i=1702, epochs=10.89, elapsed=0.56, elbo=-6064.67
i=1703, epochs=10.90, elapsed=0.56, elbo=-6141.07
i=1704, epochs=10.91, elapsed=0.56, elbo=-6096.86
i=1705, epochs=10.91, elapsed=0.56, elbo=-6153.00
i=1706, epochs=10.92, elapsed=0.56, elbo=-6143.95
i=1707, epochs=10.92, elapsed=0.56, elbo=-6117.06
i=1708, epochs=10.93, elapsed=0.57, elbo=-6072.44
i=1709, epochs=10.94, elapsed=0.57, elbo=-6093.95
i=1710, epochs=10.94, elapsed=0.57, elbo=-6050.69
i=1711, epochs=10.95, elapsed=0.57, elbo=-6129.36
i=1712, epochs=10.96, elapsed=0.57, elbo=-6080.80
i=1713, epochs=10.96, elapsed=0.57, elbo=-6087.40
i=1714, epochs=10.97, elapsed=0.57, elbo=-

i=1859, epochs=11.90, elapsed=0.62, elbo=-6133.82
i=1860, epochs=11.90, elapsed=0.62, elbo=-6118.07
i=1861, epochs=11.91, elapsed=0.62, elbo=-6103.22
i=1862, epochs=11.92, elapsed=0.62, elbo=-6085.32
i=1863, epochs=11.92, elapsed=0.62, elbo=-6137.84
i=1864, epochs=11.93, elapsed=0.62, elbo=-6122.68
i=1865, epochs=11.94, elapsed=0.62, elbo=-6095.37
i=1866, epochs=11.94, elapsed=0.62, elbo=-6053.05
i=1867, epochs=11.95, elapsed=0.62, elbo=-6119.31
i=1868, epochs=11.96, elapsed=0.62, elbo=-6133.80
i=1869, epochs=11.96, elapsed=0.62, elbo=-6119.34
i=1870, epochs=11.97, elapsed=0.62, elbo=-6070.17
i=1871, epochs=11.97, elapsed=0.62, elbo=-6104.98
i=1872, epochs=11.98, elapsed=0.62, elbo=-6048.36
i=1873, epochs=11.99, elapsed=0.62, elbo=-6143.60
i=1874, epochs=11.99, elapsed=0.62, elbo=-6104.76
i=1875, epochs=12.00, elapsed=0.62, elbo=-6058.32
i=1876, epochs=12.01, elapsed=0.62, elbo=-6104.32
i=1877, epochs=12.01, elapsed=0.62, elbo=-6092.65
i=1878, epochs=12.02, elapsed=0.62, elbo=-6118.26


i=2023, epochs=12.95, elapsed=0.67, elbo=-6084.50
i=2024, epochs=12.95, elapsed=0.67, elbo=-6101.55
i=2025, epochs=12.96, elapsed=0.67, elbo=-6044.98
i=2026, epochs=12.97, elapsed=0.67, elbo=-6046.57
i=2027, epochs=12.97, elapsed=0.67, elbo=-6099.07
i=2028, epochs=12.98, elapsed=0.67, elbo=-6055.35
i=2029, epochs=12.99, elapsed=0.67, elbo=-6073.70
i=2030, epochs=12.99, elapsed=0.67, elbo=-6095.29
i=2031, epochs=13.00, elapsed=0.67, elbo=-6054.68
i=2032, epochs=13.00, elapsed=0.67, elbo=-6091.37
i=2033, epochs=13.01, elapsed=0.67, elbo=-6098.20
i=2034, epochs=13.02, elapsed=0.67, elbo=-6119.62
i=2035, epochs=13.02, elapsed=0.67, elbo=-6096.67
i=2036, epochs=13.03, elapsed=0.67, elbo=-6092.04
i=2037, epochs=13.04, elapsed=0.67, elbo=-6128.34
i=2038, epochs=13.04, elapsed=0.67, elbo=-6123.26
i=2039, epochs=13.05, elapsed=0.67, elbo=-6072.76
i=2040, epochs=13.06, elapsed=0.67, elbo=-6066.84
i=2041, epochs=13.06, elapsed=0.67, elbo=-6092.21
i=2042, epochs=13.07, elapsed=0.67, elbo=-6135.23


KeyboardInterrupt: 