# MAD Mix GMM extension

This notebook extends MAD Mix to handle multivariate data and to learn the mixture weights and covariance matrices.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import sys,time
sys.path.insert(1, '../src/')
import madmix
import aux

# global plot styling: silence the open-figure warning, 15 x 7.5 figures, font size 40
plt.rcParams.update({
    'figure.max_open_warning': 0,
    'figure.figsize': (15, 7.5),
    'font.size': 40,
})

In [2]:
####################
####################
#   auxiliary fns  #
####################
####################
def madmix_gmm_flatten(ws, mus, sigmas):
    """
    Stack weights, means, and covariances into a single 2D array

    Inputs:
        ws     : (K,B) array, weights
        mus    : (K,D,B) array, cluster means
        sigmas : (K,D,D,B) array, cluster covariances

    Outputs:
        xc     : (K',B) array, flattened values

    Note:
    K is the number of clusters, D is data dimension,
    and B is the number of data points (for vectorizing)
    K'= K (weights) + KxD (means) + KxDxD (covariances)
    Rows of xc are ordered as weights, then means, then covariances;
    the trailing B axis is left untouched.
    """
    # K, D and B are read off the means array
    n_clusters, dim, n_batch = mus.shape

    pieces = [
        ws,
        mus.reshape(n_clusters * dim, n_batch),
        sigmas.reshape(n_clusters * dim * dim, n_batch),
    ]
    return np.vstack(pieces)


def madmix_gmm_unflatten(xc, K, D):
    """
    Split xc back into weights, means, and covariances

    Inputs:
        xc     : (K',B) array, flattened values
        K      : int, number of clusters
        D      : int, data dimension

    Outputs:
        ws     : (K,B) array, weights
        mus    : (K,D,B) array, cluster means
        sigmas : (K,D,D,B) array, cluster covariances

    Note:
    B is the number of data points (for vectorizing)
    K'= K (weights) + KxD (means) + KxDxD (covariances)
    Rows of xc are read in the order weights, means, covariances.
    """
    n_batch = xc.shape[-1]

    # row index where the means end and the covariances begin
    means_end = K * D + K

    weights = xc[:K, :]
    means = xc[K:means_end, :].reshape(K, D, n_batch)
    covariances = xc[means_end:, :].reshape(K, D, D, n_batch)

    return weights, means, covariances

In [12]:
########################
########################
# target specification #
########################
########################
def lp(xd,xc,axis=None):
    """
    Log joint or full-conditional label pmfs of the GMM target

    Inputs:
        xd     : (N,B) int array, cluster label of each observation
        xc     : (K',B) array, flattened weights, means, and covariances
                 (unpacked with madmix_gmm_unflatten)
        axis   : int or None, index of the observation whose conditional
                 over labels is returned; if None, returns the log joint

    Outputs:
        if axis is None, (B,) array: for each batch column, the sum over
        observations of the normalized log probability of the assigned label;
        else, (B,K) array: normalized log probabilities of each of the
        K labels for observation `axis`

    Note:
    K (number of clusters), D (data dimension), and the data y are read
    from the enclosing scope; y must yield N log densities per component.
    NOTE(review): the weights stored in xc are unpacked but do not enter
    the computation below -- confirm this is intended.
    """
    N,B=xd.shape

    # weights are discarded here (see NOTE(review) in the docstring)
    _,mus,sigmas=madmix_gmm_unflatten(xc,K,D)

    # Gaussian log density of every observation under every component
    lprbs=np.zeros((N,K,B))
    for k in range(K):
        for b in range(B):
            lprbs[:,k,b]=stats.multivariate_normal(mus[k,:,b],sigmas[k,:,:,b]).logpdf(y)
        # end for
    # end for

    # normalize over the cluster axis
    # presumably aux.LogSumExp reduces over the leading axis and returns (N,B) -- confirm in aux
    lprbs=lprbs-aux.LogSumExp(np.moveaxis(lprbs,1,0))[:,np.newaxis,:]

    if axis is not None:
        # conditional of a single observation, transposed to (B,K)
        return lprbs[axis,:,:].T
    # end if

    # pick the entry of each observation's assigned label, then add over observations
    picked=np.take_along_axis(lprbs,xd[:,np.newaxis,:],axis=1)[:,0,:]
    return np.sum(picked,axis=0)

In [13]:
# scratch inputs for trying out lp on a batch of two identical parameter sets
y = dat
K = 2
xd = np.random.randint(low=0, high=2, size=(y.shape[0], 2))

# both batch columns share the same weights, means, and identity covariances
ws_ = np.array([[0.6, 0.6], [0.4, 0.4]])
mus_ = np.zeros((2, 2, 2))
mus_[...] = np.array([[2, 60], [4.5, 80]])[:, :, np.newaxis]
sigmas_ = np.zeros((2, 2, 2, 2))
sigmas_[...] = np.eye(2)[np.newaxis, :, :, np.newaxis]

xc = madmix_gmm_flatten(ws_, mus_, sigmas_)

In [22]:
# label probabilities of observation 0, normalized over clusters:
# one row per batch column, one column per cluster
np.exp(lp(xd,xc,axis=0))

array([[2.79888842e-79, 1.00000000e+00],
       [2.79888842e-79, 1.00000000e+00]])

## Old Faithful

In [4]:
####################
####################
#  data wrangling  #
####################
####################
# download the tab-separated table and keep it as a plain array
of_dat = pd.read_csv(
    'https://gist.githubusercontent.com/curran/4b59d1046d9e66f2787780ad51a1cd87/raw/9ec906b78a98cf300947a37b56cfe70d01183200/data.tsv',
    sep='\t',
)
dat = np.array(of_dat)

# N rows (observations) and D columns (data dimension)
N, D = dat.shape