In [1]:
import numpy as np
import random
import pandas as pd

In [2]:
# NOTE(review): absolute, machine-specific path - this cell only runs where the file exists.
DATA_PATH = "/Users/dhritiman/Downloads/3gaussian.txt"

# Space-separated text with no header row, so pandas assigns integer column labels.
data = pd.read_csv(DATA_PATH, sep=" ", header=None)

In [3]:
# Preview the first rows of the loaded frame.
data.head()

Unnamed: 0,0,1
0,2.946933,3.162225
1,5.983996,4.846717
2,5.30143,8.168113
3,6.498912,8.594854
4,4.994876,8.675435


In [4]:
# Seed the global NumPy generator so the random initialisation is reproducible.
np.random.seed(3)

# Number of data points (rows of the frame).
n = data.shape[0]

# One row per data point, three uniform draws in [0, 1) per row (one per cluster).
random_probabilities = np.random.random_sample(size=(n, 3))

In [5]:
# Display the raw random draws (n rows, 3 columns) before row-normalisation.
random_probabilities

array([[ 0.5507979 ,  0.70814782,  0.29090474],
       [ 0.51082761,  0.89294695,  0.89629309],
       [ 0.12558531,  0.20724288,  0.0514672 ],
       ..., 
       [ 0.83131083,  0.64489214,  0.81788898],
       [ 0.116008  ,  0.75327364,  0.45213295],
       [ 0.09279916,  0.55822492,  0.65017238]])

In [6]:
# Per-row totals, kept as an (n, 1) column so they broadcast against the (n, 3) draws.
sum_random_probabilities = random_probabilities.sum(axis=1, keepdims=True)

In [7]:
# Scale every row by its total so the three values in each row sum to 1.
prob_split = np.divide(random_probabilities, sum_random_probabilities)

In [8]:
# Display the row-normalised probabilities (each row sums to 1).
prob_split

array([[ 0.35538777,  0.45691364,  0.18769858],
       [ 0.22209243,  0.38822639,  0.38968119],
       [ 0.32679369,  0.53928015,  0.13392615],
       ..., 
       [ 0.36237032,  0.28110998,  0.3565197 ],
       [ 0.08779077,  0.57005095,  0.34215829],
       [ 0.07131833,  0.42900894,  0.49967273]])

In [9]:
# Initial per-point membership probabilities, one 1-D array per cluster.
# Each is a basic-slice view of a column of `prob_split`, not a copy, so
# element-wise writes to these arrays also modify `prob_split`.
cluster1, cluster2, cluster3 = (
    prob_split[:, 0],
    prob_split[:, 1],
    prob_split[:, 2],
)

In [10]:
# Starting values for the EM parameters.

# Loop flag: stays True until the parameter estimates stop changing.
not_converged = True

# One 2-vector mean per cluster, all zeros to begin with (separate arrays).
mean_1, mean_2, mean_3 = (np.zeros(2) for _ in range(3))

# One 2x2 float covariance per cluster, all zeros to begin with (separate arrays).
cov_1, cov_2, cov_3 = (np.zeros((2, 2), dtype=float) for _ in range(3))

In [11]:
def _weighted_moments(points, resp):
    """Return (weight total, weighted mean, weighted covariance) for one cluster.

    points: (n, 2) float array of observations.
    resp:   (n,) array holding each point's membership weight for the cluster.
    """
    resp_sum = np.sum(resp)
    mean = np.dot(resp, points) / resp_sum
    centered = points - mean
    # Equivalent to summing resp[i] * outer(centered[i], centered[i]) over all rows.
    cov = np.dot(centered.T * resp, centered) / resp_sum
    return resp_sum, mean, cov


def _gaussian_density(points, mean, inv, det):
    """Evaluate a bivariate normal density at every row of `points`.

    inv, det: (pseudo-)inverse and determinant of the cluster covariance.
    Returns an (n,) array of density values.
    """
    centered = points - mean
    # Row-wise quadratic form (x - mean) * inv * (x - mean)^T.
    quad = np.sum(np.dot(centered, inv) * centered, axis=1)
    return np.exp(-quad / 2) / (2 * np.pi * np.sqrt(det))


# Observations as a plain (n, 2) float array so each step is one array expression
# instead of a Python loop over DataFrame rows.
points = data[[0, 1]].values.astype(float)

# Safety cap: the rounded-equality convergence test below can never succeed once
# a parameter becomes NaN (NaN != NaN), which would otherwise loop forever.
max_iterations = 10000

iterations = 0
while not_converged:
    # Remember the previous estimates for the convergence test.
    mean_old_1, mean_old_2, mean_old_3 = mean_1, mean_2, mean_3
    cov_old_1, cov_old_2, cov_old_3 = cov_1, cov_2, cov_3

    ##################################
    # M - step: re-estimate mean, covariance and mixing weight of every cluster
    # from the current membership probabilities.
    sum_cluster1, mean_1, cov_1 = _weighted_moments(points, cluster1)
    sum_cluster2, mean_2, cov_2 = _weighted_moments(points, cluster2)
    sum_cluster3, mean_3, cov_3 = _weighted_moments(points, cluster3)

    w1 = sum_cluster1/n
    w2 = sum_cluster2/n
    w3 = sum_cluster3/n

    ##################################
    # E - step: recompute membership probabilities under the new parameters.
    inv_1 = np.linalg.pinv(cov_1)
    inv_2 = np.linalg.pinv(cov_2)
    inv_3 = np.linalg.pinv(cov_3)

    det_1 = np.linalg.det(cov_1)
    det_2 = np.linalg.det(cov_2)
    det_3 = np.linalg.det(cov_3)

    print("MEAN:", mean_1, mean_2, mean_3)
    print("COV:", cov_1, cov_2, cov_3)

    prob_1 = _gaussian_density(points, mean_1, inv_1, det_1)
    prob_2 = _gaussian_density(points, mean_2, inv_2, det_2)
    prob_3 = _gaussian_density(points, mean_3, inv_3, det_3)

    # Mixture density at every point; normalising by it makes the three
    # membership probabilities of each point sum to 1.
    mixture = prob_1 * w1 + prob_2 * w2 + prob_3 * w3

    # Written in place on purpose: cluster1..3 keep their identity (and any
    # array they are views of is updated), exactly as the element-wise
    # assignment did before.
    cluster1[:] = prob_1 * w1 / mixture
    cluster2[:] = prob_2 * w2 / mixture
    cluster3[:] = prob_3 * w3 / mixture

    # Per-point mixture densities, kept as a list as before.
    total = list(mixture)

    # Converged when every mean and covariance entry is unchanged to 5 decimals.
    previous = (cov_old_1, cov_old_2, cov_old_3, mean_old_1, mean_old_2, mean_old_3)
    current = (cov_1, cov_2, cov_3, mean_1, mean_2, mean_3)
    if all((np.round(old, 5) == np.round(new, 5)).all()
           for old, new in zip(previous, current)):
        not_converged = False

    iterations += 1

    if not_converged and iterations >= max_iterations:
        # Leave not_converged as True so later cells can tell the fit did not settle.
        print("WARNING: EM stopped after %d iterations without converging" % iterations)
        break

('MEAN:', array([ 5.21018001,  5.32754853]), array([ 5.2102518 ,  5.28023352]), array([ 5.19792433,  5.28521018]))
('COV:', array([[ 2.93446696,  0.3646659 ],
       [ 0.3646659 ,  4.38966812]]), array([[ 2.99837257,  0.41026997],
       [ 0.41026997,  4.44141977]]), array([[ 2.94783813,  0.39394267],
       [ 0.39394267,  4.50745366]]))
('MEAN:', array([ 5.21804271,  5.33983795]), array([ 5.20654519,  5.27185762]), array([ 5.19381241,  5.28133886]))
('COV:', array([[ 2.91647254,  0.33121411],
       [ 0.33121411,  4.36074864]]), array([[ 3.00402946,  0.4177929 ],
       [ 0.4177929 ,  4.46391534]]), array([[ 2.95985612,  0.41906696],
       [ 0.41906696,  4.51245804]]))
('MEAN:', array([ 5.23399699,  5.36211879]), array([ 5.20016234,  5.2597422 ]), array([ 5.18433562,  5.27126797]))
('COV:', array([[ 2.88847173,  0.27721688],
       [ 0.27721688,  4.30094898]]), array([[ 3.01499923,  0.43768626],
       [ 0.43768626,  4.49408836]]), array([[ 2.9757382 ,  0.45094881],
       [ 0.450948

('MEAN:', array([ 5.97207124,  5.90017491]), array([ 4.82382092,  4.29920201]), array([ 3.58055725,  4.54528098]))
('COV:', array([[ 1.77242007, -1.25033159],
       [-1.25033159,  2.90551934]]), array([[ 3.37928636,  0.7506213 ],
       [ 0.7506213 ,  4.81992602]]), array([[ 1.39335515,  1.55114308],
       [ 1.55114308,  5.79134838]]))
('MEAN:', array([ 5.97215576,  5.90215478]), array([ 4.83981737,  4.23261293]), array([ 3.56422801,  4.57830279]))
('COV:', array([[ 1.76403634, -1.24116437],
       [-1.24116437,  2.90754365]]), array([[ 3.40245921,  0.77750078],
       [ 0.77750078,  4.82487179]]), array([[ 1.35873118,  1.54836148],
       [ 1.54836148,  5.73884669]]))
('MEAN:', array([ 5.9721111,  5.9044976]), array([ 4.85102143,  4.16390946]), array([ 3.55279358,  4.61292494]))
('COV:', array([[ 1.75725188, -1.23373781],
       [-1.23373781,  2.90824127]]), array([[ 3.42391688,  0.81314453],
       [ 0.81314453,  4.8224959 ]]), array([[ 1.3336872 ,  1.54626372],
       [ 1.54626372

('MEAN:', array([ 5.88778782,  6.1516977 ]), array([ 5.65850765,  2.96846683]), array([ 4.0095748 ,  5.96469273]))
('COV:', array([[ 1.65605992, -1.1311777 ],
       [-1.1311777 ,  2.38457874]]), array([[ 3.85480469,  2.30264276],
       [ 2.30264276,  2.28703665]]), array([[ 1.76303284,  1.72675172],
       [ 1.72675172,  3.23728633]]))
('MEAN:', array([ 5.87370583,  6.17641193]), array([ 5.70689616,  2.99907939]), array([ 4.0308899 ,  5.98021784]))
('COV:', array([[ 1.64444683, -1.12401092],
       [-1.12401092,  2.33914787]]), array([[ 3.8264113 ,  2.29521959],
       [ 2.29521959,  2.27829398]]), array([[ 1.77674211,  1.73085898],
       [ 1.73085898,  3.22864676]]))
('MEAN:', array([ 5.85972858,  6.20068421]), array([ 5.74854429,  3.02585721]), array([ 4.04906888,  5.99278441]))
('COV:', array([[ 1.63304247, -1.11613418],
       [-1.11613418,  2.29499279]]), array([[ 3.79744557,  2.2858101 ],
       [ 2.2858101 ,  2.26969628]]), array([[ 1.78849776,  1.73385988],
       [ 1.733859

('MEAN:', array([ 5.2585719,  6.977346 ]), array([ 6.35988721,  3.46081178]), array([ 3.90221609,  5.59022072]))
('COV:', array([[ 1.01533206, -0.19605759],
       [-0.19605759,  0.89990075]]), array([[ 2.83981161,  1.88661768],
       [ 1.88661768,  2.09652754]]), array([[ 1.91149712,  1.94912026],
       [ 1.94912026,  3.98524573]]))
('MEAN:', array([ 5.22268373,  7.01386889]), array([ 6.39947082,  3.49760682]), array([ 3.85232966,  5.49169461]))
('COV:', array([[ 0.97363924, -0.12436651],
       [-0.12436651,  0.83543861]]), array([[ 2.72524549,  1.83291907],
       [ 1.83291907,  2.08544824]]), array([[ 1.90397629,  1.9569356 ],
       [ 1.9569356 ,  4.13204591]]))
('MEAN:', array([ 5.19530781,  7.04027011]), array([ 6.44148419,  3.53498224]), array([ 3.79790936,  5.3795101 ]))
('COV:', array([[ 0.94322849, -0.065875  ],
       [-0.065875  ,  0.79056999]]), array([[ 2.59956632,  1.76759097],
       [ 1.76759097,  2.06513529]]), array([[ 1.88776254,  1.95359087],
       [ 1.95359087

('MEAN:', array([ 5.05769823,  7.03873176]), array([ 7.0401057 ,  4.02554044]), array([ 3.15073768,  3.34387771]))
('COV:', array([[ 0.93038513,  0.16034576],
       [ 0.16034576,  0.92061945]]), array([[ 0.96050131,  0.49013378],
       [ 0.49013378,  0.99592809]]), array([[ 1.12144459,  0.25480556],
       [ 0.25480556,  4.01013965]]))
('MEAN:', array([ 5.05390011,  7.03590641]), array([ 7.03861582,  4.02473617]), array([ 3.13984646,  3.31680358]))
('COV:', array([[ 0.93391677,  0.16171663],
       [ 0.16171663,  0.92535733]]), array([[ 0.96278685,  0.49092238],
       [ 0.49092238,  0.99582159]]), array([[ 1.11157106,  0.2311364 ],
       [ 0.2311364 ,  3.9515242 ]]))
('MEAN:', array([ 5.05028027,  7.03316951]), array([ 7.03718774,  4.02396277]), array([ 3.12979025,  3.29150462]))
('COV:', array([[ 0.93737045,  0.16311606],
       [ 0.16311606,  0.9297735 ]]), array([[ 0.96499838,  0.49169349],
       [ 0.49169349,  0.99573454]]), array([[ 1.10259032,  0.20948569],
       [ 0.209485

('MEAN:', array([ 5.01441466,  7.0038455 ]), array([ 7.02271692,  4.01608552]), array([ 3.04520299,  3.06410911]))
('COV:', array([[ 0.97645818,  0.1831452 ],
       [ 0.1831452 ,  0.97136413]]), array([[ 0.98846118,  0.5002135 ],
       [ 0.5002135 ,  0.99557362]]), array([[ 1.03276116,  0.03717557],
       [ 0.03717557,  3.41632121]]))
('MEAN:', array([ 5.01409215,  7.00356206]), array([ 7.02257978,  4.01601131]), array([ 3.04453876,  3.06222771]))
('COV:', array([[ 0.97684693,  0.18338283],
       [ 0.18338283,  0.97174401]]), array([[ 0.98869225,  0.50030132],
       [ 0.50030132,  0.99558143]]), array([[ 1.03224704,  0.03591923],
       [ 0.03591923,  3.41249748]]))
('MEAN:', array([ 5.01380767,  7.0033117 ]), array([ 7.02245854,  4.01594571]), array([ 3.04395373,  3.06057013]))
('COV:', array([[ 0.97719034,  0.18359335],
       [ 0.18359335,  0.97207953]]), array([[ 0.98889664,  0.50037911],
       [ 0.50037911,  0.99558857]]), array([[ 1.03179443,  0.03481458],
       [ 0.034814

('MEAN:', array([ 5.01182821,  7.00156087]), array([ 7.02160762,  4.01548562]), array([ 3.03990513,  3.04908927]))
('COV:', array([[ 0.97959182,  0.18508169],
       [ 0.18508169,  0.97442549]]), array([[ 0.99033499,  0.50092944],
       [ 0.50092944,  0.9956453 ]]), array([[ 1.02866652,  0.02722008],
       [ 0.02722008,  3.38590444]]))
('MEAN:', array([ 5.01181524,  7.00154934]), array([ 7.02160199,  4.01548258]), array([ 3.03987871,  3.04901431]))
('COV:', array([[ 0.97960763,  0.18509159],
       [ 0.18509159,  0.97444094]]), array([[ 0.99034452,  0.5009331 ],
       [ 0.5009331 ,  0.99564571]]), array([[ 1.02864613,  0.02717082],
       [ 0.02717082,  3.3857533 ]]))
('MEAN:', array([ 5.01180384,  7.00153921]), array([ 7.02159705,  4.01547991]), array([ 3.0398555 ,  3.04894848]))
('COV:', array([[ 0.97962152,  0.18510028],
       [ 0.18510028,  0.9744545 ]]), array([[ 0.99035289,  0.50093632],
       [ 0.50093632,  0.99564608]]), array([[ 1.02862821,  0.02712756],
       [ 0.027127

('MEAN:', array([ 5.01172581,  7.00146986]), array([ 7.0215632 ,  4.01546162]), array([ 3.03969663,  3.04849779]))
('COV:', array([[ 0.97971661,  0.18515981],
       [ 0.18515981,  0.97454743]]), array([[ 0.99041025,  0.50095838],
       [ 0.50095838,  0.9956486 ]]), array([[ 1.02850558,  0.02683146],
       [ 0.02683146,  3.38471194]]))
('MEAN:', array([ 5.0117253 ,  7.00146942]), array([ 7.02156298,  4.0154615 ]), array([ 3.0396956 ,  3.04849488]))
('COV:', array([[ 0.97971723,  0.1851602 ],
       [ 0.1851602 ,  0.97454803]]), array([[ 0.99041062,  0.50095852],
       [ 0.50095852,  0.99564861]]), array([[ 1.02850478,  0.02682954],
       [ 0.02682954,  3.38470607]]))
('MEAN:', array([ 5.01172486,  7.00146902]), array([ 7.02156279,  4.01546139]), array([ 3.0396947 ,  3.04849232]))
('COV:', array([[ 0.97971777,  0.18516054],
       [ 0.18516054,  0.97454856]]), array([[ 0.99041095,  0.50095865],
       [ 0.50095865,  0.99564863]]), array([[ 1.02850409,  0.02682786],
       [ 0.026827