# `nb07`: Expectation-Maximization

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multinomial
from scipy.stats import multivariate_normal

In [None]:
from matplotlib.patches import Ellipse

def _plot_cov(ax, cov, pos, n_std=2.0):
    """Draw the ``n_std``-standard-deviation ellipse of a 2x2 covariance.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    cov : 2x2 symmetric covariance matrix.
    pos : centre of the ellipse (e.g. the component mean).
    n_std : half-axis length in standard deviations along each principal
        axis; the default 2.0 reproduces the original ``4 * sqrt(eigval)``
        full widths.
    """
    vals, vecs = np.linalg.eigh(cov)
    # Largest eigenvalue first, so column 0 of vecs is the major axis.
    order = vals.argsort()[::-1]
    vals, vecs = vals[order], vecs[:, order]
    # eigh may return tiny negative eigenvalues for (near-)singular matrices;
    # clip them so sqrt does not yield NaN.
    vals = np.clip(vals, 0.0, None)

    # Orientation of the major axis in degrees: arctan2(y, x) of its eigenvector.
    theta = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))
    # Ellipse takes full axis lengths, i.e. twice the half-axis n_std * std.
    width, height = 2 * n_std * np.sqrt(vals)

    ellipse = Ellipse(xy=pos, width=width, height=height,
                      angle=theta, alpha=0.25)
    # add_patch (unlike add_artist) updates the axes data limits, so
    # autoscaling keeps the whole ellipse in view.
    ax.add_patch(ellipse)

def make_plot(x, r=None, mu=None, sigma=None, ll=None):
    """Scatter 2-D data, optionally with cluster colours and GMM overlays.

    Parameters
    ----------
    x : array indexed as ``x[:, 0]`` / ``x[:, 1]``; those two columns are plotted.
    r : optional responsibilities; each point is coloured by
        ``argmax(r, axis=1)``.
    mu : optional component means; drawn as markers (columns 0 and 1).
    sigma : optional covariances, one per component; each ``sigma[k]`` is
        drawn as an ellipse centred on ``mu[k]``. Requires ``mu``.
    ll : optional value shown in the title (the notebook passes the
        log-likelihood).

    Raises
    ------
    ValueError
        If ``sigma`` is given without ``mu`` (there is nowhere to centre
        the ellipses).
    """
    if sigma is not None and mu is None:
        raise ValueError("sigma requires mu to position the covariance ellipses")

    _, ax = plt.subplots()

    # c=None is scatter's default, so this covers both the plain and the
    # cluster-coloured case.
    colors = None if r is None else np.argmax(r, axis=1)
    ax.scatter(x[:, 0], x[:, 1], alpha=0.2, c=colors)

    if mu is not None:
        ax.scatter(mu[:, 0], mu[:, 1])

    if sigma is not None:
        for k, cov_k in enumerate(sigma):
            _plot_cov(ax, cov_k, mu[k])

    if ll is not None:
        ax.set(title=r"$ll = {}$".format(ll))

    plt.show()

In [None]:
# Load the points to cluster. make_plot reads columns 0 and 1, so the CSV is
# presumably one point per row with two comma-separated coordinates — confirm.
x = np.loadtxt(fname="data/gmm.csv", delimiter=",")

In [None]:
# Look at the raw, unlabelled data before fitting anything.
make_plot(x=x)

# Expectation-Maximization

In [None]:
# Log-likelihood log p(x | pi, mu, sigma)
def ll(x, pi, mu, sigma):
    """Return the total log-likelihood of ``x`` under a Gaussian mixture.

    Computes  sum_i log sum_k pi_k N(x_i | mu_k, sigma_k).

    Parameters
    ----------
    x : array of shape (n_samples, n_features).
    pi : mixing weights, shape (n_clusters,).
    mu : component means, shape (n_clusters, n_features).
    sigma : component covariances, shape (n_clusters, n_features, n_features).

    Returns
    -------
    float
        The summed log-likelihood over all samples.
    """
    x = np.asarray(x, dtype=float)
    log_joint = np.empty((x.shape[0], len(pi)))
    for k, (pi_k, mu_k, sigma_k) in enumerate(zip(pi, mu, sigma)):
        # Work with log densities: pdf underflows to 0 far from every mean,
        # which would turn the log-likelihood into -inf.
        log_joint[:, k] = np.log(pi_k) + multivariate_normal.logpdf(
            x, mean=mu_k, cov=sigma_k)
    # log sum_k exp(...) per sample, evaluated stably.
    return float(logsumexp(log_joint, axis=1).sum())

In [None]:
# E-step
def e_step(x, pi, mu, sigma):
    """Compute the posterior responsibilities r[i, j] = p(z = j | x_i).

    r[i, j] = pi_j N(x_i | mu_j, sigma_j) / sum_k pi_k N(x_i | mu_k, sigma_k)

    Parameters
    ----------
    x : array of shape (n_samples, n_features).
    pi : mixing weights, shape (n_clusters,).
    mu : component means, shape (n_clusters, n_features).
    sigma : component covariances, shape (n_clusters, n_features, n_features).

    Returns
    -------
    r : array of shape (n_samples, n_clusters); every row sums to 1.
    """
    x = np.asarray(x, dtype=float)
    log_joint = np.empty((x.shape[0], len(pi)))
    for j, (pi_j, mu_j, sigma_j) in enumerate(zip(pi, mu, sigma)):
        # log pi_j + log N(x_i | mu_j, sigma_j); staying in log space avoids
        # the 0/0 that plain densities produce for far-away points.
        log_joint[:, j] = np.log(pi_j) + multivariate_normal.logpdf(
            x, mean=mu_j, cov=sigma_j)
    # Normalise each row in log space, then exponentiate.
    r = np.exp(log_joint - logsumexp(log_joint, axis=1, keepdims=True))
    return r

In [None]:
# M-step
def m_step(x, r, reg_covar=1e-6):
    """Maximise E_r[log p(x, z | pi, mu, sigma)] in closed form.

    With N_j = sum_i r[i, j]:
        pi_j    = N_j / n_samples
        mu_j    = sum_i r[i, j] x_i / N_j
        sigma_j = sum_i r[i, j] (x_i - mu_j)(x_i - mu_j)^T / N_j

    Parameters
    ----------
    x : array of shape (n_samples, n_features).
    r : responsibilities from ``e_step``, shape (n_samples, n_clusters).
    reg_covar : non-negative value added to the diagonal of every covariance
        so that a component collapsing onto very few points keeps an
        invertible covariance. Use 0.0 for the exact maximum-likelihood
        estimate.

    Returns
    -------
    pi : (n_clusters,)
    mu : (n_clusters, n_features)
    sigma : (n_clusters, n_features, n_features)
    """
    x = np.asarray(x, dtype=float)
    r = np.asarray(r, dtype=float)
    n_samples, n_features = x.shape
    n_clusters = r.shape[1]

    n_k = r.sum(axis=0)  # effective number of points per component
    pi = n_k / n_samples
    # Guard the denominators only: an empty component then gets mu = 0 and
    # sigma = reg_covar * I instead of NaNs.
    denom = np.maximum(n_k, 10 * np.finfo(float).eps)
    mu = (r.T @ x) / denom[:, np.newaxis]

    sigma = np.empty((n_clusters, n_features, n_features))
    for j in range(n_clusters):
        diff = x - mu[j]
        sigma[j] = (r[:, j, np.newaxis] * diff).T @ diff / denom[j]
        sigma[j] += reg_covar * np.eye(n_features)

    return pi, mu, sigma

In [None]:
# Initialization
n_samples = len(x)
n_clusters = 4

# Uniform mixing weights, means at distinct randomly chosen data points, and
# every covariance set to the covariance of the whole data set. The fixed seed
# makes runs reproducible; change it to try other starting points.
rng = np.random.default_rng(0)
pi = np.full(n_clusters, 1.0 / n_clusters)
mu = x[rng.choice(n_samples, size=n_clusters, replace=False)]
sigma = np.tile(np.cov(x, rowvar=False), (n_clusters, 1, 1))

In [None]:
# Initial assignments: responsibilities under the starting parameters, shown
# together with the starting means, covariances and log-likelihood.
r = e_step(x, pi, mu, sigma)
make_plot(x, r=r, mu=mu, sigma=sigma, ll=ll(x, pi, mu, sigma))

In [None]:
# Iterate manually until convergence: each run of this cell performs one EM
# iteration -- an M-step on the current responsibilities, then an E-step under
# the new parameters. Re-run it until the log-likelihood in the title stops
# increasing.
pi, mu, sigma = m_step(x, r)
r = e_step(x, pi, mu, sigma)
make_plot(x, r, mu, sigma, ll=ll(x, pi, mu, sigma))

In [None]:
# Simulate new data
def simulate(n_samples, pi, mu, sigma, random_state=None):
    """Draw ``n_samples`` points from the Gaussian mixture (pi, mu, sigma).

    Sampling is ancestral:
    - Component counts are drawn from Multinomial(n_samples, pi), which is
      equivalent to drawing z_i ~ Categorical(pi) for each sample.
    - Component k then contributes that many draws from N(mu_k, sigma_k).
    - Rows are shuffled so the output order reveals nothing about z.

    Parameters
    ----------
    n_samples : number of points to draw.
    pi : mixing weights, shape (n_clusters,), summing to 1.
    mu : component means, shape (n_clusters, n_features).
    sigma : component covariances, shape (n_clusters, n_features, n_features).
    random_state : None, an int seed, or a ``numpy.random.Generator``;
        forwarded to ``np.random.default_rng``.

    Returns
    -------
    x : array of shape (n_samples, n_features).
    """
    rng = np.random.default_rng(random_state)
    counts = rng.multinomial(n_samples, pi)
    x = np.concatenate([
        rng.multivariate_normal(mu_k, sigma_k, size=count)
        for mu_k, sigma_k, count in zip(mu, sigma, counts)
    ])
    rng.shuffle(x)  # shuffles rows (axis 0)
    return x

# Draw 5000 fresh points from the mixture with the current parameters and plot
# them for comparison with the real data.
make_plot(x=simulate(5000, pi, mu, sigma))