In [306]:
import numpy as np
from scipy.stats import norm

from easydict import EasyDict as edict

from copy import deepcopy

from tqdm import tqdm_notebook

import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib tk
plt.style.use('ggplot')

In [307]:
cmap = cm.Set1

In [308]:
def new_state():
    state = edict({
        'points': [],
        'labels': [],
        'real_dists': [],
        'real_labels'
        'curr_dists': [],
        'curr_labels': []
    })
    return state

In [324]:
def plot_state(s):
    plt.figure()
    plt.subplot(2, 1, 1)
    support = np.linspace(min(s.points), max(s.points), len(s.points))
    n = len(s.points)
    plt.scatter(x=s.points, y=[1]*n, c=cmap(s.real_labels))
    
    def plot_dists(dists, labels, h, f=1):
        for dist,lab in zip(dists, labels):
            plt.plot(support, 1.001 + dist.pdf(support), 
                     color=cmap(lab+f))
            plt.fill_between(support, 1.001 + dist.pdf(support), [1.001]*n,
                        alpha=0.2, color=cmap(lab+f), hatch=h)
    plot_dists(s.real_dists, s.labels, None)
    
    plt.subplot(2,1,2)
    plt.scatter(x=s.points, y=[1]*n, c=cmap(np.array(s.curr_labels) + 5))
    if s.curr_dists:
        plot_dists(s.curr_dists, s.labels, 'X', 5)    
    plt.show()

In [325]:
def gen_cluster():
    n_clusters = 3
    n_points = 100
    params = [(np.random.randint(0,500), np.random.uniform(3,10)) 
              for i in range(n_clusters)]
    dists = [norm(loc=p[0], scale=p[1]) for p in params]
    
    r_labs = np.repeat(range(n_clusters), 100)
    print(r_labs)
    points = [np.random.normal(params[i][0], params[i][1]) for i in r_labs]
    
    state = new_state()
    state.points = points
    state.real_labels = r_labs
    state.real_dists = dists
    state.labels = range(n_clusters)
    
    return state

In [326]:
def E_step(state):
    def assign_lab(p, dists):
        point_likelihood = [dist.pdf(p) for dist in dists]
        return np.argmax(point_likelihood)
    
    s = deepcopy(state)
    s.curr_labels = [assign_lab(pt, s.curr_dists) for pt in s.points]
    return s

def M_step(state):
    def new_pdf(points):
        mean = np.mean(points)
        std = np.std(points)
        return norm(loc=mean, scale=std)
    
    s = deepcopy(state)
    
    new_dists = range(len(s.curr_dists))
    new_dists = [new_pdf([x[0] for x in zip(s.points, s.curr_labels)
                         if x[1] == i]) for i in new_dists]
    s.curr_dists = new_dists
    return s
    
        

In [327]:
def EM():
    s = gen_cluster()
    s.curr_labels = np.random.randint(0, len(s.labels),
                                      size=len(s.points))
    min_x, max_x = (min(s.points), max(s.points))
    s.curr_dists = [norm(loc=np.random.uniform(min_x, max_x), 
                    scale=np.random.uniform(1, 2)) 
                    for i in s.real_dists]
    while True:
        s = E_step(s)
        yield s
        s = M_step(s)

In [328]:
G = EM()

In [330]:
st = next(G)
plot_state(st)

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  keepdims=keepdims)
  arrmean, rcount, out=arrmean, casting='unsafe', subok=False)
  ret = ret.dtype.type(ret / rcount)
  return (self.a <= x) & (x <= self.b)
  return (self.a <= x) & (x <= self.b)


In [331]:
for i in tqdm_notebook(range(8)):
    st = next(G)
plot_state(st)

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  keepdims=keepdims)
  arrmean, rcount, out=arrmean, casting='unsafe', subok=False)
  ret = ret.dtype.type(ret / rcount)





  return (self.a <= x) & (x <= self.b)
  return (self.a <= x) & (x <= self.b)
