# GMM multivariate Concrete test

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import sys,time
sys.path.insert(1, '../src/')
import gibbs
from concrete import *
import aux

plt.rcParams.update({'figure.max_open_warning': 0})
plt.rcParams["figure.figsize"]=15,7.5
plt.rcParams.update({'font.size': 40})

In [2]:
def gmm_lposterior(xd,w,mus,Sigmas,y,mu0,Sigma0,chol=None):
    """
    Evaluate the log posterior density of a GMM (given data y)
    
    Inputs:
        xd     : (N,B) array, labels (N = # of observations, B = Monte Carlo sample size)
        w      : (K,B) array, weights (K = # of clusters)
        mus    : (K,D,B) array, cluster means (D = dimension of observations)
        Sigmas : (K,D,D,B) array, cluster covariances
        y      : (N,D) array, observations
        mu0    : (K,D,B) array, prior cluster means 
        Sigma0 : (K,D,D,B) array, prior cluster covariances
    
    Outputs:
        lp     : (B,) array, posterior distribution log density up to normalizing constant
    """
    N=xd.shape[0]
    K,D,B=mus.shape
    chol=np.linalg.cholesky(np.moveaxis(Sigmas,3,1)) if chol is None else chol # (K,B,D,D)
    
    #lp = stats.dirichlet(np.ones(K)).logpdf(w) # prior weights
    lp = 0.6931471805599453*np.ones(B) # uninformative dirichlet value; prevents issues when sum(w)<1 but very close
    for k in range(K):
        std_mu = np.squeeze(np.matmul(chol[k,:,:,:],(mus[k,:,:]-mu0[k,:,None]).T[:,:,None]))
        lp += stats.invwishart(df=N/K,scale=Sigma0[k,:,:]*N/K).logpdf(Sigmas[k,:,:,:]) # prior Sigma
        lp += stats.multivariate_normal(mean=np.zeros(D),cov=np.eye(D)).logpdf(std_mu) # prior mu
    # end for
    
    for n in range(N): lp += np.log(w[xd[n,:].astype(int),np.arange(B)]) # prior labels
    
    for n in range(N):
        for k in range(K):
            std_y = np.squeeze(np.matmul(chol[k,:,:,:],(y[n,:,None]-mus[k,:,:]).T[:,:,None]))
            tmplp = stats.multivariate_normal(mean=np.zeros(D),cov=np.eye(D)).logpdf(std_y) # likelihood
            tmplp += np.log(w[k]) # likelihood
            
            idx = (xd[n,:]==k) # label n = cluster k
            ll  = np.zeros(B)
            ll[idx]=tmplp[idx]  # when xn=k, add gaussian lp
            lp += ll
        # end for
    # end for
    
    return lp

## Penguin data set

In [3]:
from palmerpenguins import load_penguins
penguins = load_penguins().dropna()
std_penguins=(penguins-penguins.mean())/penguins.std() # normalize data

  std_penguins=(penguins-penguins.mean())/penguins.std() # normalize data
  std_penguins=(penguins-penguins.mean())/penguins.std() # normalize data


In [4]:
pg_dat=np.array(std_penguins[['bill_length_mm','bill_depth_mm','flipper_length_mm','body_mass_g']])
pg_true=np.squeeze(np.array(penguins[['species']]))
pg_true[pg_true=='Adelie']=0
pg_true[pg_true=='Gentoo']=1
pg_true[pg_true=='Chinstrap']=2

In [21]:
N,D=pg_dat.shape
K=3

mu0=np.array([[-2.,1.,-1.,-1.],  # green
             [1.,1.,-0.5,-0.5],  # purple
             [1.,-1.5,1.5,2.]])  # blue
sigma0=np.zeros((K,D,D))
for k in range(K): sigma0[k,:,:]=0.5*np.eye(D)
w0=np.ones(K)/K

In [6]:
########################
########################
# target specification #
########################
########################
gibbs_path = '../examples/GMM/sockeye_run/penguin/'
pred_x  = aux.pkl_load(gibbs_path+'pred_x')
pred_w = aux.pkl_load(gibbs_path+'pred_w')
pred_mu = aux.pkl_load(gibbs_path+'pred_mu')
pred_sigma = aux.pkl_load(gibbs_path+'pred_sigma')

# convert gibbs output to torch tensors
xs_concrete = torch.from_numpy(pred_x)
ws_concrete = torch.from_numpy(pred_w)
mus_concrete = torch.from_numpy(pred_mu)
sigmas_concrete = torch.from_numpy(pred_sigma)

N,K,D = pred_x.shape[1], pred_mu.shape[1], pred_mu.shape[2] # 333, 3, 4
tau0=0.1

In [7]:
########################
########################
#      settings        #
########################
########################
temp=1.
depth=10
width=8

max_iters=101
lr=1e-5

conc_sample=gmm_concrete_sample(xs_concrete,ws_concrete,mus_concrete,sigmas_concrete,temp)

In [8]:
tmp_flow,tmp_loss=trainGMMRealNVP(
    temp=temp,depth=depth,N=N,K=K,D=D,tau0=tau0,sample=conc_sample,width=width,max_iters=max_iters,lr=lr,seed=2023,verbose=True
)

iter 0: loss = 2159.467
iter 10: loss = 2091.535
iter 20: loss = 2037.234
iter 30: loss = 1994.536
iter 40: loss = 1960.975
iter 50: loss = 1934.192
iter 60: loss = 1912.402
iter 70: loss = 1894.458
iter 80: loss = 1879.511
iter 90: loss = 1866.915
iter 100: loss = 1856.163


In [24]:
sample_size=1000
tmp_sample=tmp_flow.sample(sample_size)
xd_pg,ws_pg,mus_pg,Sigmas_pg=concrete_gmm_unpack(tmp_sample,N,K,D)
ws_pg,mus_pg,Sigmas_pg=ws_pg.detach().numpy(),mus_pg.detach().numpy(),Sigmas_pg.detach().numpy()
ws_pg-=aux.LogSumExp(ws_pg)
ws_pg=np.exp(ws_pg)
xd_labels=np.argmax(xd_pg,axis=1)
idx = np.zeros(sample_size, dtype=bool)
for j in range(sample_size):
    if np.sum(np.isinf(Sigmas_pg[:,:,:,j]))>0: continue
    idx[j] = True
# end for


t0 = time.perf_counter()
llq = tmp_flow.log_prob(tmp_sample).detach().numpy()
llp = gmm_lposterior(xd_labels[:,idx],ws_pg[:,idx],mus_pg[:,:,idx],Sigmas_pg[:,:,:,idx],pg_dat,mu0,sigma0,
                     chol=np.linalg.cholesky(np.moveaxis(Sigmas_pg[:,:,:,idx],3,1)))
print(np.mean(llp-llq))

-40505072.1313283


## Waveform data set

In [28]:
from sklearn.decomposition import PCA
waveform_dat=pd.read_table('https://hastie.su.domains/ElemStatLearn/datasets/waveform.train')
pca = PCA(n_components=4)
pca.fit(waveform_dat[waveform_dat.columns.difference(['row.names','y'])])
waveform_pca=np.array(waveform_dat[waveform_dat.columns.difference(['row.names','y'])])@pca.components_.T

In [29]:
wf_true=np.squeeze(np.array(waveform_dat[['y']]))-1
wf_dat=waveform_pca[:,:2]
N,D=wf_dat.shape
K=3

In [30]:
# initial arrays
mu0=np.array([[-3.,4.],  # blue
              [ 5.,4.],  # purple 
              [ 0.,0.]]) # yellow
sigma0=np.zeros((K,D,D))
for k in range(K): sigma0[k,:,:]=5.*np.eye(D)
w0=np.ones(K)/K

In [25]:
########################
########################
# target specification #
########################
########################
gibbs_path = '../examples/GMM/sockeye_run/waveform/'
pred_x  = aux.pkl_load(gibbs_path+'pred_x')
pred_w = aux.pkl_load(gibbs_path+'pred_w')
pred_mu = aux.pkl_load(gibbs_path+'pred_mu')
pred_sigma = aux.pkl_load(gibbs_path+'pred_sigma')

# convert gibbs output to torch tensors
xs_concrete = torch.from_numpy(pred_x)
ws_concrete = torch.from_numpy(pred_w)
mus_concrete = torch.from_numpy(pred_mu)
sigmas_concrete = torch.from_numpy(pred_sigma)

N,K,D = pred_x.shape[1], pred_mu.shape[1], pred_mu.shape[2] # 300, 3, 2
tau0=0.1

In [26]:
########################
########################
#      settings        #
########################
########################
temp=1.
depth=10
width=8

max_iters=101
lr=1e-5

conc_sample=gmm_concrete_sample(xs_concrete,ws_concrete,mus_concrete,sigmas_concrete,temp)

In [27]:
tmp_flow,tmp_loss=trainGMMRealNVP(
    temp=temp,depth=depth,N=N,K=K,D=D,tau0=tau0,sample=conc_sample,width=width,max_iters=max_iters,lr=lr,seed=2023,verbose=True
)

iter 0: loss = 1788.605
iter 10: loss = 1755.861
iter 20: loss = 1728.707
iter 30: loss = 1705.948
iter 40: loss = 1686.476
iter 50: loss = 1669.641
iter 60: loss = 1654.796
iter 70: loss = 1641.525
iter 80: loss = 1629.609
iter 90: loss = 1618.856
iter 100: loss = 1609.125


In [31]:
sample_size=1000
tmp_sample=tmp_flow.sample(sample_size)
xd_pg,ws_pg,mus_pg,Sigmas_pg=concrete_gmm_unpack(tmp_sample,N,K,D)
ws_pg,mus_pg,Sigmas_pg=ws_pg.detach().numpy(),mus_pg.detach().numpy(),Sigmas_pg.detach().numpy()
ws_pg-=aux.LogSumExp(ws_pg)
ws_pg=np.exp(ws_pg)
xd_labels=np.argmax(xd_pg,axis=1)
idx = np.zeros(sample_size, dtype=bool)
for j in range(sample_size):
    if np.sum(np.isinf(Sigmas_pg[:,:,:,j]))>0: continue
    idx[j] = True
# end for


t0 = time.perf_counter()
llq = tmp_flow.log_prob(tmp_sample).detach().numpy()
llp = gmm_lposterior(xd_labels[:,idx],ws_pg[:,idx],mus_pg[:,:,idx],Sigmas_pg[:,:,:,idx],wf_dat,mu0,sigma0,
                     chol=np.linalg.cholesky(np.moveaxis(Sigmas_pg[:,:,:,idx],3,1)))
print(np.mean(llp-llq))

-208710.63748522804
