In [1]:
import torch
import torch.nn as nn

In [2]:
class myModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.linear_1 = nn.Linear(10, 20)
        self.linear_2 = nn.Linear(5, 10)

        self.third_layer = nn.Sequential(nn.Conv1d(1, 1, 3, padding=1), nn.Linear(20, 10))

    def forward(self, x):
        return self.linear_1(self.linear_2(x))

In [15]:
a = nn.Sequential(nn.Conv2d(1, 1, 3, padding=1), nn.Linear(20, 10))
len(list(a.parameters()))

4

In [16]:
list(a.named_parameters())

[('0.weight',
  Parameter containing:
  tensor([[[[ 0.1003, -0.3127, -0.2752],
            [ 0.2209,  0.0025,  0.2514],
            [-0.0948,  0.2127, -0.2904]]]], requires_grad=True)),
 ('0.bias',
  Parameter containing:
  tensor([-0.0777], requires_grad=True)),
 ('1.weight',
  Parameter containing:
  tensor([[-0.0386,  0.1847,  0.0200,  0.0824,  0.0632, -0.1593, -0.0861, -0.0369,
           -0.0978,  0.0437,  0.2033, -0.2067,  0.2038,  0.0943, -0.0493, -0.2016,
           -0.0472, -0.0265,  0.1920,  0.0072],
          [-0.0711,  0.0820, -0.1121,  0.1862, -0.1268,  0.0115, -0.1832,  0.0183,
           -0.0252, -0.1815,  0.0781, -0.2101,  0.0484, -0.2213, -0.0135, -0.2053,
           -0.1764,  0.0458, -0.0057, -0.0939],
          [ 0.0474, -0.1719, -0.0569, -0.0930,  0.0250,  0.2138,  0.1263, -0.0348,
            0.1050, -0.0655,  0.2144, -0.0643,  0.1849,  0.0166, -0.0249,  0.1120,
            0.0393, -0.0422, -0.0755, -0.0328],
          [-0.0033,  0.0134,  0.0326,  0.1710, -0.0171, 

In [19]:
dict(a.named_modules())

{'': Sequential(
   (0): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (1): Linear(in_features=20, out_features=10, bias=True)
 ),
 '0': Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '1': Linear(in_features=20, out_features=10, bias=True)}

In [21]:
b = nn.Linear(20, 20)
dict(b.named_modules())

{'': Linear(in_features=20, out_features=20, bias=True)}

In [3]:
net = myModule()
params = net.parameters()

In [4]:
next(params)

Parameter containing:
tensor([[-3.2434e-02, -2.1235e-01,  2.0307e-01, -3.8971e-02, -4.7705e-02,
          2.9648e-02,  2.8886e-01,  2.2747e-01,  2.5189e-01, -1.3701e-01],
        [-1.7861e-01, -2.7500e-01, -2.7414e-01,  1.1913e-01,  2.3626e-01,
         -2.0756e-02, -2.1586e-01,  1.5784e-01,  2.5881e-01,  2.2508e-01],
        [ 2.1662e-01,  1.9394e-01,  1.3046e-02, -1.6407e-01,  1.2347e-01,
          1.8318e-01,  2.7386e-02,  2.5449e-01,  2.6245e-01, -1.0728e-01],
        [-2.4045e-01,  3.0579e-01, -1.0416e-01, -2.2607e-01,  1.2463e-01,
         -9.3524e-02, -5.9915e-02,  1.4448e-01,  2.6371e-01,  3.1516e-01],
        [-2.2760e-02,  1.3410e-01,  2.9566e-01,  2.3258e-01,  8.2151e-02,
          4.3427e-05,  2.9750e-01, -1.7836e-01,  1.1250e-01, -1.1434e-01],
        [ 2.8166e-01,  2.3681e-02, -1.0265e-01, -6.9556e-02,  1.1485e-01,
          8.7006e-02, -5.8378e-04,  1.8080e-02,  1.5271e-01,  1.6586e-01],
        [ 3.0796e-01,  2.0883e-01, -2.5709e-01, -3.0728e-01,  1.1218e-01,
          

In [5]:
from functools import partial
import torch.nn.functional as F
from types import MethodType


for module_name, module in net.named_modules():
    if module_name == "":
            continue
    
    if isinstance(module, nn.Linear):
        named_params = dict(module.named_parameters())

        for param_name, param in named_params.items():
            # this param will be a storage for samples
            param.requires_grad = False

            # add mean and std as paramters
            module.register_parameter(
                param_name + "_mean", 
                nn.Parameter(torch.rand_like(param, requires_grad=True))
            )
            module.register_parameter(
                param_name + "_std", 
                nn.Parameter(torch.rand_like(param, requires_grad=True))
            )

        # inital nn.Linear params are not needed
        del module.weight
        del module.bias
        
        # use local reparamterization trick for Linear layers
        def new_forward(self, x):
            # sample from weights posterior
            output =  F.linear(x, self.weight_mean)
            output += torch.randn(output.shape[-1]) * F.linear(x, self.weight_std)

            # sample from bias posterior
            output += self.bias_mean + torch.randn_like(self.bias_mean) * self.bias_std

            return output

        module.forward = MethodType(new_forward, module)

        continue

    # obtain all inital params of the submodule
    named_params = dict(module.named_parameters())

    for param_name, param in named_params.items():
        # this param will be a storage for samples
        param.requires_grad = False

        # add mean and std as paramters
        module.register_parameter(
            param_name + "_mean", 
            nn.Parameter(torch.rand_like(param, requires_grad=True))
        )
        module.register_parameter(
            param_name + "_std", 
            nn.Parameter(torch.rand_like(param, requires_grad=True))
        )

        # add forward pre hook to sample paramters for the module
        def sample_param_prehook(cur_module: nn.Module, input, param_name) -> None:
            param = cur_module.get_parameter(param_name)
            param.copy_(
                cur_module.get_parameter(param_name + "_std") * torch.randn_like(param) + \
                  cur_module.get_parameter(param_name + "_mean")
            )
            print(f"Sampled for {param_name}")

        module.register_forward_pre_hook(partial(sample_param_prehook, param_name=param_name))


KeyError: 'parameter name can\'t contain "."'

In [18]:
for name, p in net.named_parameters():
    print(name)
    print(p.requires_grad)

linear_1.weight_mean
True
linear_1.weight_std
True
linear_1.bias_mean
True
linear_1.bias_std
True
linear_2.weight_mean
True
linear_2.weight_std
True
linear_2.bias_mean
True
linear_2.bias_std
True


In [24]:
net(torch.ones(5, 5)).shape

torch.Size([5, 20])

In [25]:
from types import MethodType

net._old_forward = net.forward

def new_forward(self, x, num_samples: int = 1):
    sampled_outputs = []
    for i in range(num_samples):
        sampled_outputs.append(self._old_forward(x))

    return sampled_outputs

net.forward = MethodType(new_forward, net)

In [32]:
net.forward(torch.ones(5), 3)

[tensor([ 15.1424,   8.8401,   3.5889,  13.7657,  20.1540,  -2.6729,  23.2921,
          15.9443,  -6.3283,  38.5617,   9.6503,  52.0065,   8.3783,  18.1263,
          10.2334,  15.9154,  -9.4791,  -2.1770,  35.2312, -14.1118],
        grad_fn=<AddBackward0>),
 tensor([ 3.5292,  9.0041, 23.1724,  2.5326, -6.5397,  8.5960,  8.2655,  4.8640,
          7.3816, 12.2835,  1.8815,  8.0460, 11.9264,  1.7556,  5.3030, -3.9603,
          7.2321,  9.5580, 10.9277,  6.2839], grad_fn=<AddBackward0>),
 tensor([23.8653, 22.8296, 34.2731,  0.3284,  7.2709, 11.7731, 35.7310, 25.2350,
         21.2042, 46.7060, 17.2928, 32.7929,  4.2334,  4.2700, 41.7038,  0.9138,
         16.6423, 30.0019, 26.2667, 39.0951], grad_fn=<AddBackward0>)]