In [92]:
import pandas as pd
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import os
dir_path = os.getcwd()


In [93]:
d_types = ['data', 'test']
data = dict()
for d_type in d_types:
        path = dir_path +"/"+"classification_data_HWK2/EMGaussian." + d_type 
        data[d_type] = pd.read_csv(path, sep=' ', header=None).values

In [94]:
K = 4

In [95]:
a = np.ones((K,K))/K

In [96]:
# We use the parameters found in HW2 - Q3 (We use our data)
pi = [0.20, 0.24, 0.25, 0.31]
centers = [[3.79, -3.64], [ 3.99, -3.64 ], [ -2.03, 4.16 ], [ -3.08, -3.56 ]]

sigma_0 = np.matrix([[ 0.87, 0.06 ], [ 0.06, 2.21]])
sigma_1 = np.matrix([[ 0.20, 0.22 ], [ 0.22, 10.40]])
sigma_2 = np.matrix([[ 2.92, 0.17 ], [ 0.17, 2.77]])
sigma_3 = np.matrix([[ 6.14, 5.94 ], [ 5.94, 6.07]])

sigmas = [sigma_0, sigma_1, sigma_2, sigma_3]

multivar = dict()
for k in range(K):
    m = scipy.stats.multivariate_normal(centers[k], sigmas[k])
    multivar[k] =m
    
# We also need to initiate A as it is not specified what is its value
# We follow Mr Chopin's advice to have strong weight on the diagonal
A =  1/10 * np.ones((4,4)) + 6/10 * np.identity(4)

Y_t_data = data['data']
T = len(Y_t_data)-1

In [97]:
def computeAlpha(z, Y_t_data, A,  t): 
    if t == 0:
        return pi[z]
    p_yt_zt = multivar[z].pdf(Y_t_data[t])
    alpha = 0
    for i in range(4):
        alpha += p_yt_zt * A[z, i] * computeAlpha(i, Y_t_data, A, t-1)
        
    return alpha

In [98]:
computeAlpha(2, Y_t_data, A, 8)

9.711462643780658e-22

In [99]:
def computeBeta(z, Y_t_data, A,  t): 
    if t == len(Y_t_data):
        return 1
    beta = 0
    for i in range(4):
        beta += A[z, i] * multivar[i].pdf(Y_t_data[t]) *  computeBeta(i, Y_t_data, A, t+1)
        
    return beta

In [100]:
computeBeta(2, Y_t_data, A, T-3)

2.7999202097087733e-08

In [101]:
# we note that values are very low, we thus need to use the log version to prevent numerical errors
log_A = np.log(A)
log_pi = np.log(pi)

# We will also use a matrix to store the already computed values of logAlpha and logBeta to prevent us from recomputing
# the whole recursivity at every step


In [138]:
def computeLogAlpha(z, Y_t_data, log_A, log_pi, memory_log_alpha, multivar,  t): 
    if t == 0:
        memory_log_alpha[z, t] = log_pi[z]
        return log_pi[z]
    p_yt_zt = multivar[z].pdf(Y_t_data[t])
    log_alpha = 0
    log_inside_sum = [0,0,0,0]
    for i in range(4):
        log_alpha_i_tminus1 = 0 
        if memory_log_alpha[i, t-1] == 0:
            log_alpha_i_tminus1 = computeLogAlpha(i, Y_t_data, log_A, log_pi, memory_log_alpha,multivar,  t-1)
        else:
            log_alpha_i_tminus1 = memory_log_alpha[i, t-1]
        log_inside_sum[i] = log_alpha_i_tminus1 + log_A[z, i]
        
    max_log = np.max(log_inside_sum)
    
    log_alpha += np.log(p_yt_zt) + max_log + np.log(np.sum(np.exp(log_inside_sum - max_log)))
    
    memory_log_alpha[z, t] = log_alpha  
    return log_alpha

In [139]:
memory_log_alpha = np.zeros((4, T + 1))
computeLogAlpha(2, Y_t_data, log_A, log_pi, memory_log_alpha,multivar,  8)

-48.38356514218258

In [142]:
def computeLogBeta(z, Y_t_data, log_A, log_pi, memory_log_beta, multivar,  t): 
    if t == T:  
        memory_log_beta[z, t] = 0
        return 0
    log_beta = 0
    log_inside_sum = [0,0,0,0]
    for i in range(4):
        log_beta_i_tplus1 = 0 
        if memory_log_beta[i, t+1] == 0:
            log_beta_i_tplus1 = computeLogBeta(i, Y_t_data, log_A, log_pi, memory_log_beta,multivar, t+1)
        else:
            log_beta_i_tplus1 = memory_log_beta[i, t+1]
        log_inside_sum[i] = log_A[z, i] + np.log(multivar[i].pdf(Y_t_data[t])) +  log_beta_i_tplus1
    
    max_log = np.max(log_inside_sum)
    log_beta += max_log + np.log(np.sum(np.exp(log_inside_sum - max_log)))  
    memory_log_beta[z, t] = log_beta  
    return log_beta

In [143]:
memory_log_beta = np.zeros((4, T + 1))
computeLogBeta(2, Y_t_data, log_A, log_pi, memory_log_beta,multivar, T-3)

-14.547237885771237

## 3. EM Algorithm

We will use the formulas found computed on the pdf file 
We will also initialize $\Pi_0$, $\mu_i$ and $\Sigma_i$ with the values found in the previous homework
For A, we will follow Mr Chopin's advice and we'll initiate our matrix with strong diagonal weights

In [173]:
def computeSmoothing(memory_log_alpha, memory_log_beta):
    p_zt_i = np.zeros((4, T))
    for t in range(T):
        memory_log_alpha_t = memory_log_alpha[:, t]
        memory_log_beta_t = memory_log_alpha[:, t]
        sum_memory_log_t = memory_log_alpha_t + memory_log_beta_t
        max_sum_log = np.min(-sum_memory_log_t)
        augmented_exp_sum_memory_log_t = np.exp(sum_memory_log_t + max_sum_log)
        for i in range(4):
            p_zt_i[i, t] = augmented_exp_sum_memory_log_t[i] / np.sum(augmented_exp_sum_memory_log_t)
            
    return p_zt_i

def computePijt(memory_log_alpha, memory_log_beta, log_A, multivar, Y_t_data):
    pijt = np.zeros((4,4, T))

    for t in range(T):
        sum_memory_log_t = memory_log_alpha[:, t] + memory_log_beta[:, t]

        max_sum_memory_log_t = np.max(sum_memory_log_t)
        log_denominator = max_sum_memory_log_t + np.log(np.sum(np.exp(sum_memory_log_t - max_sum_memory_log_t)))

        sum_memory_alphat_betat1 = memory_log_alpha[:, t] + memory_log_beta[:, t+1]
        tot = 0
        for i in range(K):
            for j in range(K):
                log_numerator = memory_log_alpha[:, t][i] + memory_log_beta[:, t+1][j] + log_A[i,j] + np.log(multivar[j].pdf(Y_t_data[t+1]))
                value = np.exp(log_numerator - log_denominator)
                pijt[i, j, t] = value
                tot+=value
                
    return pijt

In [174]:
def EM():
    Y_t_data = data['data']
    
    # init
    pi = [0.20, 0.24, 0.25, 0.31]
    centers = [[3.79, -3.64], [ 3.99, -3.64 ], [ -2.03, 4.16 ], [ -3.08, -3.56 ]]

    sigma_0 = np.matrix([[ 0.87, 0.06 ], [ 0.06, 2.21]])
    sigma_1 = np.matrix([[ 0.20, 0.22 ], [ 0.22, 10.40]])
    sigma_2 = np.matrix([[ 2.92, 0.17 ], [ 0.17, 2.77]])
    sigma_3 = np.matrix([[ 6.14, 5.94 ], [ 5.94, 6.07]])

    sigmas = [sigma_0, sigma_1, sigma_2, sigma_3]
    
    multivar = dict()
    for k in range(K):
        m = scipy.stats.multivariate_normal(centers[k], sigmas[k])
        multivar[k] =m
        A =  1/10 * np.ones((4,4)) + 6/10 * np.identity(4)
    
    
    p_z0_i = np.zeros((4,1))
    p_zt_i = np.zeros((4, T))
    p_zt1_j_zt_i = np.zeros((T, T))
    
    for j in range(1):
        # E step 
        
        # we use this memory to avoid unnecessary recomputation
        memory_log_alpha = np.zeros((4, T + 1))
        memory_log_beta = np.zeros((4, T + 1))
        log_A = np.log(A)
        log_pi = np.log(pi)
        for i in range(4):
            # we fill memory_log_alpha and memory_log_beta
            _ = computeLogAlpha(i, Y_t_data, log_A, log_pi, memory_log_alpha,multivar, T)
            _ = computeLogBeta(i, Y_t_data, log_A, log_pi, memory_log_beta,multivar, 0)
        p_zt = computeSmoothing(memory_log_alpha, memory_log_beta)
        pijt = computePijt(memory_log_alpha, memory_log_beta, log_A, multivar, Y_t_data) # todo verify values of pijt
        
        # M step
        
        pi = p_zt[:, 0]
        for i in range(4):
            for j in range(4):
                print(np.sum(pijt[i,j,:]))
                A[i,j] = np.sum(pijt[i,j,:])
#         A = 
        

In [175]:
EM()

7.480075254177476e-22
4.916812919833613e-54
0.265871790552576
7.910783944357342e-65
1.282298615002402e-22
4.130122852660904e-53
0.3190461486631853
9.492940733231611e-65
1.3357277239608007e-22
6.14601614979367e-54
2.326378167335356
9.888479930449338e-65
1.656302377710988e-22
7.621060025742287e-54
0.41210127535650287
8.583200579626785e-64
1.0294310470866773e-34
1.3144300121953075e-23
7.940390695077666e-26
2.6371780943778115e-38
7.504851870999829e-67
4.6954628544340394e-54
4.052143103041503e-57
1.3458056960912125e-69
5.939950738371218e-15
0.005309102227390718
0.0002245037154821439
1.0651801895106814e-17
1.4153184585856646e-77
1.265005504575071e-65
7.641820290586471e-68
1.7766114152037586e-79
4.861254568810684e-15
4.299739656234835e-35
6.88475865762233e-12
2.0302032672499998e-09
9.130788026327007e-05
3.9572923612068404e-23
0.9052044018212834
266.9300443919007
2.183643451943644e-05
1.3519902627495596e-24
1.5153681821389686
63.836773111233
1.2453475244313107e-18
7.710497449900785e-38
1.23460

3.632507395523995e-15
208.0316026966315
8.455618662286115e-09
1.296534849336724e-19
2.391394501785402e-27
1.9564826193734293e-11
3.896620830605589e-20
1.2470001444509043e-14
2.152935547091269e-09
1.192388569255573e-11
1.719860576420162e-11
1.7558327492823085e-22
1.4854011294014745e-15
1.1752560225109684e-18
1.6951491757245674e-18
1.0200661634944983e-07
0.12327951199374892
0.0047794257838481525
0.0009848115186022197
7.648467826867614e-17
9.243511989124168e-11
5.119455643014772e-13
5.168889665771309e-12
6.1457125462948835e-22
3.678349149954741e-69
3.054344632208615e-05
2.770916567895644e-33
1.7424099403626132e-17
5.110073269152827e-63
6.0616963221992375
5.499200906018369e-28
2.412950273926323e-18
1.0109440468429985e-64
5.876114468589821
7.615485899828735e-29
7.031948008739331e-19
2.9461469032169404e-65
0.2446355039365802
1.5535417800059398e-28
6.999696272833869e-20
4.744385337818893e-18
3.996888696271758e-21
2.1487697468296196e-20
2.551210453081535e-66
8.473115472147217e-63
1.01973466804

1.0690628200205465e-50
1.500983739680946e-56
2.7483701806707003e-79
1.57811318934999e-11
0.016172410628510336
0.00037334420312703105
5.419368630532744e-12
2.624232947871333e-15
0.00013177562180633113
4.345819525163371e-07
6.308280083466118e-15
2.2104916626158893e-21
1.5857091474794545e-11
2.562454863155978e-12
5.313705302442312e-21
5.78216020615412e-45
4.1478665068747276e-35
9.575455083707488e-37
9.729639385951783e-44
2.0397381057633018e-29
5.830821445505791e-97
1.2088489189230563e-07
5.949360813972626e-47
5.516986111325759e-22
7.727754270203406e-88
22.887457337201322
1.1264082689069502e-38
2.712312447627849e-23
5.4274157448435295e-90
7.876502482385596
5.5377539606187006e-40
1.0016590847417146e-30
2.0043488323947926e-97
4.155426369729721e-08
1.4315677745748176e-46
1.1408735455023737e-27
4.993318699667573e-23
1.3844780375190643e-25
1.6948681398992994e-24
2.1912273300775321e-94
4.6993229515762e-88
1.8613756870471508e-91
2.2786828413780045e-90
1.2374818000368627e-06
0.3791304238658973
0.0

5.767809463694969e-43
5.442341953448722e-18
2.5451079066230163e-15
4.8728895914696715e-19
2.0525895568661414e-18
3.254068388498219e-50
7.456639209732423e-46
2.0395082266257173e-50
8.590946313339195e-50
0.0001057919454232735
0.3463147728308461
0.00046414046070742223
0.00027929742556360646
5.246219106753627e-32
1.717373823599744e-28
3.2881014623813483e-32
9.695245126813264e-31
4.0862639207323106e-20
6.99365859769136e-60
0.005157458786348619
3.674796249542512e-39
2.8019529267824413e-18
2.349821806142347e-56
2.4755301945852604
1.7638665574546856e-36
2.4646044673276684e-20
2.9527269383570215e-59
0.15242375782912637
1.5515012246350404e-38
1.511931985270115e-20
1.8113747504122938e-59
0.013357944903075533
6.662468767161832e-38
1.1321539019194643e-17
3.3754358612602565e-18
3.142777869581266e-23
2.4576003420973337e-29
1.321506900052027e-56
1.930589356402499e-55
2.5678857291913803e-61
2.0080441916073324e-67
0.21212641287674394
0.4427078067314295
2.885353620723703e-05
3.2232840498014967e-12
1.0378

1.1483706357533211e-75
5.894335528611812e-39
3.881084551617908e-34
2.5219161296485842e-15
4.69496756851354e-48
1.180813886106059e-09
1.1107125063321013e-05
5.074577854710694e-11
9.447173191712608e-44
3.3943192901851903e-06
1.56447627854128
5.570704068097599e-18
8.392064966690563e-43
4.6087135631297205e-15
2.6218734352259902e-11
1.481541841221476e-51
1.0936258079587063e-74
8.579887436412586e-48
4.881053820904694e-44
5.324202851733421e-14
5.614503835955353e-38
2.1583422026877353e-09
1.7540996784343932e-06
3.50485564108343e-09
3.6959571205888735e-33
2.0297278176483265e-05
0.8082911990636777
9.940336253887214e-20
1.129313418705932e-60
2.836167314498672e-15
2.829950651510749e-09
1.4974761123062556e-44
8.336224389864728e-84
2.9908092511587764e-39
2.9842536248089425e-33
8.215863083691405e-17
6.533782090242365e-57
1.148629709717243e-10
1.637302858253332e-05
4.679101813912729e-13
3.7211223359904607e-53
9.345258249636748e-08
0.6527341908151303
2.122337943475691e-24
3.308015152543429e-69
1.267596

9.35929803442829e-15
4.083214199889925e-27
2.1699573145936747e-23
8.301985552115485e-25
4.915036928559273e-30
1.050707303888054e-40
1.1301419006778388
0.017245452950024647
2.352727125454953e-08
4.6372226605710515e-21
0.004979225364337172
0.003723055342111733
7.256018912298595e-10
1.4301605554505857e-22
3.545609031340198e-08
3.78730702620037e-09
3.6168084417102394e-14
1.018389369959182e-27
1.0827999880625605e-19
1.1566125781242504e-20
1.577919579196336e-26
2.1770544679950727e-38
0.9000520462516455
0.0013648261837625694
4.0001966399060157e-10
1.9480237113490557e-30
0.010979594555350681
0.0008158157872464259
3.415844250510266e-11
1.6634546231771448e-31
1.8068863398301878e-08
1.9179558099636072e-10
3.934962806580223e-16
2.7375085850336817e-37
3.561369245696788e-21
3.780286942039643e-23
1.1079719383562348e-29
3.776936638880883e-49
0.8978977414732703
0.0007747441818271002
6.2322885639830814e-09
4.847582132660233e-23
0.0019997450979989573
8.454777153821986e-05
9.71613085837864e-11
7.557375090