# This notebook implements an expectation-maximization (EM) algorithm for Gaussian mixtures

In [44]:
# import libraries
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from scipy.stats import multivariate_normal
%matplotlib notebook

# Define model parameters

In [95]:
d = 2 # data dimension (2d)
N = 10**5 #Number of samples
K = 5 #number of hidden variables

# Define the centers (means) and covariance matrices of the Gaussian components

In [96]:
# Sampling window for the component centers: [xmin, xmax], symmetric around 0.
xmax = 10.
xmin = -xmax


def _random_covariance(dim):
    """Return A.T @ A for a freshly drawn dim x dim standard-normal matrix A.

    The product is symmetric positive semi-definite by construction.
    """
    factor = np.random.randn(dim, dim)
    return factor.T @ factor


# Centers: uniform draws in [0, 1) shifted to [-0.5, 0.5) and scaled by 2*xmax.
mu_k = (xmax*2) * (np.random.rand(K, d) - 0.5)

# One random covariance matrix per component, stacked along the first axis.
cov_mat = np.zeros((K, d, d))
for k in range(K):
    cov_mat[k] = _random_covariance(d)

# Visualize the PDF (2D or 1D, depending on the data dimension)

In [97]:
# Evaluate the sum of the K component densities on a regular grid and plot it.
# The components are added with unit weight each (no mixing probabilities are
# applied here), so the plotted surface is a sum of PDFs, not a normalized one.
if d == 2:
    X, Y = np.meshgrid(
        np.linspace(1.5*xmin, 1.5*xmax, 100),
        np.linspace(1.5*xmin, 1.5*xmax, 100),
        indexing="ij",
    )
    pos = np.dstack((X, Y))  # shape (100, 100, 2): one (x, y) point per grid node
    cum_distrib = np.zeros_like(X)
    for k in range(K):
        rd = multivariate_normal(mean=mu_k[k, :], cov=cov_mat[k, ...])
        cum_distrib += rd.pdf(pos)
    plt.figure()
    # With indexing="ij" the first array axis runs along x, whereas imshow draws
    # the first axis vertically. Transpose so that x is horizontal and y is
    # vertical, in agreement with `extent`.
    plt.imshow(
        cum_distrib.T,
        cmap="Greys",
        aspect="auto",
        origin="lower",
        extent=[X.min(), X.max(), Y.min(), Y.max()],
    )
    plt.colorbar()
    plt.xlabel("x")
    plt.ylabel("y")
elif d == 1:
    X = np.linspace(1.5*xmin, 1.5*xmax, 100)

    cum_distrib = np.zeros_like(X)
    for k in range(K):
        rd = multivariate_normal(mean=mu_k[k, :], cov=cov_mat[k, ...])
        cum_distrib += rd.pdf(X)
    plt.figure()
    plt.plot(X, cum_distrib)
    plt.xlabel("x")
    plt.ylabel("sum of component PDFs")

<IPython.core.display.Javascript object>

# Define the hidden-variable (mixing) probabilities

In [125]:
# Draw K positive weights and normalise them so that they sum to one.
pi_k = np.random.rand(K)
pi_k /= pi_k.sum()
# Running totals of the probabilities, without the last entry (the overall total).
cum_sum_pi_k = pi_k.cumsum()[:-1]

# Construct data samples

In [30]:
# Sample Zi: draw one component index per sample and store it one-hot in Z.
z = np.random.rand(N)  # one uniform draw in [0, 1) per sample
Z = np.zeros((K, N))   # Z[k, i] == 1 iff sample i is assigned to component k

# Inverse-CDF sampling: the number of cumulative-probability thresholds that are
# <= z[i] is the index of the component whose probability bin contains z[i].
index = np.searchsorted(cum_sum_pi_k, z, side="right")
# Guard against an out-of-range index in case cum_sum_pi_k also carries the final
# total and rounding leaves it marginally below a draw.
index = np.minimum(index, K - 1)

Z[index, np.arange(N)] = 1.

In [110]:
# Scratch check: display the cumulative component probabilities.
cum_sum_pi_k

array([0.09498065, 0.28935886, 0.54791766, 0.76256798, 1.        ])

In [120]:
# Scratch check: positions of the cumulative probabilities that exceed 0.93.
np.where(cum_sum_pi_k > 0.93)

(array([4]),)

In [126]:
# Scratch check: display the cumulative component probabilities again.
cum_sum_pi_k

array([0.27369507, 0.46315726, 0.67855589, 0.77279306])