In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib notebook

import numbers

import torch
from torch import nn
from torch import optim
from torch import distributions
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from tqdm.auto import tqdm, trange

import notebook_setup
import ppo, utils
from systems import CartPoleEnv, plot_cartpole

## Probabilistic model

Traditional models predict mapping $x \rightarrow y$

Model $M$ learns to predict mapping $x \rightarrow (\mu, \sigma) : y \sim N(\mu, \sigma)$

In [245]:
# TODO: Compare single variance parameter vs. variance output per instance
class Model(nn.Module):
    def __init__(self, inputs, outputs, hidden=32, variance=None):
        super().__init__()
        self.repr =  nn.Sequential(
            nn.Linear(inputs, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden)
        )
        self.state = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, outputs)
        )
        self.uncertainty = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, outputs)
        )
        if isinstance(variance, (torch.Tensor, np.ndarray)):
            if variance.ndim == 1:
                self.var = nn.Parameter(torch.diagflat(variance))
            elif variance.ndim == 2:
                self.var = nn.Parameter(torch.tensor(variance))
        elif isinstance(variance, numbers.Real):
            self.var = nn.Parameter(torch.eye(outputs) * variance)
        else:
            self.var = None
        

    def evaluate(self, x):
        x_ = self.repr(x)
        mean = self.state(x_)
        if self.var is None:
            variance = torch.diag(torch.abs(self.uncertainty(x_)) + 1e-3)
        else:
            variance = torch.abs(self.var)
        if x.ndim > 1:
            raise ValueError('x must me 1 dimensional i.e one instance')
        dist = distributions.MultivariateNormal(mean, covariance_matrix=variance)    
        return dist
        
    def forward(self, x):
        dist = self.evaluate(x)
        state = dist.sample()
        return state

## Example
Training for regression over a simple function $z = \sqrt{x^2 + y^2}$

Domain: $x, y \in [-5..5]$

In [231]:
torch.manual_seed(0)
m = Model(2, 1, variance=None)
o = optim.Adam(m.parameters(), lr=0.001)
X, Y = torch.rand(100, 1), torch.rand(100, 1)
with torch.no_grad():
    X, Y = -5 + (X * 10), -5 + (Y * 10)
    Z = torch.sqrt(X**2 + Y**2)

uncertainties, losses = [], []
for epoch in trange(20, leave=False):
    losses.append([])
    uncertainties.append([])
    for x, y, z in zip(X, Y, Z):
        d = m.evaluate(torch.cat((x, y)))
        z_ = d.rsample()
        loss = (z_ - z)**2
        loss.backward()
        o.step()
        o.zero_grad()
        losses[-1].append(loss.item())
        uncertainties[-1].append(d.variance.item())
    uncertainties[-1] = np.mean(uncertainties[-1])
    losses[-1] = sum(losses[-1]) / len(losses[-1])

  0%|          | 0/20 [00:00<?, ?it/s]

In [244]:
# Detecting anomaly
inputs = torch.tensor([3., 4.])
d = m.evaluate(inputs)
print('sqrt(3^2 + 4^2)=5: %.2f' % torch.exp(d.log_prob(torch.tensor([5.]))).item())
print('sqrt(3^2 + 4^2)=5.1: %.2f' % torch.exp(d.log_prob(torch.tensor([5.1]))).item())
print('sqrt(3^2 + 4^2)=5.2: %.2f' % torch.exp(d.log_prob(torch.tensor([5.2]))).item())
print('\nOut of distribution:')
inputs = torch.tensor([8., 6.])
d = m.evaluate(inputs)
print('sqrt(8^2 + 6^2)=10: %.2f' % torch.exp(d.log_prob(torch.tensor([10.]))).item())
print('sqrt(8^2 + 6^2)=10.1: %.2f' % torch.exp(d.log_prob(torch.tensor([10.1]))).item())
print('sqrt(8^2 + 6^2)=10.2: %.2f' % torch.exp(d.log_prob(torch.tensor([10.2]))).item())

sqrt(3^2 + 4^2)=5: 0.98
sqrt(3^2 + 4^2)=5.1: 0.56
sqrt(3^2 + 4^2)=5.2: 0.27

Out of distribution:
sqrt(8^2 + 6^2)=10: 1.07
sqrt(8^2 + 6^2)=10.1: 0.96
sqrt(8^2 + 6^2)=10.2: 0.80


In [None]:
plt.plot(losses)
plt.title('MSE Loss over epochs')

In [None]:
plt.plot(uncertainties)
plt.title('Average predicted variance over epochs')

In [None]:
X = torch.linspace(0, 10, 25).reshape(-1, 1)
Y = torch.linspace(0, 10, 25).reshape(-1, 1)
Z = torch.zeros(len(Y), len(X))
varZ = torch.zeros_like(Z)
for i, y in enumerate(Y):
    for j, x in enumerate(X):
        d = m.evaluate(torch.cat((x,y)))
        Z[i, j] = d.mean
        varZ[i, j] = d.variance.item()
Z, varZ = Z.detach().numpy(), varZ.detach().numpy()

In [None]:
xgrid, ygrid = torch.meshgrid(X[:,0], Y[:,0])

fig = plt.figure(figsize=(6,6))
ax = fig.gca(projection='3d')

ax.plot_surface(xgrid, ygrid, Z,cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)
ax.plot_surface(xgrid, ygrid, Z + 5 * varZ,
                       linewidth=0, antialiased=False, alpha=0.3, facecolor='r')
ax.plot_surface(xgrid, ygrid, Z - 5 * varZ,
                       linewidth=0, antialiased=False, alpha=0.3, facecolor='b')

## Example
Training for regression over a simple function $z = sin(x) * y$

In [None]:
torch.manual_seed(0)
m = Model(2, 1, variance=None)
o = optim.Adam(m.parameters(), lr=0.001)
X, Y = torch.rand(100, 1), torch.rand(100, 1)
with torch.no_grad():
    X, Y = -5 + (X * 10), -5 + (Y * 10)
    Z = torch.sin(X) * Y

uncertainties, losses = [], []
for epoch in trange(20, leave=False):
    losses.append([])
    uncertainties.append([])
    for x, y, z in zip(X, Y, Z):
        d = m.evaluate(torch.cat((x, y)))
        z_ = d.rsample()
        loss = (z_ - z)**2
        loss.backward()
        o.step()
        o.zero_grad()
        losses[-1].append(loss.item())
        uncertainties[-1].append(d.variance.item())
    uncertainties[-1] = np.mean(uncertainties[-1])
    losses[-1] = sum(losses[-1]) / len(losses[-1])

In [None]:
plt.plot(losses)
plt.title('MSE Loss over epochs')

In [None]:
plt.plot(uncertainties)
plt.title('Average predicted variance over epochs')

In [None]:
X = torch.linspace(0, 10, 25).reshape(-1, 1)
Y = torch.linspace(0, 10, 25).reshape(-1, 1)
Z = torch.zeros(len(Y), len(X))
varZ = torch.zeros_like(Z)
for i, y in enumerate(Y):
    for j, x in enumerate(X):
        d = m.evaluate(torch.cat((x,y)))
        Z[i, j] = d.mean
        varZ[i, j] = d.variance.item()
Z, varZ = Z.detach().numpy(), varZ.detach().numpy()

In [None]:
xgrid, ygrid = torch.meshgrid(X[:,0], Y[:,0])

fig = plt.figure(figsize=(6,6))
ax = fig.gca(projection='3d')

ax.plot_surface(xgrid, ygrid, Z,cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)
ax.plot_surface(xgrid, ygrid, Z + 5 * varZ,
                       linewidth=0, antialiased=False, alpha=0.3, facecolor='r')
ax.plot_surface(xgrid, ygrid, Z - 5 * varZ,
                       linewidth=0, antialiased=False, alpha=0.3, facecolor='b')