In [9]:
import math
import torch
import torch.nn as nn
import hess
from hess.nets import MaskedNet
from torch.nn import Module, init
from torch.nn.parameter import Parameter
import torch.nn.functional as F

In [10]:
class MaskedLayer(Module):
    """Fully connected layer whose weight (and bias) are gated by a fixed random binary mask.

    A Bernoulli(``pct_keep``) mask is sampled once at construction time for the
    weight matrix and, when ``bias=True``, once for the bias vector. The forward
    pass computes ``F.linear(input, weight * mask, bias * bias_mask)``. Masked-out
    entries therefore contribute nothing to the output and receive zero gradient.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: if ``True``, learn an additive bias (which is also masked).
        pct_keep: probability that each individual weight/bias entry is kept
            (mask value 1); otherwise the entry is dropped (mask value 0).

    The masks are registered as (non-trainable) buffers. They move with the
    module on ``.to(device)`` and are saved in / restored from ``state_dict``.
    """
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True, pct_keep=0.6):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialized; reset_parameters() fills them in below.
        self.weight = Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

        # Sample the masks once (after initialization, as before); they stay fixed
        # for the lifetime of the layer. Buffers rather than plain attributes, so
        # they follow device moves and are persisted with the model.
        dist = torch.distributions.Bernoulli(pct_keep)
        self.register_buffer('mask', dist.sample(self.weight.shape))
        if bias:
            self.register_buffer('bias_mask', dist.sample(self.bias.shape))
        else:
            # Keep the attribute defined so forward() never hits AttributeError.
            self.register_buffer('bias_mask', None)

    def reset_parameters(self):
        """Initialize weight and bias the same way ``nn.Linear`` does.

        NOTE(review): the init scale uses the full fan-in and ignores that roughly
        (1 - pct_keep) of the weights are masked out; confirm this is intended.
        """
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the masked affine map; works with or without a bias."""
        weight = self.weight * self.mask
        bias = None if self.bias is None else self.bias * self.bias_mask
        return F.linear(input, weight, bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


In [11]:
# Sanity check: one MaskedLayer, used bare and wrapped in nn.Sequential.
# Seed first so weight init and mask sampling (and the outputs shown below)
# are reproducible on Restart & Run All.
torch.manual_seed(0)

in_ = 10   # input features
hid_ = 5   # output features of the masked layer
layer = MaskedLayer(in_, hid_, pct_keep=0.6)
module = nn.ModuleList([layer])
mod = nn.Sequential(*module)

In [12]:
# Single random input sample; width tied to in_ so it always matches the layer.
inp = torch.rand(1, in_)

In [13]:
# Forward pass through the bare masked layer (batch of one).
layer_out = layer(inp)
layer_out

tensor([[-0.8089, -0.3762,  0.2887,  0.5946, -0.3476]],
       grad_fn=<AddmmBackward>)

In [14]:
# A single-module nn.Sequential should be a transparent wrapper: its output
# must match calling the layer directly on the same input.
seq_out = mod(inp)
assert torch.equal(seq_out, layer(inp)), "Sequential output differs from bare layer"
seq_out

tensor([[-0.8089, -0.3762,  0.2887,  0.5946, -0.3476]],
       grad_fn=<AddmmBackward>)

## Building full masked networks with `hess.nets.MaskedNet`

In [15]:
# Build a masked network via the MaskedNet class imported from hess.nets in
# the first cell.
# NOTE(review): `train` is not defined in any cell visible here — presumably it
# comes from an earlier cell; confirm it exists before Restart & Run All.
MaskedNet(train)

hess.nets.masked_net.MaskedNet