# Reliable Uncertainty Estimates in Deep Neural Networks using Noise Contrastive Priors

PyTorch implementation of ["Reliable Uncertainty Estimates in Deep Neural Networks using Noise Contrastive Priors"](https://arxiv.org/abs/1807.09289). It is based on the [reference TensorFlow implementation](https://github.com/brain-research/ncp).

$$
\begin{aligned}
    w_{ij} &= \bar{w}_{ij} + \sigma_{w,ij} \epsilon_{w,ij} \\
    b_i &= \bar{b}_i + \sigma_{b,i} \epsilon_{b,i} \\
    \mu_i &= \sum_j w_{ij} x_j + b_i
\end{aligned}
$$

$$
\begin{aligned}
    \langle \epsilon_{w,ij} \epsilon_{w,kl} \rangle &= \delta_{ik} \delta_{jl} \\
    \langle \epsilon_{b,i} \epsilon_{b,k} \rangle  &= \delta_{ik}
\end{aligned}
$$

$$
\begin{aligned}
    \langle \mu_i \rangle &= \sum_j \bar{w}_{ij} x_j + \bar{b}_i \\
    \langle (\mu_i - \bar{\mu}_i) (\mu_k - \bar{\mu}_k) \rangle 
        &= 
            \sum_{jl} \sigma_{w,ij} x_j \sigma_{w,kl} x_l \langle \epsilon_{w,ij} \epsilon_{w,kl} \rangle  
            + \sigma_{b,i} \sigma_{b,k} \langle \epsilon_{b,i} \epsilon_{b,k} \rangle  
            \\
        &= \delta_{ik} \left( \sum_{j} \sigma_{w,ij}^2 x_j^2 + \sigma_{b,i}^2 \right)
\end{aligned}
$$

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data

from chmp.ds import get_color_cycle, Loop
from chmp.torch_utils.nn import Lambda, t2n

In [None]:
def generate_dataset(length=1000, noise_slope=0.2):
    """Generate a 1-d toy regression problem with held-out input regions.

    Adapted from https://github.com/brain-research/ncp/blob/master/ncp/datasets/toy.py

    The inputs are ``length`` evenly spaced points in ``[-1, 1]``. The targets
    are a sine wave plus a linear trend, with Gaussian noise whose standard
    deviation grows linearly with the input (``(input + 1) * noise_slope``,
    clipped at zero). The inputs are cut into five contiguous segments; the
    second and fourth form the training split, the remaining three the test
    split.

    Args:
        length: number of input points to generate.
        noise_slope: slope of the noise standard deviation as a function of
            the input.

    Returns:
        A dict with the keys ``domain`` (1000 points in ``[-1.2, 1.2]`` as a
        column vector), ``target_scale`` (always ``1``), and ``train`` /
        ``test`` (each a dict with column-vector ``inputs`` and ``targets``).
    """
    # Fixed seed: the dataset is identical on every call.
    random = np.random.RandomState(0)

    inputs = np.linspace(-1, 1, length)
    noise_std = np.maximum(0, (inputs + 1) * noise_slope)
    targets = 0.5 * np.sin(25 * inputs) + random.normal(0, noise_std)
    targets += 0.5 * inputs

    domain = np.linspace(-1.2, 1.2, 1000)

    # Assign every point to one of five contiguous segments. Deriving the
    # segments from `length` keeps the masks aligned with `inputs` for any
    # length; for length=1000 this is exactly five runs of 200 points.
    segment = np.arange(length) * 5 // length
    train_split = (segment == 1) | (segment == 3)
    test_split = ~train_split

    domain, inputs, targets = domain[:, None], inputs[:, None], targets[:, None]
    test_inputs, test_targets = inputs[test_split], targets[test_split]
    train_inputs, train_targets = inputs[train_split], targets[train_split]

    return dict(
        domain=domain,
        target_scale=1,
        train=dict(inputs=train_inputs, targets=train_targets),
        test=dict(inputs=test_inputs, targets=test_targets),
    )

In [None]:
# Build the toy regression dataset with the default length and noise slope.
data = generate_dataset()

In [None]:
# Scatter the train and test points, one colour per split.
c0, c1, c2 = get_color_cycle(3)

for label, c in zip(('train', 'test'), (c1, c2)):
    points = data[label]
    plt.plot(points['inputs'][:, 0], points['targets'][:, 0], '.', alpha=0.2, c=c, label=label)

plt.legend(loc='best')

In [None]:
class NCPEstimator(torch.nn.Module):
    """Regression head with a stochastic linear mean layer and an NCP loss.

    The input is first mapped through ``transform``. On top of these features
    sit two heads:

    * a linear layer for the mean whose weights and biases are Gaussian with
      learned location and scale (the scale is ``eps + softplus(param)``),
    * a deterministic linear layer followed by ``eps + softplus`` that
      predicts the scale of the target distribution.

    Args:
        transform: module applied to the input before both heads.
        transform_features: number of features produced by ``transform``.
        out_features: number of regression outputs.
        prior_scale: scale of the zero-centred normal prior on the mean
            layer's weights and biases.
        eps: constant added to every softplus output to keep scales positive.
    """

    def __init__(self, transform, transform_features, out_features, prior_scale=1e-1, eps=1e-6):
        super().__init__()
        self.transform = transform

        self.prior_scale = prior_scale
        self.transform_features = transform_features
        self.out_features = out_features
        self.eps = eps

        weight_shape = (self.out_features, self.transform_features)
        bias_shape = (self.out_features,)

        # Location and unconstrained scale of the Gaussian over the mean layer.
        self.mean_weight_loc = torch.nn.Parameter(torch.empty(*weight_shape))
        self.mean_weight_scale_p = torch.nn.Parameter(torch.ones(*weight_shape))

        self.mean_bias_loc = torch.nn.Parameter(torch.empty(*bias_shape))
        self.mean_bias_scale_p = torch.nn.Parameter(torch.ones(*bias_shape))

        # Head predicting the (strictly positive) scale of the targets.
        self.to_scale = torch.nn.Sequential(
            torch.nn.Linear(self.transform_features, self.out_features),
            Lambda(lambda x: eps + F.softplus(x)),
        )

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.mean_weight_loc)
            torch.nn.init.uniform_(self.mean_bias_loc, -1e-4, +1e-4)

    def _posterior_scales(self):
        """Return the positive scales of the mean layer's weights and biases."""
        weight_scale = self.eps + F.softplus(self.mean_weight_scale_p)
        bias_scale = self.eps + F.softplus(self.mean_bias_scale_p)
        return weight_scale, bias_scale

    def _kl_to_prior(self, loc, scale):
        """Return the summed KL divergence from ``Normal(loc, scale)`` to the prior."""
        posterior = torch.distributions.Normal(loc, scale)
        prior = torch.distributions.Normal(
            torch.zeros_like(loc),
            self.prior_scale * torch.ones_like(scale),
        )
        return torch.distributions.kl_divergence(posterior, prior).sum()

    def forward(self, x):
        """Return ``(target_params, mean_params)`` for the inputs ``x``.

        ``target_params`` is ``(target_loc, target_scale)``: the location is
        computed with one random draw of the mean layer's weights and biases,
        the scale comes from the ``to_scale`` head.

        ``mean_params`` is ``(mean_loc, mean_scale)``: the location uses the
        weight and bias locations, the scale is the square root of the
        variance obtained by pushing the squared features through the squared
        weight and bias scales.
        """
        features = self.transform(x)
        weight_scale, bias_scale = self._posterior_scales()

        mean_loc = F.linear(features, self.mean_weight_loc, self.mean_bias_loc)
        mean_scale = torch.sqrt(
            F.linear(features ** 2.0, weight_scale ** 2.0, bias_scale ** 2.0)
        )

        # Draw one set of weights and biases (weights first, then biases).
        weight_noise = torch.randn_like(weight_scale)
        bias_noise = torch.randn_like(bias_scale)
        sampled_weight = self.mean_weight_loc + weight_noise * weight_scale
        sampled_bias = self.mean_bias_loc + bias_noise * bias_scale

        target_loc = F.linear(features, sampled_weight, sampled_bias)
        target_scale = self.to_scale(features)

        return (target_loc, target_scale), (mean_loc, mean_scale)

    def predict(self, x):
        """Return ``(mean_loc, target_scale, mean_scale)`` for the inputs ``x``."""
        target_params, mean_params = self(x)
        target_scale = target_params[1]
        mean_loc, mean_scale = mean_params
        return mean_loc, target_scale, mean_scale

    def loss(self, x, y, input_noise, ood_mean_std=1.0, n_samples=1.0, bbb_scale=1.0, ncp_scale=1.0):
        """Return the training loss for the batch ``(x, y)``.

        The loss is the sum of three terms:

        * the negative log likelihood of ``y`` under the target distribution,
          averaged over the batch,
        * ``bbb_scale`` times the KL divergence of the weight and bias
          distributions from their prior, each divided by ``n_samples``,
        * ``ncp_scale`` times the KL divergence from ``Normal(y, ood_mean_std)``
          to the mean distribution evaluated at noise-perturbed inputs,
          averaged over the batch.

        Args:
            x: batch of inputs.
            y: batch of targets.
            input_noise: scale of the Gaussian noise added to ``x`` to build
                the perturbed inputs.
            ood_mean_std: scale of the prior on the mean at perturbed inputs.
            n_samples: divisor applied to the weight and bias KL terms.
            bbb_scale: multiplier of the weight and bias KL term.
            ncp_scale: multiplier of the noise contrastive prior term.
        """
        batch_size = len(x)
        ood_x = x + input_noise * torch.randn_like(x)

        # Clean inputs first, perturbed inputs second (fixes the RNG order).
        (target_loc, target_scale), _ = self(x)
        _, (ood_mean_loc, ood_mean_scale) = self(ood_x)

        likelihood = torch.distributions.Normal(target_loc, target_scale)
        nll = -likelihood.log_prob(y).sum() / batch_size

        ood_q = torch.distributions.Normal(ood_mean_loc, ood_mean_scale)
        ood_p = torch.distributions.Normal(y, ood_mean_std * torch.ones_like(ood_mean_scale))
        ncp_loss = torch.distributions.kl_divergence(ood_p, ood_q).sum() / batch_size

        weight_scale, bias_scale = self._posterior_scales()
        weight_kl = self._kl_to_prior(self.mean_weight_loc, weight_scale)
        bias_kl = self._kl_to_prior(self.mean_bias_loc, bias_scale)
        bbb_loss = weight_kl / n_samples + bias_kl / n_samples

        return nll + bbb_scale * bbb_loss + ncp_scale * ncp_loss

In [None]:
# Wrap the training split as float tensors and serve it in shuffled
# mini-batches of 20; incomplete final batches are dropped.
train_inputs = torch.as_tensor(data['train']['inputs'], dtype=torch.float)
train_targets = torch.as_tensor(data['train']['targets'], dtype=torch.float)

dataset = torch.utils.data.TensorDataset(train_inputs, train_targets)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=20, shuffle=True, drop_last=True)

In [None]:
# Two-layer feature extractor mapping the scalar input to 30 features.
feature_net = torch.nn.Sequential(
    torch.nn.Linear(1, 30),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(30, 30),
    torch.nn.LeakyReLU(),
)

# NCP estimator with a single regression output on top of the features.
model = NCPEstimator(
    transform=feature_net,
    transform_features=30,
    out_features=1,
    prior_scale=0.1,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# Train for 1000 passes over the dataloader, recording the loss of every step.
losses = []
# NOTE(review): `Loop.over` comes from chmp.ds; it presumably yields a loop
# object (used for progress printing below) together with each item - confirm.
for loop, _ in Loop.over(range(1_000)):
    for x, y in dataloader:
        optimizer.zero_grad()
        # n_samples is the divisor of the weight/bias KL term inside
        # `model.loss`; here it is set to the size of the training set.
        loss = model.loss(
            x, y, 
            input_noise=0.1, 
            ood_mean_std=1.0, 
            n_samples=len(dataset), 
            ncp_scale=1e-2,
        )
        loss.backward()
        optimizer.step()
        
        losses.append(float(loss))
        loop.print(f'{loop} {losses[-1]}')
        

In [None]:
# Plot the per-step training losses collected in `losses`.
plt.plot(losses)

In [None]:
# Evaluate the model on the plotting domain and unpack the target and mean
# parameters returned by `NCPEstimator.forward`.
# NOTE(review): `t2n` comes from chmp.torch_utils.nn; it presumably converts
# the nested tuple of tensors into numpy arrays - confirm.
(target_loc, target_scale), (mean_loc, mean_scale) = t2n(model(torch.as_tensor(data['domain'], dtype=torch.float)))

In [None]:
# Plot the predicted mean with a band of one total standard deviation,
# overlaid with the train and test points.
c0, c1, c2 = get_color_cycle(3)

domain_x = data['domain'][:, 0]
center = mean_loc[:, 0]
# Total spread: target scale and mean scale combined in quadrature.
total_std = (target_scale[:, 0] ** 2.0 + mean_scale[:, 0] ** 2.0) ** 0.5

plt.plot(domain_x, center, '.')
plt.fill_between(domain_x, center - total_std, center + total_std, color=c0, alpha=0.2)

for label, c in zip(('train', 'test'), (c1, c2)):
    points = data[label]
    plt.plot(points['inputs'][:, 0], points['targets'][:, 0], '.', alpha=0.2, c=c, label=label)

plt.legend(loc='best')

In [None]:
# Plot the variance of the mean, the variance of the target distribution,
# and their sum over the plotting domain.
domain_x = data['domain'][:, 0]
mean_var = mean_scale[:, 0] ** 2.0
target_var = target_scale[:, 0] ** 2.0

plt.plot(domain_x, mean_var)
plt.plot(domain_x, target_var)
plt.plot(domain_x, target_var + mean_var)