In [12]:
import pickle
import numpy as np

from sklearn.mixture import GaussianMixture

In [2]:
def get_values_from_trace(model, trace, thin=1, burn=0):
    """
    Collect the sampled values of every variable listed in ``model.vars``.

    :param model: pymc3 model; only ``model.vars`` (objects with a ``name``) is read
    :param trace: pymc3 trace object; queried through ``trace.get_values``
    :param thin: int, forwarded unchanged to ``trace.get_values``
    :param burn: int, number of steps to exclude, forwarded to ``trace.get_values``
    :return: dict: varname --> ndarray
    """
    # Resolve all names first, then query the trace once per name.
    names = [rv.name for rv in model.vars]
    values_by_name = {}
    for name in names:
        values_by_name[name] = trace.get_values(name, thin=thin, burn=burn)
    return values_by_name


In [76]:
class GaussMix(object):
    """
    Wrapper around sklearn's ``GaussianMixture`` that takes and returns dicts
    mapping a variable name to a 1d array of samples, instead of 2d arrays.

    The column order of the underlying data matrix is the key order of the
    dict passed to :meth:`fit`.
    """

    def __init__(self, n_components, covariance_type="diag", random_state=None):
        """
        :param n_components: int, number of mixture components
        :param covariance_type: str, forwarded to ``GaussianMixture``
        :param random_state: forwarded to ``GaussianMixture``; set it to make
            ``fit`` and ``sample`` reproducible (default ``None`` keeps the
            previous unseeded behaviour)
        """
        self._n_components = n_components
        self._vars = []
        self._gm = GaussianMixture(n_components=self._n_components,
                                   covariance_type=covariance_type,
                                   random_state=random_state)

    def fit(self, sample_dict):
        """
        Fit the mixture and remember the variable order.

        :param sample_dict: dict, var --> 1d array
        :return: self
        """
        self._vars = list(sample_dict.keys())
        X_train = self._dict_to_array(sample_dict)
        self._gm.fit(X_train)
        return self

    def score_samples(self, sample_dict):
        """
        :param sample_dict: dict, var --> 1d array; must contain every fitted variable
        :return: whatever ``GaussianMixture.score_samples`` returns for the stacked samples
        """
        X = self._dict_to_array(sample_dict)
        logp = self._gm.score_samples(X)
        return logp

    def sample(self, n_samples=1):
        """
        Draw points from the fitted mixture.

        :param n_samples: int, number of points to draw
        :return: dict, var --> 1d array of length n_samples
        """
        # GaussianMixture.sample returns a tuple; only the first element (the points) is used.
        X = self._gm.sample(n_samples=n_samples)[0]
        return {v: X[:, i] for i, v in enumerate(self._vars)}

    def get_vars(self):
        """:return: list of variable names in the column order used for fitting"""
        return self._vars

    def get_model(self):
        """:return: the wrapped ``GaussianMixture`` instance"""
        return self._gm

    def get_gm_fited_params(self):
        """
        Per-variable view of the fitted parameters.

        "sigmas" are per-feature standard deviations of each component. For
        ``covariance_type`` "full" and "tied" these are the square roots of the
        covariance diagonal only: off-diagonal terms are dropped, so a density
        rebuilt from them as a product of independent normals matches
        ``score_samples`` only for "diag" and "spherical".

        :return: dict, var --> {"weights": ndarray (n_components,),
                                "means": list of length n_components,
                                "sigmas": list of length n_components}
        """
        weights = self._gm.weights_
        means = self._gm.means_
        sigmas = np.sqrt(self._component_variances())

        results = {}
        for i, v in enumerate(self._vars):
            results[v] = {
                "weights": weights,
                "means": [means[j][i] for j in range(self._n_components)],
                "sigmas": [sigmas[j][i] for j in range(self._n_components)],
            }
        return results

    def get_n_components(self):
        """:return: int, number of mixture components requested at construction"""
        return self._n_components

    def _component_variances(self):
        """
        Per-component, per-feature variances as an (n_components, n_features) array.

        ``covariances_`` has a different layout for every ``covariance_type``;
        indexing it as ``[component][feature]`` is only valid for "diag".
        """
        cov = np.asarray(self._gm.covariances_)
        cov_type = self._gm.covariance_type
        n_features = len(self._vars)

        if cov_type == "diag":
            # (n_components, n_features): already in the target layout
            return cov
        if cov_type == "spherical":
            # (n_components,): one variance shared by all features of a component
            return np.repeat(cov[:, np.newaxis], n_features, axis=1)
        if cov_type == "full":
            # (n_components, n_features, n_features): keep each matrix diagonal
            return np.diagonal(cov, axis1=1, axis2=2)
        if cov_type == "tied":
            # (n_features, n_features): one matrix shared by all components
            return np.tile(np.diag(cov), (self._n_components, 1))
        raise ValueError("unsupported covariance_type: %r" % (cov_type,))

    def _dict_to_array(self, sample_dict):
        """
        Stack the per-variable arrays column-wise in the fitted variable order.

        :param sample_dict: dict, var --> 1d array
        :return: ndarray of shape (n_samples, n_vars)
        :raises ValueError: if no variables are registered (``fit`` not called,
            or called with an empty dict)
        """
        if not self._vars:
            raise ValueError("no variables registered; call fit() with a non-empty sample_dict first")
        X = [sample_dict[v] for v in self._vars]
        X = np.stack(X, axis=1)
        return X

In [97]:
def log_normal_pdf(mu, sigma, y):
    """
    Log-density of a univariate normal distribution evaluated at ``y``.

    :param mu: mean
    :param sigma: standard deviation
    :param y: point(s) at which the density is evaluated
    :return: log N(y | mu, sigma^2)
    """
    variance = sigma * sigma
    log_norm_const = - 0.5 * np.log(2 * np.pi * variance)
    scaled_sq_dev = (0.5 / variance) * (y - mu) ** 2
    return log_norm_const - scaled_sq_dev


def log_mult_normal_pdf(mu_vec, sigma_vec, y_vec):
    """
    Log-density of independent normals: the sum of the per-coordinate
    ``log_normal_pdf`` values.

    :param mu_vec: iterable of means
    :param sigma_vec: iterable of standard deviations
    :param y_vec: iterable of evaluation points
    :return: summed log-density (0. when the inputs are empty)
    """
    total = 0.
    # zip stops at the shortest of the three inputs
    for coord_params in zip(mu_vec, sigma_vec, y_vec):
        total = total + log_normal_pdf(*coord_params)
    return total


def log_gm_pdf(weights, mu_mat, sigma_mat, y_vec):
    """
    Log-density of a Gaussian mixture with independent (diagonal) features
    within each component, evaluated at a single point.

    The mixture is summed in log space (log-sum-exp), so a point far from every
    component yields a finite log-density instead of ``log(0) = -inf`` caused
    by ``exp`` underflow.

    :param weights: ndarray of shape (n_components,)
    :param mu_mat: ndarray of shape (n_components, n_features)
    :param sigma_mat: ndarray of shape (n_components, n_features)
    :param y_vec: ndarray of shape (n_features,)
    :return: scalar log-density
    :raises AssertionError: if ``weights`` and ``mu_mat`` disagree on n_components
    """
    weights = np.asarray(weights, dtype=float)
    mu_mat = np.asarray(mu_mat, dtype=float)
    sigma_mat = np.asarray(sigma_mat, dtype=float)
    y_vec = np.asarray(y_vec, dtype=float)

    n_components = mu_mat.shape[0]
    assert n_components == len(weights), "wrong weight len"

    # Per-component log-density: sum of univariate normal log-pdfs over features.
    var_mat = sigma_mat * sigma_mat
    log_comp = np.sum(
        - 0.5 * np.log(2 * np.pi * var_mat) - (0.5 / var_mat) * (y_vec - mu_mat) ** 2,
        axis=1,
    )

    # A zero weight legitimately maps to -inf and drops out of the reduction.
    with np.errstate(divide="ignore"):
        log_weights = np.log(weights)

    # log(sum_k w_k * p_k) computed without leaving log space.
    return np.logaddexp.reduce(log_weights + log_comp)


def make_param_mats(var_names, gm_fited_params):
    """
    Rearrange per-variable mixture parameters into component-by-feature matrices.

    :param var_names: sequence of variable names; fixes the column order
    :param gm_fited_params: dict, var --> {"weights", "means", "sigmas"}
    :return: (weights, mean_mat, sigma_mat); weights is taken from the first
             variable, both matrices have shape (n_components, n_features)
    """
    first_entry = gm_fited_params[var_names[0]]
    weights = first_entry["weights"]
    shape = [len(weights), len(var_names)]

    mean_mat = np.zeros(shape)
    sigma_mat = np.zeros(shape)

    # One column per variable, one row per mixture component.
    for col, name in enumerate(var_names):
        for row in range(shape[0]):
            mean_mat[row, col] = gm_fited_params[name]["means"][row]
            sigma_mat[row, col] = gm_fited_params[name]["sigmas"][row]
    return weights, mean_mat, sigma_mat


def logp_gm(sample_dict, gm_fited_params):
    """
    Evaluate the mixture log-density at every sample point.

    :param sample_dict: dict, var --> 1d array; its key order fixes the feature order
    :param gm_fited_params: dict, var --> {"weights", "means", "sigmas"}
    :return: ndarray with one log-density per sample index
    """
    names = list(sample_dict)
    weights, mu_mat, sigma_mat = make_param_mats(names, gm_fited_params)

    # The number of points is taken from the first variable's array.
    n_points = len(sample_dict[names[0]])

    return np.array([
        log_gm_pdf(weights, mu_mat, sigma_mat,
                   np.array([sample_dict[name][k] for name in names]))
        for k in range(n_points)
    ])

In [6]:
# Load the pickled model and trace. Context managers close the file handles,
# which the bare pickle.load(open(...)) form left open.
# NOTE(review): unpickling can execute arbitrary code; only load files from a trusted source.
with open("data/pm_model.pickle", "rb") as model_file:
    model = pickle.load(model_file, encoding="latin1")
with open("data/trace_obj.pickle", "rb") as trace_file:
    trace = pickle.load(trace_file, encoding="latin1")

In [7]:
# Pull per-variable value arrays out of the trace: burn=1000 steps are excluded and
# thin=10 is forwarded to trace.get_values (presumably keeps every 10th draw — confirm).
sample = get_values_from_trace(model, trace, thin=10, burn=1000)

In [8]:
# Show which variable names were extracted from the trace.
sample.keys()

dict_keys(['P0_interval__', 'Ls_log__', 'rho_interval__', 'DeltaG1_interval__', 'DeltaDeltaG_interval__', 'DeltaH1_interval__', 'DeltaH2_interval__', 'DeltaH_0_interval__', 'log_sigma_interval__'])

In [9]:
# Keep only three of the extracted variables ("redun" presumably means redundant — confirm);
# this subset is what the mixture is fitted on.
vars_redun = ["DeltaDeltaG_interval__", "DeltaH2_interval__", "rho_interval__"]
sample_redun = {v: sample[v] for v in vars_redun}

In [77]:
# Fit a two-component mixture with the default diagonal covariance to the selected variables.
gm = GaussMix(n_components=2)
gm.fit(sample_redun)

<__main__.GaussMix at 0x1c248637f0>

In [92]:
# Variable names in the column order the mixture was fitted with.
var_names = gm.get_vars()

In [88]:
# Per-variable view of the fitted weights, component means and component sigmas.
gm_params = gm.get_gm_fited_params()
gm_params 

{'DeltaDeltaG_interval__': {'weights': array([0.61743209, 0.38256791]),
  'means': [-2.0987785802756624, -2.072725326086235],
  'sigmas': [0.011334851006068677, 0.012968816658843139]},
 'DeltaH2_interval__': {'weights': array([0.61743209, 0.38256791]),
  'means': [-0.03600911484325556, -0.04360622257306616],
  'sigmas': [0.004091506001760568, 0.0052164789017458265]},
 'rho_interval__': {'weights': array([0.61743209, 0.38256791]),
  'means': [-1.6878444101523138, -1.8079744597974272],
  'sigmas': [0.05067689854942538, 0.06244811469778752]}}

In [86]:
# Draw 10 points from the fitted mixture (a dict of per-variable arrays) to use as test inputs.
X_test = gm.sample(n_samples=10)
X_test

{'DeltaDeltaG_interval__': array([-2.10588629, -2.1319254 , -2.09933495, -2.10669531, -2.08521986,
        -2.11142978, -2.09508337, -2.0940456 , -2.06823434, -2.09333465]),
 'DeltaH2_interval__': array([-0.03345192, -0.03537208, -0.03189837, -0.03412944, -0.03299124,
        -0.03660914, -0.03200123, -0.03110118, -0.04213963, -0.05291524]),
 'rho_interval__': array([-1.66723132, -1.67807022, -1.73722825, -1.71262908, -1.68888036,
        -1.66292375, -1.67674359, -1.71731014, -1.91959383, -1.64624435])}

In [87]:
# Reference values: log-density of the test points from the wrapped sklearn model.
gm.score_samples(X_test)

array([9.24760389, 5.41545617, 8.74600932, 9.2540315 , 8.74616119,
       8.96769106, 9.16635635, 8.75098411, 6.96002973, 2.61373642])

In [98]:
# Cross-check: the same log-densities recomputed by the hand-written mixture pdf
# from the extracted parameters; the printed values match the sklearn output.
logp_gm(X_test, gm_params)

array([9.24760389, 5.41545617, 8.74600932, 9.2540315 , 8.74616119,
       8.96769106, 9.16635635, 8.75098411, 6.96002973, 2.61373642])