In [1]:
import autograd

import autograd.numpy as np

import autograd.scipy as sp

from bnpmodeling_runjingdev.modeling_lib import my_slogdet3d
import bnpgmm_runjingdev.gmm_clustering_lib as gmm_lib

from copy import deepcopy

import paragami

The defvjp method is deprecated. See the update guide and tutorial:
https://github.com/HIPS/autograd/blob/master/docs/updateguide.md
https://github.com/HIPS/autograd/blob/master/docs/tutorial.md


In [2]:
n = 100
k_approx = 30 
dim = 4

In [3]:
_, vb_params_paragami = gmm_lib.get_vb_params_paragami_object(dim, k_approx)
vb_params_dict = vb_params_paragami.random()

In [4]:
from bnpgmm_runjingdev.gmm_preconditioner_lib import *

# Check log-partition function

In [5]:
mean = vb_params_dict['cluster_params']['centroids'].transpose()
info = vb_params_dict['cluster_params']['cluster_info']

In [6]:
# get paragami objects
mvn_params_paragami, mvn_nat_params_paragami = \
    get_mvn_paragami_objects(k_approx, dim)

In [7]:
# dictionary of parameters
mvn_params_dict = mvn_params_paragami.random()
mvn_params_dict['mean'] = mean
mvn_params_dict['info'] = info

In [8]:
# vector of canonical parameters
param_vec = mvn_params_paragami.flatten(mvn_params_dict, free = True)

# vector of natural parameters
nat_vec = get_nat_vec(param_vec, mvn_params_paragami, mvn_nat_params_paragami)

In [9]:
# get gradient of log partition function
get_grad_log_part = autograd.grad(get_mvn_log_partition)
grad_log_part = get_grad_log_part(nat_vec, mvn_nat_params_paragami)

In [11]:
# fold gradient into dictionary
grad_log_part_folded = mvn_params_paragami.fold(grad_log_part, free = False, validate_value = False)

In [12]:
# get expectations of sufficient statistics
e_x = mean
e_x2 = np.linalg.inv(info) + np.einsum('ki, kj -> kij', mean, mean)

In [13]:
# check they match with gradient of log partition
np.abs(e_x - grad_log_part_folded['mean']).max()

7.771561172376096e-16

In [18]:
np.abs(e_x2 - -grad_log_part_folded['info']).max()

3.9968028886505635e-15

# Check Fisher information

In [19]:
get_hess_nat_param = autograd.jacobian(get_jac_term, 0)

In [None]:
hess_nat_param = get_hess_nat_param(param_vec, mvn_params_paragami, mvn_nat_params_paragami)

In [None]:
hess_nat_param.shape

In [None]:
def get_log_q(vb_free_params, vb_params_paragami): 
    vb_params_dict = vb_free_params
    

In [12]:
k_approx = mean.shape[0]
dim = mean.shape[1]

# get paragami objects
mvn_params_paragami, mvn_nat_params_paragami = \
    get_mvn_paragami_objects(k_approx, dim)

# dictionary of parameters
mvn_params_dict = mvn_params_paragami.random()
mvn_params_dict['mean'] = mean
mvn_params_dict['info'] = info

# vector of parameters
param_vec = mvn_params_paragami.flatten(mvn_params_dict, free = True)
nat_vec = get_nat_vec(param_vec, mvn_params_paragami, mvn_nat_params_paragami)

fishers_info = get_log_part_hess(nat_vec, mvn_nat_params_paragami)

jac_term = get_jac_term(param_vec, mvn_params_paragami, mvn_nat_params_paragami)


In [13]:
fishers_info.shape

(600, 600)

In [14]:
jac_term.shape

(600, 420)

In [4]:
def get_nat_vec(param_vec, mvn_params_paragami, mvn_nat_params_paragami): 
    
    nat_params_dict = {}
    
    mvn_param_dict = mvn_params_paragami.fold(param_vec, free = False)
    
    mean = mvn_param_dict['mean']
    info = mvn_param_dict['info']
        
    nat_params_dict['nat1'] = np.einsum('kij, kj -> ki', info, mean)
    nat_params_dict['neg_nat2'] = 0.5 * info
    
    return mvn_nat_params_paragami.flatten(nat_params_dict, free = False)

In [5]:
def get_mvn_log_partition(nat_vec, mvn_nat_params_paragami): 
    
    nat_params_dict = mvn_nat_params_paragami.fold(nat_vec, free = False)
    
    nat1 = nat_params_dict['nat1']
    neg_nat2 = nat_params_dict['neg_nat2']
    
    nat2_inv = np.linalg.inv(-neg_nat2)

    nat2_inv_nat1 = np.einsum('kij, kj -> ki', nat2_inv, nat1)
    squared_term = np.einsum('ki, ki -> k', nat1, nat2_inv_nat1)

    return (- 0.25 * squared_term - 0.5 * my_slogdet3d(2 * neg_nat2)[1]).sum()

In [6]:
get_jac_term = autograd.jacobian(get_nat_vec, 0)
get_log_part_hess = autograd.hessian(get_mvn_log_partition, 0)

In [7]:
def get_mvn_paragami_objects(k_approx, dim):
    mvn_nat_params_paragami = paragami.PatternDict()
    mvn_nat_params_paragami['nat1'] = \
        paragami.NumericArrayPattern(shape=(k_approx, dim))
    mvn_nat_params_paragami['neg_nat2'] = \
        paragami.pattern_containers.PatternArray(array_shape = (k_approx, ), \
                    base_pattern = paragami.PSDSymmetricMatrixPattern(size=dim))
    
    mvn_params_paragami = paragami.PatternDict()
    mvn_params_paragami['mean'] = \
        paragami.NumericArrayPattern(shape=(k_approx, dim))
    mvn_params_paragami['info'] = \
        paragami.pattern_containers.PatternArray(array_shape = (k_approx, ), \
                    base_pattern = paragami.PSDSymmetricMatrixPattern(size=dim))
    
    return mvn_params_paragami, mvn_nat_params_paragami

In [16]:
def get_fishers_info(mean, info): 
    assert mean.shape[0] == info.shape[0]
    assert mean.shape[1] == info.shape[1]
    
    k_approx = mean.shape[0]
    dim = mean.shape[1]
    
    # get paragami objects 
    mvn_params_paragami, mvn_nat_params_paragami = \
        get_mvn_paragami_objects(k_approx, dim)
    
    # dictionary of parameters 
    mvn_params_dict = mvn_params_paragami.random()
    mvn_params_dict['mean'] = mean
    mvn_params_dict['info'] = info
    
    # vector of parameters 
    param_vec = mvn_params_paragami.flatten(mvn_params_dict, free = False)
    nat_vec = get_nat_vec(param_vec, mvn_params_paragami, mvn_nat_params_paragami)
    
    fishers_info = get_log_part_hess(nat_vec, mvn_nat_params_paragami)
    
    jac_term = get_jac_term(nat_vec, mvn_params_paragami, mvn_nat_params_paragami)
    
    return np.dot(jac_term, np.dot(fishers_info, jac_term))

In [17]:
foo = get_fishers_info(vb_params_dict['cluster_params']['centroids'].transpose(), 
                vb_params_dict['cluster_params']['cluster_info'])

In [20]:
get_fishers_info(vb_params_dict['stick_params']['stick_propn_mean'][:, None], 
                    vb_params_dict['stick_params']['stick_propn_info'][:, None, None])

array([[0.0885786 , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.61676455, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.57337277, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.33031977, 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.20479928,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.01039804]])

In [8]:
def get_mvn_log_partition(mean, info): 
    info_mean = np.einsum('kij, kj -> ki', info, mean)
    mean_info_mean = np.einsum('ki, kj -> k', mean, info_mean)
    
    return (0.5 * mean_info_mean + 0.5 * my_slogdet3d(info)[1]).mean()

In [9]:
def get_log_partition_free(vb_params_free, vb_params_paragami, use_logitnormal_sticks = True): 
    vb_params_dict = vb_params_paragami.fold(vb_params_free, free = True)
    
    cluster_log_part = get_mvn_log_partition(vb_params_dict['cluster_params']['centroids'].transpose(), 
                                            vb_params_dict['cluster_params']['cluster_info'])
    
    if use_logitnormal_sticks: 
        stick_log_part = get_mvn_log_partition(vb_params_dict['stick_params']['stick_propn_mean'][:, None], 
                                            vb_params_dict['stick_params']['stick_propn_info'][:, None, None])
    
    else: 
        raise NotImplementedError()
    
    return cluster_log_part + stick_log_part

In [10]:
get_fishers_info = autograd.hessian(get_log_partition_free, 0)

In [11]:
vb_params_free = vb_params_paragami.flatten(vb_params_dict, free = True)

In [12]:
get_log_partition_free(vb_params_free, vb_params_paragami)

17.78567716300717

In [None]:
get_fishers_info = autograd.hessian(get_mvn_log_partition_free, 0)

In [13]:
get_fishers_info(vb_params_free, vb_params_paragami)

array([[0.17462724, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.15158418, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.08535172, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.00089329, 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.00042449,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.00985399]])

In [61]:
NotImplementedError()

NotImplementedError()

In [55]:
cluster_params_paragami = vb_params_paragami['cluster_params']
cluster_params_free = cluster_params_paragami.flatten(vb_params_dict['cluster_params'], free = True)

In [56]:
get_mvn_log_partition_free(cluster_params_free, cluster_params_paragami)

15.191874266798116

In [58]:
get_fishers_info = autograd.hessian(get_mvn_log_partition_free, 0)

In [59]:
get_fishers_info(cluster_params_free, cluster_params_paragami)

array([[ 1.01438486e-01,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  3.85757948e-01,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  1.17817376e-01, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       ...,
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
         1.05045462e-02, -8.36064380e-20,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
        -1.96279143e-19,  1.05045462e-02, -9.24498931e-19],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
        -3.33375462e-19, -1.11030716e-18,  8.95722405e-02]])