In [1]:
import torch
import torch.nn as nn

In [15]:
from copy import deepcopy

class A:
    """Toy class whose every instance gets its own freshly created subclass.

    ``A(num)`` builds a new class deriving from the class actually being
    instantiated, renames it ``RenamedClass<num>`` and returns an instance of it,
    so ``type(A(2)).__name__ == "RenamedClass2"``.

    Note: a new class object is created on every instantiation, and because these
    classes are local to ``__new__`` their instances cannot be pickled.
    """

    def __new__(cls, num):
        # Derive from `cls`, not a hard-coded `A`, so subclasses of A keep their own
        # methods (and copies of an instance stay instances of the copied class).
        class RenamedClass(cls):
            pass

        RenamedClass.__name__ += str(num)
        RenamedClass.__qualname__ += str(num)

        return super().__new__(RenamedClass)

    def __init__(self, num: int) -> None:
        self.num_ = num

    def __getnewargs__(self):
        # copy.copy / copy.deepcopy rebuild the object via cls.__new__(cls, *args);
        # without this they call __new__ with no `num` and raise TypeError.
        return (self.num_,)

In [16]:
# Instantiate A and inspect the dynamically generated class it hands back.
a = A(2)
for shown in (a.num_, type(a), type(a).__name__):
    print(shown)

2
<class '__main__.A.__new__.<locals>.RenamedClass2'>
RenamedClass2


In [18]:
# Base classes of the generated class.
print(a.__class__.__bases__)

(<class '__main__.A'>,)


In [5]:
# Instance attributes set by __init__.
vars(a)

{'num_': 2}

In [10]:
# Metaclass of the generated class.
type(type(a))

type

In [2]:
class myModule(nn.Module):
    """Two stacked linear layers mapping 5 input features to 20 outputs.

    Note the naming: ``linear_2`` (5 -> 10) is applied first and ``linear_1``
    (10 -> 20) second. ``linear_1`` is still registered first, so it comes first
    in ``parameters()`` / ``named_parameters()``.
    """

    def __init__(self) -> None:
        super(myModule, self).__init__()
        # Registration order is observable through parameters(); keep linear_1 first.
        self.linear_1 = nn.Linear(in_features=10, out_features=20)
        self.linear_2 = nn.Linear(in_features=5, out_features=10)

    def forward(self, x):
        """Apply ``linear_2`` then ``linear_1``; the last dimension of x must be 5."""
        hidden = self.linear_2(x)
        return self.linear_1(hidden)

In [3]:
# Seed torch's global RNG so parameter initialisation (and the sampling done in
# later cells) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

net = myModule()
# One-shot generator over net's parameters; each next() yields the following one.
params = net.parameters()

In [4]:
# Peek at the next parameter left in the `params` generator
# (on a fresh run this is linear_1.weight, shape (20, 10)).
next(iter(params))

Parameter containing:
tensor([[-0.0053, -0.2688, -0.2500,  0.2547, -0.0331, -0.2060, -0.1798, -0.1644,
          0.2731, -0.1672],
        [-0.0105,  0.2532,  0.1932,  0.1279, -0.0585, -0.2119, -0.1132,  0.2680,
          0.1519,  0.2612],
        [ 0.2614, -0.0190,  0.1357,  0.2151, -0.1640, -0.2250,  0.0911,  0.1884,
         -0.1633, -0.1808],
        [ 0.0460, -0.1405, -0.3025, -0.3056, -0.1088,  0.1284,  0.3000, -0.1107,
          0.0367,  0.0755],
        [ 0.0845, -0.3000,  0.0666,  0.1791, -0.1139,  0.0707,  0.1563, -0.0166,
          0.2621,  0.0430],
        [ 0.1358, -0.2560,  0.1651,  0.2751, -0.0942,  0.1834, -0.0700, -0.1850,
         -0.0859, -0.1729],
        [-0.2302, -0.0779,  0.3085, -0.1661,  0.1191, -0.1805,  0.0731,  0.0679,
          0.2603,  0.0940],
        [-0.2694, -0.2155, -0.0845,  0.3019,  0.1486,  0.2191,  0.1877, -0.1678,
          0.1364,  0.0080],
        [-0.2938, -0.1822,  0.1818, -0.2225, -0.0695,  0.0811,  0.2484,  0.0096,
         -0.1907, -0.1783

In [5]:
# Turn every parameter of `net` into a Gaussian: each parameter `p` gets learnable
# companions `p_mean` and `p_std`, and a forward pre-hook draws a fresh sample
# p = p_std * eps + p_mean (eps ~ N(0, 1)) before every forward call of its module.
for module in net.modules():
    # recurse=False: only the module's *own* parameters. Recursing would yield
    # dotted names (e.g. "linear_1.weight" on the root, "0.weight" on containers)
    # that register_parameter rejects, so no module needs to be skipped anymore.
    own_params = dict(module.named_parameters(recurse=False))
    companion_names = {base + suffix for base in own_params for suffix in ("_mean", "_std")}

    for param_name, param in own_params.items():
        # Keep the cell idempotent: skip parameters already converted by an earlier
        # run, as well as the `_mean` / `_std` companions themselves.
        if param_name in companion_names or param_name + "_mean" in own_params:
            continue

        # This param stays registered (it still shows up in named_parameters /
        # state_dict) and only stores a detached copy of the latest sample.
        param.requires_grad = False

        # add mean and std as learnable parameters
        module.register_parameter(param_name + "_mean", nn.Parameter(torch.rand_like(param)))
        # NOTE(review): std starts in [0, 1) but is unconstrained during training;
        # consider parameterising it through softplus/exp to keep it positive.
        module.register_parameter(param_name + "_std", nn.Parameter(torch.rand_like(param)))

        def sample_param_prehook(cur_module: nn.Module, args, param_name=param_name) -> None:
            """Draw a fresh differentiable sample of ``param_name`` before ``cur_module`` runs.

            ``param_name`` is bound as a default argument so every hook keeps its own
            name instead of the loop variable's final value.
            """
            # Read the storage from _parameters directly: after the first call,
            # `cur_module.<param_name>` resolves to the plain sample set below.
            storage = cur_module._parameters[param_name]
            sample = (
                cur_module.get_parameter(param_name + "_std") * torch.randn_like(storage)
                + cur_module.get_parameter(param_name + "_mean")
            )

            # Do NOT write the sample into the registered parameter with copy_ under
            # autograd: that in-place op bumps the tensor's version counter and
            # invalidates graphs of earlier forward calls, so backward through the
            # outputs of several forward calls fails ("modified by an inplace
            # operation"). The storage only receives a detached copy ...
            with torch.no_grad():
                storage.copy_(sample)
            # ... while the differentiable sample is kept as a plain instance
            # attribute that shadows the registered parameter on attribute lookup
            # (e.g. the `self.weight` read in nn.Linear.forward). object.__setattr__
            # bypasses nn.Module.__setattr__, which only accepts Parameters for a
            # registered parameter name.
            object.__setattr__(cur_module, param_name, sample)
            print(f"Sampled for {param_name}")

        module.register_forward_pre_hook(sample_param_prehook)


In [6]:
# Parameters registered on linear_1 and whether autograd tracks them.
for param_name, tensor in net.linear_1.named_parameters():
    print(param_name, tensor.requires_grad, sep="\n")

weight
False
bias
False
weight_mean
True
weight_std
True
bias_mean
True
bias_std
True


In [7]:
# All parameters of the network (qualified names) and whether autograd tracks them.
for param_name, tensor in net.named_parameters():
    print(param_name, tensor.requires_grad, sep="\n")

linear_1.weight
False
linear_1.bias
False
linear_1.weight_mean
True
linear_1.weight_std
True
linear_1.bias_mean
True
linear_1.bias_std
True
linear_2.weight
False
linear_2.bias
False
linear_2.weight_mean
True
linear_2.weight_std
True
linear_2.bias_mean
True
linear_2.bias_std
True


In [8]:
# One stochastic forward pass; the pre-hooks resample each layer's weight and bias first.
net(torch.ones(size=(5,)))

Sampled for weight
Sampled for bias
Sampled for weight
Sampled for bias


tensor([28.7302, 15.8110, 11.7792, 24.1246, 23.4398, 29.1531, 10.5294,  6.4090,
         9.9188,  8.4975, 22.0848, 12.5610, 11.8430, 14.7623, 12.0982, 15.4158,
        19.2481,  8.7205, 21.6517, 12.2197], grad_fn=<ViewBackward0>)

In [9]:
from types import MethodType

# Wrap the model's forward so that one call returns `num_samples` stochastic outputs.
# Bind the *class* forward rather than `net.forward`: once this cell has run,
# `net.forward` is the wrapper itself, so a re-run would make the wrapper call
# itself forever (RecursionError).
net._old_forward = MethodType(type(net).forward, net)


def new_forward(self, x, num_samples: int = 1):
    """Run the wrapped forward ``num_samples`` times and return the outputs as a list.

    Each inner call invokes the sub-modules through their ``__call__``, so their
    forward pre-hooks run once per sample.
    """
    return [self._old_forward(x) for _ in range(num_samples)]


net.forward = MethodType(new_forward, net)

In [15]:
# Call the module itself rather than `.forward` so nn.Module.__call__ runs any hooks
# registered on `net`; the patched forward receives num_samples.
net(torch.ones(5), num_samples=3)

Sampled for weight
Sampled for bias
Sampled for weight
Sampled for bias
Sampled for weight
Sampled for bias
Sampled for weight
Sampled for bias
Sampled for weight
Sampled for bias
Sampled for weight
Sampled for bias


[tensor([ 6.5150, 14.8343, 15.1563,  9.7855, 16.7029, 12.4275, 13.6434, 11.8232,
         18.7794, 11.2198, 18.6116, 17.9930, 23.7289, 17.5775,  8.4132, 18.4627,
         10.6028, 33.9548, 20.1214, 13.8757], grad_fn=<ViewBackward0>),
 tensor([22.7525, 10.8335,  8.2850,  5.2818, 19.1122, 21.8927, 22.0851, 14.2603,
         13.7619, 13.4381, 16.2655, 20.6990, 12.2032, 22.3954,  7.1262, 18.9957,
         25.9468, 13.9817, 31.5973, 12.3743], grad_fn=<ViewBackward0>),
 tensor([ 1.7146, 19.5655,  6.6596, 12.0201,  1.5429, 14.9744, 10.6279, 20.5214,
         14.6439, 16.6848,  4.1312,  9.1195, 15.2734, 12.6234,  9.1329, 13.3845,
         14.0635,  7.4908, 17.8786, 13.4148], grad_fn=<ViewBackward0>)]