In [1]:
import torch

In [4]:
class BernouilliWeights:
    """Bayesian binary weights with a Bernoulli distribution over {-1, +1}.

    Each weight is described by one entry of ``lambda_``. ``sample`` draws a
    differentiable relaxation of the binary weights in (-1, 1) so that
    gradients can flow back to ``lambda_``.
    """

    def __init__(self, lambda_):
        """Initialize the weights with a Bernoulli distribution.

        Args:
            lambda_ (torch.Tensor): Tensor of the same shape as the weights;
                represents the certainty of each weight being 1 or -1.
        """
        self.lambda_ = lambda_
        # Noise source for the reparameterisation used in ``sample``.
        self.uniform = torch.distributions.uniform.Uniform(0, 1)

    def sample(self, samples=1, temperature=1.0):
        """Draw relaxed weights using the Gumbel-softmax trick.

        Computes ``tanh((lambda_ + delta) / temperature)`` where
        ``delta = 0.5 * log(U / (1 - U))`` and ``U ~ Uniform(0, 1)``.

        Args:
            samples (int): Number of independent draws.
            temperature (float): Positive relaxation temperature. The default
                of 1.0 leaves the pre-activation unscaled; smaller values push
                the samples closer to -1 / +1.

        Returns:
            torch.Tensor: Relaxed weights of shape
            ``(samples, *self.lambda_.shape)`` with values in (-1, 1).

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(
                f"temperature must be strictly positive, got {temperature}")

        # 1. Sample U ~ U(0, 1); shape (samples, *self.lambda_.shape).
        noise = self.uniform.sample(
            (samples, *self.lambda_.shape)).to(self.lambda_.device)
        # The uniform draw can be exactly 0, which would make the log below
        # -inf (saturated weight, zero gradient, NaN if lambda_ is +inf).
        # Keep the noise strictly inside (0, 1).
        eps = torch.finfo(noise.dtype).eps
        noise = noise.clamp(eps, 1 - eps)
        # 2. Logistic noise delta = 1/2 * log(U / (1 - U)), i.e. half the
        #    difference of two Gumbel variables (G1 - G2).
        delta = torch.log(noise / (1 - noise)) / 2
        # 3. Relaxed weights around the mean; lambda_ broadcasts over the
        #    leading sample dimension.
        return torch.tanh((self.lambda_ + delta) / temperature)

In [122]:
# Seed the global RNG so the printed samples are reproducible on
# Restart Kernel -> Run All.
torch.manual_seed(0)

# Natural parameters of a 4x4 weight matrix, drawn from N(0, 1) and wrapped
# in a Parameter so gradients are tracked.
lambda_ = torch.nn.Parameter(torch.randn(4, 4))
weights = BernouilliWeights(lambda_)

# Five relaxed draws; shape (5, 4, 4).
relaxed_w = weights.sample(5)
print(relaxed_w)
# `.T` on a tensor that is not 2-D is deprecated in PyTorch; reverse all
# three axes explicitly instead (same result as the old `relaxed_w.T`).
print(relaxed_w.permute(2, 1, 0))

tensor([[[-0.5204, -0.3785, -0.8682, -0.5991],
         [ 0.5199, -0.1939,  0.9848,  0.5348],
         [-0.5119,  0.7114, -0.5791,  0.0712],
         [ 0.9896,  0.4623, -0.9462, -0.2486]],

        [[-0.1238,  0.3015,  0.1561, -0.7636],
         [-0.3629,  0.8665,  0.7653,  0.8692],
         [ 0.2279,  0.2945, -0.1524, -0.9783],
         [-0.0323,  0.2806, -0.9515, -0.8894]],

        [[-0.6353, -0.9143,  0.3775, -0.0537],
         [ 0.2229,  0.7806,  0.8320,  0.6712],
         [-0.3936, -0.9235, -0.6569, -0.9744],
         [-0.5836, -0.7403, -0.1723, -0.3913]],

        [[ 0.7785, -0.8510, -0.1429, -0.2452],
         [ 0.9694,  0.9219,  0.5260,  0.3916],
         [ 0.5950, -0.6732,  0.9818, -0.5747],
         [ 0.1652, -0.1687, -0.9570, -0.7221]],

        [[ 0.0991, -0.3027,  0.9812, -0.8097],
         [ 0.9450, -0.7345,  0.9936, -0.1031],
         [-0.6236,  0.6347,  0.9114, -0.3396],
         [-0.9018,  0.5795, -0.6098, -0.4076]]], grad_fn=<TanhBackward0>)
tensor([[[-0.5204, -0.123