In [1]:
import pandas as pd
import os
import numpy as np
import scipy as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import pycuda.driver as cuda

In [2]:
from livelossplot import PlotLosses
from time import sleep
import timeit

In [3]:
import pyro
import pyro.distributions as dist
from pyro.distributions import Normal, Uniform, Delta
from pyro.infer.mcmc import MCMC, HMC, NUTS
from pyro.infer.mcmc.api import MCMC
import pyro.poutine as poutine
from pyro.infer import EmpiricalMarginal, SVI, Trace_ELBO, TracePredictive, JitTrace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO
from pyro.optim import Adam
from pyro.infer.mcmc.util import predictive
from pyro.distributions.util import sum_rightmost

In [4]:
from scipy.stats import uniform, norm
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt

In [5]:
# For CI testing: turn on Pyro's validation checks and fix the RNG seed so
# runs are reproducible. (Validation only needs to be enabled once.)
pyro.enable_validation(True)
pyro.set_rng_seed(1)

In [6]:
# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

In [7]:
# Display which device was selected by the previous cell.
device

device(type='cuda', index=0)

In [8]:
# Observation-noise standard deviation used by the Normal likelihood in `model`.
sigma_noise = 0.1
# NOTE: torch.load unpickles the file, so only load data files you trust.
# map_location remaps any tensors saved on a GPU onto `device`, so this cell
# also works on a CPU-only machine (matching the CPU fallback for `device`).
(x_data, y_data) = torch.load('foong_data.pt', map_location=device)
x_data = x_data.to(device)
y_data = y_data.to(device)

In [None]:
# Number of hidden units in the network; `model` also uses it for the prior shapes.
H = 50

In [None]:
class RegressionModel(nn.Module):
    """Single-hidden-layer tanh network: p inputs -> H hidden units -> 1 output.

    The hidden width is read from the global ``H`` and the module is moved to
    the global ``device`` when constructed, so both must be defined beforehand.
    """

    def __init__(self, p):
        super().__init__()
        self.linear1 = nn.Linear(p, H)
        self.linear2 = nn.Linear(H, 1)
        self.to(device)

    def forward(self, x):
        """Map inputs whose last dimension is p to outputs whose last dimension is 1."""
        hidden = torch.tanh(self.linear1(x))
        return self.linear2(hidden)


p = 1  # number of input features
# NOTE(review): softplus is not referenced by any other cell visible here.
softplus = nn.Softplus()
regression_model = RegressionModel(p)

In [None]:
def model(x_data, y_data, prior_scale=5.0):
    """Bayesian version of ``regression_model`` for Pyro inference.

    Places independent zero-mean Normal(0, prior_scale) priors on every weight
    and bias of ``regression_model``. It lifts the module into a random module
    and draws one network. It then scores a Normal(prediction, sigma_noise)
    likelihood against ``y_data``.

    Args:
        x_data: Inputs fed to the network. Their dtype/device are reused for
            every prior tensor.
        y_data: Observed targets, or ``None`` to leave ``obs`` unobserved.
        prior_scale: Standard deviation of every weight/bias prior. The default
            of 5.0 is the value that was previously hard-coded.

    Returns:
        The network output with its last dimension squeezed (the predicted mean).
    """
    options = dict(dtype=x_data.dtype, device=x_data.device)

    def normal_prior(*shape):
        # Zero-mean Normal with every dimension declared an event dimension,
        # so each parameter tensor is treated as one joint sample.
        loc = torch.zeros(*shape, **options)
        scale = torch.full(shape, prior_scale, **options)
        return Normal(loc, scale).to_event(len(shape))

    # Keys match the parameter names of `regression_model`'s two Linear layers.
    priors = {
        'linear1.weight': normal_prior(H, p),
        'linear1.bias': normal_prior(H),
        'linear2.weight': normal_prior(1, H),
        'linear2.bias': normal_prior(1),
    }

    # Lift the deterministic module into a distribution over modules and draw one.
    lifted_module = pyro.random_module("module", regression_model, priors)
    lifted_reg_model = lifted_module()

    # Observations are conditionally independent given the sampled weights.
    with pyro.plate("map", device=x_data.device):
        prediction_mean = lifted_reg_model(x_data).squeeze(-1)
        pyro.sample("obs", Normal(prediction_mean, sigma_noise), obs=y_data)
    return prediction_mean

In [None]:
# Mean-field (diagonal Gaussian) variational guide over the latent sites of
# `model`; its parameters are registered under the "modelA" prefix.
from pyro.contrib.autoguide import AutoDiagonalNormal
guide = AutoDiagonalNormal(model, prefix="modelA")

In [None]:
# Start from an empty parameter store, so re-running training in the same
# kernel does not pick up parameters left over from an earlier run.
pyro.clear_param_store()

In [None]:
# Adam wrapped in a ReduceLROnPlateau scheduler. Per torch's ReduceLROnPlateau
# semantics, the lr is multiplied by `factor` after `patience` scheduler steps
# without improvement.
scheduler = pyro.optim.ReduceLROnPlateau({
    'optimizer': torch.optim.Adam,
    'optim_args': {'lr': 0.05},
    'factor': 0.5,
    'patience': 2,
})
svi = SVI(model, guide, scheduler,
          loss=Trace_ELBO(num_particles=1, vectorize_particles=True),
          num_samples=100)
# One initial step, so the guide's parameters are created before training.
svi.step(x_data, y_data)

In [None]:
%matplotlib inline
num_epoch = 100
num_iterations = 1000
liveloss = PlotLosses()
for j in range(num_epoch):
    logs = {}
    losses = [None] * num_iterations
    for k in range(num_iterations):
        loss = svi.step(x_data, y_data)
        losses[k] = loss
    scheduler.step(np.mean(losses))
    L = next(iter(scheduler.optim_objs.values()))
    logs['expected_loss'] = np.mean(losses)
    logs['learning_rate'] = L.optimizer.param_groups[0]['lr']
    liveloss.update(logs)
    liveloss.draw()

In [None]:
def get_marginal(traces, sites):
    """Return the samples of ``sites`` collected in ``traces`` as a NumPy array.

    Builds an EmpiricalMarginal over the traces and keeps only the samples
    (discarding the weights). The samples are detached, moved to the CPU and
    converted to NumPy.

    NOTE(review): this relies on the private ``_get_samples_and_weights``
    helper, which may change between Pyro versions.
    """
    samples = EmpiricalMarginal(traces, sites)._get_samples_and_weights()[0]
    return samples.detach().cpu().numpy()


def wrapped_model(x_data, y_data):
    """Run `model` and record its returned prediction mean as the Delta site "prediction"."""
    # NOTE(review): not called by any other cell visible here.
    pyro.sample("prediction", Delta(model(x_data, y_data)))


# Collect posterior traces from the trained SVI object.
posterior = svi.run(x_data, y_data)

In [None]:
# 100 evenly spaced test inputs on [-2, 2], unsqueezed to shape (100, 1) to give
# the single input feature. They are placed on `device` (not hard-coded .cuda())
# so the notebook also runs when the CPU fallback was selected.
x_test = torch.linspace(-2.0, 2.0, steps=100).unsqueeze(1).to(device)

In [None]:
# Posterior predictive plot: the training data plus one curve for each set of
# network weights sampled from the variational guide.
num_posterior_draws = 1000

fig, ax = plt.subplots(figsize=(11.7, 8.27))  # A4 landscape, in inches

ax.scatter(x_data.cpu(), y_data.cpu())
# The test grid is the same for every draw, so convert it to NumPy once.
x_test_np = x_test.detach().cpu().numpy()
for _ in range(num_posterior_draws):
    # Sample all latent weights from the guide, recording them in a trace...
    guide_trace = poutine.trace(guide).get_trace(x_test, None)
    # ...then replay `model` against that trace so it uses those sampled weights.
    lifted_reg_model = poutine.replay(model, guide_trace)
    y_test = lifted_reg_model(x_test, None)
    ax.plot(x_test_np, y_test.detach().cpu().numpy(), alpha=0.05, linewidth=1, color='lightblue')

ax.set(title='Posterior predictive draws (AutoDiagonalNormal guide)', xlabel='x', ylabel='y');

In [None]:
# Print every learned parameter in the store by name. items() already yields
# each value, so there is no need to look it up again with pyro.param(name).
for name, value in pyro.get_param_store().items():
    print(name, value)