In [102]:
# def func_gradient(func, theta_sym, values):
#    grad = np.array([diff(L(*theta_sym), i) for i in theta_sym])
#    substitutes = list(zip(theta_sym, values))
#
#    for i, element in enumerate(grad):
#        for j in substitutes:
#            grad[i] = grad[i].subs(*j)
#    return grad  

In [140]:
# from Algorithm 1: HMC, page1353
def Leapfrog(theta, r, eps, f):
    logp_tilde, grad_tilde = f(theta) 
    r_tilde = r + (eps/2) * f(theta)[1]
    theta_tilde = theta + eps * r_tilde
    r_tilde = r_tilde + (eps/2) * f(theta_tilde)[0]
    return theta_tilde, r_tilde, logp_tilde, grad_tilde

In [150]:
# from Algorithm 3: Efficeint NUTS pg.1364
def BuildTree(theta, r, grad, u, v, j, eps, f):
    if j == 0:
        # base case, take one leapfrog step in the direction v
        theta_prime, r_prime, logp_prime, grad_prime = Leapfrog(theta, r, v*eps, f)
        
        n_prime = int(u <= np.exp(logp_prime - 0.5 * np.dot(r_prime, r_prime)))
        
        s_prime = int(logp_prime - 0.5 * np.dot(r_prime, r_prime) > np.log(u) - 1000)
        
        theta_minus = theta_prime
        theta_plus = theta_prime
        
        r_minus = r_prime
        r_plus = r_prime
        
        grad_minus = grad_prime
        grad_plus = grad_prime
    else:
        # recursion, build left and right subtrees
        theta_minus, r_minus, grad_minus, theta_plus, r_plus, grad_plus, theta_prime, grad_prime, logp_prime, n_prime, s_prime = BuildTree(theta, r, grad, u, v, j-1, eps, f)
        
        if s_prime == 1:
            if v == -1:
                theta_minus, r_minus, grad_minus, _,_,_, theta_doub_prime, grad_doub_prime, logp_doub_prime, n_doub_prime, s_doub_prime = BuildTree(theta_minus, r_minus, grad_minus, u, v, j-1, eps, f)
            else:
                _, _, _, theta_plus, r_plus, grad_plus, theta_doub_prime, grad_doub_prime, logp_doub_prime, n_doub_prime, s_doub_prime = BuildTree(theta_plus, r_plus, grad_plus, u, v, j-1, eps, f)

            # Use Metropolis-Hastings
            prob = n_doub_prime / max((n_prime + n_doub_prime),1)
            if (np.random.uniform(0,1,1) < prob):
                theta_prime = theta_doub_prime
                grad_prime = grad_doub_prime
                logp_prime = logp_doub_prime
            
            ind_1 = int(np.dot(theta_plus-theta_minus, r_minus) >= 0)
            ind_2 = int(np.dot(theta_plus-theta_minus, r_plus) >= 0)
            s_prime = s_prime * s_doub_prime * ind_1 * ind_2
            n_prime = n_prime + n_doub_prime
        
    return theta_minus, r_minus, grad_minus, theta_plus, r_plus, grad_plus, theta_prime, grad_prime, logp_prime, n_prime, s_prime

In [151]:
def efficient_nuts(theta0, eps, M, f):
    # initialize samples matrix
    # put initial theta0 in first row of matrix
    parems = len(theta0)
    samples = np.empty((M+1, parems))
    samples[0, :] = theta0
    logp, grad = f(theta0)
    
    for m in range(1, M+1):
        # resample
        norm_samp = np.random.multivariate_normal(np.zeros(parems), np.identity(parems), 1)
        r0 = norm_samp.ravel()

        inside = logp - 0.5 * np.dot(r0,r0)
        # resample u ~ uniform([0, exp(inside)])
        u = np.random.uniform(0, np.exp(inside), 1)

        # initialize minus's and plus's
        theta_minus = samples[m-1, :]
        theta_plus = samples[m-1 ,:]
        r_minus = r0
        r_plus = r0
        grad_minus = grad
        grad_plus = grad
        
        j = 0
        samples[m,:] = samples[m-1,:]
        n = 1
        s = 1
        
        while s == 1:
            v_j = np.random.uniform(-1,1,1)
            if v_j == -1:
                theta_minus, r_minus, grad_minus, _, _, _, theta_prime, grad_prime, logp_prime, n_prime, s_prime = BuildTree(theta_minus, r_minus, grad_minus, u, v_j, j, eps, f)
            else:
                _, _, _, theta_plus, r_plus, grad_plus, theta_prime, grad_prime, logp_prime, n_prime, s_prime = BuildTree(theta_plus, r_plus, grad_plus, u, v_j, j, eps, f)
            
            if s_prime == 1:
                # Use Metropolis-Hastings
                prob = min(1, n_prime/n)
                if (np.random.uniform(0,1,1) < prob):
                    samples[m,:] = theta_prime
                    logp = logp_prime
                    grad = grad_prime
                    
            n = n + n_prime

            boolean_1 = int(np.dot(theta_plus-theta_minus, r_minus) >= 0)
            boolean_2 = int(np.dot(theta_plus-theta_minus, r_plus) >= 0)
            s = s_prime * boolean_1 * boolean_2
            j = j + 1
    return samples

array([[ 0.90665779,  1.94079932],
       [ 0.51203117,  1.08595953],
       [ 1.00571311,  2.00746388]])

In [184]:
X = np.random.poisson(5, 100)
N = 100
def f(theta):
    grad = sum(X)/theta - N
    logp = sum(X)*np.log(theta) - theta*N
    return logp, grad


In [199]:
# simple example
theta0 = np.array([5])
eps = 0.05
M = 5000
results = efficient_nuts(theta0, eps, M, f)
np.mean(results)


2.5851798697107835

In [154]:
# more complicated example 

In [None]:
# try to optimize code

In [None]:
# comparison of NUTS, metropolis, and Gibbs (speed and plots)