In [2]:
import numpy as np
import pandas as pd
from scipy.stats import multivariate_normal as mvn


In [4]:
data = pd.read_csv('EMGaussian.data', delim_whitespace=True, header=None, names=['x', 'y'])
test = pd.read_csv('EMGaussian.test', delim_whitespace=True, header=None, names=['x', 'y'])

In [42]:
def forward(X, A, pi, mu, sigma):
    (T, p) = X.shape
    alpha = np.ones((T,K))
    # The LOG of the messages alpha are contained in the matrix Alpha. 
    # The t-th row corresponds to the time t
    # The k-th column corresponds to the case where the state takes the value k

    # Computation of the first alpha(q_0)
    for k in range(K):
        # Watch out we directly use logpdf, not pdf
        alpha[0,k] = mvn.logpdf(X[k], mu[k], sigma[k])*pi[k]

    for t in range(1,T):
        for k in range(K):
            # Alpha message formula p9 chp 12.4 of the book
            constant_term = mvn.logpdf(X[t-1], mu[k], sigma[k])
            log_proba_vec = alpha[t-1] + np.log(A[:,k])
            m = max(log_proba_vec)
            alpha[t, k] = np.log(np.exp(log_proba_vec-m).sum()) + m + const
            
    return alpha

In [43]:
X = np.array(data)
(T,p) = data.shape
K = 4 # Number of states (assumed)

#parameters of the MV Gaussian in R^2
mu = np.ones((4,2)) # Mean mu1
sigma = [np.eye(2) for k in range(K)]
pi = 1.0/4 * np.ones(4)

# Transition matrix
A = np.eye(K)*(1/2-1/6) + np.ones((K,K))*1/6

alpha = forward(X, A, pi , mu, sigma)
alpha





array([[ -5.81460025e-01,  -6.07650734e+00,  -8.63118562e+00,
         -1.10910930e+00],
       [ -6.23218063e+01,  -6.31295416e+01,  -6.31342799e+01,
         -6.25807546e+01],
       [ -1.23802199e+02,  -1.24073679e+02,  -1.24074867e+02,
         -1.23905209e+02],
       ..., 
       [ -3.04318016e+04,  -3.04318016e+04,  -3.04318016e+04,
         -3.04318016e+04],
       [ -3.04930296e+04,  -3.04930296e+04,  -3.04930296e+04,
         -3.04930296e+04],
       [ -3.05542575e+04,  -3.05542575e+04,  -3.05542575e+04,
         -3.05542575e+04]])

array([[ -5.81460025e-01,  -6.07650734e+00,  -8.63118562e+00,
         -1.10910930e+00],
       [ -6.23218063e+01,  -6.31295416e+01,  -6.31342799e+01,
         -6.25807546e+01],
       [ -1.23802199e+02,  -1.24073679e+02,  -1.24074867e+02,
         -1.23905209e+02],
       ..., 
       [ -3.04318016e+04,  -3.04318016e+04,  -3.04318016e+04,
         -3.04318016e+04],
       [ -3.04930296e+04,  -3.04930296e+04,  -3.04930296e+04,
         -3.04930296e+04],
       [ -3.05542575e+04,  -3.05542575e+04,  -3.05542575e+04,
         -3.05542575e+04]])