In [1]:
from copy import deepcopy
from itertools import product

import autograd
import autograd.numpy as np
import autograd.scipy as sp
import paragami
from scipy import sparse

from bnpmodeling_runjingdev.modeling_lib import my_slogdet3d
import bnpgmm_runjingdev.gmm_clustering_lib as gmm_lib
import bnpgmm_runjingdev.gmm_preconditioner_lib as precond_lib

The defvjp method is deprecated. See the update guide and tutorial:
https://github.com/HIPS/autograd/blob/master/docs/updateguide.md
https://github.com/HIPS/autograd/blob/master/docs/tutorial.md


In [2]:
# dimension of the multivariate normal used throughout the checks below
dim = 4

# Check log-partition function

In [3]:
# fix the seed so that every random draw below, and therefore every printed
# check, is reproducible under Restart Kernel -> Run All
np.random.seed(453)

mean = np.random.randn(dim)

# info_factor info_factor^T is symmetric positive semi-definite, which makes it
# usable as an information matrix; the factor gets its own name so that
# re-running the cell never feeds an old `info` back into itself
info_factor = np.random.randn(dim, dim)
info = np.dot(info_factor, info_factor.transpose())

In [4]:
# paragami patterns for the canonical and for the natural parameterisation
(mvn_params_paragami,
 mvn_nat_params_paragami) = precond_lib.get_mvn_paragami_objects(dim)

In [5]:
# start from a random dictionary with the right structure, then overwrite its
# entries with the mean and information matrix drawn above
mvn_params_dict = mvn_params_paragami.random()
mvn_params_dict.update({'mean': mean, 'info': info})

In [6]:
# canonical parameters flattened into paragami's free parameterisation
mvn_free_params = mvn_params_paragami.flatten(mvn_params_dict, free = True)

# the corresponding vector of natural parameters
nat_vec = precond_lib.get_nat_vec(
    mvn_free_params,
    mvn_params_paragami,
    mvn_nat_params_paragami)

In [7]:
# evaluate the log-partition function at the natural parameters; the value is
# shown as the cell output
precond_lib.get_mvn_log_partition(nat_vec, mvn_nat_params_paragami)

1.0435665952180575

In [8]:
# gradient of the log-partition function with respect to its first argument,
# the natural-parameter vector
get_grad_log_part = autograd.grad(precond_lib.get_mvn_log_partition, 0)
grad_log_part = get_grad_log_part(
    nat_vec,
    mvn_nat_params_paragami)

In [9]:
# reshape the flat gradient into a dictionary with 'mean' and 'info' entries.
# Validation is switched off -- presumably because a gradient does not have to
# satisfy the constraints of a parameter value (TODO confirm).
grad_log_part_folded = mvn_params_paragami.fold(
    grad_log_part,
    free = False,
    validate_value = False)

In [10]:
# expected sufficient statistics of a normal with this mean and covariance
# inv(info): E[x] = mean and E[x x^T] = inv(info) + mean mean^T
e_x = mean
mean_outer = np.einsum('i, j -> ij', mean, mean)
e_x2 = np.linalg.inv(info) + mean_outer

In [11]:
# check they match with gradient of log partition: E[x] against the 'mean'
# entry of the folded gradient. The largest absolute difference should be at
# floating-point precision.
np.abs(e_x - grad_log_part_folded['mean']).max()

1.1102230246251565e-16

In [12]:
# same check for the second moment; note the sign: the 'info' entry of the
# folded gradient is compared against -E[x x^T]
np.abs(-e_x2 - grad_log_part_folded['info']).max()

3.885780586188048e-16

# Check Fisher information

In [15]:
# reference value: the library's Fisher information, evaluated at the free
# canonical parameters of the dim-dimensional normal
fishers_info = precond_lib.get_fishers_info(mvn_free_params, dim)

In [16]:
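# A slightly different way to get the Fisher information: differentiate the
# log-partition function through the map from free canonical parameters to
# natural parameters, and compare against the library value above.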

In [17]:
# dA / deta, eta being the natural parameters
get_dA_deta = autograd.grad(precond_lib.get_mvn_log_partition, 0)
dA_deta = get_dA_deta(
    nat_vec,
    mvn_nat_params_paragami)

In [18]:
def get_log_partition_free_canon(mvn_free_param, mvn_params_paragami, mvn_nat_params_paragami):
    """Log-partition function as a function of the free canonical parameters.

    This is A(eta(theta)): ``precond_lib.get_nat_vec`` maps the free canonical
    parameters theta to the natural-parameter vector eta, and
    ``precond_lib.get_mvn_log_partition`` evaluates A at that eta.

    Parameters
    ----------
    mvn_free_param :
        Vector of free canonical parameters (theta).
    mvn_params_paragami :
        Paragami pattern for the canonical parameters.
    mvn_nat_params_paragami :
        Paragami pattern for the natural parameters.

    Returns
    -------
    Whatever ``precond_lib.get_mvn_log_partition`` returns for eta(theta).
    """
    nat_vec = precond_lib.get_nat_vec(
        mvn_free_param, mvn_params_paragami, mvn_nat_params_paragami)

    return precond_lib.get_mvn_log_partition(nat_vec, mvn_nat_params_paragami)

In [20]:
# Hessian of A(eta(theta)) with respect to the free canonical parameters theta
get_dA2_dtheta2 = autograd.hessian(get_log_partition_free_canon, argnum = 0)
dA2_dtheta2 = get_dA2_dtheta2(
    mvn_free_params,
    mvn_params_paragami,
    mvn_nat_params_paragami)

In [21]:
# this is d^2 eta / dtheta^2: the Jacobian, with respect to the free parameters
# (argument 0), of precond_lib.get_jac_term -- presumably d eta / d theta,
# TODO confirm against the library
get_deta2_dtheta2 = autograd.jacobian(precond_lib.get_jac_term, argnum = 0)
deta2_dtheta2 = get_deta2_dtheta2(
    mvn_free_params,
    mvn_params_paragami,
    mvn_nat_params_paragami)

In [22]:
# alternatively the Fisher information can be computed as
# d^2A/dtheta^2 - sum_j (dA/deta_j) * (d^2 eta_j / dtheta^2)
second_order_term = np.einsum('j, jik -> ik', dA_deta, deta2_dtheta2)
fishers_info2 = dA2_dtheta2 - second_order_term

In [23]:
# largest absolute difference between the library's Fisher information and the
# chain-rule version; should be at floating-point precision
np.abs(fishers_info - fishers_info2).max()

3.1086244689504383e-15

### We also check against sampling

In [63]:
def get_normal_logpdf(x, mean, info):
    """Row-wise Gaussian log-density, up to an additive constant.

    For every row ``x[n]`` this evaluates
    ``-0.5 * (x[n] - mean)^T info (x[n] - mean) + 0.5 * log|det(info)|``.
    The ``-0.5 * d * log(2 * pi)`` normalising term is not included; it does
    not depend on ``mean`` or ``info``.

    Parameters
    ----------
    x :
        2-d array with one observation per row and ``len(mean)`` columns.
    mean :
        1-d array, the mean vector.
    info :
        Array with ``len(mean)`` rows, used as the matrix of the quadratic form
        (the inverse of the covariance that the sampling cell draws from).

    Returns
    -------
    1-d array with one value per row of ``x``.

    Raises
    ------
    AssertionError
        If ``x.shape[1]`` or ``info.shape[0]`` differs from ``len(mean)``.
    """
    n_dim = len(mean)
    assert x.shape[1] == n_dim
    assert info.shape[0] == n_dim

    centered = x - mean[None, :]

    # quadratic form (x - mean)^T info (x - mean), one value per row
    centered_info = np.einsum('ni, ij -> nj', centered, info)
    quad_form = np.einsum('nj, nj -> n', centered, centered_info)

    # slogdet returns (sign, log|det|); only the log-magnitude is used
    _, log_det_info = np.linalg.slogdet(info)

    return 0.5 * log_det_info - 0.5 * quad_form

In [116]:
def get_log_q(x, mvn_free_params, mvn_params_paragami):
    """Average of ``get_normal_logpdf`` over the rows of ``x``.

    The free parameter vector is folded with ``mvn_params_paragami`` into a
    dictionary, whose 'mean' and 'info' entries are passed on to
    ``get_normal_logpdf``.

    Parameters
    ----------
    x :
        Observations, one per row.
    mvn_free_params :
        Free parameter vector to be folded by ``mvn_params_paragami``.
    mvn_params_paragami :
        Paragami pattern whose folded value has 'mean' and 'info' entries.

    Returns
    -------
    The mean of the per-row log-densities.
    """
    folded = mvn_params_paragami.fold(mvn_free_params, free = True)

    log_pdfs = get_normal_logpdf(x, folded['mean'], folded['info'])

    return log_pdfs.mean()

In [117]:
# Monte Carlo sample from the normal with this mean and covariance inv(info)
n_samples = 100000
cov = np.linalg.inv(info)
x = np.random.multivariate_normal(mean, cov, n_samples)

In [118]:
# average log-density of the samples at the parameters they were drawn from
get_log_q(x, mvn_free_params, mvn_params_paragami)

-0.208378139780283

In [119]:
# Monte Carlo estimate of the Fisher information: minus the Hessian of the
# average log-density with respect to the free parameters (argument 1)
get_est_fishers_info = autograd.hessian(get_log_q, argnum = 1)
log_q_hessian = get_est_fishers_info(x, mvn_free_params, mvn_params_paragami)
est_fishers_info = -log_q_hessian

In [120]:
# NOTE(review): `score_funs` is not assigned in any cell visible here, so this
# raises NameError on a fresh kernel unless it is defined elsewhere -- confirm,
# or remove this leftover cell.
score_funs.shape

(14,)

In [124]:
# largest absolute difference between the library's Fisher information and the
# sampling-based estimate; the two should agree up to Monte Carlo error only
np.abs(fishers_info - est_fishers_info).max()

0.0211458253964827

# OK, now we do it for the full set of vb_params

In [200]:
# get vb parameters
# k_approx is the number of components handed to the VB parameter constructor
# (presumably the truncation level of the approximation -- TODO confirm)
k_approx = 30
# only the paragami pattern is kept from this call: the dictionary it returns
# is replaced by a random draw on the next line
vb_params_dict, vb_params_paragami = gmm_lib.get_vb_params_paragami_object(dim, k_approx)
vb_params_dict = vb_params_paragami.random()

# free-parameter vector that the preconditioner below is evaluated at
vb_free_params = vb_params_paragami.flatten(vb_params_dict, free = True)

In [230]:
def _set_inverse_fisher_block(preconditioner, vb_free_params, block_indx, block_dim):
    """Write one inverse Fisher information block into ``preconditioner``.

    The free parameters at ``block_indx`` are passed, together with
    ``block_dim``, to ``precond_lib.get_fishers_info``; the inverse of the
    result is stored in the rows and columns ``block_indx`` of
    ``preconditioner``, which is modified in place.
    """
    # every (row, column) pair of the block, in row-major order, so that the
    # pairs line up with the entries of ``.flatten()`` below
    indx_product = np.array(list(product(block_indx, block_indx)))

    fishers_info_block = precond_lib.get_fishers_info(
        vb_free_params[block_indx], block_dim)

    preconditioner[indx_product[:, 0], indx_product[:, 1]] = \
        np.linalg.inv(fishers_info_block).flatten()


def get_gmm_preconditioner(vb_free_params, vb_params_paragami):
    """Assemble a sparse preconditioner from per-factor inverse Fisher informations.

    For each cluster ``k`` the free parameters of ``centroids[:, k]`` and
    ``cluster_info[k]`` are selected, and for each of the first
    ``k_approx - 1`` sticks the free parameters of ``stick_propn_mean[k]`` and
    ``stick_propn_info[k]``. The inverse of ``precond_lib.get_fishers_info``
    evaluated at the selected parameters is written into the matching rows and
    columns of the result.

    Parameters
    ----------
    vb_free_params :
        Free-parameter vector of the variational parameters.
    vb_params_paragami :
        Paragami pattern of the variational parameters, with
        'cluster_params' and 'stick_params' entries.

    Returns
    -------
    ``scipy.sparse.lil_matrix`` of shape
    ``(len(vb_free_params), len(vb_free_params))``.
    """
    n_free_params = len(vb_free_params)
    preconditioner = sparse.lil_matrix((n_free_params, n_free_params))

    # the centroid mask has one row per dimension and one column per cluster
    centroids_shape = \
        vb_params_paragami.empty_bool(False)['cluster_params']['centroids'].shape
    dim = centroids_shape[0]
    k_approx = centroids_shape[1]

    # get preconditioners for cluster parameters
    for k in range(k_approx):
        # fresh all-False mask on every pass, so only cluster k is selected
        bool_dict = vb_params_paragami.empty_bool(False)
        bool_dict['cluster_params']['centroids'][:, k] = True
        bool_dict['cluster_params']['cluster_info'][k] = True

        indx_cluster_params_k = vb_params_paragami.flat_indices(bool_dict, free = True)

        _set_inverse_fisher_block(
            preconditioner, vb_free_params, indx_cluster_params_k, dim)

    # get preconditioners for stick parameters; the Fisher information is
    # requested with dimension 1 for each stick
    for k in range(k_approx - 1):
        bool_dict = vb_params_paragami.empty_bool(False)
        bool_dict['stick_params']['stick_propn_mean'][k] = True
        bool_dict['stick_params']['stick_propn_info'][k] = True

        indx_stick_params_k = vb_params_paragami.flat_indices(bool_dict, free = True)

        _set_inverse_fisher_block(
            preconditioner, vb_free_params, indx_stick_params_k, 1)

    return preconditioner


In [231]:
# build the preconditioner at the random VB parameters drawn above; the repr
# shows its shape and the number of stored entries
get_gmm_preconditioner(vb_free_params, vb_params_paragami)

<478x478 sparse matrix of type '<class 'numpy.float64'>'
	with 5996 stored elements in List of Lists format>

In [221]:
# NOTE(review): in top-to-bottom order `bool_dict` is not bound at module level
# yet (the name inside get_gmm_preconditioner is local, and the first
# module-level assignment visible here is in a later cell), so this cell fails
# on a fresh Run All -- confirm, then move or remove it.
bool_dict

OrderedDict([('cluster_params',
              OrderedDict([('centroids',
                            array([[ True, False, False, False, False, False, False, False, False,
                                    False, False, False, False, False, False, False, False, False,
                                    False, False, False, False, False, False, False, False, False,
                                    False, False, False],
                                   [ True, False, False, False, False, False, False, False, False,
                                    False, False, False, False, False, False, False, False, False,
                                    False, False, False, False, False, False, False, False, False,
                                    False, False, False],
                                   [ True, False, False, False, False, False, False, False, False,
                                    False, False, False, False, False, False, False, False, False,
                   

In [201]:
# step-by-step walkthrough of the cluster-parameter update for a single k:
# start from an empty square sparse matrix, one row/column per free parameter
preconditioner = sparse.lil_matrix((len(vb_free_params), len(vb_free_params)))

In [202]:
# cluster index used by the walkthrough cells below
k = 0

In [203]:
# boolean mask with the same structure as the VB parameters, every entry False
bool_dict = vb_params_paragami.empty_bool(False)

In [204]:
# switch on the entries that belong to cluster k: column k of the centroids and
# entry k of the cluster information
bool_dict['cluster_params']['centroids'][:, k] = True
bool_dict['cluster_params']['cluster_info'][k] = True

# positions of those entries in the free-parameter vector, and their values
indx_cluster_params_k = vb_params_paragami.flat_indices(bool_dict, free = True)
free_params_cluster_params_k = vb_free_params[indx_cluster_params_k]

In [205]:
# Fisher information for cluster k's free parameters
fishers_info_cluster_params_k = precond_lib.get_fishers_info(free_params_cluster_params_k, dim)

In [209]:
# all (row, column) index pairs of the block belonging to cluster k
indx_product = np.array(list(product(indx_cluster_params_k, indx_cluster_params_k)))

In [216]:
# store the inverse Fisher information in that block; the flattened entries are
# matched one-to-one with the index pairs above
preconditioner[indx_product[:, 0], indx_product[:, 1]] = np.linalg.inv(fishers_info_cluster_params_k).flatten()

In [217]:
# after a single cluster the matrix holds one 14 x 14 block (196 stored entries)
preconditioner

<478x478 sparse matrix of type '<class 'numpy.float64'>'
	with 196 stored elements in List of Lists format>

In [206]:
# the 14 free parameters selected for cluster k
free_params_cluster_params_k

array([0.9344494 , 0.04248797, 0.33043887, 0.2112118 , 0.90321367,
       0.81396135, 0.32261851, 0.89816291, 0.32479739, 0.40991716,
       0.63988727, 0.36597507, 0.00829408, 0.64773205])

In [None]:
# call empty_bool to obtain an all-False mask; without the call, bool_dict
# would be bound to the method itself
bool_dict = vb_params_paragami.empty_bool(False)

In [168]:
# NOTE(review): this copies the parameter dictionary itself, so despite its name
# `bool_dict` presumably holds parameter values rather than booleans here, and
# it replaces the mask built in the cells above -- confirm this cell is still
# wanted.
bool_dict = deepcopy(vb_params_dict)
bool_dict 

In [None]:
# TODO: unfinished stub -- the right-hand side was never written. It is kept as
# a comment because a bare `free_indx =` is a SyntaxError that stops
# Restart Kernel -> Run All.
# free_indx = ...

In [None]:
# TODO: unfinished stub -- the right-hand side was never written. It is kept as
# a comment because a bare `fishers_info_k =` is a SyntaxError that stops
# Restart Kernel -> Run All.
# fishers_info_k = ...

In [162]:
from itertools import product

In [163]:
# sanity check of the index construction: product() lists the (row, column)
# pairs with the first index varying slowest, i.e. in row-major order, which is
# the order in which .flatten() lists the entries of a matrix
np.array(list(product([1, 2, 3], [1, 2, 3])))

array([[1, 1],
       [1, 2],
       [1, 3],
       [2, 1],
       [2, 2],
       [2, 3],
       [3, 1],
       [3, 2],
       [3, 3]])