In [153]:
from torchbnn.modules.module import BayesModule
from torchbnn.modules.linear import BayesLinear
from torchbnn.utils import freeze, unfreeze
from torch.nn.functional import relu, tanh
import torchbnn
import torch

test_input = torch.randn(20)    


In [210]:
class BNN(BayesModule):
    def __init__(self, action_dim, obs_dim, reward_dim, W_world_model=None):
        super(BayesModule, self).__init__()

        # is reward dim == 1 ???
        self.in_features = action_dim + obs_dim
        self.h_in_features = self.in_features + 64
        self.h_out_features = self.h_in_features + 32
        self.out_features = obs_dim + reward_dim

        self.input_layer = BayesLinear(prior_mu=0, prior_sigma=1,
                                       in_features=self.in_features,
                                       out_features=self.h_in_features)
        self.hidden_layer = BayesLinear(prior_mu=0, prior_sigma=1,
                                        in_features=self.h_in_features,
                                        out_features=self.h_out_features)
        self.ouput_layer = BayesLinear(prior_mu=0, prior_sigma=1,
                                       in_features=self.h_out_features,
                                       out_features=self.out_features)

        if W_world_model:
            self.copy_params_from_world_model(W_world_model)

    def forward(self, x):
        x = relu(self.input_layer(x))
        x = relu(self.hidden_layer(x))
        x = self.ouput_layer(x)
        return x

    def copy_params_from_world_model(self, W):
        try:
            self.load_state_dict(W)
        except BaseException:
            print('non compatible W')

    def deterministic_mode(self):
        '''deterministic output'''
        freeze(self)
        
    def stochatisc_mode(self):
        '''stochatisc output'''
        unfreeze(self)



net = BNN(action_dim=10, obs_dim=10, reward_dim=1)
net.deterministic_mode()

In [222]:
net(test_input)

tensor([  81.3517, -361.4016,  337.7245,  302.7568,  -37.9019,   44.4398,
         -29.8435,   55.0830, -151.8567,  118.5528, -428.5110],
       grad_fn=<AddBackward0>)

In [130]:
torchbnn.utils.freeze(net)
print('FREEZE')
print('first: ', net(test))
print('second: ', net(test))


torchbnn.utils.unfreeze(net)
print('UNFREEZE')
print('first: ', net(test))
print('second: ', net(test))


FREEZE
first:  tensor([-142.6856, -154.8361, -362.0369,   63.5905, -372.0220,    9.6467,
         375.8112,  313.8050, -168.6241,  169.5249,  277.2051],
       grad_fn=<AddBackward0>)
second:  tensor([-142.6856, -154.8361, -362.0369,   63.5905, -372.0220,    9.6467,
         375.8112,  313.8050, -168.6241,  169.5249,  277.2051],
       grad_fn=<AddBackward0>)
UNFREEZE
first:  tensor([ 304.5933,  431.0191, -203.6651,  216.2456, -108.9122,  -41.8183,
         -41.3938,  -56.9498, -391.9138,   98.8383, -335.1754],
       grad_fn=<AddBackward0>)
second:  tensor([-250.9065,  701.5959, -173.7515,  197.5743, -163.2101, -200.6471,
         553.5250,  110.7768, -351.0387, -447.2856,  233.4259],
       grad_fn=<AddBackward0>)


In [133]:
# SAMPLE PARAMETERS TEST
net_params_first_query = dict(net.named_parameters())
net_params_second_query = dict(net.named_parameters())
net_params_key = list(net_params_first_query.keys())

In [134]:
net_params_key

['in_layer.weight_mu',
 'in_layer.weight_log_sigma',
 'in_layer.bias_mu',
 'in_layer.bias_log_sigma',
 'hidden_layer.weight_mu',
 'hidden_layer.weight_log_sigma',
 'hidden_layer.bias_mu',
 'hidden_layer.bias_log_sigma',
 'out_layer.weight_mu',
 'out_layer.weight_log_sigma',
 'out_layer.bias_mu',
 'out_layer.bias_log_sigma']

In [137]:
[net_params_first_query[k] == net_params_second_query[k] for k in net_params_key]

[tensor([[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]),
 tensor([[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]),
 tensor([True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
        

In [152]:
net2 = BasicNet(action_dim=10, obs_dim=10, reward_dim=1, W_world_model=net.state_dict())

In [143]:
net_params_first_query

{'in_layer.weight_mu': Parameter containing:
 tensor([[-0.1157,  0.1072, -0.1538,  ..., -0.0555, -0.0156, -0.1273],
         [-0.0550, -0.2128,  0.0103,  ..., -0.0186,  0.1170, -0.0612],
         [-0.1844, -0.0708, -0.0252,  ..., -0.1186, -0.1652, -0.2070],
         ...,
         [-0.1870,  0.1595, -0.1436,  ..., -0.0367, -0.0323, -0.2139],
         [ 0.0441, -0.1508,  0.0981,  ..., -0.1064,  0.1030,  0.0827],
         [ 0.0303,  0.2087,  0.0735,  ..., -0.1331, -0.1048, -0.0448]],
        requires_grad=True),
 'in_layer.weight_log_sigma': Parameter containing:
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True),
 'in_layer.bias_mu': Parameter containing:
 tensor([-0.1026, -0.0030, -0.0679,  0.1733, -0.2041,  0.0486,  0.0939,  0.1831,
         -0.0933,  0.1928, 

In [142]:
net.state_dict()

OrderedDict([('in_layer.weight_mu',
              tensor([[-0.1157,  0.1072, -0.1538,  ..., -0.0555, -0.0156, -0.1273],
                      [-0.0550, -0.2128,  0.0103,  ..., -0.0186,  0.1170, -0.0612],
                      [-0.1844, -0.0708, -0.0252,  ..., -0.1186, -0.1652, -0.2070],
                      ...,
                      [-0.1870,  0.1595, -0.1436,  ..., -0.0367, -0.0323, -0.2139],
                      [ 0.0441, -0.1508,  0.0981,  ..., -0.1064,  0.1030,  0.0827],
                      [ 0.0303,  0.2087,  0.0735,  ..., -0.1331, -0.1048, -0.0448]])),
             ('in_layer.weight_log_sigma',
              tensor([[0., 0., 0.,  ..., 0., 0., 0.],
                      [0., 0., 0.,  ..., 0., 0., 0.],
                      [0., 0., 0.,  ..., 0., 0., 0.],
                      ...,
                      [0., 0., 0.,  ..., 0., 0., 0.],
                      [0., 0., 0.,  ..., 0., 0., 0.],
                      [0., 0., 0.,  ..., 0., 0., 0.]])),
             ('in_layer.bias_mu',