## [Modules in Pyro](http://pyro.ai/examples/modules.html#Modules-in-Pyro)

In [1]:
import os
import torch
import torch.nn as nn

In [2]:
import pyro
from pyro.nn import PyroModule, PyroParam, PyroSample

In [3]:
class Linear(nn.Module):
    """Plain (non-Pyro) affine layer computing ``bias + x @ weights``.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
    """

    def __init__(self, in_dim, out_dim) -> None:
        super().__init__()
        # Draw weights before bias (and register them in that order) so seeded
        # runs and ``named_parameters()`` ordering stay the same.
        weight_init = torch.randn(in_dim, out_dim)
        bias_init = torch.randn(out_dim)
        self.weights = nn.Parameter(weight_init)
        self.bias = nn.Parameter(bias_init)

    def forward(self, x):
        projected = torch.matmul(x, self.weights)
        return self.bias + projected

In [4]:
# A plain torch layer mapping 5 input features to 2 outputs.
linear = Linear(in_dim=5, out_dim=2)

In [5]:
# `Linear` subclasses torch's nn.Module ...
assert isinstance(linear, nn.Module)

In [6]:
# ... but it is not a PyroModule.
assert not isinstance(linear, PyroModule)

In [7]:
# A batch of 100 random 5-dimensional inputs. The name `input` shadows the
# builtin, but later cells refer to it, so it is kept.
input = torch.randn((100, 5))
out_put = linear(input)
out_put.size()

torch.Size([100, 2])

In [8]:
class PLinear(Linear, PyroModule):
    """Pyro-aware version of ``Linear`` built by mixing in ``PyroModule``.

    The class reuses ``Linear.__init__`` and ``Linear.forward`` unchanged. It
    is the subclassing counterpart of ``PyroModule[Linear]``. The previous
    definition (a bare ``PyroModule`` subclass) had no relation to ``Linear``
    and was never used.
    """

In [9]:
# PyroModule[Linear] builds a PyroModule subclass of Linear on the fly.
pl = PyroModule[Linear](in_dim=5, out_dim=2)

In [10]:
# Same output shape as the plain Linear layer.
pl(input).size()

torch.Size([100, 2])

In [11]:
from pyro.nn.module import to_pyro_module_

In [12]:
# Convert the existing `linear` instance into a PyroModule in place (the object
# keeps its name and is traced as a Pyro module afterwards).
to_pyro_module_(linear)

In [13]:
# The converted module is still callable like the original layer; this
# displays the full (100, 2) output tensor.
linear(input)

tensor([[ 0.8498,  0.3232],
        [ 3.6070, -0.0276],
        [ 6.3415, -0.0842],
        [ 2.1117,  0.3711],
        [ 1.0867,  2.7148],
        [ 6.6765, -1.1550],
        [ 4.2565,  1.1502],
        [ 4.2230, -0.4409],
        [ 3.1318, -0.4750],
        [ 1.3123, -0.1186],
        [ 2.7899,  1.5780],
        [ 2.7586, -0.4809],
        [ 3.6424,  0.6866],
        [ 0.6538,  2.1651],
        [ 6.0489,  0.8106],
        [ 7.3646,  0.1095],
        [ 1.8836,  2.5773],
        [ 1.3078,  1.2425],
        [ 1.2893, -1.4890],
        [ 2.2098,  0.8795],
        [ 2.6475,  0.2944],
        [-2.3164,  1.4191],
        [ 2.0159,  0.6273],
        [ 1.0478,  1.5719],
        [ 6.0270, -0.7681],
        [ 4.0176,  1.8191],
        [ 3.1099,  0.1005],
        [ 1.2689,  0.2008],
        [ 1.5573, -0.8755],
        [ 5.5609,  0.8563],
        [ 4.2680, -0.6270],
        [ 1.4890,  0.8538],
        [ 3.1696, -0.5615],
        [ 6.8742,  2.3512],
        [ 4.0680, -1.2307],
        [ 4.5771, -0 …   (output truncated)

### [Pyro Effects](http://pyro.ai/examples/modules.html#How-effects-work)

In [14]:
import pyro.poutine as poutine

In [15]:
# Start from an empty global parameter store so the listings below only show
# parameters registered in this section.
pyro.clear_param_store()

In [16]:
# A fresh, unconverted layer for the effects demonstration.
linear = Linear(in_dim=5, out_dim=2)

In [17]:
# Trace a call to the plain nn.Module, then list the recorded sites and the
# contents of the param store.
with poutine.trace() as tr:
    linear(input)
print([site for site in tr.trace.nodes])
print([key for key in pyro.get_param_store().keys()])

[]
[]


In [18]:
# Convert in place, then repeat the same trace: the parameters now show up
# both as trace sites and in the param store.
to_pyro_module_(linear)
with poutine.trace() as tr:
    linear(input)
print([site for site in tr.trace.nodes])
print([key for key in pyro.get_param_store().keys()])

['bias', 'weights']
['bias', 'weights']


## [Constraints](http://pyro.ai/examples/modules.html#How-to-constrain-parameters)

In [19]:
from pyro.distributions import constraints

In [20]:
# named_parameters() yields (name, Parameter) pairs; print them all.
print(list(linear.named_parameters()))

[('weights', Parameter containing:
tensor([[ 1.1425,  1.8639],
        [ 0.7406,  0.9192],
        [ 0.7400,  1.8840],
        [-0.0123, -1.6374],
        [-0.0611, -0.7611]], requires_grad=True)), ('bias', Parameter containing:
tensor([0.5482, 0.3786], requires_grad=True))]


In [21]:
# Replace the bias with a positivity-constrained PyroParam. The initial value
# must satisfy the constraint, hence the exponentiated normal draw.
linear.bias = PyroParam(
    torch.exp(torch.randn(2)),
    constraint=constraints.positive,
)

In [22]:
# Same listing as before: the bias now appears under an "_unconstrained" name.
print(list(linear.named_parameters()))

[('weights', Parameter containing:
tensor([[ 1.1425,  1.8639],
        [ 0.7406,  0.9192],
        [ 0.7400,  1.8840],
        [-0.0123, -1.6374],
        [-0.0611, -0.7611]], requires_grad=True)), ('bias_unconstrained', Parameter containing:
tensor([0.0840, 0.9590], requires_grad=True))]


## [How to make a PyroModule Bayesian](http://pyro.ai/examples/modules.html#How-to-make-a-PyroModule-Bayesian)

In [23]:
import pyro.distributions as dist

In [25]:
class NormalModel(PyroModule):
    """Base module holding one latent ``loc`` with a Normal(0, 1) prior.

    Reading ``self.loc`` inside a subclass's ``forward`` draws a sample from
    that prior.
    """

    def __init__(self, name=""):
        super().__init__(name=name)
        prior = dist.Normal(0, 1)
        self.loc = PyroSample(prior=prior)

In [29]:
class GlobalModel(NormalModel):
    """Variant that samples ``loc`` once, before entering the data plate."""

    def forward(self, data):
        # Read outside the plate: a single draw shared by every observation.
        loc = self.loc
        n_obs = len(data)
        with pyro.plate("data", n_obs):
            print(loc.shape)
            pyro.sample("obs", dist.Normal(loc, 1), obs=data)

In [28]:
class LocalModel(NormalModel):
    """Variant that samples ``loc`` inside the data plate, once per datum."""

    def forward(self, data):
        n_obs = len(data)
        with pyro.plate("data", n_obs):
            # Read under the plate, so loc is batched to one value per datum.
            loc = self.loc
            print(loc.shape)
            pyro.sample("obs", dist.Normal(loc, 1), obs=data)

In [30]:
# Ten standard-normal observations shared by both plate demos.
data = torch.randn((10,))

In [31]:
local_model = LocalModel()
local_model(data)

torch.Size([10])


In [32]:
global_model = GlobalModel()
global_model(data)

torch.Size([])


### [How to create a complex nested PyroModule](http://pyro.ai/examples/modules.html#How-to-create-a-complex-nested-PyroModule)

In [33]:
class BayesianLinear(PyroModule):
    """Affine layer whose bias and weight are latent variables with priors.

    ``bias`` has an elementwise LogNormal(0, 1) prior over ``out_size`` values.
    ``weight`` has an elementwise Normal(0, 1) prior over an
    ``(in_size, out_size)`` matrix. ``to_event`` makes each one a single
    sample site.

    Args:
        in_size: size of the last input dimension.
        out_size: size of the last output dimension.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        bias_prior = dist.LogNormal(0, 1).expand([out_size]).to_event(1)
        weight_prior = dist.Normal(0, 1).expand([in_size, out_size]).to_event(2)
        self.bias = PyroSample(prior=bias_prior)
        self.weight = PyroSample(prior=weight_prior)

    def forward(self, input):
        # Each attribute read draws a sample: bias first, then weight.
        sampled_bias = self.bias
        sampled_weight = self.weight
        return sampled_bias + input @ sampled_weight

In [37]:
class Model(PyroModule):
    """Bayesian linear regression with a single learned observation scale.

    Args:
        in_dim: number of input features.
        out_dim: number of output values per instance.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = BayesianLinear(in_dim, out_dim)
        self.obs_scale = PyroSample(prior=dist.LogNormal(0, 1))

    def forward(self, input, output=None):
        """Run the generative model and return the value of the ``obs`` site.

        If ``output`` is None, ``obs`` is sampled; otherwise it is conditioned
        on ``output``.
        """
        mean = self.linear(input)
        # Sampled outside the plate, so one scale is shared by all instances.
        noise_scale = self.obs_scale
        with pyro.plate("instances", len(input)):
            likelihood = dist.Normal(mean, noise_scale).to_event(1)
            return pyro.sample("obs", likelihood, obs=output)

In [39]:
from pyro.infer.autoguide import AutoNormal

In [42]:
from pyro.infer import Trace_ELBO, SVI

In [44]:
from pyro.optim import Adam

In [46]:
%%time
pyro.clear_param_store()
pyro.set_rng_seed(2)

model = Model(5,2)
x = torch.randn(100,5)
y = model(x)

guide = AutoNormal(model)

svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())

for step in range(501):
    loss = svi.step(x,y) / y.numel()
    if step % 100 ==0:
        print("step {} loss= {:0.4g}".format(step, loss))

step 0 loss= 6.516
step 100 loss= 2.61
step 200 loss= 2.235
step 300 loss= 2.076
step 400 loss= 1.968
step 500 loss= 1.98
CPU times: user 2.54 s, sys: 0 ns, total: 2.54 s
Wall time: 2.83 s


In [47]:
# Record every site hit by one forward pass of the model and show its shape.
with poutine.trace() as tr:
    model(x)
for node in tr.trace.nodes.values():
    kind, site_name, value = node["type"], node["name"], node["value"]
    print(kind, site_name, value.shape)

sample linear.bias torch.Size([2])
sample linear.weight torch.Size([5, 2])
sample obs_scale torch.Size([])
sample instances torch.Size([100])
sample obs torch.Size([100, 2])


In [48]:
# Same listing for the guide. For each latent it registers loc/scale params
# and samples the unconstrained value before the constrained one.
with poutine.trace() as tr:
    guide(x)
for node in tr.trace.nodes.values():
    kind, site_name, value = node["type"], node["name"], node["value"]
    print(kind, site_name, value.shape)

param AutoNormal.locs.linear.bias torch.Size([2])
param AutoNormal.scales.linear.bias torch.Size([2])
sample linear.bias_unconstrained torch.Size([2])
sample linear.bias torch.Size([2])
param AutoNormal.locs.linear.weight torch.Size([5, 2])
param AutoNormal.scales.linear.weight torch.Size([5, 2])
sample linear.weight_unconstrained torch.Size([5, 2])
sample linear.weight torch.Size([5, 2])
param AutoNormal.locs.obs_scale torch.Size([])
param AutoNormal.scales.obs_scale torch.Size([])
sample obs_scale_unconstrained torch.Size([])
sample obs_scale torch.Size([])
