In [9]:
import os
import typing

import numpy as np
import torch
import torch.optim
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score
from torch import nn
from torch.nn import functional as F
from tqdm import trange

import abc
import warnings

In [6]:
torch.Tensor([1.0]).item()

1.0

In [11]:
class ParameterDistribution(torch.nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract class that models a distribution over model parameters,
    usable for Bayes by backprop.
    You can implement this class using any distribution you want
    and try out different priors and variational posteriors.
    All torch.nn.Parameter that you add in the __init__ method of this class
    will automatically be registered and know to PyTorch.
    """

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """
        Calculate the log-likelihood of the given values
        :param values: Values to calculate the log-likelihood on
        :return: Log-likelihood
        """
        pass

    @abc.abstractmethod
    def sample(self) -> torch.Tensor:
        """
        Sample from this distribution.
        Note that you only need to implement this method for variational posteriors, not priors.

        :return: Sample from this distribution. The sample shape depends on your semantics.
        """
        pass

    def forward(self, values: torch.Tensor) -> torch.Tensor:
        # DO NOT USE THIS METHOD
        # We only implement it since torch.nn.Module requires a forward method
        warnings.warn('ParameterDistribution should not be called! Use its explicit methods!')
        return self.log_likelihood(values)

In [112]:
class UnivariateGaussian(ParameterDistribution):
    """
    Univariate Gaussian distribution.
    For multivariate data, this assumes all elements to be i.i.d.
    """

    def __init__(self, mu: torch.Tensor, sigma: torch.Tensor):
        super(
            UnivariateGaussian, self
        ).__init__()  # always make sure to include the super-class init call!
        assert mu.size() == () and sigma.size() == ()
        assert sigma > 0
        self.mu = mu
        self.sigma = sigma

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        ll = torch.log(((1. / np.sqrt(2.0 * np.pi))/self.sigma) * torch.exp(- (values - self.mu) ** 2 / (2.0 * self.sigma ** 2))).sum()
        return ll  # clip to avoid numerical issues

    def sample(self) -> torch.Tensor:
        # TODO: Implement this
        raise NotImplementedError()

In [113]:
test_norm = UnivariateGaussian(torch.tensor(0.0), torch.tensor(1.0))

In [68]:
test_values = torch.Tensor([1.1])

In [69]:
torch.log(1/ np.sqrt(2* np.pi * (torch.tensor(1.0) ** 2)) * torch.exp(-1/2*(torch.tensor(1.0) ** 2) * (test_values - torch.tensor(0.0)) ** 2)).sum()

tensor(-1.5239)

In [114]:
test_norm.log_likelihood(weight)

NameError: name 'sigma' is not defined

In [95]:
test_norm.log_likelihood(weight) + test_norm.log_likelihood(bias)

tensor(-22.2006, grad_fn=<AddBackward0>)

In [89]:
out_features = 4
in_features = 5
weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-3., -3.))
weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).normal_(0., .1))
bias_mu = nn.Parameter(torch.Tensor(out_features).normal_(0., .1))
bias_rho = nn.Parameter(torch.Tensor(out_features).uniform_(-3., -3.))

weight_sigma = torch.log(1. + torch.exp(weight_rho))
bias_sigma = torch.log(1. + torch.exp(bias_rho))
epsilon_weight = torch.autograd.Variable(torch.Tensor(out_features, in_features).normal_(0., 1.))
epsilon_bias = torch.autograd.Variable(torch.Tensor(out_features).normal_(0., 1.))
weight = weight_mu + weight_sigma * epsilon_weight
bias = bias_mu + bias_sigma * epsilon_bias

In [94]:
GAUSSIAN_SCALER = 1. / np.sqrt(2.0 * np.pi)
def gaussian(x, mu, sigma):
    bell = torch.exp(- (x - mu) ** 2 / (2.0 * sigma ** 2))
    return torch.clamp(GAUSSIAN_SCALER / sigma * bell, 1e-10, 1.)  # clip to avoid numerical issues

torch.log(gaussian(weight, 0, 1.0).sum() + gaussian(bias, 0, 1.0).sum())

tensor(2.2531, grad_fn=<LogBackward>)