In [13]:
import numpy as np
import torch
import torch.nn as nn

from torch.autograd import Variable

import pyro
from pyro.distributions import Normal
from pyro.infer import SVI
from pyro.optim import Adam


In [14]:
N = 100  # size of toy data
p = 1    # number of features

def build_linear_dataset(N, noise_std=0.1):
    X = np.linspace(-6, 6, num=N)
    y = 3 * X + 1 + np.random.normal(0, noise_std, size=N)
    X, y = X.reshape((N, 1)), y.reshape((N, 1))
    X, y = Variable(torch.Tensor(X)), Variable(torch.Tensor(y))
    return torch.cat((X, y), 1)


In [15]:
class RegressionModel(nn.Module):
    def __init__(self, p):
        super(RegressionModel, self).__init__()
        self.linear = nn.Linear(p, 1)

    def forward(self, x):
        return self.linear(x)

regression_model = RegressionModel(p)


In [16]:
loss_fn = torch.nn.MSELoss(size_average=False)
optim = torch.optim.Adam(regression_model.parameters(), lr=0.01)
num_iterations = 2000

def main():
    data = build_linear_dataset(N, p)
    x_data = data[:, :-1]
    y_data = data[:, -1]
    for j in range(num_iterations):
        # run the model forward on the data
        y_pred = regression_model(x_data)
        # calculate the mse loss
        loss = loss_fn(y_pred, y_data)
        # initialize gradients to zero
        optim.zero_grad()
        # backpropagate
        loss.backward()
        # take a gradient step
        optim.step()
        if (j + 1) % 50 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss.data[0]))
    # Inspect learned parameters
    print("Learned parameters:")
    for name, param in regression_model.named_parameters():
        print("%s: %.3f" % (name, param.data.numpy()))


In [17]:
main()

[iteration 0050] loss: 4002.3042
[iteration 0100] loss: 2344.6309
[iteration 0150] loss: 1298.9349
[iteration 0200] loss: 688.2485
[iteration 0250] loss: 360.9415
[iteration 0300] loss: 201.0548
[iteration 0350] loss: 130.1795
[iteration 0400] loss: 101.7277
[iteration 0450] loss: 91.3897
[iteration 0500] loss: 87.9892
[iteration 0550] loss: 86.9765
[iteration 0600] loss: 86.7035
[iteration 0650] loss: 86.6369
[iteration 0700] loss: 86.6222
[iteration 0750] loss: 86.6193
[iteration 0800] loss: 86.6188
[iteration 0850] loss: 86.6187
[iteration 0900] loss: 86.6187
[iteration 0950] loss: 86.6187
[iteration 1000] loss: 86.6187
[iteration 1050] loss: 86.6187
[iteration 1100] loss: 86.6187
[iteration 1150] loss: 86.6187
[iteration 1200] loss: 86.6187
[iteration 1250] loss: 86.6186
[iteration 1300] loss: 86.6186
[iteration 1350] loss: 86.6186
[iteration 1400] loss: 86.6186
[iteration 1450] loss: 86.6186
[iteration 1500] loss: 86.6186
[iteration 1550] loss: 86.6187
[iteration 1600] loss: 86.61

In [18]:
mu = Variable(torch.zeros(1, 1))
sigma = Variable(torch.ones(1, 1))
# define a unit normal prior
prior = Normal(mu, sigma)
# overload the parameters in the regression nn with samples from the prior
lifted_module = pyro.random_module("regression_module", regression_model, prior)
# sample a nn from the prior
sampled_nn = lifted_module()


In [19]:
def model(data):
    # Create unit normal priors over the parameters
    x_data = data[:, :-1]
    y_data = data[:, -1]
    mu, sigma = Variable(torch.zeros(p, 1)), Variable(10 * torch.ones(p, 1))
    bias_mu, bias_sigma = Variable(torch.zeros(1)), Variable(10 * torch.ones(1))
    w_prior, b_prior = Normal(mu, sigma), Normal(bias_mu, bias_sigma)
    priors = {'linear.weight': w_prior, 'linear.bias': b_prior}
    
    # lift module parameters to random variables
    lifted_module = pyro.random_module("module", regression_model, priors)
    
    # sample a nn (which also samples w and b)
    lifted_nn = lifted_module()
    
    # run the nn forward
    latent = lifted_nn(x_data).squeeze()
    
    # condition on the observed data
    pyro.observe("obs", Normal(latent, Variable(0.1 * torch.ones(data.size(0)))),
                 y_data.squeeze())


In [20]:
softplus = torch.nn.Softplus()

def guide(data):
    # define our variational parameters
    w_mu = Variable(torch.randn(p, 1), requires_grad=True)
    
    # note that we initialize our sigmas to be pretty narrow
    w_log_sig = Variable(-3.0 * torch.ones(p, 1) + 0.05 * torch.randn(p, 1),
                         requires_grad=True)
    b_mu = Variable(torch.randn(1), requires_grad=True)
    b_log_sig = Variable(-3.0 * torch.ones(1) + 0.05 * torch.randn(1),
                         requires_grad=True)
    
    # register learnable params in the param store
    mw_param = pyro.param("guide_mean_weight", w_mu)
    sw_param = softplus(pyro.param("guide_log_sigma_weight", w_log_sig))
    mb_param = pyro.param("guide_mean_bias", b_mu)
    sb_param = softplus(pyro.param("guide_log_sigma_bias", b_log_sig))
    
    # guide distributions for w and b
    w_dist, b_dist = Normal(mw_param, sw_param), Normal(mb_param, sb_param)
    dists = {'linear.weight': w_dist, 'linear.bias': b_dist}
    
    # overload the parameters in the module with random samples
    # from the guide distributions
    lifted_module = pyro.random_module("module", regression_model, dists)
    
    # sample a nn (which also samples w and b)
    return lifted_module()


In [21]:
optim = Adam({"lr": 0.01})
svi = SVI(model, guide, optim, loss="ELBO")


In [22]:
def main():
    pyro.clear_param_store()
    data = build_linear_dataset(N, p)
    for j in range(num_iterations):
        # calculate the loss and take a gradient step
        loss = svi.step(data)
        if j % 100 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / float(N)))


In [23]:
main()

[iteration 0001] loss: 3270.7712
[iteration 0101] loss: 1088.8938
[iteration 0201] loss: 330.5920
[iteration 0301] loss: 92.8331
[iteration 0401] loss: 53.6706
[iteration 0501] loss: 47.4898
[iteration 0601] loss: 48.9851
[iteration 0701] loss: 45.7022
[iteration 0801] loss: 45.8497
[iteration 0901] loss: 47.1203
[iteration 1001] loss: 46.2104
[iteration 1101] loss: 46.1773
[iteration 1201] loss: 46.7581
[iteration 1301] loss: 45.8975
[iteration 1401] loss: 46.5042
[iteration 1501] loss: 46.0531
[iteration 1601] loss: 45.8823
[iteration 1701] loss: 45.6859
[iteration 1801] loss: 45.6970
[iteration 1901] loss: 46.1963


In [24]:
for name in pyro.get_param_store().get_all_param_names():
    print("[%s]: %.3f" % (name, pyro.param(name).data.numpy()))


[guide_mean_weight]: 3.006
[guide_log_sigma_weight]: -3.902
[guide_mean_bias]: 1.033
[guide_log_sigma_bias]: -3.685
