# G-Wishart Inference

## 1. Static Inference

In [None]:
import numpy as np
import scipy as sp
from sklearn.covariance import GraphLasso
from sklearn.datasets import make_spd_matrix, make_sparse_spd_matrix
from sklearn.linear_model import LassoLars

In [None]:
# Synthetic ground truth: a sparse precision matrix K, its covariance Sigma,
# and n_samples Gaussian draws X. make_sparse_spd_matrix is called without a
# random_state, so it draws from numpy's global RNG; seeding that RNG once
# makes both K and X reproducible.
tol = 1e-8  # tolerance passed to markov_blankets in the next cell
n_dim, n_samples = 5, 100

np.random.seed(0)
K = make_sparse_spd_matrix(n_dim, alpha=0.75)  # sparse SPD precision matrix
Sigma = sp.linalg.pinvh(K)  # covariance = (pseudo-)inverse of K

# Zero-mean samples with covariance Sigma, shape (n_samples, n_dim).
X = np.random.multivariate_normal(
    np.zeros(n_dim), Sigma, size=n_samples)

In [None]:
# `reload` is a builtin only on Python 2; on Python 3 it lives in importlib.
try:
    from importlib import reload
except ImportError:
    pass  # Python 2: fall back to the builtin reload

# Reload the inference module so edits to its source take effect without
# restarting the kernel, then pull its public names into the notebook.
from regain.bayesian import gwishart_inference
reload(gwishart_inference)
from regain.bayesian.gwishart_inference import *

n_samples, n_dim = X.shape
alphas = np.logspace(-2, 0, 20)  # 20 log-spaced values from 0.01 to 1

# Get a series of Markov blankets for various alphas: fit the graphical lasso
# once per alpha and keep each estimated precision matrix.
mdl = GraphLasso(verbose=False)
precisions = [
    mdl.set_params(alpha=a).fit(X).precision_
    for a in alphas]
mblankets = markov_blankets(precisions, tol=tol, unique=1)

In [None]:
# Score the Markov blankets on X, then build candidate graphs from them.
normalized_scores = score_blankets(mblankets, X=X, alphas=[0.01, 0.5, 1])

graphs = get_graphs(mblankets, normalized_scores, n_dim=n_dim,
                    n_resampling=200)

# Keep the non-zero pattern of each graph's strict upper triangle and make
# sure the diagonal is non-zero.
nonzeros_all = [np.triu(g, 1) + np.eye(n_dim, dtype=bool) for g in graphs]

# Roverato'02: convert from HIW to G-Wishart (delta + |V| - 1), here delta = 3.
d0 = 3 + n_dim - 1
S0 = np.eye(n_dim)  # same as Roverato'02

# Score every candidate graph under the G-Wishart prior.
res = [GWishartScore(X, G, d0=d0, S0=S0, mode='gl', score_method='diaglaplace')
       for G in nonzeros_all]

# Display P of the highest-scoring result. Taking `max` over the reversed list
# selects the same element as sorting in descending order and taking the head
# (the last of any tied maxima), without sorting the whole list.
max(reversed(res), key=lambda r: r.score).P

In [None]:
# Inspect the GWishartScore result for the first candidate graph.
res[0]

In [None]:
# Ground-truth sparse precision matrix, for comparison with the results above.
K

## Using GL with a fixed graph: timing comparison

In [None]:
from sklearn.utils import Bunch

# Prior hyper-parameters d0 / S0 (defined earlier) packed into the container
# passed to GWishartFit in the timing cells below.
# NOTE(review): lognormconst / lognormconstDiag are set to 0, presumably as
# placeholders — confirm against GWishartFit.
GWprior = Bunch(d0=d0, S0=S0, lognormconst=0, lognormconstDiag=0)

# Fixed graph used by the timing cells. Bind it explicitly: otherwise `G` only
# exists as a side effect of Python 2 leaking the list-comprehension variable
# from the scoring cell, which is a NameError on Python 3. Taking the last
# graph reproduces the Python 2 value.
G = nonzeros_all[-1]

In [None]:
# Time GWishartFit on the fixed graph G from an earlier cell with
# mode='covsel'; compare with the mode='gl' timing in the next cell.
# NOTE(review): `gwishart` is not bound by any visible cell (the module was
# imported as `gwishart_inference`) — confirm it is provided by a wildcard
# import, otherwise this line raises NameError.
%timeit gwishart.GWishartFit(X, G, GWprior, mode='covsel')

In [None]:
# Same timing with mode='gl', for comparison with the mode='covsel' run above.
# NOTE(review): `gwishart` is not bound by any visible cell (the module was
# imported as `gwishart_inference`) — confirm it is provided by a wildcard
# import, otherwise this line raises NameError.
%timeit gwishart.GWishartFit(X, G, GWprior, mode='gl')

## Plotting of 2-d covariance matrices

In [None]:
from regain.plot import *

# A random 2x2 SPD covariance matrix to plot. No random_state is given, so
# it is drawn from numpy's global RNG and depends on the earlier RNG state.
Cov = make_spd_matrix(n_dim=2)
Cov

In [None]:
# Plot the 2-d Gaussian described by the point (2, 4) and covariance Cov.
# NOTE(review): the first argument is presumably the mean and sdwidth=2 the
# number of standard deviations drawn — confirm in regain.plot.show2d.
show2d(np.asarray([2, 4]), Cov, sdwidth=2)