## [Modules in Pyro](http://pyro.ai/examples/modules.html#Modules-in-Pyro)

In [4]:
import os
import torch
import torch.nn as nn

In [14]:
import pyro
from pyro.nn import PyroModule, PyroParam, PyroSample

In [22]:
class Linear(nn.Module):
    def __init__(self, in_dim, out_dim) -> None:
        super().__init__()
        self.weights = nn.Parameter(torch.randn(in_dim, out_dim))
        self.bias = nn.Parameter(torch.randn(out_dim))

    def forward(self, x):
        return self.bias + x @ self.weights

In [23]:
linear = Linear(5,2)

In [11]:
assert isinstance(linear, nn.Module)

In [17]:
assert not isinstance(linear, PyroModule)

In [26]:
input = torch.randn(100, 5)
out_put = linear(input)
out_put.shape

torch.Size([100, 2])

In [28]:
class PLinear(PyroModule):
    pass

In [30]:
pl = PyroModule[Linear](5,2)

In [32]:
pl(input).shape

torch.Size([100, 2])

In [33]:
from pyro.nn.module import to_pyro_module_

In [35]:
to_pyro_module_(linear)

In [36]:
linear(input)

tensor([[ 2.0766e+00, -2.6112e+00],
        [ 1.6895e+00, -5.4963e+00],
        [ 2.0128e+00, -7.5941e-01],
        [ 4.8355e+00, -1.5488e+00],
        [ 9.2909e-01, -9.6555e-01],
        [ 1.9485e+00,  4.5479e+00],
        [ 1.4266e+00,  8.9857e-02],
        [ 3.3437e+00,  1.1225e+00],
        [ 2.0931e+00, -5.6345e+00],
        [ 2.1787e+00, -2.1629e+00],
        [ 1.4653e+00,  6.7229e+00],
        [ 2.6442e+00, -2.4068e+00],
        [ 1.1619e+00, -1.5065e+00],
        [-9.0335e-02, -2.2592e+00],
        [ 5.2148e+00,  6.8842e+00],
        [ 8.0969e-01,  8.8792e-01],
        [ 3.1978e+00, -8.8508e-01],
        [ 1.8560e+00, -4.0673e+00],
        [ 2.7549e+00, -3.6104e+00],
        [ 5.3286e+00, -3.5658e+00],
        [ 2.5343e+00, -1.5373e+00],
        [ 1.9811e+00,  5.7311e+00],
        [ 2.2636e+00,  9.7492e-01],
        [ 1.2939e+00,  3.1828e+00],
        [ 2.3499e+00, -3.1883e+00],
        [ 2.4353e+00, -1.6102e+00],
        [ 5.4609e-01, -4.8563e+00],
        [ 2.3265e+00,  2.404

### [Pyro Effects](http://pyro.ai/examples/modules.html#How-effects-work)

In [39]:
import pyro.poutine as poutine

In [37]:
pyro.clear_param_store()

In [38]:
linear = Linear(5,2)

In [42]:
with poutine.trace() as tr:
    linear(input)
print(list(tr.trace.nodes.keys()))
print(list(pyro.get_param_store().keys()))

[]
[]


In [43]:
to_pyro_module_(linear)
with poutine.trace() as tr:
    linear(input)
print(list(tr.trace.nodes.keys()))
print(list(pyro.get_param_store().keys()))

['bias', 'weights']
['bias', 'weights']


## [Constraints](http://pyro.ai/examples/modules.html#How-to-constrain-parameters)

In [45]:
from pyro.distributions import constraints

In [46]:
print([name for name in linear.named_parameters()])

[('weights', Parameter containing:
tensor([[-1.8910, -0.3617],
        [ 0.6565, -0.2591],
        [-1.2691,  0.0461],
        [-0.9272,  1.5582],
        [-0.4612,  2.2665]], requires_grad=True)), ('bias', Parameter containing:
tensor([-1.2014, -0.7166], requires_grad=True))]


In [49]:
linear.bias = PyroParam(torch.randn(2).exp(), constraint=constraints.positive)

In [50]:
print([name for name in linear.named_parameters()])

[('weights', Parameter containing:
tensor([[-1.8910, -0.3617],
        [ 0.6565, -0.2591],
        [-1.2691,  0.0461],
        [-0.9272,  1.5582],
        [-0.4612,  2.2665]], requires_grad=True)), ('bias_unconstrained', Parameter containing:
tensor([1.3703, 1.3749], requires_grad=True))]


## [How to make a PyroModule Bayesian](http://pyro.ai/examples/modules.html#How-to-make-a-PyroModule-Bayesian)