In [1]:
import argparse
import math
import os
import time
from functools import partial

import numpy as np
import torch
import visdom

import pyro
import pyro.contrib.examples.multi_mnist as multi_mnist
import pyro.optim as optim
import pyro.poutine as poutine
from components.AIR import AIR, latents_to_tensor
from pyro.contrib.examples.util import get_data_directory
from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO
from utils.viz import draw_many, tensor_to_objs
from utils.visualizer import plot_mnist_sample


## Support functions

In [2]:
def count_accuracy(X, true_counts, air, batch_size):
    assert X.size(0) == true_counts.size(0), 'Size mismatch.'
    assert X.size(0) % batch_size == 0, 'Input size must be multiple of batch_size.'
    counts = torch.LongTensor(3, 4).zero_()
    error_latents = []
    error_indicators = []

    def count_vec_to_mat(vec, max_index):
        out = torch.LongTensor(vec.size(0), max_index + 1).zero_()
        out.scatter_(1, vec.type(torch.LongTensor).view(vec.size(0), 1), 1)
        return out

    for i in range(X.size(0) // batch_size):
        X_batch = X[i * batch_size:(i + 1) * batch_size]
        true_counts_batch = true_counts[i * batch_size:(i + 1) * batch_size]
        z_where, z_pres = air.guide(X_batch, batch_size)
        inferred_counts = sum(z.cpu() for z in z_pres).squeeze().data
        true_counts_m = count_vec_to_mat(true_counts_batch, 2)
        inferred_counts_m = count_vec_to_mat(inferred_counts, 3)
        counts += torch.mm(true_counts_m.t(), inferred_counts_m)
        error_ind = 1 - (true_counts_batch == inferred_counts)
        error_ix = error_ind.nonzero().squeeze()
        error_latents.append(latents_to_tensor((z_where, z_pres)).index_select(0, error_ix))
        error_indicators.append(error_ind)

    acc = counts.diag().sum().float() / X.size(0)
    error_indices = torch.cat(error_indicators).nonzero().squeeze()
    if X.is_cuda:
        error_indices = error_indices.cuda()
    return acc, counts, torch.cat(error_latents), error_indices


# Defines something like a truncated geometric. Like the geometric,
# this has the property that there's a constant difference in log prob
# between p(steps=n) and p(steps=n+1).
def make_prior(k):
    assert 0 < k <= 1
    u = 1 / (1 + k + k**2 + k**3)
    p0 = 1 - u
    p1 = 1 - (k * u) / p0
    p2 = 1 - (k**2 * u) / (p0 * p1)
    trial_probs = [p0, p1, p2]
    # dist = [1 - p0, p0 * (1 - p1), p0 * p1 * (1 - p2), p0 * p1 * p2]
    # print(dist)
    return lambda t: trial_probs[t]


# Implements "prior annealing" as described in this blog post:
# http://akosiorek.github.io/ml/2017/09/03/implementing-air.html

# That implementation does something very close to the following:
# --z-pres-prior (1 - 1e-15)
# --z-pres-prior-raw
# --anneal-prior exp
# --anneal-prior-to 1e-7
# --anneal-prior-begin 1000
# --anneal-prior-duration 1e6

# e.g. After 200K steps z_pres_p will have decayed to ~0.04

# These compute the value of a decaying value at time t.
# initial: initial value
# final: final value, reached after begin + duration steps
# begin: number of steps before decay begins
# duration: number of steps over which decay occurs
# t: current time step


def lin_decay(initial, final, begin, duration, t):
    assert duration > 0
    x = (final - initial) * (t - begin) / duration + initial
    return max(min(x, initial), final)


def exp_decay(initial, final, begin, duration, t):
    assert final > 0
    assert duration > 0
    # half_life = math.log(2) / math.log(initial / final) * duration
    decay_rate = math.log(initial / final) / duration
    x = initial * math.exp(-decay_rate * (t - begin))
    return max(min(x, initial), final)

## Args

In [3]:
parser = argparse.ArgumentParser(description="Pyro AIR example", argument_default=argparse.SUPPRESS)
parser.add_argument('-n', '--num-steps', type=int, default=int(1e8),
                    help='number of optimization steps to take')
parser.add_argument('-b', '--batch-size', type=int, default=64,
                    help='batch size')
parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4,
                    help='learning rate')
parser.add_argument('-blr', '--baseline-learning-rate', type=float, default=1e-3,
                    help='baseline learning rate')
parser.add_argument('--progress-every', type=int, default=1,
                    help='number of steps between writing progress to stdout')
parser.add_argument('--eval-every', type=int, default=0,
                    help='number of steps between evaluations')
parser.add_argument('--baseline-scalar', type=float,
                    help='scale the output of the baseline nets by this value')
parser.add_argument('--no-baselines', action='store_true', default=False,
                    help='do not use data dependent baselines')
parser.add_argument('--encoder-net', type=int, nargs='+', default=[200],
                    help='encoder net hidden layer sizes')
parser.add_argument('--decoder-net', type=int, nargs='+', default=[200],
                    help='decoder net hidden layer sizes')
parser.add_argument('--predict-net', type=int, nargs='+',
                    help='predict net hidden layer sizes')
parser.add_argument('--embed-net', type=int, nargs='+',
                    help='embed net architecture')
parser.add_argument('--bl-predict-net', type=int, nargs='+',
                    help='baseline predict net hidden layer sizes')
parser.add_argument('--non-linearity', type=str,
                    help='non linearity to use throughout')
parser.add_argument('--viz', action='store_true', default=True,
                    help='generate vizualizations during optimization')
parser.add_argument('--viz-every', type=int, default=2,
                    help='number of steps between vizualizations')
parser.add_argument('--visdom-env', default='main',
                    help='visdom enviroment name')
parser.add_argument('--load', type=str,
                    help='load previously saved parameters')
parser.add_argument('--save', type=str,
                    help='save parameters to specified file')
parser.add_argument('--save-every', type=int, default=1e4,
                    help='number of steps between parameter saves')
parser.add_argument('--cuda', action='store_true', default=False,
                    help='use cuda')
parser.add_argument('--jit', action='store_true', default=False,
                    help='use PyTorch jit')
parser.add_argument('-t', '--model-steps', type=int, default=3,
                    help='number of time steps')
parser.add_argument('--rnn-hidden-size', type=int, default=256,
                    help='rnn hidden size')
parser.add_argument('--encoder-latent-size', type=int, default=50,
                    help='attention window encoder/decoder latent space size')
parser.add_argument('--decoder-output-bias', type=float,
                    help='bias added to decoder output (prior to applying non-linearity)')
parser.add_argument('--decoder-output-use-sigmoid', action='store_true',
                    help='apply sigmoid function to output of decoder network')
parser.add_argument('--window-size', type=int, default=60,
                    help='attention window size')
parser.add_argument('--z-pres-prior', type=float, default=0.5,
                    help='prior success probability for z_pres')
parser.add_argument('--z-pres-prior-raw', action='store_true', default=False,
                    help='use --z-pres-prior directly as success prob instead of a geometric like prior')
parser.add_argument('--anneal-prior', choices='none lin exp'.split(), default='none',
                    help='anneal z_pres prior during optimization')
parser.add_argument('--anneal-prior-to', type=float, default=1e-7,
                    help='target z_pres prior prob')
parser.add_argument('--anneal-prior-begin', type=int, default=0,
                    help='number of steps to wait before beginning to anneal the prior')
parser.add_argument('--anneal-prior-duration', type=int, default=100000,
                    help='number of steps over which to anneal the prior')
parser.add_argument('--pos-prior-mean', type=float,
                    help='mean of the window position prior')
parser.add_argument('--pos-prior-sd', type=float,
                    help='std. dev. of the window position prior')
parser.add_argument('--scale-prior-mean', type=float,
                    help='mean of the window scale prior')
parser.add_argument('--scale-prior-sd', type=float,
                    help='std. dev. of the window scale prior')
parser.add_argument('--no-masking', action='store_true', default=False,
                    help='do not mask out the costs of unused choices')
parser.add_argument('--seed', type=int, help='random seed', default=None)
parser.add_argument('-v', '--verbose', action='store_true', default=False,
                    help='write hyper parameters and network architecture to stdout')

_StoreTrueAction(option_strings=['-v', '--verbose'], dest='verbose', nargs=0, const=True, default=False, type=None, choices=None, help='write hyper parameters and network architecture to stdout', metavar=None)

In [4]:
# vars(parser.parse_args(""))
args = argparse.Namespace(**vars(parser.parse_args("")))

In [5]:
if 'save' in args:
    if os.path.exists(args.save):
        raise RuntimeError('Output file "{}" already exists.'.format(args.save))

if args.seed is not None:
    pyro.set_rng_seed(args.seed)

# Build a function to compute z_pres prior probabilities.
if args.z_pres_prior_raw:
    def base_z_pres_prior_p(t):
        return args.z_pres_prior
else:
    base_z_pres_prior_p = make_prior(args.z_pres_prior)

# Wrap with logic to apply any annealing.
def z_pres_prior_p(opt_step, time_step):
    p = base_z_pres_prior_p(time_step)
    if args.anneal_prior == 'none':
        return p
    else:
        decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
        return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                     args.anneal_prior_duration, opt_step)



## Data loader

In [6]:
def load_data():
#     inpath = './air/.data'
#     X_np, Y = multi_mnist.load(inpath)
    X_np = np.load('/Users/chamathabeysinghe/Projects/monash/test/variational_auto_encoder/data/ANTS2/masks_400x400.npy')
    X_np = X_np.astype(np.float32)
    X_np /= 255.0
    X = torch.from_numpy(X_np)
    # Using FloatTensor to allow comparison with values sampled from
    # Bernoulli.
    counts = torch.FloatTensor([1 for objs in X_np])
    return X, counts

In [9]:
X, true_counts = load_data()
X_size = X.size(0)
if args.cuda:
    X = X.cuda()

## Model

In [10]:
model_arg_keys = ['window_size',
                  'rnn_hidden_size',
                  'decoder_output_bias',
                  'decoder_output_use_sigmoid',
                  'baseline_scalar',
                  'encoder_net',
                  'decoder_net',
                  'predict_net',
                  'embed_net',
                  'bl_predict_net',
                  'non_linearity',
                  'pos_prior_mean',
                  'pos_prior_sd',
                  'scale_prior_mean',
                  'scale_prior_sd']
model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}

In [11]:
air = AIR(
        num_steps=args.model_steps,
        x_size=400,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

In [12]:
if 'load' in args:
    print('Loading parameters...')
    air.load_state_dict(torch.load(args.load))

## Visualize

In [13]:
if args.viz:
    vis = visdom.Visdom(env=args.visdom_env)
    z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
    vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

Setting up a new session...


## Optimizer

In [14]:
def isBaselineParam(module_name, param_name):
    return 'bl_' in module_name or 'bl_' in param_name

def per_param_optim_args(module_name, param_name):
    lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name) else args.learning_rate
    return {'lr': lr}

adam = optim.Adam(per_param_optim_args)
elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
svi = SVI(air.model, air.guide, adam, loss=elbo)

In [None]:
t0 = time.time()
examples_to_viz = X[5:10]

for i in range(1, args.num_steps + 1):

    loss = svi.step(X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

    if args.progress_every > 0 and i % args.progress_every == 0:
        print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
            i,
            (i * args.batch_size) / X_size,
            (time.time() - t0) / 3600,
            loss / X_size))

    if args.viz and i % args.viz_every == 0:
        print('Drawing')
        trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
        z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
        z_wheres = tensor_to_objs(latents_to_tensor(z))

        # Show data with inferred objection positions.
        vis.images(draw_many(examples_to_viz, z_wheres))
        # Show reconstructions of data.
        vis.images(draw_many(recons, z_wheres))

    if args.eval_every > 0 and i % args.eval_every == 0:
        # Measure accuracy on subset of training data.
        acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
        print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
        if args.viz and error_ix.size(0) > 0:
            vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                       opts=dict(caption='errors ({})'.format(i)))

    if 'save' in args and i % args.save_every == 0:
        print('Saving parameters...')
        torch.save(air.state_dict(), args.save)

i=1, epochs=0.01, elapsed=0.00, elbo=-43560.35
i=2, epochs=0.02, elapsed=0.01, elbo=-43873.53
Drawing
i=3, epochs=0.04, elapsed=0.01, elbo=-43901.61
i=4, epochs=0.05, elapsed=0.01, elbo=-44116.16
Drawing
i=5, epochs=0.06, elapsed=0.02, elbo=-44038.20
i=6, epochs=0.07, elapsed=0.02, elbo=-44024.16
Drawing
i=7, epochs=0.08, elapsed=0.02, elbo=-44176.02
i=8, epochs=0.10, elapsed=0.02, elbo=-44100.34
Drawing
i=9, epochs=0.11, elapsed=0.03, elbo=-44242.08
i=10, epochs=0.12, elapsed=0.03, elbo=-43942.19
Drawing
i=11, epochs=0.13, elapsed=0.03, elbo=-43981.84
i=12, epochs=0.14, elapsed=0.04, elbo=-44189.46
Drawing
i=13, epochs=0.16, elapsed=0.04, elbo=-44119.58
i=14, epochs=0.17, elapsed=0.04, elbo=-44366.00
Drawing
i=15, epochs=0.18, elapsed=0.04, elbo=-44322.22
i=16, epochs=0.19, elapsed=0.05, elbo=-44235.07
Drawing
i=17, epochs=0.20, elapsed=0.05, elbo=-44080.97
i=18, epochs=0.22, elapsed=0.05, elbo=-44302.06
Drawing
i=19, epochs=0.23, elapsed=0.06, elbo=-44056.67
i=20, epochs=0.24, elapse

i=158, epochs=1.89, elapsed=0.48, elbo=-44791.90
Drawing
i=159, epochs=1.91, elapsed=0.48, elbo=-44738.40
i=160, epochs=1.92, elapsed=0.48, elbo=-44658.90
Drawing
i=161, epochs=1.93, elapsed=0.49, elbo=-44830.29
i=162, epochs=1.94, elapsed=0.49, elbo=-44758.25
Drawing
i=163, epochs=1.95, elapsed=0.49, elbo=-44771.63
i=164, epochs=1.97, elapsed=0.50, elbo=-44770.99
Drawing
i=165, epochs=1.98, elapsed=0.50, elbo=-44773.30
i=166, epochs=1.99, elapsed=0.50, elbo=-44812.80
Drawing
i=167, epochs=2.00, elapsed=0.51, elbo=-44778.30
i=168, epochs=2.01, elapsed=0.51, elbo=-44838.49
Drawing
i=169, epochs=2.03, elapsed=0.51, elbo=-44788.57
i=170, epochs=2.04, elapsed=0.51, elbo=-44740.50
Drawing
i=171, epochs=2.05, elapsed=0.52, elbo=-44750.89
i=172, epochs=2.06, elapsed=0.52, elbo=-44828.99
Drawing
i=173, epochs=2.07, elapsed=0.52, elbo=-44790.37
i=174, epochs=2.09, elapsed=0.53, elbo=-44843.07
Drawing
i=175, epochs=2.10, elapsed=0.53, elbo=-44813.33
i=176, epochs=2.11, elapsed=0.53, elbo=-44886.

i=313, epochs=3.75, elapsed=0.94, elbo=-44857.96
i=314, epochs=3.76, elapsed=0.94, elbo=-44860.69
Drawing
i=315, epochs=3.78, elapsed=0.94, elbo=-44876.70
i=316, epochs=3.79, elapsed=0.95, elbo=-44910.30
Drawing
i=317, epochs=3.80, elapsed=0.95, elbo=-44806.77
i=318, epochs=3.81, elapsed=0.95, elbo=-44880.70
Drawing
i=319, epochs=3.82, elapsed=0.96, elbo=-44812.94
i=320, epochs=3.84, elapsed=0.96, elbo=-44816.67
Drawing
i=321, epochs=3.85, elapsed=0.96, elbo=-44923.76
i=322, epochs=3.86, elapsed=0.96, elbo=-44935.16
Drawing
i=323, epochs=3.87, elapsed=0.97, elbo=-44774.40
i=324, epochs=3.88, elapsed=0.97, elbo=-44923.27
Drawing
i=325, epochs=3.90, elapsed=0.97, elbo=-44927.30
i=326, epochs=3.91, elapsed=0.98, elbo=-44830.79
Drawing
i=327, epochs=3.92, elapsed=0.98, elbo=-44822.06
i=328, epochs=3.93, elapsed=0.98, elbo=-44879.51
Drawing
i=329, epochs=3.94, elapsed=0.99, elbo=-44676.35
i=330, epochs=3.96, elapsed=0.99, elbo=-44840.45
Drawing
i=331, epochs=3.97, elapsed=0.99, elbo=-44918.

i=468, epochs=5.61, elapsed=1.49, elbo=-44959.33
Drawing
i=469, epochs=5.62, elapsed=1.49, elbo=-44951.52
i=470, epochs=5.63, elapsed=1.49, elbo=-44906.79
Drawing
i=471, epochs=5.65, elapsed=1.50, elbo=-44813.09
i=472, epochs=5.66, elapsed=1.50, elbo=-44896.76
Drawing
i=473, epochs=5.67, elapsed=1.50, elbo=-44950.56
i=474, epochs=5.68, elapsed=1.51, elbo=-44893.75
Drawing
i=475, epochs=5.69, elapsed=1.51, elbo=-44908.28
i=476, epochs=5.71, elapsed=1.51, elbo=-44882.53
Drawing
i=477, epochs=5.72, elapsed=1.52, elbo=-44900.23
i=478, epochs=5.73, elapsed=1.52, elbo=-44881.21
Drawing
i=479, epochs=5.74, elapsed=1.52, elbo=-44944.81
i=480, epochs=5.75, elapsed=1.53, elbo=-44934.93
Drawing
i=481, epochs=5.77, elapsed=1.53, elbo=-44925.61
i=482, epochs=5.78, elapsed=1.53, elbo=-44928.10
Drawing
i=483, epochs=5.79, elapsed=1.54, elbo=-44857.07
i=484, epochs=5.80, elapsed=1.54, elbo=-44937.80
Drawing
i=485, epochs=5.81, elapsed=1.54, elbo=-44945.33
i=486, epochs=5.83, elapsed=1.55, elbo=-44932.

i=623, epochs=7.47, elapsed=2.00, elbo=-44953.48
i=624, epochs=7.48, elapsed=2.00, elbo=-44963.07
Drawing
i=625, epochs=7.49, elapsed=2.00, elbo=-44982.10
i=626, epochs=7.50, elapsed=2.00, elbo=-44978.17
Drawing
i=627, epochs=7.52, elapsed=2.01, elbo=-45005.61
i=628, epochs=7.53, elapsed=2.01, elbo=-45004.00
Drawing
i=629, epochs=7.54, elapsed=2.01, elbo=-44923.31
i=630, epochs=7.55, elapsed=2.02, elbo=-44918.75
Drawing
i=631, epochs=7.56, elapsed=2.02, elbo=-44856.84
i=632, epochs=7.58, elapsed=2.03, elbo=-44969.14
Drawing
i=633, epochs=7.59, elapsed=2.03, elbo=-44971.04
i=634, epochs=7.60, elapsed=2.03, elbo=-44956.41
Drawing
i=635, epochs=7.61, elapsed=2.04, elbo=-44914.60
i=636, epochs=7.62, elapsed=2.04, elbo=-44912.02
Drawing
i=637, epochs=7.64, elapsed=2.04, elbo=-44715.48
i=638, epochs=7.65, elapsed=2.05, elbo=-44935.41
Drawing
i=639, epochs=7.66, elapsed=2.05, elbo=-44904.45
i=640, epochs=7.67, elapsed=2.05, elbo=-44959.68
Drawing
i=641, epochs=7.68, elapsed=2.06, elbo=-44885.

i=778, epochs=9.33, elapsed=2.49, elbo=-44994.42
Drawing
i=779, epochs=9.34, elapsed=2.50, elbo=-45011.38
i=780, epochs=9.35, elapsed=2.50, elbo=-44928.26
Drawing
i=781, epochs=9.36, elapsed=2.50, elbo=-44981.91
i=782, epochs=9.37, elapsed=2.51, elbo=-44963.61
Drawing
i=783, epochs=9.39, elapsed=2.51, elbo=-45054.08
i=784, epochs=9.40, elapsed=2.51, elbo=-44958.91
Drawing
i=785, epochs=9.41, elapsed=2.51, elbo=-44959.98
i=786, epochs=9.42, elapsed=2.52, elbo=-44952.24
Drawing
i=787, epochs=9.43, elapsed=2.52, elbo=-44982.70
i=788, epochs=9.45, elapsed=2.52, elbo=-45014.13
Drawing
i=789, epochs=9.46, elapsed=2.53, elbo=-44993.20
i=790, epochs=9.47, elapsed=2.53, elbo=-44937.03
Drawing
i=791, epochs=9.48, elapsed=2.53, elbo=-44737.04
i=792, epochs=9.49, elapsed=2.53, elbo=-44895.73
Drawing
i=793, epochs=9.51, elapsed=2.54, elbo=-44966.33
i=794, epochs=9.52, elapsed=2.54, elbo=-44919.77
Drawing
i=795, epochs=9.53, elapsed=2.54, elbo=-44978.39
i=796, epochs=9.54, elapsed=2.55, elbo=-44967.

i=931, epochs=11.16, elapsed=2.97, elbo=-44948.08
i=932, epochs=11.17, elapsed=2.97, elbo=-45003.21
Drawing
i=933, epochs=11.18, elapsed=2.97, elbo=-44988.36
i=934, epochs=11.20, elapsed=2.98, elbo=-44973.89
Drawing
i=935, epochs=11.21, elapsed=2.98, elbo=-45050.95
i=936, epochs=11.22, elapsed=2.98, elbo=-44963.52
Drawing
i=937, epochs=11.23, elapsed=2.99, elbo=-44964.26
i=938, epochs=11.24, elapsed=2.99, elbo=-44862.75
Drawing
i=939, epochs=11.26, elapsed=2.99, elbo=-44933.00
i=940, epochs=11.27, elapsed=3.00, elbo=-44949.19
Drawing
i=941, epochs=11.28, elapsed=3.00, elbo=-45004.84
i=942, epochs=11.29, elapsed=3.00, elbo=-44989.47
Drawing
i=943, epochs=11.30, elapsed=3.01, elbo=-44978.84
i=944, epochs=11.32, elapsed=3.01, elbo=-45004.23
Drawing
i=945, epochs=11.33, elapsed=3.01, elbo=-44912.36
i=946, epochs=11.34, elapsed=3.02, elbo=-44880.75
Drawing
i=947, epochs=11.35, elapsed=3.02, elbo=-44954.48
i=948, epochs=11.36, elapsed=3.02, elbo=-44994.72
Drawing
i=949, epochs=11.38, elapsed

i=1082, epochs=12.97, elapsed=3.45, elbo=-44915.89
Drawing
i=1083, epochs=12.98, elapsed=3.46, elbo=-44966.26
i=1084, epochs=12.99, elapsed=3.46, elbo=-45002.14
Drawing
i=1085, epochs=13.01, elapsed=3.46, elbo=-44998.01
i=1086, epochs=13.02, elapsed=3.46, elbo=-45005.03
Drawing
i=1087, epochs=13.03, elapsed=3.47, elbo=-44925.43
i=1088, epochs=13.04, elapsed=3.47, elbo=-45023.30
Drawing
i=1089, epochs=13.05, elapsed=3.47, elbo=-45009.29
i=1090, epochs=13.07, elapsed=3.47, elbo=-45003.83
Drawing
i=1091, epochs=13.08, elapsed=3.48, elbo=-44978.92
i=1092, epochs=13.09, elapsed=3.48, elbo=-44961.74
Drawing
i=1093, epochs=13.10, elapsed=3.48, elbo=-45010.18
i=1094, epochs=13.11, elapsed=3.49, elbo=-44972.11
Drawing
i=1095, epochs=13.13, elapsed=3.49, elbo=-45033.50
i=1096, epochs=13.14, elapsed=3.49, elbo=-44920.08
Drawing
i=1097, epochs=13.15, elapsed=3.50, elbo=-44944.94
i=1098, epochs=13.16, elapsed=3.50, elbo=-44983.97
Drawing
i=1099, epochs=13.17, elapsed=3.50, elbo=-45039.42
i=1100, ep

i=1231, epochs=14.76, elapsed=3.96, elbo=-44907.67
i=1232, epochs=14.77, elapsed=3.96, elbo=-45000.95
Drawing
i=1233, epochs=14.78, elapsed=3.96, elbo=-45011.24
i=1234, epochs=14.79, elapsed=3.97, elbo=-44889.27
Drawing
i=1235, epochs=14.80, elapsed=3.97, elbo=-45012.48
i=1236, epochs=14.82, elapsed=3.97, elbo=-44899.52
Drawing
i=1237, epochs=14.83, elapsed=3.98, elbo=-44822.24
i=1238, epochs=14.84, elapsed=3.98, elbo=-45018.18
Drawing
i=1239, epochs=14.85, elapsed=3.98, elbo=-44956.92
i=1240, epochs=14.86, elapsed=3.99, elbo=-44953.94
Drawing
i=1241, epochs=14.88, elapsed=3.99, elbo=-44950.40
i=1242, epochs=14.89, elapsed=3.99, elbo=-44991.88
Drawing
i=1243, epochs=14.90, elapsed=4.00, elbo=-44921.27
i=1244, epochs=14.91, elapsed=4.00, elbo=-44969.76
Drawing
i=1245, epochs=14.92, elapsed=4.00, elbo=-45011.23
i=1246, epochs=14.94, elapsed=4.01, elbo=-44968.22
Drawing
i=1247, epochs=14.95, elapsed=4.01, elbo=-44986.15
i=1248, epochs=14.96, elapsed=4.01, elbo=-44994.56
Drawing
i=1249, ep

i=1381, epochs=16.55, elapsed=4.51, elbo=-44955.58
i=1382, epochs=16.57, elapsed=4.52, elbo=-44960.39
Drawing
i=1383, epochs=16.58, elapsed=4.52, elbo=-45052.80
i=1384, epochs=16.59, elapsed=4.53, elbo=-45000.32
Drawing
i=1385, epochs=16.60, elapsed=4.54, elbo=-44856.82
i=1386, epochs=16.61, elapsed=4.55, elbo=-44992.67
Drawing
i=1387, epochs=16.63, elapsed=4.55, elbo=-44952.75
i=1388, epochs=16.64, elapsed=4.55, elbo=-44947.26
Drawing
i=1389, epochs=16.65, elapsed=4.56, elbo=-45041.94
i=1390, epochs=16.66, elapsed=4.56, elbo=-44978.85
Drawing
i=1391, epochs=16.67, elapsed=4.56, elbo=-44872.66
i=1392, epochs=16.69, elapsed=4.56, elbo=-44938.47
Drawing
i=1393, epochs=16.70, elapsed=4.57, elbo=-44979.47
i=1394, epochs=16.71, elapsed=4.57, elbo=-45010.92
Drawing
i=1395, epochs=16.72, elapsed=4.57, elbo=-44970.14
i=1396, epochs=16.73, elapsed=4.58, elbo=-44965.30
Drawing
i=1397, epochs=16.75, elapsed=4.58, elbo=-44965.99
i=1398, epochs=16.76, elapsed=4.58, elbo=-44945.43
Drawing
i=1399, ep

i=1531, epochs=18.35, elapsed=4.99, elbo=-44957.55
i=1532, epochs=18.36, elapsed=4.99, elbo=-44908.59
Drawing
i=1533, epochs=18.38, elapsed=5.00, elbo=-44708.21
i=1534, epochs=18.39, elapsed=5.00, elbo=-44950.59
Drawing
i=1535, epochs=18.40, elapsed=5.00, elbo=-45022.69
i=1536, epochs=18.41, elapsed=5.00, elbo=-45044.11
Drawing
i=1537, epochs=18.42, elapsed=5.01, elbo=-44974.80
i=1538, epochs=18.44, elapsed=5.01, elbo=-44985.07
Drawing
i=1539, epochs=18.45, elapsed=5.01, elbo=-45046.37
i=1540, epochs=18.46, elapsed=5.02, elbo=-45000.11
Drawing
i=1541, epochs=18.47, elapsed=5.02, elbo=-45026.55
i=1542, epochs=18.48, elapsed=5.02, elbo=-45014.04
Drawing
i=1543, epochs=18.50, elapsed=5.03, elbo=-44979.13
i=1544, epochs=18.51, elapsed=5.03, elbo=-44979.49
Drawing
i=1545, epochs=18.52, elapsed=5.03, elbo=-44947.18
i=1546, epochs=18.53, elapsed=5.03, elbo=-44992.85
Drawing
i=1547, epochs=18.54, elapsed=5.04, elbo=-45030.21
i=1548, epochs=18.56, elapsed=5.04, elbo=-45004.57
Drawing
i=1549, ep