In [11]:
import scipy.stats as stats
import numpy as np
import math
from scipy.stats import norm
from scipy.stats import gamma

n=50

true_beta = np.transpose(stats.norm.rvs(loc=0,scale=1,size=2))
true_phi = stats.gamma.rvs(a=3,scale=1/2,size=1)

x = np.transpose(np.array([np.ones(n),stats.norm.rvs(loc=0,scale=1,size=n)]))
y = np.random.normal(x.dot(true_beta), np.sqrt(1/true_phi))


beta0 = stats.norm.rvs(loc=0,scale=1,size=2)
phi0 = stats.gamma.rvs(a=3,scale=1/2,size=1)
theta0 = np.hstack([beta0, phi0])

#Set hyperparameters
a = 3.0
b= 2.0

print(true_beta,true_phi)
print(theta0)

[-0.04431882 -0.23033121] [ 0.52173396]
[-1.33771753  0.17439902  1.47057371]


In [26]:
import warnings
warnings.filterwarnings("ignore")

In [49]:
def leapfrog(theta, A, r, eps):
    #gradients are specific to this example
    gradients = A.dot(theta)
    r_upd = r + eps/2 * (gradients)
    theta_upd = theta + eps * r_upd
    gradients = A.dot(theta_upd)
    r_upd = r_upd + eps/2 * (gradients)
    return theta_upd, r_upd

In [50]:
def log_joint(theta, A):
    #seems like this can be up to a normalizing constant so that's what i did, but otherwise that could be the issue
    return -1/2 * theta.T.dot(A).dot(theta)


In [51]:
def BuildTree(theta,A, r, u, v, j, eps):
    triangle_max = 1000 #recommend value pg 1359
    if(j==0):
        #base case, take one leapfrog step in direction v
        theta_prime,r_prime = leapfrog(theta,A,r,v*eps)
        if(u <= np.exp(log_joint(theta_prime,A)-(1/2)*r_prime.dot(r_prime))):
            n_prime = 1
        else:
            n_prime = 0
        if(log_joint(theta_prime,A)-(1/2)*r_prime.dot(r_prime) > u-triangle_max):
            s_prime = 1
        else:
            s_prime = 0
        return theta_prime,r_prime,theta_prime,r_prime,theta_prime,n_prime,s_prime    
    else:
        #recursion-build the left and right subtrees
        theta_minus,r_minus,theta_plus,r_plus,theta_prime,n_prime,s_prime = BuildTree(theta,A,r,u,v,j-1,eps)
        if(s_prime==1):
            if(v == -1):
                theta_minus,r_minus,dash1,dash2,theta_primep,n_primep,s_primep = BuildTree(theta_minus,A,r_minus,u,v,j-1,eps)
            else:
                dash1,dash2,theta_plus,r_plus,theta_primep,n_primep,s_primep = BuildTree(theta_plus,A,r_plus,u,v,j-1,eps)
            #if(n_prime+n_primep==0):
                #print('Ahhhh cant divide by zero:',n_prime,n_primep)
                #p=0.99
            #else:    
            p = np.exp(np.log(n_primep)-np.log(n_prime+n_primep))
            unif = np.random.uniform()
            if(p>u):
                theta_prime = theta_primep
            if((theta_plus-theta_minus).dot(r_minus) >= 0 and (theta_plus-theta_minus).dot(r_plus) >= 0):  
                s_prime = s_primep  
            else:
                s_prime = 0  
            n_prime = n_prime+n_primep
    return theta_minus,r_minus,theta_plus,r_plus,theta_prime,n_prime,s_prime

In [64]:
def NUTS_Mvt(theta0,A,eps,M):
    no_par = theta0.shape[0]
    theta_m = np.zeros((M,no_par))
    theta_m[0,:] = theta0
    for m in range(1,M):
        #print('M: ',m,' Theta: ',theta_m[m-1,:])
        r0 = stats.norm.rvs(size=no_par)
        u = np.random.uniform(low=0,high=np.exp(log_joint(theta_m[m-1,:],A)-(1/2)*r0.dot(r0)))
        theta_minus = theta_m[m-1,:]
        theta_plus = theta_m[m-1,:]
        r_minus = r0
        r_plus = r0
        j=0
        theta_m[m,:] = theta_m[m-1,:]
        n = 1
        s=1
        while(s==1):
            v_j = np.random.choice([-1,1])
            if(v_j==-1):
                theta_minus,r_minus,dash1,dash2,theta_prime,n_prime,s_prime = BuildTree(theta_minus,A,r_minus,u,v_j,j,eps)
            else:
                dash1,dash2,theta_plus,r_plus,theta_prime,n_prime,s_prime = BuildTree(theta_plus,A,r_plus,u,v_j,j,eps)
            if(s_prime == 1):
                p = min(1,n_prime/n)
                unif = np.random.uniform()
                if(p>u):
                    theta_m[m,:] = theta_prime
            n = n+n_prime
            if((theta_plus-theta_minus).dot(r_minus) >= 0 and (theta_plus-theta_minus).dot(r_plus) >= 0):
                s = s_prime
            else:
                s = 0
            j = j+1    
    return(theta_m)

In [76]:
import scipy.stats as stats
#Known precision matrix A - df=250 with identity scale
A = stats.wishart.rvs(df=250, scale= np.eye(250))
#target distribution is zero-mean 250-dimensional multivariate normal with known precision
theta0 = stats.norm.rvs(size=250)
theta0[0:10]

array([-1.53448603,  0.08281695,  0.00444814, -0.29519678, -0.43522506,
       -0.69960618, -1.19593756, -1.20629432,  1.41021219,  0.1400045 ])

In [77]:
M = 100
eps = .01

In [78]:
%%time
results = NUTS_Mvt(theta0,A,eps, M)

CPU times: user 18.9 ms, sys: 1.27 ms, total: 20.1 ms
Wall time: 18.8 ms


In [79]:
print(np.mean(results[80:99,0]),np.mean(results[80:99,1]),np.mean(results[80:99,2]))

-1.53448602724 0.0828169505582 0.00444813831731
