In [1]:
import torch
from torch.autograd import Variable

import pyro
import pyro.distributions as dist

In [20]:
# torch.autograd.Variable is a pass-through since PyTorch 0.4 (it just returns a
# Tensor), so plain tensors are used directly.
mu = torch.zeros(1)    # mean zero
sigma = torch.ones(1)  # scale (standard deviation) one, i.e. unit variance
x = dist.Normal(mu, sigma)  # x is the distribution N(0, 1) itself, not a sample
print(x, x())               # calling the distribution object draws one sample

Normal() tensor([-1.5713])


In [28]:
def weather():
    """Sample a (sky, temperature) pair from a toy weather model.

    Returns:
        tuple: ``(cloudy, temp)`` where ``cloudy`` is the string ``'cloudy'``
        or ``'sunny'`` and ``temp`` is the tensor sampled at site ``'temp'``.
    """
    # Bernoulli(0.3) gives 1. with probability 0.3, which maps to 'cloudy'.
    cloudy = pyro.sample('cloudy', dist.Bernoulli(0.3))
    cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'
    mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
    sigma_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
    # Pass a distribution *instance* so 'temp' is a real sample statement.
    # The legacy form pyro.sample('temp', dist.Normal, mean, sigma) handed back
    # a distribution object (the original code then had to call temp()), so
    # the 'temp' site did not carry a sampled value.
    temp = pyro.sample('temp', dist.Normal(mean_temp, sigma_temp))
    return cloudy, temp

# Draw ten independent (sky, temperature) samples. The index is deliberately
# not named `_`: rebinding `_` at notebook top level shadows IPython's
# last-output variable and stops IPython from updating `_`/`__`/`___`.
for draw_idx in range(10):
    print(weather())

('sunny', tensor(60.5679))
('cloudy', tensor(72.2203))
('sunny', tensor(86.3582))
('sunny', tensor(65.5065))
('sunny', tensor(93.9287))
('sunny', tensor(89.1511))
('sunny', tensor(99.6871))
('cloudy', tensor(45.3701))
('sunny', tensor(80.9162))
('cloudy', tensor(49.2897))


In [31]:
def ice_cream_sales():
    """Sample ice-cream sales conditioned on one draw from ``weather()``.

    Expected sales are 200 on sunny days hotter than 80 and 50 otherwise;
    the sales draw uses a Normal with scale 10 around that expectation.
    """
    sky, temperature = weather()
    hot_and_sunny = sky == 'sunny' and temperature > 80.0
    if hot_and_sunny:
        expected_sales = 200.
    else:
        expected_sales = 50.
    return pyro.sample('ice_cream', dist.Normal(expected_sales, 10.0))

In [32]:
# Draw one sales value; as the cell's last expression it is shown as the output.
ice_cream_sales()

tensor(198.9955)

In [40]:
def geometric(p, t=None):
    """Count Bernoulli(p) successes before the first failure, recursively.

    Each trial is its own named sample site ``"x_<t>"``, so the number of
    sample sites depends on the random outcome.

    Args:
        p: success probability of each Bernoulli trial.
        t: index of the current trial; ``None`` (the default) starts at 0.

    Returns:
        Tensor holding the number of 1-valued draws before the first 0.
    """
    if t is None:
        t = 0
    x = pyro.sample("x_{}".format(t), dist.Bernoulli(p))
    # Compare the drawn value itself. torch.equal(x, torch.zeros(1)) also
    # requires identical shapes, and a Bernoulli draw with a float p is 0-dim,
    # so it would never match the shape-(1,) zeros tensor and the recursion
    # would never hit its base case.
    if x.item() == 0:
        return x
    return x + geometric(p, t + 1)

print(geometric(0.5))

tensor(3.)


In [46]:
def normal_product(loc, scale):
    """Return the product of two independent draws from Normal(loc, scale).

    Args:
        loc: mean shared by both normals; a float or a single-element tensor
            (e.g. a previously sampled latent).
        scale: standard deviation shared by both normals; float or
            single-element tensor.

    Returns:
        Shape-(1,) tensor ``z1 * z2``.
    """
    # torch.as_tensor reuses an existing tensor (keeping its autograd history)
    # instead of copying it into a fresh leaf as torch.Tensor([loc]) did, so
    # gradients can flow back into `loc` when it is itself a sampled value.
    # reshape(1) keeps the original shape-(1,) parameters and result.
    mu = torch.as_tensor(loc, dtype=torch.get_default_dtype()).reshape(1)
    sigma = torch.as_tensor(scale, dtype=torch.get_default_dtype()).reshape(1)
    z1 = pyro.sample("z1", dist.Normal(mu, sigma))
    z2 = pyro.sample("z2", dist.Normal(mu, sigma))
    return z1 * z2

def make_normal_normal():
    """Sample a shared latent mean and return a sampler bound to it.

    Draws ``mu_latent`` from Normal(0, 1), prints it, and returns a
    one-argument function ``scale -> normal_product(mu_latent, scale)``.
    """
    mu_latent = pyro.sample("mu_latent", dist.Normal(0, 1))
    print(mu_latent)

    def fn(scale):
        # Closes over mu_latent, so every call reuses the same latent mean.
        return normal_product(mu_latent, scale)

    return fn


In [47]:
# Sample mu_latent (printed inside make_normal_normal), then print the product
# of two Normal(mu_latent, 1.) draws.
print(make_normal_normal()(1.))

tensor(0.8393)
tensor([ 0.7729])


In [None]:

%matplotlib inline
# import some dependencies
import numpy as np
import matplotlib.pyplot as plt
try:
    import seaborn as sns
    sns.set()
except ImportError:
    pass

import torch
from torch.autograd import Variable

import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

# Seed PyTorch's global RNG so torch-based sampling after this point repeats.
# NOTE(review): only torch is seeded here (numpy imported above is not), and
# the cells above this one ran before any seed was set.
torch.manual_seed(101)