# Developing an RNN track filter model with Gaussian uncertainty

In this notebook I experiment with code for an RNN track filter model that predicts the next hit location and also produces an uncertainty on that prediction in the form of a Gaussian.

Basically, the model should produce both the central values and a covariance matrix.
The loss that we want to minimize is then a Gaussian negative log-likelihood:

$L(x,y) = \log{|\Sigma|} + (y - f(x))^{\rm T}\Sigma^{-1}(y-f(x))$
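
Up to an additive constant and an overall factor of 2, this is the negative log of the Gaussian density of $y$ with mean $f(x)$ and covariance $\Sigma$. A minimal one-dimensional check in plain Python; the function names and numbers here are illustrative assumptions, not part of the model code:

```python
import math

def gaussian_nll_loss_1d(y, mu, var):
    """L = log(var) + (y - mu)^2 / var, the d=1 case of the loss above."""
    return math.log(var) + (y - mu) ** 2 / var

def gaussian_pdf_1d(y, mu, var):
    """Density of a 1D normal distribution with mean mu and variance var."""
    return math.exp(-0.5 * (y - mu) ** 2 / var) / math.sqrt(2 * math.pi * var)

y, mu, var = 0.3, -0.5, 2.0  # arbitrary illustrative values
loss = gaussian_nll_loss_1d(y, mu, var)
# -2 log p(y) = log(2 pi) + log(var) + (y - mu)^2 / var
reference = -2 * math.log(gaussian_pdf_1d(y, mu, var)) - math.log(2 * math.pi)
print(loss, reference)
```

The two printed numbers should agree, which is why minimizing $L$ is equivalent to maximizing the Gaussian likelihood.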

I want to use PyTorch for now, but unfortunately this version (0.2.0) has no built-in way to compute a matrix determinant that supports autograd. I can piece one together from code I've found online, however.

In [1]:
# Select a GPU first
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6'
cuda = False

In [2]:
# External imports
import numpy as np

# Torch imports
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

# Visualization
import matplotlib.pyplot as plt

# Magic
%matplotlib notebook

In [3]:
torch.__version__

'0.2.0_2'

## Utilities

To compute the determinant with gradients, we need a Cholesky decomposition with gradient: for $\Sigma = LL^{\rm T}$ with $L$ lower triangular, $\log{|\Sigma|} = 2\sum_i \log{L_{ii}}$.

In [4]:
class Cholesky(torch.autograd.Function):
    """
    Cholesky decomposition with gradient. Taken from
    https://github.com/t-vi/pytorch-tvmisc/blob/master/misc/gaussian_process_regression_basic.ipynb
    """
    @staticmethod
    def forward(ctx, a):
        l = torch.potrf(a, False)
        ctx.save_for_backward(l)
        return l

    @staticmethod
    def backward(ctx, grad_output):
        l, = ctx.saved_variables
        # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
        # Ideally, this should use some form of solve triangular instead of inverse...
        linv =  l.inverse()
        
        inner = (torch.tril(torch.mm(l.t(), grad_output)) * 
                 torch.tril(1.0 - Variable(l.data.new(l.size(1)).fill_(0.5).diag())))
        s = torch.mm(linv.t(), torch.mm(inner, linv))
        return s
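
One way to sanity-check the backward formula is to apply it to the log-determinant, whose gradient with respect to a symmetric positive-definite matrix $A$ is known to be $A^{-1}$. The NumPy sketch below re-implements the same formula outside of autograd; `cholesky_backward` is a hypothetical helper name and the test matrix is arbitrary.

```python
import numpy as np

def cholesky_backward(l, grad_l):
    """Map a gradient w.r.t. the Cholesky factor l to a gradient w.r.t. a.

    Same formula as in Cholesky.backward:
    l^{-T} @ (tril(l^T @ grad_l) * (tril(ones) - eye/2)) @ l^{-1}
    """
    n = l.shape[0]
    linv = np.linalg.inv(l)
    mask = np.tril(np.ones((n, n))) - 0.5 * np.eye(n)
    inner = np.tril(l.T.dot(grad_l)) * mask
    return linv.T.dot(inner).dot(linv)

# Arbitrary symmetric positive-definite test matrix
a = np.array([[2.0, 0.5, 0.1],
              [0.5, 1.5, 0.3],
              [0.1, 0.3, 1.0]])
l = np.linalg.cholesky(a)  # lower triangular, a = l @ l.T

# f(l) = 2 * sum(log(diag(l))) = log det(a), so df/dl = diag(2 / l_ii)
grad_l = np.diag(2.0 / np.diag(l))
grad_a = cholesky_backward(l, grad_l)

# The gradient of log det(a) should be the inverse of a
print(np.allclose(grad_a, np.linalg.inv(a)))
```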

## Test determinant calculation

In [5]:
# Generate a random symmetric positive semi-definite matrix
x = torch.randn(2, 2)
x = torch.mm(x, x.t())
v = Variable(x)
v

Variable containing:
 0.1574  0.2672
 0.2672  5.8062
[torch.FloatTensor of size 2x2]

In [6]:
# Compare ways to calculate the determinant
print((Cholesky.apply(v).diag().log().sum()*2).exp().data[0])
print(np.exp(torch.potrf(x).diag().log().sum()*2))
print(np.linalg.det(x.numpy()))

0.842461764812
0.842461755804
0.842462


## Convert model outputs into Gaussian parameters

For dimension $d$, there will be $d$ mean values and a covariance matrix with $\frac{d(d+1)}{2}$ unique values.

That's a total of $\frac{d(d+3)}{2}$ values.
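
The count can also be broken down as $d$ means, $d$ variances, and $d(d-1)/2$ correlations, one per unique off-diagonal pair. A small sketch (the helper name `gaussian_output_size` is made up for illustration):

```python
def gaussian_output_size(d):
    """Number of model outputs needed for a d-dimensional Gaussian."""
    n_means = d
    n_variances = d
    n_correlations = d * (d - 1) // 2  # one per unique off-diagonal pair
    return n_means + n_variances + n_correlations

for d in (1, 2, 3):
    # Matches the closed form d(d+3)/2
    assert gaussian_output_size(d) == d * (d + 3) // 2
    print(d, gaussian_output_size(d))
```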

Ok, I think this will be the trick:

1. Separate out the means, variances, and correlations from the model outputs right away.
2. Construct the covariance matrix initially as the square root of the outer product of the variances, so that entry $(i,j)$ is $\sigma_i\sigma_j$.
3. Then the off-diagonal elements will already have the scale factors, which can be multiplied by the correlations.
4. Construct the correlation matrix using triangular indexing for the off-diagonal entries plus the identity on the diagonal.

In [7]:
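def parse_outputs(outputs, d):
    """Split a flat output vector into means and a covariance matrix."""
    means = outputs[:d]
    variances = outputs[d:2*d].exp()
    correlations = outputs[2*d:].tanh()
    # Constant tensor of upper off-diagonal indices
    offdiag_idx = torch.ones(d, d).triu(1).nonzero().t()
    # Populate the covariance matrix initially with sqrt(vi*vj)
    cov_matrix = torch.mm(variances[:, None], variances[None, :]).sqrt()
    # Scale the off-diagonal terms with the correlation to get covariance
    cov_matrix[offdiag_idx[0], offdiag_idx[1]] = cov_matrix[offdiag_idx[0], offdiag_idx[1]] * correlations
    cov_matrix[offdiag_idx[1], offdiag_idx[0]] = cov_matrix[offdiag_idx[1], offdiag_idx[0]] * correlations
    return means, cov_matrix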

In [8]:
def cov_log_determinant(cov):
    return Cholesky.apply(cov).diag().log().sum()*2

In [9]:
d = 2
output_size = d * (d + 3) // 2  # integer division so the size stays an int
output = Variable(torch.randn(output_size), requires_grad=True)

means = output[:d]
variances = output[d:2*d].exp()
correlations = output[2*d:].tanh()

# Constant tensor of upper off-diagonal indices
offdiag_idx = torch.ones(d, d).triu(1).nonzero().t()

# Populate the covariance matrix initially with sqrt(vi*vj)
cov_matrix = torch.mm(variances[:, None], variances[None, :]).sqrt()

# Scale the off-diagonal terms with the correlation to get covariance
cov_matrix[offdiag_idx[0], offdiag_idx[1]] = cov_matrix[offdiag_idx[0], offdiag_idx[1]] * correlations
cov_matrix[offdiag_idx[1], offdiag_idx[0]] = cov_matrix[offdiag_idx[1], offdiag_idx[0]] * correlations

means, variances, correlations, cov_matrix

(Variable containing:
 -1.1518
 -0.0865
 [torch.FloatTensor of size 2], Variable containing:
  0.4833
  1.0233
 [torch.FloatTensor of size 2], Variable containing:
  0.2890
 [torch.FloatTensor of size 1], Variable containing:
  0.4833  0.2032
  0.2032  1.0233
 [torch.FloatTensor of size 2x2])
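
As a cross-check, the printed covariance matrix can be reproduced from the printed variances and correlation with a few lines of NumPy. This is only a sketch using the rounded values shown above, so agreement is expected to about three decimal places.

```python
import numpy as np

# Rounded values copied from the cell output above
variances = np.array([0.4833, 1.0233])
correlation = 0.2890

# sigma_i * sigma_j on every entry, i.e. sqrt(v_i * v_j)
cov = np.sqrt(np.outer(variances, variances))
# Scale the off-diagonal entries by the correlation
cov[0, 1] *= correlation
cov[1, 0] *= correlation
print(cov)

# For d=2 the determinant is v1 * v2 * (1 - rho^2), which is positive
# whenever |rho| < 1, so the tanh keeps the matrix positive definite
det = variances[0] * variances[1] * (1 - correlation ** 2)
print(det, np.linalg.det(cov))
```

Note that the positive-definiteness argument in the comment is specific to $d=2$.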

## Develop Gaussian log-likelihood loss

In [10]:
# Target value
y = Variable(torch.randn(d))

# Residual term
error = y - means
res_term = error.dot(torch.mv(cov_matrix.inverse(), error))

# Log determinant term
det_term = cov_log_determinant(cov_matrix)

# Total loss
lgl_loss = res_term + det_term
lgl_loss

Variable containing:
 2.4388
[torch.FloatTensor of size 1]
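
Since the Cholesky factor is already needed for the determinant term, the residual term could in principle reuse it instead of calling `inverse()`: with $\Sigma = LL^{\rm T}$, the quadratic form equals $\|L^{-1}(y - f(x))\|^2$. A NumPy sketch of that idea with made-up numbers (an alternative formulation, not what the cell above does):

```python
import numpy as np

def gaussian_loss_via_cholesky(error, cov):
    """log|cov| + error^T cov^{-1} error using one Cholesky factorization."""
    l = np.linalg.cholesky(cov)       # cov = l @ l.T, l lower triangular
    # Generic solve for z = l^{-1} error; a triangular solver would be cheaper
    z = np.linalg.solve(l, error)
    res_term = z.dot(z)               # error^T cov^{-1} error
    det_term = 2 * np.log(np.diag(l)).sum()
    return res_term + det_term

# Illustrative inputs (assumed values, not taken from the notebook run)
cov = np.array([[0.5, 0.2],
                [0.2, 1.0]])
error = np.array([0.3, -1.1])

# Reference value using an explicit inverse and determinant
direct = error.dot(np.linalg.inv(cov)).dot(error) + np.log(np.linalg.det(cov))
print(gaussian_loss_via_cholesky(error, cov), direct)
```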

## Extend it to work on batches

### Constructing the outputs

In [11]:
# Let's implement specifically for d=2 for simplicity
d = 2
batch_size = 2
output_size = d * (d + 3) // 2  # integer division so the size stays an int
output = Variable(torch.randn(batch_size, output_size), requires_grad=True)

In [12]:
means = output[:, :2]
variances = output[:, 2:4].exp()
correlations = output[:, 4].tanh()

cov_matrix = torch.bmm(variances[:, :, None], variances[:, None, :]).sqrt()
cov_matrix[:, 0, 1] = cov_matrix[:, 0, 1] * correlations
cov_matrix[:, 1, 0] = cov_matrix[:, 1, 0] * correlations

means, variances, correlations, cov_matrix

(Variable containing:
 -0.0125 -1.3630
  1.9493  0.8524
 [torch.FloatTensor of size 2x2], Variable containing:
  0.6606  0.2098
  0.4833  1.9598
 [torch.FloatTensor of size 2x2], Variable containing:
 -0.2505
  0.4468
 [torch.FloatTensor of size 2], Variable containing:
 (0 ,.,.) = 
   0.6606 -0.0932
  -0.0932  0.2098
 
 (1 ,.,.) = 
   0.4833  0.4348
   0.4348  1.9598
 [torch.FloatTensor of size 2x2x2])

### Calculating the loss

I could just loop over the samples in the batch. It's not very parallelizable, though. Maybe if it's just in the loss it won't be too bad.

In [13]:
# Target value
y = Variable(torch.randn(batch_size, d))

In [14]:
# Residual error
error = y - means

In [15]:
# Calculate each inverse separately
inv_cov_matrix = torch.stack([cov.inverse() for cov in cov_matrix])

In [16]:
inv_cov_matrix.size(), error.size()

(torch.Size([2, 2, 2]), torch.Size([2, 2]))

In [17]:
# Compute Cov^{-1} * error as a batch matrix multiply
res_right = torch.bmm(inv_cov_matrix, error.unsqueeze(-1)).squeeze(-1)
# Expand dimensions to perform dot products as batch matrix multiply
res_term = torch.bmm(error[:,None,:], res_right[:,:,None]).squeeze()
res_term

Variable containing:
 14.6666
  4.3070
[torch.FloatTensor of size 2]

In [18]:
diag_chols = torch.stack([Cholesky.apply(cov).diag() for cov in cov_matrix])
log_det = diag_chols.log().sum(1) * 2
log_det

Variable containing:
-2.0412
-0.2768
[torch.FloatTensor of size 2]

In [19]:
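# Total loss: per-sample residual and log-determinant terms, summed over the batch
# (uses log_det from this batch, not det_term from the single-sample cell)
lgl_loss = (res_term + log_det).sum()
lgl_loss

Variable containing:
 16.6556
[torch.FloatTensor of size 1]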
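
For the $d=2$ case used here, the per-sample Python loops could also be avoided, because a 2x2 covariance has a closed-form determinant and inverse: $|\Sigma| = v_1 v_2 (1-\rho^2)$ and $r^{\rm T}\Sigma^{-1}r = \left(r_1^2/v_1 - 2\rho r_1 r_2/\sqrt{v_1 v_2} + r_2^2/v_2\right)/(1-\rho^2)$. The NumPy sketch below is an assumption about how that could look, not the notebook's implementation. It uses the rounded variances and correlations printed earlier, with hypothetical targets since the random ones are not shown.

```python
import numpy as np

def gaussian_loss_2d(means, variances, correlations, targets):
    """Per-sample loss log|cov| + r^T cov^{-1} r for 2D Gaussians, vectorized."""
    r = targets - means                      # residuals, shape (batch, 2)
    v1, v2 = variances[:, 0], variances[:, 1]
    one_minus_rho2 = 1 - correlations ** 2
    log_det = np.log(v1) + np.log(v2) + np.log(one_minus_rho2)
    res_term = (r[:, 0] ** 2 / v1
                - 2 * correlations * r[:, 0] * r[:, 1] / np.sqrt(v1 * v2)
                + r[:, 1] ** 2 / v2) / one_minus_rho2
    return res_term + log_det, log_det

# Rounded values copied from the batched cell output
means = np.array([[-0.0125, -1.3630], [1.9493, 0.8524]])
variances = np.array([[0.6606, 0.2098], [0.4833, 1.9598]])
correlations = np.array([-0.2505, 0.4468])
# Hypothetical targets; the notebook's random targets are not printed
targets = np.zeros((2, 2))

loss, log_det = gaussian_loss_2d(means, variances, correlations, targets)
print(log_det)  # compare with the log_det values printed by the notebook
```

Everything here is elementwise over the batch dimension, so the same expressions should carry over to tensor operations without any per-sample inverse or Cholesky call.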