In [1]:
import numpy as np
import torch
from pyro.infer import MCMC, NUTS, HMC
import matplotlib.pyplot as plt
import hamiltorch
import pandas as pd
# Fix hamiltorch's random seed so that sampler runs can be repeated.
hamiltorch.set_random_seed(123)

# Relative path of the whitespace-delimited data file read by load_data().
DATA_PATH = "../data/rdata"

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
def load_data(path=None, n_train=100):
    """Read the two-column data file and split it into train and test arrays.

    The file is parsed as whitespace-separated text without a header row.
    The first column is the input ``X`` and the second the target ``Y``.
    The first ``n_train`` rows form the training set; all remaining rows
    form the test set.

    Parameters
    ----------
    path : str or path-like, optional
        File to read. When omitted, the module-level ``DATA_PATH`` is used.
    n_train : int, optional
        Number of leading rows assigned to the training set (default 100).

    Returns
    -------
    X_train, X_test, Y_train, Y_test : numpy.ndarray
        Inputs are column vectors of shape ``(n, 1)``; targets are 1-D
        arrays of shape ``(n,)``.
    """
    if path is None:
        # Resolved at call time so the function can be defined (and tested)
        # independently of the notebook's configuration cell.
        path = DATA_PATH

    # `delim_whitespace=True` is deprecated in recent pandas releases;
    # the regex separator below is the supported equivalent.
    df = pd.read_table(path, header=None, sep=r"\s+")
    df.columns = ["X", "Y"]

    # Positional split: leading rows are training data, the rest is test data.
    is_train = np.arange(len(df)) < n_train

    inputs = df["X"].to_numpy().reshape(-1, 1)
    targets = df["Y"].to_numpy()

    return inputs[is_train], inputs[~is_train], targets[is_train], targets[~is_train]

# Load the split once at module level; X_train / Y_train are turned into tensors in the next cell.
X_train, X_test, Y_train, Y_test = load_data()

In [3]:
# Training inputs as a float32 tensor; targets keep whatever dtype the NumPy array has.
X = torch.tensor(X_train, dtype=torch.float32)
Y = torch.tensor(Y_train)

def log_prob(params, x=None, y=None, include_likelihood=False):
    """Log density of the BNN parameters for a given ``params`` state vector.

    By default only the (hierarchical) prior is evaluated, which is what this
    function has always returned. Pass ``include_likelihood=True`` to add the
    Gaussian data term and obtain the full, unnormalised log posterior.

    Layout of ``params`` (length 29):

    Low-level parameters
        w_ih = params[0:8]    weights entering the hidden layer
        b_h  = params[8:16]   biases of the hidden units
        w_ho = params[16:24]  weights exiting the hidden layer
        b_o  = params[24]     bias of the output unit

    Hyperparameters (stored as the log of the actual precision)
        log_w_ih_prec = params[25]  precision of the weights entering the hidden layer
        log_b_h_prec  = params[26]  precision of the hidden-unit biases
        log_w_ho_prec = params[27]  precision of the weights exiting the hidden layer
        log_y_prec    = params[28]  precision of the observation noise

    Parameters
    ----------
    params : torch.Tensor
        1-D state vector laid out as described above.
    x : torch.Tensor, optional
        Inputs of shape ``(N, 1)``. Only used when ``include_likelihood`` is
        true; defaults to the module-level ``X``, looked up at call time.
    y : torch.Tensor, optional
        Targets with ``N`` elements. Only used when ``include_likelihood`` is
        true; defaults to the module-level ``Y``, looked up at call time.
    include_likelihood : bool, optional
        When true, add ``sum_i log N(y_i | f(x_i), 1 / y_prec)`` to the result.

    Returns
    -------
    torch.Tensor
        Zero-dimensional tensor holding the log density.
    """
    # Extract parameters from the parameter vector
    w_ih = params[0:8].reshape(1, 8)
    b_h = params[8:16].reshape(8)
    w_ho = params[16:24].reshape(8, 1)
    b_o = params[24].reshape(1)

    log_w_ih_prec = params[25]
    log_b_h_prec = params[26]
    log_w_ho_prec = params[27]
    log_y_prec = params[28]

    # Exponentiate the log transforms (the small constant guards against underflow to 0)
    w_ih_prec = torch.exp(log_w_ih_prec) + 1e-12
    b_h_prec = torch.exp(log_b_h_prec) + 1e-12
    w_ho_prec = torch.exp(log_w_ho_prec) + 1e-12
    y_prec = torch.exp(log_y_prec) + 1e-12

    # All four precisions share the same Gamma(concentration=0.25, rate=0.000625) prior
    prec_prior = torch.distributions.Gamma(0.25, 0.000625)

    # Zero-mean Gaussian priors on the low-level parameters, scaled by their precision
    w_ih_dist = torch.distributions.Normal(loc=torch.zeros((1, 8)),
                                           scale=1 / torch.sqrt(w_ih_prec))
    b_h_dist = torch.distributions.Normal(loc=torch.zeros((8, )),
                                          scale=1 / torch.sqrt(b_h_prec))
    w_ho_dist = torch.distributions.Normal(loc=torch.zeros((8, 1)),
                                           scale=1 / torch.sqrt(w_ho_prec))
    b_o_dist = torch.distributions.Normal(loc=torch.zeros((1, )),
                                          scale=100)

    terms = [
        # Hyperpriors: adding the log-precision is the log-Jacobian of
        # prec = exp(log_prec), because sampling happens in log space.
        prec_prior.log_prob(w_ih_prec) + log_w_ih_prec,
        prec_prior.log_prob(b_h_prec) + log_b_h_prec,
        prec_prior.log_prob(w_ho_prec) + log_w_ho_prec,
        prec_prior.log_prob(y_prec) + log_y_prec,
        # Priors on the low-level parameters
        torch.sum(w_ih_dist.log_prob(w_ih)),
        torch.sum(b_h_dist.log_prob(b_h)),
        torch.sum(w_ho_dist.log_prob(w_ho)),
        torch.sum(b_o_dist.log_prob(b_o)),
    ]

    if include_likelihood:
        if x is None:
            x = X
        if y is None:
            y = Y

        # Forward pass of the single-hidden-layer network
        hidden = torch.tanh(torch.mm(x, w_ih) + b_h)
        output = torch.mm(hidden, w_ho) + b_o  # shape (N, 1)

        # Flatten both sides: comparing an (N, 1) mean with (N,) targets
        # would broadcast to an (N, N) grid and count every pair.
        predictions = output.reshape(-1)
        targets = y.reshape(-1).to(predictions.dtype)

        likelihood = torch.distributions.Normal(
            loc=predictions,
            scale=1 / torch.sqrt(y_prec),
        )
        terms.append(torch.sum(likelihood.log_prob(targets)))

    return sum(terms)

In [4]:
# Test: hyperprior log-density when a log-precision is so negative that exp() underflows

y_prec_dist = torch.distributions.Gamma(0.25, 0.000625)
y_prec = torch.exp(torch.tensor(-23182.5137))  # underflows to exactly 0
# Evaluate at the underflowed value plus the same 1e-12 offset that log_prob adds,
# rather than at a hard-coded 0, so the check actually exercises y_prec.
y_prec_dist.log_prob(y_prec + 1e-12)

tensor(17.5908)

In [5]:
# Starting point shared by all samplers: a zero vector of length 29
# (25 low-level parameters + 4 log-precisions) that tracks gradients.
params_init = torch.zeros(29, requires_grad=True)

# Sample count used by the HMC and NUTS runs
N = 200

In [6]:
# Standard HMC

# Integration steps per sample and the integrator step size
L = 100
step_size = 0.0004

params_hmc = hamiltorch.sample(
    log_prob_func=log_prob,
    params_init=params_init,
    sampler=hamiltorch.Sampler.HMC,
    num_samples=N,
    num_steps_per_sample=L,
    step_size=step_size,
)

Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent  | Time remain.| Progress             | Samples | Samples/sec
0d:00:00:39 | 0d:00:00:00 | #################### | 200/200 | 5.13       
Acceptance Rate 1.00


In [None]:
# NUTS

step_size = 0.003

# NOTE(review): this rebinds `params_hmc`, replacing the samples from the standard
# HMC cell. Consider a dedicated name once downstream usage has been checked.
params_hmc = hamiltorch.sample(
    log_prob_func=log_prob,
    params_init=params_init,
    sampler=hamiltorch.Sampler.HMC_NUTS,
    num_samples=N,
    burn=100,
    step_size=step_size,
)

In [7]:
# Implicit RMHMC

L = 10
omega = 10.  # NOTE(review): not passed to the sample() call below
softabs_const = 10**24
step_size = 0.05

params_ermhmc = hamiltorch.sample(
    log_prob_func=log_prob,
    params_init=params_init,
    sampler=hamiltorch.Sampler.RMHMC,
    integrator=hamiltorch.Integrator.IMPLICIT,
    metric=hamiltorch.Metric.HESSIAN,
    num_samples=100,
    num_steps_per_sample=L,
    step_size=step_size,
    jitter=10,
    softabs_const=softabs_const,
    fixed_point_max_iterations=10,
)

Sampling (Sampler.RMHMC; Integrator.IMPLICIT)
Time spent  | Time remain.| Progress             | Samples | Samples/sec


RuntimeError: cholesky_cpu: U(28,28) is zero, singular U.

In [9]:
# Explicit RMHMC

L = 10
omega = 100.  # binding constant of the explicit integrator
softabs_const = 10**3
step_size = 0.004

# `omega` was previously defined but never handed to the sampler, so editing it had
# no effect. It is now wired to `explicit_binding_const`; the value 100 matches
# hamiltorch's default, so the run itself is unchanged.
params_ermhmc = hamiltorch.sample(
    log_prob_func=log_prob,
    params_init=params_init,
    sampler=hamiltorch.Sampler.RMHMC,
    integrator=hamiltorch.Integrator.EXPLICIT,
    metric=hamiltorch.Metric.SOFTABS,
    num_samples=100,
    num_steps_per_sample=L,
    step_size=step_size,
    jitter=10,
    softabs_const=softabs_const,
    explicit_binding_const=omega,
    debug=2,
)

Sampling (Sampler.RMHMC; Integrator.EXPLICIT)
Time spent  | Time remain.| Progress             | Samples | Samples/sec
Invalid hessian: tensor([[ 1.7827e+17, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
                 nan, -0.0000e+00, -0.0000e+00, -0.0000e+00],
        [-0.0000e+00,  1.7827e+17, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
                 nan, -0.0000e+00, -0.0000e+00, -0.0000e+00],


0d:00:02:48 | 0d:00:08:31 | #####--------------- |  24/100 | 0.15       

ValueError: The parameter scale has invalid values

### Debug

In [None]:
# from hamiltorch.util import hessian
# from hamiltorch.samplers import fisher


# params_init = torch.zeros(29, requires_grad=True)
# log_prob_value = log_prob(params_init)

# params_init.requires_grad = True
# # log_prob_value.requires_grad = True
# temp = fisher(params_init, log_prob, jitter=1e5)[0]

# torch.distributions.multivariate_normal.MultivariateNormal(
#     loc=torch.zeros_like(params_init),
#     covariance_matrix=temp
# )

# torch.sort(torch.abs(torch.diag(fisher(params_init, log_prob, jitter=10)[0])))

# Sigma_k = torch.rand(29, 29)
# Sigma_k = torch.mm(Sigma_k, Sigma_k.t())
# Sigma_k.add_(torch.eye(29))
# torch.distributions.MultivariateNormal(torch.tensor([1]), Sigma_k)