This notebook illustrates how to fit a Gaussian mixture model (GMM) with several optimization routines: gradient descent with momentum, RMSProp, Adam, and Newton-CG. The same approach can be followed to fit other models with these optimizers.

In [1]:
from Mixture_Models import *

#### Simulating some data

In [2]:
# Simulate a 2-D pinwheel dataset: 3 classes x 100 points each.
# A RandomState seeded with 0 is passed as `rs`, presumably so the data are
# reproducible across runs (confirm make_pinwheel draws only from `rs`).
data = make_pinwheel(
    radial_std=0.3,
    tangential_std=0.05,
    num_classes=3,
    num_per_class=100,
    rate=0.4,
    rs=npr.RandomState(0),
)

#### Initializing the model

In [3]:
# Build the GMM model object around the simulated data. Component parameters
# are not set here; they are generated later with `test_GMM.init_params(...)`
# and passed to `test_GMM.fit(...)`.
test_GMM = GMM(data)

#### Initializing the input parameters

In [4]:
# Settings for the random initialization of the GMM parameters.
INIT_SEED = 10
NUM_COMPONENTS = 3  # matches num_classes=3 used to simulate the pinwheel data
INIT_SCALE = 0.5

# Seed the global RNG (npr is presumably numpy.random, re-exported by
# Mixture_Models) so the random initialization below is reproducible.
# Assumes init_params draws from the global npr state -- TODO confirm.
npr.seed(INIT_SEED)
init_params = test_GMM.init_params(num_components=NUM_COMPONENTS, scale=INIT_SCALE)

# Last expression -> Jupyter's rich display of the parameter dict
# ('log proportions', 'means', 'lower triangles').
init_params

{'log proportions': array([ 0.66579325,  0.35763949, -0.77270015]), 'means': array([[-0.00419192,  0.31066799],
       [-0.36004278,  0.13275579],
       [ 0.05427426,  0.00214572]]), 'lower triangles': array([[[1., 0.],
        [0., 1.]],

       [[1., 0.],
        [0., 1.]],

       [[1., 0.],
        [0., 1.]]])}


#### Fitting the model with the above initialization parameters

Gradient Descent with momentum

In [5]:
# Gradient descent with momentum. `mass` is presumably the momentum
# coefficient -- confirm in Mixture_Models. The result gets its own name so it
# is not overwritten by the fits with the other optimizers in later cells.
params_store_grad_descent = test_GMM.fit(
    init_params,
    opt_routine="grad_descent",
    learning_rate=0.0005,
    mass=0.9,
    maxiter=100,
    tol=1e-7,
)

Log likelihood -720.8092616186832
Log likelihood -719.9324989689964
Log likelihood -718.274814939572
Log likelihood -715.9300034638668
Log likelihood -712.9900481929888
Log likelihood -709.5458269817979
Log likelihood -705.6880453853834
Log likelihood -701.5082241387775
Log likelihood -697.0995833835279
Log likelihood -692.557653132106
Log likelihood -687.9804054374333
Log likelihood -683.4676605634239
Log likelihood -679.1194852214626
Log likelihood -675.0333048263767
Log likelihood -671.2995356451652
Log likelihood -667.9957548102311
Log likelihood -665.1797992207992
Log likelihood -662.8826958402399
Log likelihood -661.1028521459016
Log likelihood -659.8032389258503
Log likelihood -658.9131007549157
Log likelihood -658.3348876031125
Log likelihood -657.9557758247458
Log likelihood -657.6618055087696
Log likelihood -657.3518342045413
Log likelihood -656.9484866199952
Log likelihood -656.4040216179089
Log likelihood -655.7002839476836
Log likelihood -654.8433206006828
Log likelihood -

RMSProp

In [6]:
# RMSProp from the same initialization. `gamma` is presumably the decay rate of
# the squared-gradient moving average -- confirm in Mixture_Models. The result
# gets its own name so the fits in other cells do not overwrite it.
params_store_rms_prop = test_GMM.fit(
    init_params,
    opt_routine="rms_prop",
    learning_rate=0.01,
    gamma=0.9,
    maxiter=100,
    tol=1e-7,
)

Log likelihood -720.8092616186832
Log likelihood -708.7687436918704
Log likelihood -700.5695943512029
Log likelihood -694.1365026952342
Log likelihood -688.7992701317978
Log likelihood -684.2343094906564
Log likelihood -680.2509729901997
Log likelihood -676.72377944885
Log likelihood -673.5643825913835
Log likelihood -670.706629891055
Log likelihood -668.0968903183066
Log likelihood -665.6867120416829
Log likelihood -663.4258206923182
Log likelihood -661.2538140270026
Log likelihood -659.0896525495962
Log likelihood -656.8207660722843
Log likelihood -654.2999317201795
Log likelihood -651.361311175862
Log likelihood -647.8531607362195
Log likelihood -643.6709453080134
Log likelihood -638.7843858447935
Log likelihood -633.2539663815364
Log likelihood -627.2221184965576
Log likelihood -620.8801549218979
Log likelihood -614.4636427271009
Log likelihood -608.2643929292244
Log likelihood -602.4560196420489
Log likelihood -596.9081858274919
Log likelihood -591.4105651077606
Log likelihood -58

Adam

In [7]:
# Adam from the same initialization. `beta1`/`beta2` are presumably the decay
# rates of the first/second moment estimates -- confirm in Mixture_Models.
# With learning_rate=0.1 the logged likelihood oscillates (see output below).
# The result gets its own name so the fits in other cells do not overwrite it.
params_store_adam = test_GMM.fit(
    init_params,
    opt_routine="adam",
    learning_rate=0.1,
    beta1=0.9,
    beta2=0.99,
    maxiter=100,
    tol=1e-7,
)

Log likelihood -720.8092616186832
Log likelihood -685.0765304238932
Log likelihood -662.6350166131663
Log likelihood -649.0726027318387
Log likelihood -631.6760439953945
Log likelihood -600.1727102280388
Log likelihood -630.9274646458698
Log likelihood -578.9238847231685
Log likelihood -562.165276661053
Log likelihood -599.5333653984799
Log likelihood -576.6578614655705
Log likelihood -571.7616447184505
Log likelihood -572.1276897293137
Log likelihood -570.8580451771827
Log likelihood -568.4294465585931
Log likelihood -563.1954191409978
Log likelihood -547.1253907516916
Log likelihood -535.9863256854494
Log likelihood -547.0920329764722
Log likelihood -533.8007869184902
Log likelihood -541.2124368508833
Log likelihood -544.895639511094
Log likelihood -542.4479327601205
Log likelihood -536.8051646725139
Log likelihood -526.1368045225732
Log likelihood -510.95425998896087
Log likelihood -543.4440078030493
Log likelihood -512.1522515529334
Log likelihood -527.9199129627668
Log likelihood 

Newton-CG (Newton conjugate gradient)

In [12]:
# Newton-CG from the same initialization (no learning-rate hyperparameters are
# passed). The result is stored under an optimizer-specific name like the other
# fits; `params_store` is also kept bound to it for any later cells.
params_store_newton_cg = test_GMM.fit(
    init_params,
    opt_routine="Newton-CG",
    maxiter=100,
    tol=1e-7,
)
params_store = params_store_newton_cg

Log likelihood -686.1832497203502
Log likelihood -662.5130705607678
Log likelihood -582.3683234440628
Log likelihood -565.349055873944
Log likelihood -559.7514610450845
Log likelihood -556.7400049073273
Log likelihood -527.1685213755479
Log likelihood -510.3776376440154
Log likelihood -503.49668673474605
Log likelihood -491.3439022514285
Log likelihood -481.65444262064193
Log likelihood -468.53913023916493
Log likelihood -464.82906831707436
Log likelihood -461.59604667743866
Log likelihood -460.841081195283
Log likelihood -459.64656630858565
Log likelihood -457.7399580790257
Log likelihood -456.8178174626321
Log likelihood -455.33888113669093
Log likelihood -453.8382571209291
Log likelihood -451.27541712101697
Log likelihood -447.6385484419127
Log likelihood -441.1096579896714
Log likelihood -429.56336005267997
Log likelihood -428.21805675532295
Log likelihood -426.9664441206501
Log likelihood -424.90783424323513
Log likelihood -422.94938394306297
Log likelihood -421.6728304341238
Log 