In [1]:
import numpy as np
import tensorflow as tf
import autograd
from autograd import numpy as anp

import scipy as sp
from scipy import optimize

In [2]:
from matplotlib import pyplot as plt
%matplotlib inline

In [26]:
def normal_lp_suff(mu, mu2, tau, log_tau, xsum, x2sum, num_obs):
    mu_dim = len(xsum)
    quad_term = \
        tf.linalg.trace(x2sum) - \
        2.0 * tf.tensordot(xsum, mu, 1) + \
        num_obs * tf.linalg.trace(mu2)
    lp = \
        -0.5 * tau * quad_term + \
        0.5 * mu_dim * num_obs * log_tau
    return lp

In [25]:
dim = 3
num_obs = 10000

mu_true = np.arange(dim, dtype=np.float64)
sd_true = 0.5
tau_true = 1 / sd_true ** 2
print(tau_true)

x = np.random.normal(loc=mu_true, scale=sd_true, size=(num_obs, dim))
xsum = np.sum(x, axis=0)
x2sum = x.T @ x

muhat = xsum / num_obs
covhat = x2sum / num_obs - np.outer(muhat, muhat) 
tauhat = 1 / np.mean(np.diag(covhat))
print(tauhat)

quad_term = \
    (tf.linalg.trace(x2sum) - \
     2.0 * tf.tensordot(xsum, muhat, 1) + \
     num_obs * tf.linalg.trace(tf.tensordot(muhat, muhat, 0)))

print(num_obs / quad_term)
print(quad_term / num_obs , np.diag(covhat))

4.0
3.9835313226529028
tf.Tensor(1.3278437742176368, shape=(), dtype=float64)
tf.Tensor(0.7531006428743457, shape=(), dtype=float64) [0.25234554 0.24833973 0.25241537]


In [5]:
normal_lp_suff(
    mu=mu_true,
    mu2=np.outer(mu_true, mu_true),
    tau=tau_true,
    log_tau=np.log(tau_true),
    xsum=xsum,
    x2sum=x2sum,
    num_obs=num_obs)

<tf.Tensor: shape=(), dtype=float64, numpy=-8004.203393872718>

In [6]:
def normal_lp(par, data):
    mu = tf.convert_to_tensor(par['mu'], dtype=tf.float64)
    tau = tf.convert_to_tensor(par['tau'], dtype=tf.float64)
    return normal_lp_suff(
        mu=mu,
        mu2=tf.tensordot(mu, mu, 0),
        tau=tau,
        log_tau=tf.math.log(tau),
        **data)

data = { 'xsum': xsum, 'x2sum': x2sum, 'num_obs': num_obs}
par = { 'mu': mu_true, 'tau': tau_true }
normal_lp(par, data)

<tf.Tensor: shape=(), dtype=float64, numpy=-8004.203393872718>

In [19]:
def flatten_par(par):
    tau = tf.convert_to_tensor(par['tau'], dtype=tf.float64)
    tau = tf.reshape(tau, (1, ))
    return tf.concat([ par['mu'], tf.math.log(tau) ], axis=0)

def fold_par(par_flat):
    par_flat = tf.convert_to_tensor(par_flat)
    par_len = par_flat.get_shape()[0]
    mu_dim = len(par_flat) - 1
    return { 'mu': par_flat[0:mu_dim], 'tau': tf.math.exp(par_flat[mu_dim]) } 

par_flat = flatten_par(par)
par_fold = fold_par(par_flat)
print(par['tau'], par_fold['tau'])
normal_lp(par_fold, data)

4.0 tf.Tensor(4.0, shape=(), dtype=float64)


<tf.Tensor: shape=(), dtype=float64, numpy=-8004.203393872718>

In [14]:
def normal_objective(par_flat, to_numpy=True):
    lp = -1 * normal_lp(fold_par(par_flat), data) / data['num_obs']
    print(lp.numpy())
    if to_numpy:
        return lp.numpy()
    else: return lp

def normal_objective_grad(par_flat, to_numpy=True):
    par_flat_tf = tf.Variable(par_flat)
    with tf.GradientTape() as tape:
        lp = normal_objective(par_flat_tf, to_numpy=False)
    grad = tape.gradient(lp, par_flat_tf)
    if to_numpy:
        return grad.numpy()
    else:
        return grad

def normal_objective_hessian(par_flat, to_numpy=True):
    par_flat_tf = tf.Variable(par_flat)
    with tf.GradientTape() as tape:
        with tf.GradientTape() as gtape:
            lp = normal_objective(par_flat_tf, to_numpy=False)
        grad = gtape.gradient(lp, par_flat_tf)
    hess = tape.jacobian(grad, par_flat_tf)
    if to_numpy:
        return hess.numpy()
    else:
        return hess

lp = normal_objective(par_flat)
grad = normal_objective_grad(par_flat)
hess = normal_objective_hessian(par_flat)

lp, grad, hess

2.461774584468713
2.461774584468713
2.461774584468713


(2.461774584468713,
 array([ 1.19318164, -0.7885926 , -2.39945047,  2.27735079]),
 array([[ 1.87977559,  0.        ,  0.        ,  1.19318164],
        [ 0.        ,  1.87977559,  0.        , -0.7885926 ],
        [ 0.        ,  0.        ,  1.87977559, -2.39945047],
        [ 1.19318164, -0.7885926 , -2.39945047,  2.77735079]]))

In [15]:
par_flat = np.random.random(dim  + 1)
lp = normal_objective(tf.convert_to_tensor(par_flat))
grad = normal_objective_grad(tf.convert_to_tensor(par_flat))
hess = normal_objective_hessian(par_flat)

lp, grad, hess

2.948708499610251
2.948708499610251
2.948708499610251


(2.948708499610251,
 array([ 1.25433344, -1.31664649, -2.07587351,  2.64095621]),
 array([[ 1.46887297,  0.        ,  0.        ,  1.25433344],
        [ 0.        ,  1.46887297,  0.        , -1.31664649],
        [ 0.        ,  0.        ,  1.46887297, -2.07587351],
        [ 1.25433344, -1.31664649, -2.07587351,  3.14095621]]))

In [16]:
opt_result = sp.optimize.minimize(
    x0=np.zeros(dim + 1),
    fun=normal_objective,
    jac=normal_objective_grad,
    hess=normal_objective_hessian,
    method="bfgs")

2.863398098918461
2.863398098918461
1.1139570525411449
1.1139570525411449
0.7653548074088837
0.7653548074088837
0.46728885841601203
0.46728885841601203
0.4333120935951877
0.4333120935951877
0.4110301139189698
0.4110301139189698
0.38071921101806194
0.38071921101806194
0.3597169314778389
0.3597169314778389
0.3544213306579572
0.3544213306579572
0.3539902919460659
0.3539902919460659
0.3539849428004315
0.3539849428004315
0.35398493835143024
0.35398493835143024


In [17]:
lp = normal_objective(opt_result.x)
grad = normal_objective_grad(opt_result.x)
hess = normal_objective_hessian(opt_result.x)
grad, hess

0.35398493835143024
0.35398493835143024
0.35398493835143024


(array([-2.12450497e-10, -1.26470173e-07, -2.54039874e-07,  1.38086059e-07]),
 array([[ 1.33914373e+00,  0.00000000e+00,  0.00000000e+00,
         -2.12450497e-10],
        [ 0.00000000e+00,  1.33914373e+00,  0.00000000e+00,
         -1.26470173e-07],
        [ 0.00000000e+00,  0.00000000e+00,  1.33914373e+00,
         -2.54039874e-07],
        [-2.12450497e-10, -1.26470173e-07, -2.54039874e-07,
          5.00000138e-01]]))

In [18]:
par_opt = fold_par(opt_result.x)
print(par_opt['mu'].numpy(), muhat)
print(par_opt['tau'].numpy(), tauhat, tau_true)

[1.67067895e-03 9.94542538e-01 1.99773160e+00] [1.67067911e-03 9.94542632e-01 1.99773179e+00]
1.339143726279069 4.01743006933528 4.0
