# L3: VI and Gaussian mixture model (GMM)

Here we explore how to use variational inference (VI) to learn the parameters of a Gaussian mixture model (GMM).


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import networkx as nx

In [None]:
import sys
sys.path.append('../../../src/')
import tools as tl
import plot as viz
# import pysbm


In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
# Qualitative palette used to colour clusters; `colors[i]` is the RGBA tuple for index i.
# NOTE(review): 20 indices are requested here; tab10 presumably only defines 10 distinct
# colours, so indices >= 10 may all map to the same colour - confirm.
colormap = plt.cm.tab10
colors = {i: colormap(i) for i in range(20)}

In [None]:
# Output directory handed to tl.savefig and lecture number handed to tl.get_filename.
outdir_fig = '../figures/'
lecture_id = 3

In [None]:
# Dedicated, seeded generator used for every `prng.*` draw in this notebook
# (synthetic data, plotting samples, baselines).
seed = 10
prng = np.random.RandomState(seed)

# 1. Generate data from GMM
Let's generate some synthetic data from the model.

#### Ground truth parameters

In [None]:
K=3 # number of mixture components
SAMPLE=1000 # number of data per cluster

#### Sample data

In [None]:
num_components = K
mu_arr = [0,1.,5.]
assert len(mu_arr) == K, "need exactly one ground-truth mean per cluster"

# Ground truth cluster assignments: one-hot rows, SAMPLE consecutive rows per
# cluster (rows [0, SAMPLE) -> cluster 0, [SAMPLE, 2*SAMPLE) -> cluster 1, ...).
# Built from K instead of hard-coding three clusters.
c_GT = np.repeat(np.eye(K), SAMPLE, axis=0)
assert np.all(np.sum(c_GT,axis=1)==1)

# Ground truth X: SAMPLE unit-variance Gaussian draws around each mean, stacked
# in cluster order. The draws are made in the same order as mu_arr, so the
# random stream of `prng` is consumed exactly once per cluster.
X = np.concatenate([prng.normal(loc=mu, scale=1, size=SAMPLE) for mu in mu_arr])
X.shape

## 1.1 Plot data

In [None]:
# One histogram (with KDE) per ground-truth cluster, overlaid on a single axis.
fig, ax = plt.subplots(figsize=(10, 4))
cluster_rows = (slice(None, SAMPLE), slice(SAMPLE, SAMPLE*2), slice(SAMPLE*2, None))
for rows in cluster_rows:
    sns.histplot(X[rows], ax=ax, kde=True)
ax.set_xlabel('x')
ax.set_ylabel('Count')

# The generated file name is deliberately overwritten with None before the call.
# NOTE(review): outfile=None presumably makes tl.savefig skip writing to disk - confirm in tools.
filename = tl.get_filename("GMMexample",lecture_id=lecture_id)
filename = None
tl.savefig(plt,outfile = filename,outdir = outdir_fig)



# 2. Implement CAVI updates

In [None]:



class UGMM(object):
    '''Univariate GMM fitted with coordinate-ascent variational inference (CAVI).

    The update equations implemented below are those of the mean-field
    treatment of the model
        mu_k ~ N(0, sigma^2),   x_i | c_i, mu ~ N(mu_{c_i}, 1),
    with a uniform prior on the assignments c_i and the variational family
        q(mu_k) = N(m_k, s2_k),   q(c_i) = Categorical(phi_i).

    Parameters
    ----------
    X : 1-D array of observations.
    K : number of mixture components.
    sigma : standard deviation of the Gaussian prior on the component means.
    seed : None, int or np.random.RandomState, optional
        Source of randomness for the initialisation. ``None`` (default) keeps
        the historical behaviour of drawing from NumPy's global generator.

    Attributes set by ``fit``
    -------------------------
    phi : (N, K) responsibilities, each row sums to one.
    m, s2 : (K,) variational means / variances of the component means.
    elbo_values, m_history, s2_history : per-iteration traces; index 0 is the
        random initialisation.
    '''
    def __init__(self, X, K=2, sigma=1, seed=None):
        self.X = X
        self.K = K
        self.N = self.X.shape[0]
        self.sigma2 = sigma**2
        if seed is None:
            # np.random exposes the same sampling functions as a RandomState
            self._rng = np.random
        elif isinstance(seed, np.random.RandomState):
            self._rng = seed
        else:
            self._rng = np.random.RandomState(seed)

    def _init(self):
        '''Draw a random starting point for phi, m and s2.'''
        rng = self._rng
        concentration = rng.random_sample()*rng.randint(1, 10)
        self.phi = rng.dirichlet([concentration]*self.K, self.N)
        low, high = int(self.X.min()), int(self.X.max())
        if high <= low:
            # randint needs low < high; happens when the data span less than one unit
            high = low + 1
        self.m = rng.randint(low, high=high, size=self.K).astype(float)
        self.m += self.X.max()*rng.random_sample(self.K)
        self.s2 = np.ones(self.K) * rng.random_sample(self.K)
        print('Init mean')
        print(self.m)
        print('Init s2')
        print(self.s2)

    def get_elbo(self):
        '''Return the ELBO of the current variational parameters.

        Additive constants that do not depend on (phi, m, s2) are dropped.
        '''
        # E_q[log p(mu)] + entropy of q(mu):
        #   sum_k 0.5*log(s2_k) - (s2_k + m_k^2) / (2*sigma^2)
        t1 = 0.5*np.log(self.s2) - 0.5*(self.s2 + self.m**2)/self.sigma2
        t1 = t1.sum()
        # E_q[log p(x | c, mu)]: sum_ik phi_ik * (x_i*m_k - 0.5*(x_i^2 + s2_k + m_k^2))
        t2 = -0.5*np.add.outer(self.X**2, self.s2+self.m**2)
        t2 += np.outer(self.X, self.m)
        t2 *= self.phi
        # Entropy of q(c): -sum phi*log(phi), with the convention 0*log(0) = 0
        # (a plain np.log would turn exact zeros into NaN).
        safe_phi = np.where(self.phi > 0, self.phi, 1.0)
        t2 -= self.phi * np.log(safe_phi)
        t2 = t2.sum()
        return t1 + t2

    def fit(self, max_iter=100, tol=1e-10):
        '''Run CAVI until the ELBO changes by at most `tol` or `max_iter` is hit.

        Progress (the current means) is printed every 5 iterations.
        '''
        self._init()
        self.elbo_values = [self.get_elbo()]
        # store copies so that later updates can never alias the history
        self.m_history = [self.m.copy()]
        self.s2_history = [self.s2.copy()]
        print("it mean")
        for iter_ in range(1, max_iter+1):
            self._cavi()
            self.m_history.append(self.m.copy())
            self.s2_history.append(self.s2.copy())
            self.elbo_values.append(self.get_elbo())
            if iter_ % 5 == 0:
                print(iter_, self.m_history[iter_])
            if np.abs(self.elbo_values[-2] - self.elbo_values[-1]) <= tol:
                print('ELBO converged with ll %.3f at iteration %d'%(self.elbo_values[-1],
                                                                     iter_))
                break
        else:
            # reached only when the loop was not left through `break`
            # (also covers max_iter == 0, where no iteration runs at all)
            print('ELBO ended with ll %.3f'%(self.elbo_values[-1]))


    def _cavi(self):
        '''One CAVI sweep: responsibilities first, then the component means.'''
        self._update_phi()
        self._update_mu()

    def _update_phi(self):
        '''phi_ik proportional to exp(x_i*m_k - 0.5*(m_k^2 + s2_k)), normalised per row.'''
        t1 = np.outer(self.X, self.m)
        t2 = -(0.5*self.m**2 + 0.5*self.s2)
        exponent = t1 + t2[np.newaxis, :]
        # subtracting the row maximum leaves the normalised result unchanged
        # but keeps np.exp from overflowing
        exponent -= exponent.max(axis=1, keepdims=True)
        self.phi = np.exp(exponent)
        self.phi /= self.phi.sum(axis=1, keepdims=True)

    def _update_mu(self):
        '''Closed-form update of the Gaussian factors q(mu_k) = N(m_k, s2_k).'''
        precision = 1/self.sigma2 + self.phi.sum(0)
        self.m = (self.phi*self.X[:, np.newaxis]).sum(0) / precision
        assert self.m.size == self.K
        self.s2 = 1.0 / precision
        assert self.s2.size == self.K

# 3. Fit model to data

In [None]:
# UGMM._init draws its starting point from NumPy's global generator rather
# than from `prng`, so seed the global generator to make this fit (and every
# figure / permutation derived from it) reproducible across runs.
np.random.seed(seed)
ugmm = UGMM(X, K)
ugmm.fit()

## 3.1 Plot results at convergence

In [None]:
# Hard assignment of every sample to its most probable component, then the
# sample indices grouped per component: inferred_colors[k] lists the samples
# whose arg-max responsibility is k.
hard_labels = [np.argmax(ugmm.phi[i]) for i in range(SAMPLE*K)]
inferred_colors = [
    [i for i, q in enumerate(hard_labels) if q == k]
    for k in range(K)
]

In [None]:

# Data histograms per ground-truth cluster, overlaid with samples drawn around
# each inferred mean (unit variance, dashed KDE).
fig, ax = plt.subplots(figsize=(10, 4))
sns.histplot(X[:SAMPLE], ax=ax, kde=True,label='Data')
sns.histplot(prng.normal(ugmm.m[0], 1, SAMPLE),ax=ax,color=colors[0], kde=True,line_kws={'ls':'--'},alpha=0.3,label='Inferred')
sns.histplot(X[SAMPLE:SAMPLE*2], ax=ax, kde=True)
sns.histplot(prng.normal(ugmm.m[1], 1, SAMPLE),ax=ax,color=colors[4] , kde=True,line_kws={'ls':'--'},alpha=0.1)
sns.histplot(X[SAMPLE*2:], ax=ax, kde=True)
sns.histplot(prng.normal(ugmm.m[2], 1, SAMPLE),ax=ax,color=colors[2],kde=True,line_kws={'ls':'--'},alpha=0.1)

# elbo_values holds the initial value plus one entry per CAVI iteration, so
# the index of the last entry is the iteration at which the fit stopped.
# (Previously a hard-coded number that went stale whenever the run changed.)
final_it = len(ugmm.elbo_values) - 1
plt.figtext(0.15,0.75,f't = {final_it}\n(convergence)',fontsize=14)
plt.legend(loc='best')

ax.set_xlabel('x')
ax.set_ylabel('Count')

# The generated file name is deliberately overwritten with None before the call.
filename = tl.get_filename("GMMexample_itConv",lecture_id=lecture_id)
filename = None
tl.savefig(plt,outfile = filename,outdir = outdir_fig)


# 4. Analyze ELBO
How did the ELBO evolve during training?


In [None]:
fs=20

# Focus iteration points, clamped to the last recorded iteration so that a
# run that stops early cannot index past the end of the ELBO trace.
last_it = len(ugmm.elbo_values) - 1
it1 = min(8, last_it)
it2 = min(15, last_it)
# ---------

plt.figure()
iterations = np.arange(len(ugmm.elbo_values))
plt.scatter(iterations,ugmm.elbo_values)
plt.plot(iterations,ugmm.elbo_values, alpha=0.3)
# red squares mark the two focus iterations
plt.scatter(it1,ugmm.elbo_values[it1],marker='s',facecolors='none',edgecolors='r',s=200, linewidth=3)
plt.scatter(it2,ugmm.elbo_values[it2],marker='s',facecolors='none',edgecolors='r',s=200, linewidth=3)
plt.xlim([-1,25])
# plt.ylim([-850,-100])
plt.xlabel('Iterations',fontsize=fs)
plt.ylabel('ELBO',fontsize=fs)


# The generated file name is deliberately overwritten with None before the call.
filename = tl.get_filename("GMMexample_ELBO",lecture_id=lecture_id)
filename = None
tl.savefig(plt,outfile = filename,outdir = outdir_fig)


We have highlighted two interesting iterations, where the ELBO changes the most.

## 4.1 Plot changing points

In [None]:
# One figure per selected iteration: the three data histograms, each followed
# by samples drawn around the corresponding variational mean at that iteration.
cluster_rows = (slice(None, SAMPLE), slice(SAMPLE, SAMPLE*2), slice(SAMPLE*2, None))
inferred_color_ids = (0, 4, 2)

for it in [0,it1,it2]:
    fig, ax = plt.subplots(figsize=(10, 4))
    means_at_it = ugmm.m_history[it]

    for comp, (rows, color_id) in enumerate(zip(cluster_rows, inferred_color_ids)):
        sns.histplot(X[rows], ax=ax, kde=True)
        sns.histplot(prng.normal(means_at_it[comp], 1, SAMPLE), ax=ax, color=colors[color_id],
                     kde=True, line_kws={'ls':'--'}, alpha=0.1)

    plt.figtext(0.8,0.8,f't = {it}',fontsize=20)
    ax.set_xlabel('x')
    ax.set_ylabel('Count')

    # The generated file name is deliberately overwritten with None before the call.
    filename = tl.get_filename(f"GMMexample_it{it}",lecture_id=lecture_id)
    filename = None
    tl.savefig(plt,outfile = filename,outdir = outdir_fig)

# 5. Evaluate model performance

How do we evaluate if results are good?

In [None]:
# Every row of responsibilities must be a probability vector. The previous
# check only tested that the row sums were non-zero, which any positive
# matrix satisfies.
assert np.allclose(np.sum(ugmm.phi,axis=1), 1)

## 5.1 Cluster assignments posteriors

We can start by visualizing the posterior distributions over the cluster assignments, to see which samples are the most **uncertain**.

In [None]:
# Sample indices ordered by the spread of their responsibility vector:
# a flat (low-std) row means the sample is shared between components, so the
# most uncertain samples come first and the most confident ones last.
sorted_std = np.std(ugmm.phi, axis=1).argsort()

In [None]:
# Responsibilities of the L samples with the flattest posterior (most uncertain).
L = 3
fig, ax = plt.subplots(1,L,figsize=(10, 4), sharey=True)
for panel, sample_id in zip(ax, sorted_std[:L]):
    sns.barplot(ugmm.phi[sample_id], ax=panel)

    msg = f'i = {sample_id}\nx = {X[sample_id]:.2f}'
    panel.text(0.0,0.46,msg,fontsize=12)
    panel.set_xlabel('k')
    panel.set_ylabel('P(k)')

Higher uncertainty is placed on samples that fall at the intersection between clusters.  
We can now look at the samples with **lower uncertainty**.

In [None]:
# Responsibilities of the L samples with the most peaked posterior (least uncertain),
# taken from the end of the std-sorted index array.
L = 3
fig, ax = plt.subplots(1,L,figsize=(10, 4), sharey=True)
for panel, sample_id in zip(ax, sorted_std[::-1][:L]):
    sns.barplot(ugmm.phi[sample_id], ax=panel)

    msg = f'i = {sample_id}\nx = {X[sample_id]:.2f}'
    panel.text(-0.4,0.9,msg,fontsize=12)
    panel.set_xlabel('k')
    panel.set_ylabel('P(k)')

Lower uncertainty is placed on samples that fall far from the intersection between clusters.

## 5.2 Gaussian centers posteriors

We can also check the posteriors of the Gaussian means.

In [None]:
# Sanity check: responsibilities and one-hot ground truth should have the same shape.
ugmm.phi.shape, c_GT.shape

In [None]:
# Permutation that aligns inferred components with ground-truth clusters:
# P[i] is the index of the inferred component matched to ground-truth cluster i.
# Component labels are arbitrary and depend on the random initialisation, so
# the match is computed (by rank of the means) rather than hard-coded.
# To map inferred labels back to ground-truth labels use the inverse, np.argsort(P).
gt_rank = np.argsort(np.argsort(mu_arr))   # rank of each ground-truth mean
P = np.argsort(ugmm.m)[gt_rank]            # inferred component with the same rank
# P0 = tl.CalculatePermuation(ugmm.phi,c_GT) # permutation to match cluster by cluster
# P = np.argmax(P0,axis=0) 


In [None]:
# One panel per ground-truth cluster i: samples from the matched variational
# posterior N(m, s2) of the mean, with the estimate and the true mean marked.
L = ugmm.m.shape[0]

fig, ax = plt.subplots(1,L,figsize=(10, 4), sharey=True)
for i in range(L):
    panel = ax[i]
    matched = P[i]
    m_est = ugmm.m[matched]
    shade = colors[matched]

    posterior_draws = prng.normal(m_est, np.sqrt(ugmm.s2[matched]), SAMPLE)
    sns.histplot(posterior_draws, color=shade, kde=True, line_kws={'ls':'--'}, alpha=0.1, ax=panel)
    panel.axvline(x=m_est, color=shade, ls='--', label='Estimated')
    panel.axvline(x=mu_arr[i], color='black', ls='-.', alpha=0.5, label='GT')

    xlim = panel.get_xlim()
    msg = f'k = {i}\nx_GT = {mu_arr[i]:.2f}\nx_est = {m_est:.2f}'
    panel.text(xlim[0],100,msg,fontsize=12)
    panel.set_xlabel(r'$\mu_k$')
# legend is attached to the current (last) panel only
plt.legend(loc='best')

## 5.3 Evaluation metrics
We can for instance measure prediction quality in reconstructing the cluster assignments

In [None]:
from sklearn.metrics import log_loss

In [None]:
# Column i of c_GT refers to ground-truth cluster i, while the columns of
# ugmm.phi follow the arbitrary labelling of the fit. Reorder them with P
# (P[i] = inferred component matched to cluster i) before scoring; otherwise
# the loss mostly measures label switching.
log_loss(c_GT, ugmm.phi[:, P])

We need a **baseline** for comparison, as this bare number is not interpretable on its own (is it good? bad?).

For instance, we can build a random **permutation** of the ground truth. The worst performance is expected when we permute the whole ground-truth cluster assignment vector.  
The best performance is obtained when we do not permute anything (the GT is not manipulated).

We can vary the proportion of manipulated GT entries.

In [None]:
# Baseline: corrupt a growing fraction rho of the ground-truth rows by
# re-drawing their cluster uniformly at random, and score the corrupted
# one-hot matrix against the untouched ground truth.
n_samples = c_GT.shape[0]
perm = prng.permutation(np.arange(n_samples))

print('rho','logL')
for rho in np.linspace(0,1,21):
    n_perm = int(rho * n_samples)
    corrupted_rows = perm[:n_perm]

    c_GT_perm = c_GT.copy()
    # blank the selected rows ...
    c_GT_perm[corrupted_rows] = 0
    assert np.sum(np.sum(c_GT_perm,axis=1) == 0) == n_perm
    # ... and give each of them a freshly drawn cluster
    c_GT_perm[corrupted_rows, prng.choice(np.arange(K), n_perm)] = 1

    assert np.all(np.sum(c_GT_perm,axis=1)==1)

    print(f"{rho:.2f} {log_loss(c_GT,c_GT_perm):.2f}")

Alternatively, we can check accuracy

In [None]:
from sklearn.metrics import accuracy_score

In [None]:
# Accuracy of the hard assignments WITHOUT aligning inferred labels to the
# ground-truth labels; the following cell applies the permutation.
accuracy_score(np.argmax(c_GT,axis=1), np.argmax(ugmm.phi,axis=1))

Don't forget to permute!

In [None]:
# P[i] gives the inferred component matched to ground-truth cluster i (that is
# how it indexes ugmm.m in the posterior plot). Relabelling predictions needs
# the opposite direction, inferred -> ground truth, i.e. the inverse
# permutation np.argsort(P). For a self-inverse P such as [0, 2, 1] both
# directions coincide, so the result is unchanged in that case.
P_inv = np.argsort(P)
accuracy_score(np.argmax(c_GT,axis=1), P_inv[np.argmax(ugmm.phi,axis=1)])

In [None]:
# Same corruption baseline as for the log loss, scored with accuracy: a
# fraction rho of ground-truth rows gets a uniformly re-drawn cluster.
n_samples = c_GT.shape[0]
perm = prng.permutation(np.arange(n_samples))

# header fixed: the second column is accuracy, not a log-likelihood
print('rho','accuracy')
for rho in np.linspace(0,1,21):
    # rho = 0.9 # permuted %
    n_perm = int(rho * n_samples)
    
    c_GT_perm = np.copy(c_GT)
    # blank the selected rows, then assign each a freshly drawn cluster
    c_GT_perm[perm[:n_perm]] = np.zeros((n_perm,K))
    assert np.sum(np.sum(c_GT_perm,axis=1) == 0) == n_perm
    c_GT_perm[perm[:n_perm], prng.choice(np.arange(K), n_perm)] = 1 
    
    assert np.all(np.sum(c_GT_perm,axis=1)==1)

    print(f"{rho:.2f} {accuracy_score(np.argmax(c_GT,axis=1), np.argmax(c_GT_perm,axis=1)):.2f}")

# 6. Appendix: 2D example

In [None]:
from sklearn.mixture import BayesianGaussianMixture

In [None]:
# Re-create the seeded generator so the 2D example does not depend on how many
# draws the cells above consumed from `prng`.
seed = 10
prng = np.random.RandomState(seed)

## 6.1 Generate data

In [None]:
# Ground truth of the 2D example: three Gaussians stacked along the y axis.
n_samples = 500      # points per component (rebinds the earlier n_samples)
n_components = 3
covars = np.array(
    [[[0.7, 0.2], [0.2, 0.1]], [[0.5, 0.0], [0.0, 0.1]], [[0.5, 0.0], [0.0, 0.1]]]
)
samples = np.array([n_samples] * 3)
means = np.array([[0.0, -0.70], [0.0, 0.0], [0.0, 0.70]])

# Integer label per point (rebinds c_GT, which was a one-hot matrix in the 1D example) ...
c_GT = np.repeat(np.arange(3), n_samples)

# ... and the matching one-hot encoding, n_samples consecutive rows per component.
c_GT_vect = np.eye(n_components)[c_GT]
assert np.all(np.sum(c_GT_vect,axis=1)==1)

c_GT.shape

In [None]:
# Draw each component's points in turn and stack them row-wise, so rows follow
# the same component order as c_GT.
component_draws = []
for j in range(n_components):
    component_draws.append(prng.multivariate_normal(means[j], covars[j], samples[j]))
X = np.vstack(component_draws)
X.shape

## 6.2 Fit VI-GMM to data

In [None]:
# Variational Bayesian GMM from scikit-learn, with diagonal covariances, a
# Dirichlet prior on the weights, and as many components as the ground truth.
vi_gmm_options = dict(
    covariance_type='diag',
    weight_concentration_prior_type="dirichlet_distribution",
    n_components=1 * n_components,
    reg_covar=0,
    init_params="random",
    max_iter=1500,
    mean_precision_prior=0.8,
    random_state=seed,
)
estimator = BayesianGaussianMixture(**vi_gmm_options)

estimator.fit(X)

In [None]:
# Inspect the fitted parameters: shapes of the means and covariances, and the mixture weights.
estimator.means_.shape, estimator.covariances_.shape, estimator.weights_

In [None]:
# estimator.predict_proba(X)

In [None]:

def multivariate_gaussian(pos, mu, Sigma):
    """Evaluate the density of N(mu, Sigma) at every point of `pos`.

    `pos` holds the points along its last axis (length ``mu.shape[0]``); the
    result has the shape of `pos` without that axis.
    """
    dim = mu.shape[0]
    centered = pos - mu
    precision = np.linalg.inv(Sigma)

    # Squared Mahalanobis distance (x-mu)^T Sigma^{-1} (x-mu), contracted over
    # the last axis for all points at once.
    mahalanobis_sq = np.einsum('...k,kl,...l->...', centered, precision, centered)

    # Normalising constant sqrt((2*pi)^n * |Sigma|)
    normaliser = np.sqrt((2*np.pi)**dim * np.linalg.det(Sigma))
    return np.exp(-mahalanobis_sq / 2) / normaliser


In [None]:
# Regular N x N evaluation grid over [-2, 2] x [-2, 2].
N =  n_components * n_samples
axis_ticks = np.linspace(-2, 2, N)
Xs, Ys = np.meshgrid(axis_ticks, axis_ticks)

# pos[i, j] = (Xs[i, j], Ys[i, j]): grid coordinates packed along a trailing axis
pos = np.stack((Xs, Ys), axis=-1)

In [None]:
# Match inferred components to ground-truth clusters, then keep, for each row
# of the returned matrix, the column with the largest entry.
# NOTE(review): `tl.CalculatePermuation` is a project helper not visible here;
# presumably it returns a K x K matching matrix between the two label sets - confirm.
P = tl.CalculatePermuation(c_GT_vect,estimator.predict_proba(X))
print(P.shape)
P = np.argmax(P,axis=1)
# P = np.array([1,2,0])
P

We have inferred diagonal covariances, so we need to expand each of them into a full 2 × 2 matrix.

In [None]:
# Bring the fitted covariances to the same layout as `covars`: when the
# estimator returned one variance vector per component, place each vector on
# the diagonal of an otherwise zero matrix.
if estimator.covariances_.ndim == covars.ndim:
    estimated_cov = estimator.covariances_.copy()
else:
    K = means.shape[0]
    estimated_cov = np.zeros_like(covars)
    for k in range(K):
        estimated_cov[k] = np.diag(estimator.covariances_[k])
estimated_cov


In [None]:
# Shape of the fitted component means.
estimator.means_.shape

In [None]:
# One panel per ground-truth component: filled contours of the true density,
# line contours of the matched estimate, and the points of that component.
fig,ax = plt.subplots(1,means.shape[0],figsize=(12,4),sharex=True,sharey=True)
for k in np.arange(means.shape[0]):
    selected = c_GT == k
    Z_gt = multivariate_gaussian(pos, means[k], covars[k])
    Z_est = multivariate_gaussian(pos, estimator.means_[P[k]], estimated_cov[P[k]])
    
    # contour/contourf do not turn a `label=` keyword into a legend entry
    # (Matplotlib reports it as unused), so the labels are attached to empty
    # proxy artists instead.
    ax[k].contour(Xs,Ys,Z_est,cmap='Reds',alpha=0.8)
    ax[k].contourf(Xs,Ys,Z_gt,cmap='Reds')
    ax[k].scatter(X[selected,0],X[selected,1],zorder=1,color=colors[k],alpha=0.5,s=10,label='Data')
    ax[k].plot([], [], color=plt.cm.Reds(0.9), alpha=0.8, label='Estimated')
    ax[k].plot([], [], color=plt.cm.Reds(0.5), linewidth=6, label='GT')
# legend is attached to the current (last) panel only
plt.legend(loc='best')

In [None]:
# Four figures comparing the true density of component k = 0 (filled contours,
# with its data points) against different candidate estimates (line contours).
k = 0
selected = c_GT == k
# The ground-truth density is identical in all four figures: evaluate it once.
Z_gt = multivariate_gaussian(pos, means[k], covars[k])


def _plot_gt_vs_estimate(mean_est, cov_est, figure_name):
    """Draw one GT-vs-estimate figure for component `k` and pass it to tl.savefig.

    mean_est, cov_est : mean vector and 2 x 2 covariance of the Gaussian drawn
        as line contours.
    figure_name : base name handed to tl.get_filename.
    """
    Z_est = multivariate_gaussian(pos, mean_est, cov_est)

    fig, ax = plt.subplots(1,1,figsize=(4,4))
    # contour/contourf do not turn a `label=` keyword into a legend entry
    # (Matplotlib reports it as unused), so the labels are attached to empty
    # proxy artists instead.
    ax.contour(Xs,Ys,Z_est,cmap='Reds',alpha=0.8)
    ax.contourf(Xs,Ys,Z_gt,cmap='Reds')
    ax.scatter(X[selected,0],X[selected,1],zorder=1,color=colors[k],alpha=0.5,s=10,label='Data')
    ax.plot([], [], color=plt.cm.Reds(0.9), alpha=0.8, label='Estimated')
    ax.plot([], [], color=plt.cm.Reds(0.5), linewidth=6, label='GT')
    plt.legend(loc='best')

    # The generated file name is deliberately overwritten with None before the call.
    filename = tl.get_filename(figure_name,lecture_id=lecture_id)
    filename=None
    tl.savefig(plt,outfile = filename,outdir = outdir_fig)


# 1) the component matched to k by the permutation P
_plot_gt_vs_estimate(estimator.means_[P[k]], estimated_cov[P[k]], "GMMexample_2D_inf")

# 2) the component matched to cluster 1, shown against the data of cluster 0
_plot_gt_vs_estimate(estimator.means_[P[1]], estimated_cov[P[1]], "GMMexample_2D_inf1")

# 3) matched component with a shifted mean and a 3x inflated covariance
mean_shift = np.array([-0.5,-0.2])
_plot_gt_vs_estimate(estimator.means_[P[k]] + mean_shift, 3 * estimated_cov[P[k]], "GMMexample_2D_inf2")

# 4) matched component with a shifted mean and a halved covariance
_plot_gt_vs_estimate(estimator.means_[P[k]] + mean_shift, 0.5 * estimated_cov[P[k]], "GMMexample_2D_inf3")