In [1]:
import torch
import pyro
import pyro.distributions as dist

In [2]:
# model
def scale(guess):
    """Generative model of weighing an object on a noisy scale.

    The true weight is drawn from a Normal prior centred on ``guess`` with
    standard deviation 1.0; the scale then reports that weight corrupted by
    Normal noise with standard deviation 0.75.

    Args:
        guess: prior mean for the weight (a float or a scalar tensor).

    Returns:
        The sampled value of the "measurement" site.
    """
    prior = dist.Normal(guess, 1.0)
    true_weight = pyro.sample("weight", prior)
    reading = pyro.sample("measurement", dist.Normal(true_weight, 0.75))
    return reading


In [3]:
# One forward draw from the generative model (random: the value changes on every run).
scale(guess=torch.tensor(3.3))

tensor(3.8493)

##### for condition
This model only considers weights that are consistent with the observation `measurement == 9.5` (the `measurement` site is conditioned with `obs=`):

In [4]:
# condition0
def conditioned_scale_obs(guess):
    """The scale model conditioned on an observed measurement of 9.5.

    Same generative structure as ``scale``: a ``Normal(guess, 1.0)`` prior on
    the weight and ``Normal(weight, 0.75)`` noise on the measurement. Here the
    "measurement" site is fixed to 9.5 through ``obs=``, so inference only
    considers weights consistent with that reading.

    Args:
        guess: prior mean for the weight (a float or a scalar tensor).

    Returns:
        The value of the observed "measurement" site, ``tensor(9.5)``.
    """
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    # Condition on measurement == 9.5. The observation is passed as a tensor
    # because torch distributions' log_prob rejects non-tensor values when
    # argument validation is enabled.
    return pyro.sample("measurement", dist.Normal(weight, 0.75), obs=torch.tensor(9.5))

##### for guide
i.e. a parametrized approximation $q(z)$ of the posterior over the weight, $q(z) \approx p(z \mid x)$, with $z$ = weight and $x$ = measurement. Its parameters are fitted by maximizing the ELBO:

In [5]:
# guide0
from torch.distributions import constraints

def scale_parametrized_guide(guess):
    """Variational guide q(weight) = Normal(a, b) for the scale model.

    Registers two learnable parameters in Pyro's global param store:

    - ``a`` is the location, initialised to ``guess``.
    - ``b`` is the scale, initialised to 1.0 and constrained to be positive.

    ``pyro.param`` only uses an initial value the first time a name is
    registered; on later calls it returns the stored (possibly trained) value.

    Args:
        guess: initial value for ``a`` (a float or a scalar tensor).

    Returns:
        A sample of the "weight" site drawn from ``Normal(a, b)``.
    """
    # as_tensor(...).detach().clone() builds a fresh leaf tensor from either a
    # Python number or an existing tensor. torch.tensor(<tensor>) would emit a
    # copy-construct UserWarning.
    a = pyro.param("a", torch.as_tensor(guess).detach().clone())
    b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive)
    # b is already constrained positive, so it can be used as the scale directly.
    return pyro.sample("weight", dist.Normal(a, b))

In [6]:
# Draw one sample from the guide. Side effect: this registers "a" and "b" in the global param store.
scale_parametrized_guide(guess=torch.tensor(2.2))



tensor(2.5861, grad_fn=<AddBackward0>)

##### for params

In [7]:
# Parameter names currently held in Pyro's global param store (rich display of the last expression).
pyro.get_param_store().keys()

dict_keys(['a', 'b'])


In [8]:
# Empty Pyro's global param store so later cells start from freshly initialised parameters.
pyro.get_param_store().clear()

In [9]:
# The store should be empty after it was cleared in the previous cell.
assert len(pyro.get_param_store().keys()) == 0, "param store was not cleared"
pyro.get_param_store().keys()

dict_keys([])


##### for inference
Fit the parameters `a` and `b` so that the `weight` samples produced by `scale_parametrized_guide` match the posterior of the conditioned model `conditioned_scale_obs`, which serves as the SVI model. We know the prior $p(z)$ and the likelihood $p(x \mid z)$, and we want to approximate the posterior $p(z \mid x)$:

In [11]:
guess = 8.5  # prior mean for the weight; also the initial value of the guide's "a"

pyro.clear_param_store()
pyro.set_rng_seed(0)  # make the stochastic loss estimates (and the plots below) reproducible
svi = pyro.infer.SVI(model=conditioned_scale_obs,
                     guide=scale_parametrized_guide,
                     optim=pyro.optim.SGD({"lr": 0.001, "momentum": 0.1}),
                     loss=pyro.infer.Trace_ELBO())

# Per-step history: the loss returned by svi.step and the current guide parameters a and b.
losses, a, b = [], [], []
# At lr=0.001, 250 steps stop short of the exact posterior (a ~ 9.14, b = 0.6);
# use 2500 steps to get close to it.
num_steps = 250
for _ in range(num_steps):
    losses.append(svi.step(guess))
    a.append(pyro.param("a").item())
    b.append(pyro.param("b").item())

In [12]:
# Sanity check: exactly one recorded value of "a" per SVI step; then show the count and the final estimate.
assert len(a) == num_steps, "expected one recorded 'a' value per SVI step"
len(a), a[-1]

250
8.809281349182129


##### results:

In [13]:
import matplotlib.pyplot as plt


In [14]:
%matplotlib
plt.plot(losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss");
print('a = ',pyro.param("a").item())
print('b = ', pyro.param("b").item())

Using matplotlib backend: Qt5Agg
a =  8.809281349182129
b =  0.7539201378822327


In [16]:
%matplotlib
plt.subplot(1,2,1)
plt.plot([0,num_steps],[9.14,9.14], 'k:')
plt.plot(a)
plt.ylabel('a')

plt.subplot(1,2,2)
plt.ylabel('b')
plt.plot([0,num_steps],[0.6,0.6], 'k:')
plt.plot(b)
plt.tight_layout()

Using matplotlib backend: Qt5Agg
