In [7]:
import numpy as np
import pandas as pd
import cPickle as pkl
from matplotlib import pyplot as plt
import time

from optimizers import *
import linear_corex_sgd
import theano_linear_corex
import linearcorex

# Load Data

In [3]:
def load_data(path='../data/EOD_week.pkl'):
    """Load a pickled table from disk and return it as a pandas DataFrame.

    Parameters
    ----------
    path : str, optional
        Location of the pickle file. Defaults to the weekly EOD file this
        notebook has always read, so ``load_data()`` behaves as before.

    Returns
    -------
    pandas.DataFrame
        ``pd.DataFrame(obj)`` where ``obj`` is whatever the pickle contains.

    Notes
    -----
    SECURITY: unpickling can execute arbitrary code stored in the file.
    Only point ``path`` at files that come from a trusted source.
    """
    # Binary mode is required for pickle payloads.
    with open(path, 'rb') as f:
        payload = pkl.load(f)
    return pd.DataFrame(payload)

# Load the dataset once; the experiment cells below slice rows from `df`.
df = load_data()
# Print (rows, columns) as a quick check on what was loaded.
print("Data.shape = {}".format(df.shape))

Data.shape = (887, 5038)


# SGD - Run 10 times and average

In [4]:
# Keyword arguments forwarded to linear_corex_sgd.Corex(**corex_sgd_params).
corex_sgd_params = dict(
    n_hidden=10,
    max_iter=300,  # original note: "1e4" -- presumably the full-run setting; confirm
    tol=1e-5,
    anneal=True,
    missing_values=None,
    discourage_overlap=True,
    gaussianize='standard',
    gpu=False,
    verbose=1,
    seed=None,
    # NOTE(review): this one Adam() instance is reused by every Corex built
    # from this dict. If Adam keeps internal state between fits, the repeated
    # runs are not independent -- confirm against the optimizers module.
    optimizer=Adam(),
)

In [13]:
# Empty container for the per-run `tc` values gathered by the SGD loop below.
all_values = list()

In [14]:
# Start from an empty list inside this cell so that re-running it cannot
# append to results left over from an earlier execution.
all_values = []
for _ in range(10):
    # Build a new model from the shared parameter dict on every repetition.
    corex = linear_corex_sgd.Corex(**corex_sgd_params)
    # Fit on the first 200 rows of the loaded table.
    corex.fit(df[:200])
    # Record the fitted model's `tc` attribute (presumably total correlation -- confirm).
    all_values.append(corex.tc)

Linear CorEx with 10 latent factors
294 iterations to tol: 0.000010, Time: 4.67767190933
Linear CorEx with 10 latent factors
Linear CorEx with 10 latent factors
280 iterations to tol: 0.000010, Time: 4.48975014687
193 iterations to tol: 0.000010, Time: 3.06348800659
Linear CorEx with 10 latent factors
255 iterations to tol: 0.000010, Time: 4.03380203247
Linear CorEx with 10 latent factors
Linear CorEx with 10 latent factors
Linear CorEx with 10 latent factors
Linear CorEx with 10 latent factors
Linear CorEx with 10 latent factors
Linear CorEx with 10 latent factors


In [15]:
# Summarise the `tc` values collected in `all_values`.
# Single-argument print(...) produces the same output on Python 2 and is
# valid Python 3, matching the call form already used for the data-shape print.
print("{} runs".format(len(all_values)))
print("mean = {}".format(np.mean(all_values)))
print("max  = {}".format(np.max(all_values)))
print("min  = {}".format(np.min(all_values)))

10 runs
mean = 652.349555741
max  = 659.680652449
min  = 649.596666758


# Original - Run 10 times and average

In [10]:
# Keyword arguments forwarded to linearcorex.Corex(**corex_params).
corex_params = dict(
    n_hidden=10,
    max_iter=300,  # original note: "1e4" -- presumably the full-run setting; confirm
    tol=1e-5,
    anneal=True,
    missing_values=None,
    discourage_overlap=True,
    gaussianize='standard',
    gpu=False,
    verbose=1,
    seed=None,
)

In [11]:
# Empty container for the per-run `tc` values gathered by the loop below.
all_values = list()

In [12]:
# Start from an empty list inside this cell so that re-running it cannot
# append to results left over from an earlier execution.
all_values = []
for _ in range(10):
    # Build a new model from the shared parameter dict on every repetition.
    corex = linearcorex.Corex(**corex_params)
    # Fit on the first 200 rows of the loaded table.
    corex.fit(df[:200])
    # Record the fitted model's `tc` attribute (presumably total correlation -- confirm).
    all_values.append(corex.tc)

Linear CorEx with 10 latent factors
55 iterations to tol: 0.000010
125 iterations to tol: 0.000010
49 iterations to tol: 0.000010
Linear CorEx with 10 latent factors
193 iterations to tol: 0.000010
64 iterations to tol: 0.000010
Linear CorEx with 10 latent factors
226 iterations to tol: 0.000010
58 iterations to tol: 0.000010
87 iterations to tol: 0.000010
63 iterations to tol: 0.000010
Linear CorEx with 10 latent factors
281 iterations to tol: 0.000010
84 iterations to tol: 0.000010
71 iterations to tol: 0.000010
63 iterations to tol: 0.000010
Linear CorEx with 10 latent factors
294 iterations to tol: 0.000010
49 iterations to tol: 0.000010
76 iterations to tol: 0.000010
53 iterations to tol: 0.000010
Linear CorEx with 10 latent factors
237 iterations to tol: 0.000010
77 iterations to tol: 0.000010
35 iterations to tol: 0.000010
43 iterations to tol: 0.000010
67 iterations to tol: 0.000010
Linear CorEx with 10 latent factors
278 iterations to tol: 0.000010
72 iterations to tol: 0.0000

In [13]:
# Summarise the `tc` values collected in `all_values`.
# Single-argument print(...) produces the same output on Python 2 and is
# valid Python 3, matching the call form already used for the data-shape print.
print("{} runs".format(len(all_values)))
print("mean = {}".format(np.mean(all_values)))
print("max  = {}".format(np.max(all_values)))
print("min  = {}".format(np.min(all_values)))

10 runs
mean = 652.555400658
max  = 664.724098206
min  = 648.976228714


# Theano - Run 10 times and average

In [6]:
# Empty container for the per-run `tc` values gathered by the Theano loop.
all_values = list()

# Keyword arguments forwarded to theano_linear_corex.Corex(**corex_params_theano).
corex_params_theano = dict(
    nv=df.shape[1],  # number of columns in the loaded table
    n_hidden=10,
    max_iter=300,  # original note: "1e4" -- presumably the full-run setting; confirm
    tol=1e-5,
    anneal=True,
    missing_values=None,
    discourage_overlap=True,
    gaussianize='standard',
    gpu=False,
    verbose=1,
    seed=None,
)

In [7]:
# Start from an empty list inside this cell so that re-running it cannot
# append to results left over from an earlier execution.
all_values = []
for _ in range(10):
    # Build a new model from the shared parameter dict on every repetition.
    corex = theano_linear_corex.Corex(**corex_params_theano)
    # Fit on the first 200 rows of the loaded table.
    corex.fit(df[:200])
    # Record the fitted model's `tc` attribute (presumably total correlation -- confirm).
    all_values.append(corex.tc)

Linear CorEx with 10 latent factors
tc = -208.009068657, obj = -968.156049495, eps = 0.6
tc = 197.015522988, obj = -1091.07193394, eps = 0.6
tc = 253.720780731, obj = -1150.01523565, eps = 0.6
tc = 281.294498867, obj = -1197.49569621, eps = 0.6
tc = 303.318531529, obj = -1198.27700182, eps = 0.6
tc = 315.211076839, obj = -1224.82056671, eps = 0.6
tc = 324.644773603, obj = -1222.93523733, eps = 0.6
tc = 330.004205346, obj = -1232.63203406, eps = 0.6
tc = 333.643432173, obj = -1229.63102422, eps = 0.6
tc = 337.342282708, obj = -1239.11283683, eps = 0.6
tc = 339.820247273, obj = -1236.82960797, eps = 0.6
tc = 341.984594531, obj = -1241.60007264, eps = 0.6
tc = 343.520007867, obj = -1245.79347815, eps = 0.6
tc = 343.998928536, obj = -1246.1241054, eps = 0.6
tc = 346.237739481, obj = -1243.60056993, eps = 0.6
tc = 347.405147997, obj = -1244.67411271, eps = 0.6
tc = 348.085904038, obj = -1243.23050096, eps = 0.6
tc = 348.849108717, obj = -1252.89471058, eps = 0.6
tc = 349.616528412, obj = -1

tc = 341.376643329, obj = -1245.17806596, eps = 0.6
tc = 343.337616173, obj = -1244.95112032, eps = 0.6
tc = 345.269530419, obj = -1242.27935442, eps = 0.6
tc = 346.999653852, obj = -1242.13677303, eps = 0.6
tc = 348.418363106, obj = -1247.93997622, eps = 0.6
tc = 349.239671155, obj = -1255.29194033, eps = 0.6
tc = 350.12704141, obj = -1253.44434192, eps = 0.6
tc = 351.074652908, obj = -1245.21945719, eps = 0.6
tc = 352.02907635, obj = -1248.1343045, eps = 0.6
tc = 352.304320822, obj = -1257.51399671, eps = 0.6
tc = 352.592183803, obj = -1255.95238883, eps = 0.6
tc = 478.11441685, obj = -2203.95072483, eps = 0.36
tc = 515.730370288, obj = -2194.39951862, eps = 0.36
tc = 519.394920517, obj = -2200.56623665, eps = 0.36
tc = 521.769113601, obj = -2207.88532392, eps = 0.36
tc = 524.565208938, obj = -2205.75755997, eps = 0.36
tc = 525.756002344, obj = -2209.07582192, eps = 0.36
tc = 526.176796278, obj = -2212.91213786, eps = 0.36
tc = 526.602872205, obj = -2213.15594923, eps = 0.36
tc = 527

tc = 478.360408467, obj = -2199.98296137, eps = 0.36
tc = 514.524303088, obj = -2194.3696408, eps = 0.36
tc = 519.572573746, obj = -2204.11381747, eps = 0.36
tc = 520.617542898, obj = -2204.04953286, eps = 0.36
tc = 521.8833501, obj = -2203.75181905, eps = 0.36
tc = 522.611606459, obj = -2208.07155335, eps = 0.36
tc = 523.65444344, obj = -2205.79914812, eps = 0.36
tc = 524.497024478, obj = -2212.757929, eps = 0.36
tc = 524.881198061, obj = -2200.47571279, eps = 0.36
tc = 525.591759247, obj = -2213.1420276, eps = 0.36
tc = 526.533717768, obj = -2213.10379539, eps = 0.36
tc = 527.333013913, obj = -2209.69918527, eps = 0.36
tc = 527.203295174, obj = -2204.47070351, eps = 0.36
tc = 527.893736034, obj = -2212.80842562, eps = 0.36
tc = 528.616604546, obj = -2209.45211092, eps = 0.36
tc = 529.065559452, obj = -2218.03277224, eps = 0.36
tc = 529.884719356, obj = -2212.55726589, eps = 0.36
tc = 529.606895885, obj = -2207.25593485, eps = 0.36
tc = 530.663638691, obj = -2213.0119091, eps = 0.36
t

tc = 529.955052686, obj = -2213.88203035, eps = 0.36
tc = 530.102475339, obj = -2215.70473629, eps = 0.36
tc = 530.477339932, obj = -2215.88400704, eps = 0.36
tc = 530.590038357, obj = -2213.98246842, eps = 0.36
tc = 531.091535694, obj = -2215.25433457, eps = 0.36
tc = 531.197873232, obj = -2213.43575159, eps = 0.36
tc = 531.797369345, obj = -2212.97863432, eps = 0.36
tc = 531.66017503, obj = -2213.79584959, eps = 0.36
tc = 531.499774669, obj = -2212.22557593, eps = 0.36
tc = 579.548765056, obj = -3087.33518, eps = 0.216
tc = 600.01629573, obj = -3079.77451331, eps = 0.216
tc = 599.937918925, obj = -3080.86800964, eps = 0.216
tc = 601.699835447, obj = -3080.2158009, eps = 0.216
tc = 602.077388694, obj = -3084.62718574, eps = 0.216
tc = 602.381261105, obj = -3085.24269472, eps = 0.216
tc = 602.38959981, obj = -3081.79095176, eps = 0.216
tc = 601.986613917, obj = -3082.81250278, eps = 0.216
tc = 603.028465865, obj = -3082.52845925, eps = 0.216
tc = 602.784689467, obj = -3086.99445623, ep

tc = 604.365475013, obj = -3082.39361334, eps = 0.216
tc = 604.826096782, obj = -3090.67434693, eps = 0.216
tc = 604.555559476, obj = -3081.91835924, eps = 0.216
tc = 605.565588399, obj = -3089.36606756, eps = 0.216
tc = 605.354740015, obj = -3090.58957748, eps = 0.216
tc = 605.25306445, obj = -3087.21315549, eps = 0.216
tc = 605.725283181, obj = -3085.62011902, eps = 0.216
tc = 606.261144216, obj = -3086.79300271, eps = 0.216
tc = 606.259272912, obj = -3086.27059486, eps = 0.216
tc = 606.413717001, obj = -3088.38446426, eps = 0.216
tc = 606.703764016, obj = -3092.1777847, eps = 0.216
tc = 606.816564251, obj = -3089.2770401, eps = 0.216
tc = 606.84956675, obj = -3090.51929114, eps = 0.216
tc = 607.192915203, obj = -3089.71366528, eps = 0.216
tc = 607.327454745, obj = -3087.74160998, eps = 0.216
tc = 607.266630378, obj = -3090.50569942, eps = 0.216
tc = 607.543105391, obj = -3091.38551065, eps = 0.216
tc = 607.462560734, obj = -3091.51778982, eps = 0.216
tc = 621.613695549, obj = -3923.

tc = 603.444741651, obj = -3086.73270265, eps = 0.216
tc = 603.363383865, obj = -3083.78833767, eps = 0.216
tc = 602.76530922, obj = -3085.59586974, eps = 0.216
tc = 603.368158251, obj = -3084.37791043, eps = 0.216
tc = 603.221557256, obj = -3087.26353843, eps = 0.216
tc = 603.087758938, obj = -3084.7580175, eps = 0.216
tc = 603.225621957, obj = -3085.98627485, eps = 0.216
tc = 610.449409487, obj = -3919.14028242, eps = 0.1296
tc = 630.151207988, obj = -3914.48223148, eps = 0.1296
tc = 630.632446267, obj = -3916.96073158, eps = 0.1296
tc = 630.76990372, obj = -3917.37744976, eps = 0.1296
tc = 630.939696752, obj = -3921.71259874, eps = 0.1296
tc = 630.961140058, obj = -3921.13955582, eps = 0.1296
tc = 630.726841736, obj = -3916.93273272, eps = 0.1296
tc = 631.095890578, obj = -3916.8086678, eps = 0.1296
tc = 630.897789858, obj = -3915.76020352, eps = 0.1296
tc = 631.186519424, obj = -3917.75952598, eps = 0.1296
tc = 631.349929145, obj = -3920.0979677, eps = 0.1296
tc = 631.162470809, ob

tc = 628.360286272, obj = -3918.72767553, eps = 0.1296
tc = 628.177412864, obj = -3919.63934432, eps = 0.1296
tc = 628.543398071, obj = -3910.63802728, eps = 0.1296
tc = 628.391532711, obj = -3918.29654914, eps = 0.1296
tc = 628.231313164, obj = -3915.93613935, eps = 0.1296
tc = 628.592625351, obj = -3914.00151478, eps = 0.1296
tc = 628.897408858, obj = -3914.2023277, eps = 0.1296
tc = 628.802187476, obj = -3917.70145826, eps = 0.1296
tc = 628.775593142, obj = -3917.76500948, eps = 0.1296
tc = 629.501439198, obj = -3916.27051715, eps = 0.1296
tc = 629.647684691, obj = -3917.81120677, eps = 0.1296
tc = 629.866796235, obj = -3918.3559781, eps = 0.1296
tc = 630.518533727, obj = -3915.9263064, eps = 0.1296
tc = 630.54656025, obj = -3917.4671064, eps = 0.1296
tc = 630.982122512, obj = -3919.13159515, eps = 0.1296
tc = 631.560561759, obj = -3919.25324902, eps = 0.1296
tc = 629.511051536, obj = -4737.99359678, eps = 0.07776
tc = 641.695697287, obj = -4737.73633558, eps = 0.07776
tc = 641.9820

tc = 636.402535986, obj = -3922.04404451, eps = 0.1296
tc = 636.19713091, obj = -3924.04723053, eps = 0.1296
tc = 636.130140115, obj = -3926.57219727, eps = 0.1296
tc = 636.448501594, obj = -3921.34574992, eps = 0.1296
tc = 636.550835096, obj = -3924.34216821, eps = 0.1296
tc = 638.70387945, obj = -4746.51731724, eps = 0.07776
tc = 647.836096862, obj = -4745.49800933, eps = 0.07776
tc = 648.069695064, obj = -4740.88583572, eps = 0.07776
tc = 648.347959983, obj = -4747.29433303, eps = 0.07776
tc = 648.47965212, obj = -4743.00580226, eps = 0.07776
tc = 648.521029905, obj = -4745.12796974, eps = 0.07776
tc = 648.397780755, obj = -4745.5507711, eps = 0.07776
tc = 648.910987676, obj = -4742.71148903, eps = 0.07776
tc = 648.710425946, obj = -4744.72409717, eps = 0.07776
tc = 648.627976633, obj = -4743.00179588, eps = 0.07776
tc = 647.924516081, obj = -4749.68775471, eps = 0.07776
tc = 648.49913854, obj = -4746.18139907, eps = 0.07776
tc = 648.985627273, obj = -4743.1619428, eps = 0.07776
tc 

tc = 639.513894785, obj = -4734.43671037, eps = 0.07776
tc = 639.704420487, obj = -4733.37789085, eps = 0.07776
tc = 639.48346093, obj = -4735.36493539, eps = 0.07776
tc = 639.216458456, obj = -4736.40076315, eps = 0.07776
tc = 639.928512711, obj = -4738.98845917, eps = 0.07776
tc = 639.656987058, obj = -4731.54478926, eps = 0.07776
tc = 639.640538207, obj = -4735.48857079, eps = 0.07776
tc = 639.563214534, obj = -4737.4809978, eps = 0.07776
tc = 639.771527837, obj = -4737.91154656, eps = 0.07776
tc = 639.353959578, obj = -4735.90883252, eps = 0.07776
tc = 639.045134753, obj = -4734.59898027, eps = 0.07776
tc = 639.46649598, obj = -4739.2190859, eps = 0.07776
tc = 639.657343583, obj = -4732.11229823, eps = 0.07776
tc = 639.793691022, obj = -4736.26137329, eps = 0.07776
tc = 626.441449117, obj = -5552.16531299, eps = 0.046656
tc = 641.677045862, obj = -5547.31650499, eps = 0.046656
tc = 643.777284945, obj = -5548.72838943, eps = 0.046656
tc = 644.043352829, obj = -5549.7495965, eps = 0.

tc = 649.03332269, obj = -4743.59384842, eps = 0.07776
tc = 648.89665696, obj = -4746.66812452, eps = 0.07776
tc = 648.888863438, obj = -4745.10954448, eps = 0.07776
tc = 649.037700955, obj = -4743.07546812, eps = 0.07776
tc = 644.485196448, obj = -5562.76350978, eps = 0.046656
tc = 653.715932471, obj = -5557.64773605, eps = 0.046656
tc = 653.903380477, obj = -5559.74554942, eps = 0.046656
tc = 653.976685799, obj = -5555.95152811, eps = 0.046656
tc = 654.330889342, obj = -5558.73562506, eps = 0.046656
tc = 654.313815228, obj = -5559.86750781, eps = 0.046656
tc = 654.160905429, obj = -5558.99735916, eps = 0.046656
tc = 654.192358729, obj = -5561.69012937, eps = 0.046656
tc = 654.204677194, obj = -5559.70027222, eps = 0.046656
tc = 653.772050592, obj = -5557.80605415, eps = 0.046656
tc = 654.427295555, obj = -5565.87249409, eps = 0.046656
tc = 654.38407915, obj = -5559.07245702, eps = 0.046656
tc = 654.056490184, obj = -5559.92211336, eps = 0.046656
tc = 653.93317465, obj = -5561.6643705

In [8]:
# Summarise the `tc` values collected in `all_values`.
# Single-argument print(...) produces the same output on Python 2 and is
# valid Python 3, matching the call form already used for the data-shape print.
print("{} runs".format(len(all_values)))
print("mean = {}".format(np.mean(all_values)))
print("max  = {}".format(np.max(all_values)))
print("min  = {}".format(np.min(all_values)))

10 runs
mean = 653.546141986
max  = 659.433333187
min  = 647.053301305
