<a href="https://colab.research.google.com/github/cruxcode/probabilistic_modeling/blob/master/em_discrete.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np

In [0]:
# prepare test data 1

# Prepare K = 2 discrete distribution distribution
K = 2 # number of components in mixture
C = 10 # number of categories
comp1, comp2 = np.array([0., 0., 0.1, 0.2, 0.3, 0.2, 0.1, 0.1, 0., 0.]), np.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
pie1, pie2 = 0.3, 0.7
dist1 = pie1*comp1 + pie2*comp2
X1 = np.random.choice(np.arange(C), p=dist1, size=(10000, 1))

In [3]:
dist1

array([0.07, 0.07, 0.1 , 0.13, 0.16, 0.13, 0.1 , 0.1 , 0.07, 0.07])

In [0]:
# The plate notation consists of [(pie)] --[-> (t) --> (x) <-]-- [ (C_{k}) ]
# TODO: Draw the plate notation here

**EM algorithm for mixture of categorical distributions**

---


$N$ data points in data.

$C_k$ = mixture components (multinomial (categorical) distributions).

$t$ = latent variable is the mixture number.

EM Algo for discrete distribution have following parameters:

$\theta_{kc}$ = probability of each of the $C$ categories for mixture component $k$

$\pi_{k}$  = mixture proportions

$q(t_i)$ = variational distribution for each data point $i$

For categorical distribution $q(t_i)$ can be stored in $K*C$ matrix and would not require computing $p(x_i|t_i, \theta, \pi)$ for each data point. In case of mixture of continous distrbution we do not store $q(t_i)$ rather compute it for each data point. We compute $q(t)$ for each $c$ and store them. Here we define $N_c$ as number of data points with $x=c$.

In [0]:
# Init operations
# categorical distribution parameters
thetas = np.array([np.abs(np.random.randn()) for _ in range(K*C)]).reshape((K, C))
thetas = thetas / np.sum(thetas, axis=1).reshape((K, 1))
thetas = thetas[:, 0:C-1]
# mixing proportions
pies = np.array([np.abs(np.random.randn()) for _ in range(K)])
pies = pies / np.sum(pies)
pies = pies[0:K-1]

In [6]:
thetas

array([[0.09452016, 0.04414052, 0.11654581, 0.13735718, 0.10925236,
        0.11169507, 0.03510657, 0.14223729, 0.14198639],
       [0.11717396, 0.02651716, 0.04842748, 0.00178599, 0.16797863,
        0.11603825, 0.09820173, 0.05449376, 0.131717  ]])

In [7]:
pies

array([0.71160887])

In [0]:
# E - Step
# q(t_i) = p(t_i|x_i, \theta)
# A K*C array will help to store q(t_i) efficiently
def e_step():
  q_t = np.zeros((K, C))
  for k in range(K):
      for c in range(C):
          if c == C-1:
              theta = 1-np.sum(thetas[k,:])
          else:
              theta = thetas[k, c]
        
          if k == K-1:
              pie = 1 - np.sum(pies)
          else:
              pie = pies[k]

          q_t[k, c] = theta*pie

  q_t = q_t/np.sum(q_t, axis=0)
  return q_t

In [0]:
# M-step
# Update thetas
def m_step_thetas(q_t):
  N = np.unique(X1)
  for k in range(K-1):
    A = np.zeros((C-1, C-1))
    b = np.zeros((C-1, 1))
    for c in range(C-1):
      # Prepare one row of A in each iteration
      coeff1, coeff2 = N[c]*q_t[k, c], N[C-1]*q_t[k, C-1]
      A[c, :] = np.repeat(np.array([coeff1]), C-1)
      A[c, k] = coeff1 + coeff2
      b[c] = coeff1
    # Solve system of linear equations and update thetas
    thetas[k, :] = np.squeeze(np.linalg.solve(A, b))

In [0]:
# M-step
# Update pies
def m_step_pies(q_t):
  N = np.unique(X1)
  A = np.zeros((K-1, K-1))
  b = np.zeros((K-1, 1))
  for k in range(K-1):
    coeff1 = np.sum(np.array([N[c]*q_t[k, c] for c in range(C)]))
    coeff2 = np.sum(np.array([N[c]*q_t[K-1, c] for c in range(C)]))
    A[k, :] = np.repeat(np.array([coeff1]), K-1)
    A[k, k] = coeff1 + coeff2
    b[k] = coeff1
  pies = np.squeeze(np.linalg.solve(A, b), axis=0)

In [0]:
def em_run(n_iter=100):
  for i in range(n_iter):
    print("Iter ", i)
    q_t = e_step()
    print("E-step", q_t)
    m_step_thetas(q_t)
    m_step_pies(q_t)

In [12]:
em_run()

Iter  0
E-step [[0.66560294 0.80420668 0.85587308 0.99475814 0.61610172 0.70371749
  0.46868485 0.86560218 0.72676804 0.41081483]
 [0.33439706 0.19579332 0.14412692 0.00524186 0.38389828 0.29628251
  0.53131515 0.13439782 0.27323196 0.58918517]]


LinAlgError: ignored

In [0]:
pies

In [0]:
thetas

In [0]:
# Lower bound should increase at each iteration
# Otherwise their is implementation glitch