# MAD Mix GMM extension

This notebook extends MAD Mix to handle multivariate data and to learn the mixture weights, means, and covariance matrices.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import sys,time
sys.path.insert(1, '../src/')
import madmix
import aux

# global plotting defaults: silence the many-open-figures warning,
# use wide figures, and large fonts for readability
plt.rcParams.update({
    'figure.max_open_warning': 0,
    'figure.figsize': (15, 7.5),
    'font.size': 40,
})

In [30]:
####################
####################
#   auxiliary fns  #
####################
####################
def madmix_gmm_flatten(ws,mus,sigmas):
    """
    Flatten weights, means, and covariances into a single 2D array
    
    Inputs:
        ws     : (K,B) array, weights
        mus    : (K,D,B) array, cluster means
        sigmas : (K,D,D,B) array, cluster covariances
    
    Outputs:
        xc     : (K',B) array, flattened values, stacked row-wise as
                 [weights; means; covariances]
        
    Note:
    K is the number of clusters, D is data dimension, 
    and B is the number of data points (for vectorizing)
    K'= K (weights) + KxD (means) + KxDxD (covariances)
    K, D, and B are inferred from mus.shape, so no globals are needed.
    Mismatched input shapes raise ValueError (from reshape/vstack).
    """
    
    # infer dimensions from the inputs instead of reading globals K, D, B
    K,D,B=mus.shape
    
    # C-order reshape: row k*D+d of flat_mus holds mus[k,d,:], which is
    # the layout that madmix_gmm_unflatten undoes with reshape(K,D,B)
    flat_mus=mus.reshape(K*D,B)
    flat_sigmas=sigmas.reshape(K*D*D,B)
    return np.vstack((ws,flat_mus,flat_sigmas))


def madmix_gmm_unflatten(xc,K,D):
    """
    Unflatten xc into weights, means, and covariances
    
    Inputs:
        xc     : (K',B) array, flattened values
        K      : int, number of clusters
        D      : int, data dimension
    
    Outputs:
        ws     : (K,B) array, weights
        mus    : (K,D,B) array, cluster means
        sigmas : (K,D,D,B) array, cluster covariances
        
    Raises:
        ValueError : if xc does not have exactly K' rows
        
    Note:
    K is the number of clusters, D is data dimension, 
    and B is the number of data points (for vectorizing)
    K'= K (weights) + KxD (means) + KxDxD (covariances)
    B is inferred from the number of columns of xc.
    """
    
    # number of vectorized points (previously read from a global B)
    B=xc.shape[1]
    
    # row offsets delimiting each flattened variable
    mu_end=K+K*D
    sigma_end=mu_end+K*D*D
    if xc.shape[0]!=sigma_end:
        raise ValueError(
            f"xc has {xc.shape[0]} rows, expected {sigma_end} "
            f"(K={K}, D={D})"
        )
    
    # recover each flattened var
    ws=xc[:K,:]
    flat_mus=xc[K:mu_end,:]
    flat_sigmas=xc[mu_end:sigma_end,:]
    
    # unflatten separately (inverse of the C-order reshape used to flatten)
    mus=flat_mus.reshape(K,D,B)
    sigmas=flat_sigmas.reshape(K,D,D,B)
    
    return ws,mus,sigmas

## Old Faithful

In [2]:
####################
####################
#  data wrangling  #
####################
####################
# Old Faithful eruption data (tab-separated), fetched from a public gist
of_dat = pd.read_csv(
    'https://gist.githubusercontent.com/curran/4b59d1046d9e66f2787780ad51a1cd87/raw/'
    '9ec906b78a98cf300947a37b56cfe70d01183200/data.tsv',
    sep='\t',
)
# plain array copy of the table for numerical work
dat = np.array(of_dat)