# MAD Mix GMM extension

This notebook extends MAD Mix to Gaussian mixture models (GMMs), so that it can handle multivariate data and learn the mixture weights and covariance matrices.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import sys,time
sys.path.insert(1, '../src/')
import madmix
import aux

# Notebook-wide matplotlib defaults: figure.max_open_warning=0 turns off the
# "too many open figures" warning, figures are 15 x 7.5 inches, and the base
# font size is 40.
plt.rcParams.update({
    'figure.max_open_warning': 0,
    'figure.figsize': (15, 7.5),
    'font.size': 40,
})

In [19]:
####################
####################
#   auxiliary fns  #
####################
####################
def madmix_gmm_flatten(ws,mus,sigmas):
    """
    Flatten weights, means, and covariances into a single 2D array.

    Inputs:
        ws     : (K,B) array, weights
        mus    : (K,D,B) array, cluster means
        sigmas : (K,D,D,B) array, cluster covariances

    Outputs:
        xc     : (K',B) array, flattened values. Rows are stacked as
                 [weights; means; covariances], and the means/covariances
                 are flattened over their leading axes in C (row-major) order.

    Raises:
        ValueError : if the input arrays do not have mutually consistent
                     (K,B), (K,D,B), (K,D,D,B) shapes.

    Note:
    K is the number of clusters, D is the length of each mean vector, and
    B is the number of data points (for vectorizing).
    K'= K (weights) + KxD (means) + KxDxD (covariances)
    """
    ws=np.asarray(ws)
    mus=np.asarray(mus)
    sigmas=np.asarray(sigmas)
    if ws.ndim!=2 or mus.ndim!=3 or sigmas.ndim!=4:
        raise ValueError('expected ws (K,B), mus (K,D,B), sigmas (K,D,D,B); got shapes '
                         +str(ws.shape)+', '+str(mus.shape)+', '+str(sigmas.shape))

    # Infer the sizes from the inputs rather than reading notebook-level
    # globals K, D, B, which may be undefined or stale at call time.
    K,B=ws.shape
    D=mus.shape[1]
    if mus.shape!=(K,D,B) or sigmas.shape!=(K,D,D,B):
        raise ValueError('inconsistent shapes: ws '+str(ws.shape)+', mus '+str(mus.shape)
                         +', sigmas '+str(sigmas.shape))

    flat_mus=mus.reshape(K*D,B)
    flat_sigmas=sigmas.reshape(K*D*D,B)
    return np.vstack((ws,flat_mus,flat_sigmas))

In [20]:
# Small toy sizes with sequential-valued arrays, so the flattened layout is
# easy to read off by eye.
K, D, B = 3, 2, 4
ws = np.arange(K * B).reshape(K, B) * 0.1                      # (K,B)
mus = np.arange(K * D * B).reshape(K, D, B)                    # (K,D,B)
sigmas = np.arange(K * D * D * B).reshape(K, D, D, B) * 10     # (K,D,D,B)

In [22]:
# Show the toy inputs, then the shape of their flattened form.
for arr in (ws, mus, sigmas):
    print(arr)
print()
xc = madmix_gmm_flatten(ws, mus, sigmas)
print(xc.shape)

[[0.  0.1 0.2 0.3]
 [0.4 0.5 0.6 0.7]
 [0.8 0.9 1.  1.1]]
[[[ 0  1  2  3]
  [ 4  5  6  7]]

 [[ 8  9 10 11]
  [12 13 14 15]]

 [[16 17 18 19]
  [20 21 22 23]]]
[[[[  0  10  20  30]
   [ 40  50  60  70]]

  [[ 80  90 100 110]
   [120 130 140 150]]]


 [[[160 170 180 190]
   [200 210 220 230]]

  [[240 250 260 270]
   [280 290 300 310]]]


 [[[320 330 340 350]
   [360 370 380 390]]

  [[400 410 420 430]
   [440 450 460 470]]]]

(21, 4)


## Old Faithful

In [2]:
####################
####################
#  data wrangling  #
####################
####################
# Old Faithful data as a tab-separated file from a public GitHub gist.
of_url = ('https://gist.githubusercontent.com/curran/4b59d1046d9e66f2787780ad51a1cd87'
          '/raw/9ec906b78a98cf300947a37b56cfe70d01183200/data.tsv')
of_dat = pd.read_table(of_url)
# NumPy copy of the table's values.
dat = np.array(of_dat)