In this notebook, we illustrate how to fit a Gaussian mixture model (GMM) with different optimization routines. The same approach can be used to fit other models with these optimizers.

In [1]:
from Mixture_Models import *

#### Simulating some data

In [2]:
# Simulate a pinwheel-shaped dataset: num_classes=3 groups of num_per_class=100
# points each. A dedicated RandomState(0) is passed as `rs` so the draw is
# reproducible (assuming make_pinwheel draws only from `rs`, not the global RNG).
data = make_pinwheel(
    num_classes=3,
    num_per_class=100,
    radial_std=0.3,
    tangential_std=0.05,
    rate=0.4,
    rs=npr.RandomState(0),
)

#### Initializing the model

In [3]:
# Wrap the simulated `data` in a GMM model object; every optimizer below is run
# through this object's `fit` method, starting from the same `init_params`.
test_GMM = GMM(data)

#### Initializing the input parameters

In [4]:
# Seed numpy's global RNG before initializing; init_params presumably draws the
# random initial means/proportions from it (the printed sqrt_covs start as identities).
npr.seed(10)
init_params = test_GMM.init_params(scale=0.5, num_components=3)
print(init_params)

{'log proportions': array([ 0.66579325,  0.35763949, -0.77270015]), 'means': array([[-0.00419192,  0.31066799],
       [-0.36004278,  0.13275579],
       [ 0.05427426,  0.00214572]]), 'sqrt_covs': array([[[1., 0.],
        [0., 1.]],

       [[1., 0.],
        [0., 1.]],

       [[1., 0.],
        [0., 1.]]])}


#### Fitting the model with the above initialization parameters

Gradient Descent with momentum

In [5]:
# Gradient descent with momentum. `mass` is presumably the momentum coefficient,
# `maxiter` the iteration cap and `tol` a convergence tolerance -- TODO confirm in fit().
params_store = test_GMM.fit(
    init_params,
    opt_routine="grad_descent",
    mass=0.9,
    learning_rate=5e-4,
    maxiter=100,
    tol=1e-7,
)

Log likelihood -720.8092616186832
Log likelihood -719.9324989689964
Log likelihood -718.2748149395721
Log likelihood -715.9300034638669
Log likelihood -712.9900481929888
Log likelihood -709.5458269817979
Log likelihood -705.6880453853835
Log likelihood -701.5082241387776
Log likelihood -697.0995833835279
Log likelihood -692.5576531321062
Log likelihood -687.9804054374333
Log likelihood -683.467660563424
Log likelihood -679.1194852214626
Log likelihood -675.0333048263767
Log likelihood -671.2995356451652
Log likelihood -667.9957548102311
Log likelihood -665.1797992207992
Log likelihood -662.88269584024
Log likelihood -661.1028521459016
Log likelihood -659.8032389258503
Log likelihood -658.9131007549157
Log likelihood -658.3348876031125
Log likelihood -657.9557758247458
Log likelihood -657.6618055087696
Log likelihood -657.3518342045413
Log likelihood -656.9484866199953
Log likelihood -656.4040216179089
Log likelihood -655.7002839476836
Log likelihood -654.8433206006829
Log likelihood -6

RMSProp

In [10]:
# RMSProp. `gamma` is presumably the decay rate of the running average of squared
# gradients; `maxiter`/`tol` are the same stopping settings as the other routines.
params_store = test_GMM.fit(
    init_params,
    opt_routine="rms_prop",
    gamma=0.9,
    learning_rate=1e-2,
    maxiter=100,
    tol=1e-7,
)

Log likelihood -720.8092616186832
Log likelihood -708.7687436918704
Log likelihood -700.5695943512029
Log likelihood -694.1365026952342
Log likelihood -688.7992701317978
Log likelihood -684.2343094906564
Log likelihood -680.2509729901997
Log likelihood -676.72377944885
Log likelihood -673.5643825913835
Log likelihood -670.706629891055
Log likelihood -668.0968903183067
Log likelihood -665.6867120416829
Log likelihood -663.4258206923182
Log likelihood -661.2538140270027
Log likelihood -659.0896525495963
Log likelihood -656.8207660722843
Log likelihood -654.2999317201798
Log likelihood -651.3613111758623
Log likelihood -647.8531607362195
Log likelihood -643.6709453080134
Log likelihood -638.7843858447937
Log likelihood -633.2539663815367
Log likelihood -627.2221184965578
Log likelihood -620.880154921898
Log likelihood -614.463642727101
Log likelihood -608.2643929292243
Log likelihood -602.456019642049
Log likelihood -596.908185827492
Log likelihood -591.4105651077605
Log likelihood -585.7

Adam

In [7]:
# Adam with a comparatively large step size (0.1). `beta1`/`beta2` are presumably
# the first- and second-moment decay rates. Note the log-likelihood trace is not
# monotone at this step size (e.g. it drops from about -600 to about -631 early on).
params_store = test_GMM.fit(
    init_params,
    opt_routine="adam",
    beta1=0.9,
    beta2=0.99,
    learning_rate=0.1,
    maxiter=100,
    tol=1e-7,
)

Log likelihood -720.8092616186832
Log likelihood -685.0765304238932
Log likelihood -662.6350166131663
Log likelihood -649.072602731839
Log likelihood -631.6760439953945
Log likelihood -600.1727102280388
Log likelihood -630.92746464587
Log likelihood -578.9238847231692
Log likelihood -562.1652766610533
Log likelihood -599.5333653984795
Log likelihood -576.6578614655712
Log likelihood -571.7616447184506
Log likelihood -572.1276897293135
Log likelihood -570.8580451771821
Log likelihood -568.429446558592
Log likelihood -563.1954191409968
Log likelihood -547.1253907516902
Log likelihood -535.9863256854517
Log likelihood -547.0920329764597
Log likelihood -533.8007869184894
Log likelihood -541.2124368508769
Log likelihood -544.8956395110821
Log likelihood -542.4479327601019
Log likelihood -536.805164672487
Log likelihood -526.1368045225358
Log likelihood -510.95425998893893
Log likelihood -543.4440078032451
Log likelihood -512.1522515530087
Log likelihood -527.9199129629405
Log likelihood -53

Newton-CG (Newton conjugate gradient)

In [8]:
# Newton-CG (presumably dispatched to scipy.optimize.minimize). Unlike the
# first-order routines above, no learning rate is passed -- only the stopping settings.
params_store = test_GMM.fit(
    init_params,
    opt_routine="Newton-CG",
    maxiter=100,
    tol=1e-7,
)

Log likelihood -686.1832497203501
Log likelihood -662.5130705607683
Log likelihood -582.3683234440556
Log likelihood -565.3490558739422
Log likelihood -559.7514610450737
Log likelihood -556.7400049073144
Log likelihood -527.1685213754887
Log likelihood -510.3776376442779
Log likelihood -503.4966867349866
Log likelihood -491.34390225156164
Log likelihood -481.6544426206567
Log likelihood -468.5391302392583
Log likelihood -464.8290683176063
Log likelihood -461.596046677574
Log likelihood -460.84108119521454
Log likelihood -459.6465663082922
Log likelihood -457.73995807840106
Log likelihood -456.81781746005606
Log likelihood -455.3388811327538
Log likelihood -453.8382571244513
Log likelihood -451.27541713751464
Log likelihood -447.63854844102804
Log likelihood -441.1096579659725
Log likelihood -429.5633650756498
Log likelihood -428.2180654269032
Log likelihood -426.96631907564773
Log likelihood -424.90755478132974
Log likelihood -422.94918959045117
Log likelihood -421.6720004272188
Log li