In [1]:
import os
import time

import numpy as np
import pandas as pd
import pyro
import pyro.contrib.examples.multi_mnist as multi_mnist
import torch
from pyro.infer import (
    SVI,
    RenyiELBO,
    Trace_ELBO,
    TraceGraph_ELBO,
    TraceEnum_ELBO,
    ReweightedWakeSleep,
)
from pyro.optim import Adam

from pyro_air import AIR, make_prior, get_per_param_lr, count_accuracy, visualize_model

In [2]:
#####################
# Benchmark Configs
#####################
# -- run control --
seed = 123456  # base random seed for the benchmark
use_cuda = torch.cuda.is_available()  # use the GPU whenever torch can see one
batch_size = 1
num_epoches = 6

# -- optimisation --
# Both rates are handed to get_per_param_lr; presumably `baseline_lr` applies to
# the baseline networks and `learning_rate` to everything else — confirm in pyro_air.
learning_rate = 1e-4
baseline_lr = 1e-1
# Renyi-divergence objective (pyro RenyiELBO) estimated with 2 particles.
elbo = RenyiELBO(num_particles=2)

# NOTE(review): not read by the z_pres prior schedule in the setup cell, which
# hard-codes its own probabilities — confirm whether this constant is still needed.
z_pres_prior = 0.01

# Every AIR constructor option is spelled out explicitly so the full benchmark
# configuration is visible in one place.
air_model_args = {
    # step count and sizes
    "num_steps": 3,
    "x_size": 50,
    "window_size": 28,
    "z_what_size": 50,
    "rnn_hidden_size": 256,
    # network layer widths
    "encoder_net": [200],
    "decoder_net": [200],
    "predict_net": [200],
    "embed_net": None,
    "bl_predict_net": [200],
    "non_linearity": "ReLU",
    # decoder output
    "decoder_output_bias": -2,
    "decoder_output_use_sigmoid": True,
    # masking / baselines
    "use_masking": True,
    "use_baselines": False,
    "baseline_scalar": None,
    # prior and likelihood parameters
    "scale_prior_mean": 3.0,
    "scale_prior_sd": 0.2,
    "pos_prior_mean": 0.0,
    "pos_prior_sd": 1.0,
    "likelihood_sd": 0.3,
    # device
    "use_cuda": use_cuda,
}

In [3]:
#####################
# Initial Setup
#####################
device = torch.device("cuda" if use_cuda else "cpu")

# Hand-picked z_pres prior probabilities indexed by step t: 0.05, 0.05**2.3, 0.05**5.
# NOTE(review): `z_pres_prior` from the config cell is not used here.
Z_PRES_PRIOR_PROBS = (0.05, 0.05**2.3, 0.05**5)


def z_pres_prior_fn(t):
    """Return the hand-picked ``z_pres`` prior probability for step ``t``.

    Defined with ``def`` instead of a lambda bound to a name (PEP 8, E731) so it
    has a real name in tracebacks and can carry this docstring. It is passed as
    ``z_pres_prior_p`` to ``svi.step`` in the training cell.

    Args:
        t: step index; only 0, 1 and 2 have an entry.

    Returns:
        ``Z_PRES_PRIOR_PROBS[t]``.

    Raises:
        IndexError: for ``t >= 3``. Presumably ``air_model_args["num_steps"]``
            must therefore stay at most 3 — confirm against AIR.
    """
    return Z_PRES_PRIOR_PROBS[t]


X, Y = multi_mnist.load("data/air/.data")
# Replace the raw array with a float tensor divided by 255 and moved to `device`
# (assumes raw pixel values are 0-255 — TODO confirm multi_mnist's dtype).
X = (torch.from_numpy(X).float() / 255.0).to(device)
visualize_examples = X[5:10]  # five fixed images, presumably for visualize_model
# Using float to allow comparison with values sampled from
# Bernoulli.
# NOTE(review): created on the default device rather than `device`; confirm that
# count_accuracy expects CPU counts.
true_counts = torch.tensor([len(objs) for objs in Y], dtype=torch.float)

In [None]:
# Train AIR `n_runs` times from scratch and record per-epoch metrics.
#
# Outputs, all under `output_dir`:
#   * `<run_name>_<train_idx>.csv`: per-epoch ELBO loss, count accuracy and
#     cumulative training (svi.step-only) wall-clock time of one run;
#   * `<run_name>.csv`: per-epoch mean and population std (ddof=0) of those
#     metrics across all runs.
n_runs = 5
eval_batch_size = 1000  # batch size handed to count_accuracy
output_dir = "./training_runs"

# The file-name label must follow `elbo`; a hard-coded label lets this notebook
# silently overwrite another objective's (e.g. RWS) benchmark results.
_ELBO_LABELS = {RenyiELBO: "renyi", ReweightedWakeSleep: "rws"}
elbo_label = _ELBO_LABELS.get(type(elbo), type(elbo).__name__.lower())
run_name = f"pyro_air_{elbo_label}_epochs_{num_epoches}_mccoy_prior"
os.makedirs(output_dir, exist_ok=True)  # to_csv fails if the directory is missing

# Loop-invariant setup, done once for all runs.
pyro.distributions.enable_validation(False)
# ceil(N / batch_size) steps per epoch: technically this might step over
# slightly more than 1 epoch...
steps_per_epoch = int(np.ceil(X.size(0) / batch_size))

# Per-run metric lists; stacked once after the loop so shapes do not depend on n_runs.
run_losses, run_accuracies, run_wall_clock_times = [], [], []
for train_idx in range(n_runs):
    pyro.clear_param_store()  # drop parameters left over from the previous run
    # Seed per run: each run is reproducible on its own, and runs still differ.
    pyro.set_rng_seed(seed + train_idx)

    air = AIR(**air_model_args)
    adam = Adam(get_per_param_lr(learning_rate, baseline_lr))
    svi = SVI(air.model, air.guide, adam, loss=elbo)

    all_loss = []
    all_accuracy = []
    time_per_epoch = []

    for i in range(num_epoches):
        start_time = time.perf_counter()
        losses = [
            svi.step(X, batch_size=batch_size, z_pres_prior_p=z_pres_prior_fn)
            for _ in range(steps_per_epoch)
        ]
        end_time = time.perf_counter()

        # Evaluation runs outside the timed region, so it is not counted in epoch times.
        accuracy, counts, error_z, error_ix = count_accuracy(
            X, true_counts, air, eval_batch_size
        )

        all_loss.append(np.mean(losses) / X.size(0))
        all_accuracy.append(accuracy)
        time_per_epoch.append(end_time - start_time)

        print(
            f"Epoch={i}, current_epoch_step_time={time_per_epoch[-1]:.2f}, loss={all_loss[-1]:.2f}"
        )
        print(f"accuracy={accuracy}, counts={counts}")

    wall_clock_times = np.cumsum(time_per_epoch)
    run_losses.append(all_loss)
    run_accuracies.append(all_accuracy)
    run_wall_clock_times.append(wall_clock_times)

    # Save run.
    run_table = np.array([all_loss, all_accuracy, wall_clock_times])
    df = pd.DataFrame(
        run_table.T, columns=["ELBO loss", "Accuracy", "Epoch wall clock times"]
    )
    df.to_csv(os.path.join(output_dir, f"{run_name}_{train_idx}.csv"), index=False)

# Each is (n_runs, num_epoches) and stays 2-D even when n_runs == 1.
ac_losses = np.array(run_losses)
ac_accuracy = np.array(run_accuracies)
ac_wall_clock_times = np.array(run_wall_clock_times)

# (3 metrics, n_runs, num_epoches) -> per-epoch mean/std over runs, each (3, num_epoches).
arr = np.array([ac_losses, ac_accuracy, ac_wall_clock_times])
mean_arr = np.mean(arr, axis=1)
std_arr = np.std(arr, axis=1)  # population std (ddof=0)
df_arr = np.vstack((mean_arr, std_arr))
df = pd.DataFrame(
    df_arr.T,
    columns=[
        "Mean ELBO loss",
        "Mean accuracy",
        "Mean epoch wall clock times",
        "Std ELBO loss",
        "Std accuracy",
        "Std epoch wall clock times",
    ],
)
df.to_csv(os.path.join(output_dir, f"{run_name}.csv"), index=False)

Output of the training cell: if you want to enumerate sites, you need to use `TraceEnum_ELBO` instead.
