In [159]:
import numpy as np
from scipy.stats import norm

from easydict import EasyDict as edict

from copy import deepcopy

from tqdm import tqdm_notebook

import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib tk
plt.style.use('ggplot')

In [160]:
# Colour map shared by the plotting code: integer cluster labels are passed
# straight to `cmap(...)` to pick one colour per cluster.
cmap = cm.Set1

In [161]:
def new_state():
    """Create an empty clustering state.

    Returns:
        An ``EasyDict`` whose fields all start as empty lists:

        * ``points``      -- the observed samples.
        * ``labels``      -- the set of cluster ids.
        * ``real_dists``  -- the distributions the samples were drawn from.
        * ``real_labels`` -- the true cluster id of every sample.
        * ``curr_dists``  -- the current estimate of each cluster's distribution.
        * ``curr_labels`` -- the current cluster assignment of every sample.
    """
    # Every key needs its own `: []`. Two adjacent string literals are silently
    # concatenated by Python into a single key, which would leave both intended
    # fields undefined.
    return edict({
        'points': [],
        'labels': [],
        'real_dists': [],
        'real_labels': [],
        'curr_dists': [],
        'curr_labels': [],
    })

In [162]:
def plot_state(s):
    """Draw the true clustering (top) and the current estimate (bottom).

    A new figure with two stacked subplots is created. Each subplot shows the
    samples as a scatter at ``y = 1`` coloured by label, with the density
    curves of the corresponding distributions drawn above them.

    Args:
        s: state providing ``points``, ``labels``, ``real_labels``,
           ``real_dists``, ``curr_labels`` and ``curr_dists``.
    """
    plt.figure()
    ax1 = plt.subplot(2, 1, 1)

    n = len(s.points)
    # The density curves are evaluated on as many grid samples as there are points.
    support = np.linspace(min(s.points), max(s.points), n)
    ones = [1] * n
    plt.scatter(x=s.points, y=ones, c=cmap(s.real_labels))

    def plot_dists(ax, dists, labels, h, f=0):
        """Plot each distribution's pdf on `ax`, lifted above the scatter row.

        `h` is the hatch pattern of the filled area and `f` an offset added to
        the label before it is looked up in the colour map.
        """
        curves = [(dist.pdf(support), lab) for dist, lab in zip(dists, labels)]
        peak = max([curve.max() for curve, _ in curves])
        # Raise every curve so its base sits slightly above the points at y = 1.
        up = 1 + peak * 0.1
        baseline = [up] * n
        for curve, lab in curves:
            curve += up
            colour = cmap(lab + f)
            ax.plot(support, curve, color=colour)
            ax.fill_between(support, curve, baseline,
                            alpha=0.2, color=colour, hatch=h)
        ax.plot(support, baseline, c='black')

    plot_dists(ax1, s.real_dists, s.labels, None)

    ax2 = plt.subplot(2, 1, 2)
    # The estimate uses colours shifted by 5 so it is distinguishable from the truth.
    shifted_labels = np.array(s.curr_labels) + 5
    ax2.scatter(x=s.points, y=ones, c=cmap(shifted_labels))
    if s.curr_dists:
        plot_dists(ax2, s.curr_dists, s.labels, 'X', 5)

    plt.show()

In [163]:
def gen_cluster(n_clusters=3, n_points=100, seed=None):
    """Generate a synthetic 1-D data set made of Gaussian clusters.

    Each cluster gets an integer mean drawn from ``[0, 500)`` and a standard
    deviation drawn uniformly from ``[3, 10)``. ``n_points`` samples are then
    drawn per cluster; the samples are ordered cluster by cluster.

    Args:
        n_clusters: number of Gaussian clusters to generate (default 3).
        n_points: number of samples drawn from each cluster (default 100).
        seed: optional seed. When given, a private ``RandomState`` is used so
            the data set is reproducible and the global NumPy random state is
            left untouched. When ``None`` the global ``np.random`` state is
            used.

    Returns:
        A state from ``new_state()`` with ``points``, ``real_labels``,
        ``real_dists`` and ``labels`` filled in. ``curr_dists`` and
        ``curr_labels`` are left for the caller to initialise.

    Raises:
        ValueError: if ``n_clusters`` or ``n_points`` is smaller than 1.
    """
    if n_clusters < 1 or n_points < 1:
        raise ValueError("n_clusters and n_points must both be at least 1")

    rng = np.random if seed is None else np.random.RandomState(seed)

    # (mean, std) of every generating distribution.
    params = [(rng.randint(0, 500), rng.uniform(3, 10))
              for _ in range(n_clusters)]
    dists = [norm(loc=mean, scale=std) for mean, std in params]

    # True label of each sample: 0,0,...,1,1,...,2,2,...
    r_labs = np.repeat(range(n_clusters), n_points)
    points = [rng.normal(params[i][0], params[i][1]) for i in r_labs]

    state = new_state()
    state.points = points
    state.real_labels = r_labs
    state.real_dists = dists
    state.labels = range(n_clusters)

    return state

In [202]:
def E_step(state):
    """Expectation step (hard assignment).

    Every point is labelled with the index of the distribution in
    ``state.curr_dists`` under which it is most likely.

    Likelihoods are compared in log space. A point that lies far from every
    (narrow) distribution has a density that underflows to exactly 0.0 under
    all of them; comparing raw densities would then always pick index 0,
    whereas log-densities still rank the candidates correctly.

    Args:
        state: state providing ``points`` and ``curr_dists``. It is not
            modified.

    Returns:
        A deep copy of ``state`` whose ``curr_labels`` is a list holding one
        distribution index per point.
    """
    s = deepcopy(state)

    pts = np.asarray(s.points, dtype=float)
    if pts.size == 0:
        # Nothing to label; avoids calling argmax on an empty array.
        s.curr_labels = []
        return s

    # One row per candidate distribution, one column per point.
    log_lik = np.array([dist.logpdf(pts) for dist in s.curr_dists])
    s.curr_labels = list(np.argmax(log_lik, axis=0))
    return s

def M_step(state, min_std=1e-3):
    """Maximisation step: refit one normal distribution per cluster.

    For every index ``i`` of ``state.curr_dists`` the points currently
    labelled ``i`` are collected and a normal distribution with their mean and
    standard deviation replaces the old one. A cluster without any points is
    re-seeded at the mean of all points with a fifth of their overall
    standard deviation.

    Args:
        state: state providing ``points``, ``curr_labels`` (one label per
            point, same length as ``points``) and ``curr_dists``. It is not
            modified.
        min_std: lower bound applied to every fitted standard deviation. A
            cluster with a single member (or identical members) has a
            standard deviation of 0, for which ``scipy.stats.norm`` returns
            NaN densities; the floor keeps the distribution usable.

    Returns:
        A deep copy of ``state`` with ``curr_dists`` replaced by the refitted
        distributions.
    """
    s = deepcopy(state)

    points = np.asarray(s.points, dtype=float)
    labels = np.asarray(s.curr_labels)

    def fit(members):
        """Fit a normal to `members`, falling back to all points if empty."""
        if members.size == 0:
            mean, std = np.mean(points), np.std(points) / 5
        else:
            mean, std = np.mean(members), np.std(members)
        return norm(loc=mean, scale=max(std, min_std))

    s.curr_dists = [fit(points[labels == i])
                    for i in range(len(s.curr_dists))]
    return s
    
        

In [216]:
def EM():
    """Endless generator running the E/M iterations on a fresh data set.

    A new data set is generated, every cluster receives a random initial
    normal distribution (mean anywhere inside the data range, standard
    deviation between 1 and 2), and the generator then alternates E-steps and
    M-steps forever.

    Yields:
        The state right after each E-step, i.e. with ``curr_labels`` matching
        the distributions in ``curr_dists``.
    """
    state = gen_cluster()
    state.curr_labels = np.random.randint(0, len(state.labels),
                                          size=len(state.points))

    lo, hi = min(state.points), max(state.points)
    initial_dists = []
    for _ in state.real_dists:
        # Draw the mean first, then the spread.
        centre = np.random.uniform(lo, hi)
        spread = np.random.uniform(1, 2)
        initial_dists.append(norm(loc=centre, scale=spread))
    state.curr_dists = initial_dists

    while True:
        state = E_step(state)
        yield state
        state = M_step(state)

In [217]:
# Create the EM generator. Nothing runs yet; every `next(G)` advances the
# algorithm to its next E-step and returns the resulting state.
G = EM()

In [218]:
# Advance the generator once and plot the yielded state (true clusters on top,
# current estimate below).
st = next(G)
plot_state(st)

In [219]:
# Run five more EM iterations, opening one figure per iteration so the
# evolution of the fit can be followed.
for i in tqdm_notebook(range(5)):
    st = next(G)
    plot_state(st)


