In [242]:
'''
!pip install numpy
!pip install scipy
!pip install seaborn
!pip install matplotlib
'''

'\n!pip install numpy\n!pip install scipy\n!pip install seaborn\n!pip install matplotlib\n'

In [678]:
import numpy as np
import scipy as sp

import seaborn as sns
from matplotlib import pyplot as plt
%matplotlib inline

## Data Generation

In [683]:
sigma = 3
num_clusters = 5
sample_size = 1000


In [684]:
SEED = 1

In [685]:
np.random.seed(SEED)

cluster_mus = np.random.normal(size=num_clusters, loc=0, scale=sigma)
cluster_choices = np.random.choice(num_clusters, 
                            size=sample_size, 
                            p=[1/num_clusters]*num_clusters)
sample_mus = [cluster_mus[cluster_choice] for cluster_choice in cluster_choices]
X = np.random.normal(size=sample_size, loc=sample_mus, scale=1)

In [686]:
cluster_mus

array([ 4.87303609, -1.83526924, -1.58451526, -3.21890587,  2.59622289])

## Model

In [702]:
class VIGMM:
    
    """
    Variational Inference Gaussian Mixture Model
    """

    def __init__(self, 
                num_clusters,
                sample_size,
                random_seed=1):
        self._num_clusters = num_clusters
        self._sample_size = sample_size
        self._random_seed = random_seed
        
        np.random.seed(self._random_seed)
        
        self._params = None
        
    @property
    def params(self):
        
        return self._params
        
    def _elbo(self, mu_k, s2_k, phis, X, sigma):
        
        t = s2_k + np.square(mu_k)
        a = -(1/(2*(sigma**2)))*sum(t)
        b = 2*sum(sum((phis*mu_k)*X[:,np.newaxis])) - 0.5*sum(sum(phis*t))
        c = -0.5*sum(np.log(2*np.pi*s2_k))
        d = sum(sum(phis*np.log(phis)))
        return a+b+c+d
        
    def fit(self, X, sigma=1, num_iterations=100):
    
        np.random.seed(self._random_seed)
        
        mu_k = np.random.normal(size=self._num_clusters, loc = 0, scale = 1) 
        s2_k = np.random.gamma(size=self._num_clusters, shape=5)
        phis = np.random.dirichlet(size=self._sample_size, alpha=np.repeat(1, self._num_clusters))

        elbos_res = np.zeros(num_iterations+1)

        elbos_res[0] = self._elbo(mu_k, s2_k, phis, X, sigma)

        for i in range(num_iterations):
            phis_updated = np.zeros((self._sample_size, self._num_clusters))
            for j in range(self._sample_size):
                phis_updated[j] = np.exp(X[j]*mu_k - 0.5*(s2_k + np.square(mu_k)))
                phis_updated[j] = phis_updated[j]/np.sum(phis_updated[j])
            phis = phis_updated

            mu_k_updated = np.zeros(num_clusters)
            s2_k_updated = np.zeros(num_clusters)
            for k in range(self._num_clusters):
                s2_k_updated[k] = 1/(1/(sigma**2)+sum(phis[:,k]))
                mu_k_updated[k] = s2_k_updated[k]*sum(phis[:,k]*X)   
            s2_k = s2_k_updated
            mu_k = mu_k_updated

            elbos_res[i+1] = elbo(mu_k, s2_k, phis, X, sigma)
            if i % 10 == 0:
                print(f"Iteration: {i}:", "ELBO Difference: ", np.abs(elbos_res[i+1]-elbos_res[i-1]))
            if np.abs(elbos_res[i+1]-elbos_res[i-1]) < 1e-10:
                break
        
        self._params = {"mu":mu_k,
                       "sigma_2":s2_k,
                       "phi":phis}
    
    

In [703]:
vigmm = VIGMM(num_clusters, sample_size)
    
vigmm.fit(X)


Iteration: 0: ELBO Difference:  12016.870330622976
Iteration: 10: ELBO Difference:  14.857553495998218
Iteration: 20: ELBO Difference:  0.021279346492519835
Iteration: 30: ELBO Difference:  0.01827372305888275
Iteration: 40: ELBO Difference:  0.009023976286698598
Iteration: 50: ELBO Difference:  0.004384191102872137
Iteration: 60: ELBO Difference:  0.002131248515070183
Iteration: 70: ELBO Difference:  0.0010364467507315567
Iteration: 80: ELBO Difference:  0.0005041299318691017
Iteration: 90: ELBO Difference:  0.0002452326134516625


In [704]:
vigmm.params

{'mu': array([ 5.06517401, -1.61555367, -1.61108706, -3.22413929,  2.57812598]),
 'sigma_2': array([0.00511833, 0.00491193, 0.00491211, 0.00496382, 0.00497507]),
 'phi': array([[4.18231281e-18, 7.96014882e-02, 7.87681618e-02, 8.41630349e-01,
         8.84351884e-10],
        [2.05116423e-18, 7.05220663e-02, 6.97551938e-02, 8.59722739e-01,
         5.40523678e-10],
        [5.08349336e-06, 4.60900258e-01, 4.64390996e-01, 9.15912826e-03,
         6.55445338e-02],
        ...,
        [3.29398232e-11, 4.26335248e-01, 4.26044966e-01, 1.47584551e-01,
         3.52354475e-05],
        [1.38303257e-01, 1.55259169e-05, 1.58674701e-05, 2.21187120e-09,
         8.61665348e-01],
        [2.86725614e-15, 2.09861821e-01, 2.08466762e-01, 5.81671341e-01,
         7.64953706e-08]])}

1e-10

In [656]:
cluster_mus

array([ 4.87303609, -1.83526924, -1.58451526, -3.21890587,  2.59622289])

In [657]:
mu_k


array([ 5.10806429, -1.62425388, -1.61792085, -3.25050285,  2.60912006])

In [660]:
mu_k

array([ 5.10806429, -1.62425388, -1.61792085, -3.25050285,  2.60912006])

In [667]:
%%timeit 
np.square(mu_k)

219 ns ± 0.461 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [668]:
%%timeit 
mu_k**2

237 ns ± 0.392 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [487]:
import numpy as np

class UGMM(object):
    '''Univariate GMM with CAVI'''
    def __init__(self, X, K=2, sigma=1):
        self.X = X
        self.K = K
        self.N = self.X.shape[0]
        self.sigma2 = sigma**2

    def _init(self):
        self.phi = np.random.dirichlet([np.random.random()*np.random.randint(1, 10)]*self.K, self.N)
        self.m = np.random.randint(int(self.X.min()), high=int(self.X.max()), size=self.K).astype(float)
        self.m += self.X.max()*np.random.random(self.K)
        self.s2 = np.ones(self.K) * np.random.random(self.K)
        print('Init mean')
        print(self.m)
        print('Init s2')
        print(self.s2)

    def get_elbo(self):
        t1 = np.log(self.s2) - self.m/self.sigma2
        t1 = t1.sum()
        t2 = -0.5*np.add.outer(self.X**2, self.s2+self.m**2)
        t2 += np.outer(self.X, self.m)
        t2 -= np.log(self.phi)
        t2 *= self.phi
        t2 = t2.sum()
        return t1 + t2

    def fit(self, max_iter=10000, tol=1e-50):
        self._init()
        self.elbo_values = [self.get_elbo()]
        self.m_history = [self.m]
        self.s2_history = [self.s2]
        for iter_ in range(1, max_iter+1):
            self._cavi()
            self.m_history.append(self.m)
            self.s2_history.append(self.s2)
            self.elbo_values.append(self.get_elbo())
            if iter_ % 5 == 0:
                print(iter_, self.m_history[iter_])
            #if np.abs(self.elbo_values[-2] - self.elbo_values[-1]) <= tol:
            #    print('ELBO converged with ll %.3f at iteration %d'%(self.elbo_values[-1],
            #                                                         iter_))
             #   break

        if iter_ == max_iter:
            print('ELBO ended with ll %.3f'%(self.elbo_values[-1]))


    def _cavi(self):
        self._update_phi()
        self._update_mu()

    def _update_phi(self):
        t1 = np.outer(self.X, self.m)
        t2 = -(0.5*self.m**2 + 0.5*self.s2)
        exponent = t1 + t2[np.newaxis, :]
        self.phi = np.exp(exponent)
        self.phi = self.phi / self.phi.sum(1)[:, np.newaxis]

    def _update_mu(self):
        self.m = (self.phi*self.X[:, np.newaxis]).sum(0) * (1/self.sigma2 + self.phi.sum(0))**(-1)
        assert self.m.size == self.K
        #print(self.m)
        self.s2 = (1/self.sigma2 + self.phi.sum(0))**(-1)
        assert self.s2.size == self.K

In [488]:
num_components = 3
mu_arr = np.random.choice(np.arange(-10, 10, 2),
                          num_components) +\
         np.random.random(num_components)
mu_arr

array([ 0.07567094,  4.80765538, -7.24094278])

In [489]:
SAMPLE = 1000

In [490]:
X = np.random.normal(loc=mu_arr[0], scale=1, size=SAMPLE)
for i, mu in enumerate(mu_arr[1:]):
    X = np.append(X, np.random.normal(loc=mu, scale=1, size=SAMPLE))

In [491]:
ugmm = UGMM(X, 3)
ugmm.fit()


Init mean
[5.42493938 4.54248928 4.98731226]
Init s2
[0.07104972 0.87961869 0.99397203]
5 [ 4.82448762 -7.19264041  0.12632493]
10 [ 4.8022313  -7.19274795  0.10163018]
15 [ 4.80222709 -7.19274797  0.10162582]
20 [ 4.80222709 -7.19274797  0.10162582]
25 [ 4.80222709 -7.19274797  0.10162582]
30 [ 4.80222709 -7.19274797  0.10162582]
35 [ 4.80222709 -7.19274797  0.10162582]
40 [ 4.80222709 -7.19274797  0.10162582]
45 [ 4.80222709 -7.19274797  0.10162582]
50 [ 4.80222709 -7.19274797  0.10162582]
55 [ 4.80222709 -7.19274797  0.10162582]
60 [ 4.80222709 -7.19274797  0.10162582]
65 [ 4.80222709 -7.19274797  0.10162582]
70 [ 4.80222709 -7.19274797  0.10162582]
75 [ 4.80222709 -7.19274797  0.10162582]
80 [ 4.80222709 -7.19274797  0.10162582]
85 [ 4.80222709 -7.19274797  0.10162582]
90 [ 4.80222709 -7.19274797  0.10162582]
95 [ 4.80222709 -7.19274797  0.10162582]
100 [ 4.80222709 -7.19274797  0.10162582]
105 [ 4.80222709 -7.19274797  0.10162582]
110 [ 4.80222709 -7.19274797  0.10162582]
115 [ 4.

1350 [ 4.80222709 -7.19274797  0.10162582]
1355 [ 4.80222709 -7.19274797  0.10162582]
1360 [ 4.80222709 -7.19274797  0.10162582]
1365 [ 4.80222709 -7.19274797  0.10162582]
1370 [ 4.80222709 -7.19274797  0.10162582]
1375 [ 4.80222709 -7.19274797  0.10162582]
1380 [ 4.80222709 -7.19274797  0.10162582]
1385 [ 4.80222709 -7.19274797  0.10162582]
1390 [ 4.80222709 -7.19274797  0.10162582]
1395 [ 4.80222709 -7.19274797  0.10162582]
1400 [ 4.80222709 -7.19274797  0.10162582]
1405 [ 4.80222709 -7.19274797  0.10162582]
1410 [ 4.80222709 -7.19274797  0.10162582]
1415 [ 4.80222709 -7.19274797  0.10162582]
1420 [ 4.80222709 -7.19274797  0.10162582]
1425 [ 4.80222709 -7.19274797  0.10162582]
1430 [ 4.80222709 -7.19274797  0.10162582]
1435 [ 4.80222709 -7.19274797  0.10162582]
1440 [ 4.80222709 -7.19274797  0.10162582]
1445 [ 4.80222709 -7.19274797  0.10162582]
1450 [ 4.80222709 -7.19274797  0.10162582]
1455 [ 4.80222709 -7.19274797  0.10162582]
1460 [ 4.80222709 -7.19274797  0.10162582]
1465 [ 4.80

2785 [ 4.80222709 -7.19274797  0.10162582]
2790 [ 4.80222709 -7.19274797  0.10162582]
2795 [ 4.80222709 -7.19274797  0.10162582]
2800 [ 4.80222709 -7.19274797  0.10162582]
2805 [ 4.80222709 -7.19274797  0.10162582]
2810 [ 4.80222709 -7.19274797  0.10162582]
2815 [ 4.80222709 -7.19274797  0.10162582]
2820 [ 4.80222709 -7.19274797  0.10162582]
2825 [ 4.80222709 -7.19274797  0.10162582]
2830 [ 4.80222709 -7.19274797  0.10162582]
2835 [ 4.80222709 -7.19274797  0.10162582]
2840 [ 4.80222709 -7.19274797  0.10162582]
2845 [ 4.80222709 -7.19274797  0.10162582]
2850 [ 4.80222709 -7.19274797  0.10162582]
2855 [ 4.80222709 -7.19274797  0.10162582]
2860 [ 4.80222709 -7.19274797  0.10162582]
2865 [ 4.80222709 -7.19274797  0.10162582]
2870 [ 4.80222709 -7.19274797  0.10162582]
2875 [ 4.80222709 -7.19274797  0.10162582]
2880 [ 4.80222709 -7.19274797  0.10162582]
2885 [ 4.80222709 -7.19274797  0.10162582]
2890 [ 4.80222709 -7.19274797  0.10162582]
2895 [ 4.80222709 -7.19274797  0.10162582]
2900 [ 4.80

4225 [ 4.80222709 -7.19274797  0.10162582]
4230 [ 4.80222709 -7.19274797  0.10162582]
4235 [ 4.80222709 -7.19274797  0.10162582]
4240 [ 4.80222709 -7.19274797  0.10162582]
4245 [ 4.80222709 -7.19274797  0.10162582]
4250 [ 4.80222709 -7.19274797  0.10162582]
4255 [ 4.80222709 -7.19274797  0.10162582]
4260 [ 4.80222709 -7.19274797  0.10162582]
4265 [ 4.80222709 -7.19274797  0.10162582]
4270 [ 4.80222709 -7.19274797  0.10162582]
4275 [ 4.80222709 -7.19274797  0.10162582]
4280 [ 4.80222709 -7.19274797  0.10162582]
4285 [ 4.80222709 -7.19274797  0.10162582]
4290 [ 4.80222709 -7.19274797  0.10162582]
4295 [ 4.80222709 -7.19274797  0.10162582]
4300 [ 4.80222709 -7.19274797  0.10162582]
4305 [ 4.80222709 -7.19274797  0.10162582]
4310 [ 4.80222709 -7.19274797  0.10162582]
4315 [ 4.80222709 -7.19274797  0.10162582]
4320 [ 4.80222709 -7.19274797  0.10162582]
4325 [ 4.80222709 -7.19274797  0.10162582]
4330 [ 4.80222709 -7.19274797  0.10162582]
4335 [ 4.80222709 -7.19274797  0.10162582]
4340 [ 4.80

5660 [ 4.80222709 -7.19274797  0.10162582]
5665 [ 4.80222709 -7.19274797  0.10162582]
5670 [ 4.80222709 -7.19274797  0.10162582]
5675 [ 4.80222709 -7.19274797  0.10162582]
5680 [ 4.80222709 -7.19274797  0.10162582]
5685 [ 4.80222709 -7.19274797  0.10162582]
5690 [ 4.80222709 -7.19274797  0.10162582]
5695 [ 4.80222709 -7.19274797  0.10162582]
5700 [ 4.80222709 -7.19274797  0.10162582]
5705 [ 4.80222709 -7.19274797  0.10162582]
5710 [ 4.80222709 -7.19274797  0.10162582]
5715 [ 4.80222709 -7.19274797  0.10162582]
5720 [ 4.80222709 -7.19274797  0.10162582]
5725 [ 4.80222709 -7.19274797  0.10162582]
5730 [ 4.80222709 -7.19274797  0.10162582]
5735 [ 4.80222709 -7.19274797  0.10162582]
5740 [ 4.80222709 -7.19274797  0.10162582]
5745 [ 4.80222709 -7.19274797  0.10162582]
5750 [ 4.80222709 -7.19274797  0.10162582]
5755 [ 4.80222709 -7.19274797  0.10162582]
5760 [ 4.80222709 -7.19274797  0.10162582]
5765 [ 4.80222709 -7.19274797  0.10162582]
5770 [ 4.80222709 -7.19274797  0.10162582]
5775 [ 4.80

7100 [ 4.80222709 -7.19274797  0.10162582]
7105 [ 4.80222709 -7.19274797  0.10162582]
7110 [ 4.80222709 -7.19274797  0.10162582]
7115 [ 4.80222709 -7.19274797  0.10162582]
7120 [ 4.80222709 -7.19274797  0.10162582]
7125 [ 4.80222709 -7.19274797  0.10162582]
7130 [ 4.80222709 -7.19274797  0.10162582]
7135 [ 4.80222709 -7.19274797  0.10162582]
7140 [ 4.80222709 -7.19274797  0.10162582]
7145 [ 4.80222709 -7.19274797  0.10162582]
7150 [ 4.80222709 -7.19274797  0.10162582]
7155 [ 4.80222709 -7.19274797  0.10162582]
7160 [ 4.80222709 -7.19274797  0.10162582]
7165 [ 4.80222709 -7.19274797  0.10162582]
7170 [ 4.80222709 -7.19274797  0.10162582]
7175 [ 4.80222709 -7.19274797  0.10162582]
7180 [ 4.80222709 -7.19274797  0.10162582]
7185 [ 4.80222709 -7.19274797  0.10162582]
7190 [ 4.80222709 -7.19274797  0.10162582]
7195 [ 4.80222709 -7.19274797  0.10162582]
7200 [ 4.80222709 -7.19274797  0.10162582]
7205 [ 4.80222709 -7.19274797  0.10162582]
7210 [ 4.80222709 -7.19274797  0.10162582]
7215 [ 4.80

8540 [ 4.80222709 -7.19274797  0.10162582]
8545 [ 4.80222709 -7.19274797  0.10162582]
8550 [ 4.80222709 -7.19274797  0.10162582]
8555 [ 4.80222709 -7.19274797  0.10162582]
8560 [ 4.80222709 -7.19274797  0.10162582]
8565 [ 4.80222709 -7.19274797  0.10162582]
8570 [ 4.80222709 -7.19274797  0.10162582]
8575 [ 4.80222709 -7.19274797  0.10162582]
8580 [ 4.80222709 -7.19274797  0.10162582]
8585 [ 4.80222709 -7.19274797  0.10162582]
8590 [ 4.80222709 -7.19274797  0.10162582]
8595 [ 4.80222709 -7.19274797  0.10162582]
8600 [ 4.80222709 -7.19274797  0.10162582]
8605 [ 4.80222709 -7.19274797  0.10162582]
8610 [ 4.80222709 -7.19274797  0.10162582]
8615 [ 4.80222709 -7.19274797  0.10162582]
8620 [ 4.80222709 -7.19274797  0.10162582]
8625 [ 4.80222709 -7.19274797  0.10162582]
8630 [ 4.80222709 -7.19274797  0.10162582]
8635 [ 4.80222709 -7.19274797  0.10162582]
8640 [ 4.80222709 -7.19274797  0.10162582]
8645 [ 4.80222709 -7.19274797  0.10162582]
8650 [ 4.80222709 -7.19274797  0.10162582]
8655 [ 4.80

9975 [ 4.80222709 -7.19274797  0.10162582]
9980 [ 4.80222709 -7.19274797  0.10162582]
9985 [ 4.80222709 -7.19274797  0.10162582]
9990 [ 4.80222709 -7.19274797  0.10162582]
9995 [ 4.80222709 -7.19274797  0.10162582]
10000 [ 4.80222709 -7.19274797  0.10162582]
ELBO ended with ll -1499.300


In [492]:
sorted(mu_arr)

[-7.240942775613905, 0.0756709413613913, 4.807655381843853]

In [493]:
sorted(ugmm.m)

[-7.192747968443836, 0.10162582204442755, 4.802227092299927]