In [3]:
import numpy as np
import pandas as pd

from sklearn.model_selection import KFold
from utils import extract_for_ensemble, create_matrix_from_raw, RAND_SEED
from models import IRSVD, Baseline, GBias, SVP, SVT, RSVD

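## K-fold split
- Split the data into 10 folds: each model is trained on 90% of the rows and validated on the remaining 10%
- Use a fixed random state (`RAND_SEED`) so the folds are reproducible across runs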

In [4]:
data_pd = pd.read_csv("./data/data_train.csv")
kf = KFold(n_splits=10, shuffle=True, random_state=RAND_SEED)

# Sanity-check the splits instead of only eyeballing truncated index arrays.
# The model loops below call kf.split(data_pd) again, so every call must yield
# the same folds (expected with an integer random_state; verified here), and the
# test folds must together cover every row exactly once.
test_row_counts = np.zeros(len(data_pd), dtype=int)
for (train_set, test_set), (train_again, test_again) in zip(kf.split(data_pd), kf.split(data_pd)):
    assert np.array_equal(train_set, train_again) and np.array_equal(test_set, test_again), \
        "KFold produced different splits on repeated calls"
    test_row_counts[test_set] += 1
    # Still printed so the splits can be compared by eye with other runs/notebooks.
    print(train_set)
    print(test_set)
assert (test_row_counts == 1).all(), "test folds do not partition the rows exactly once"

[      0       1       3 ... 1176948 1176950 1176951]
[      2       7      14 ... 1176940 1176941 1176949]
[      0       1       2 ... 1176948 1176949 1176951]
[      5      12      13 ... 1176920 1176927 1176950]
[      1       2       3 ... 1176949 1176950 1176951]
[      0      36      43 ... 1176933 1176936 1176945]
[      0       1       2 ... 1176949 1176950 1176951]
[      3      10      17 ... 1176922 1176925 1176943]
[      0       1       2 ... 1176949 1176950 1176951]
[      6       9      32 ... 1176907 1176928 1176929]
[      0       1       2 ... 1176948 1176949 1176950]
[     11      26      29 ... 1176893 1176946 1176951]
[      0       1       2 ... 1176949 1176950 1176951]
[     28      35      61 ... 1176931 1176935 1176948]
[      0       1       2 ... 1176949 1176950 1176951]
[      4      15      16 ... 1176912 1176917 1176934]
[      0       2       3 ... 1176949 1176950 1176951]
[      1       8      21 ... 1176919 1176930 1176944]
[      0       1       2 ...

## Models
- Predict the matrix with different models and hyperparameters
- Save each reconstructed matrix for the ensemble (one per fold)
- Also train on the entire dataset

### Improved Regularized SVD (IRSVD)

- Train every configuration below on each fold and save its reconstruction for ensemble training

In [5]:
# IRSVD configurations, one tuple per model:
#   (biases, features, eta, lambda1, lambda2, epochs)
# Every configuration is trained on every fold, and its reconstruction is saved
# for the ensemble.
params = (
    ("mean", 96, 0.01, 0.02, 0.05, 13),
    ("mean", 148, 0.01, 0.02, 0.05, 13),
    ("mean", 296, 0.01, 0.02, 0.05, 14),
    ("mean", 324, 0.01, 0.02, 0.05, 15),
    ("zero", 96, 0.01, 0.02, 0.05, 13),
    ("zero", 148, 0.01, 0.02, 0.05, 13),
    ("zero", 296, 0.01, 0.02, 0.05, 14),
    ("zero", 324, 0.01, 0.02, 0.05, 15),
)

for idx, (train_set, test_set) in enumerate(kf.split(data_pd)):
    train_data = data_pd.iloc[train_set]
    test_data = data_pd.iloc[test_set]

    train_matrix = create_matrix_from_raw(train_data)
    test_matrix = create_matrix_from_raw(test_data)

    for param in params:
        biases, features, eta, lambda1, lambda2, epochs = param
        # Artifact name, e.g. "irsvd_mean_96". The previous string literal had a
        # stray leading space (" irsvd_..."). NOTE(review): fname presumably ends
        # up in the saved file name inside extract_for_ensemble — confirm that the
        # ensemble loader expects names without the leading space.
        fname = f"irsvd_{biases}_{features}"
        print(param)
        model = IRSVD(train_matrix, biases=biases, features=features,
                      eta=eta, lambda1=lambda1, lambda2=lambda2, epochs=epochs)
        # train() returns a dict of per-epoch 'train_rmse' / 'test_rmse' lists
        # (see the cell output); printed to monitor convergence per fold.
        print(model.train(test_matrix=test_matrix))
        rec_matrix = model.reconstruct_matrix()
        # Fold numbers passed to extract_for_ensemble are 1-based.
        extract_for_ensemble(rec_matrix, fname, idx + 1, train=True)

('mean', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [07:40<00:00, 35.41s/it]


{'train_rmse': [0.9928676922975986, 0.9918426269254519, 0.9918486763409934, 0.991125206101243, 0.9902663735465557, 0.989069888814417, 0.9861477389209141, 0.9799428860827163, 0.9682538009695231, 0.9510397479373288, 0.9288328726413521, 0.9016282070873272, 0.8700921113027357], 'test_rmse': [1.0037524033844163, 1.0030437870005222, 1.0027006664794795, 1.0024598112025538, 1.0017754613821215, 1.0019236615015619, 1.001530768242282, 0.9992211821398828, 0.9950817199945327, 0.9909491006985983, 0.9871527051390491, 0.9848431564677214, 0.9848267504102693]}
('mean', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [10:14<00:00, 47.29s/it]


{'train_rmse': [0.9927189913825504, 0.9921557928045148, 0.9920041332109162, 0.9914389014093807, 0.9909997278959303, 0.9898153165077408, 0.9880969785674063, 0.9835078559120553, 0.9747289443366217, 0.9606272725773228, 0.9419624595735441, 0.9180997806909194, 0.8880585561108983], 'test_rmse': [1.0038568078380468, 1.003129739023154, 1.0029745809912536, 1.0026483230072485, 1.0022573837124362, 1.001758252252774, 1.0021261854789516, 0.9997636149112271, 0.9963780177413344, 0.9921704577847621, 0.9880162152281193, 0.9848738179006072, 0.9830231978779059]}
('mean', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [13:35<00:00, 58.28s/it]


{'train_rmse': [0.9929555485612032, 0.9922004929638176, 0.9918162511212755, 0.9917500428889585, 0.9913256999432767, 0.9909424728191467, 0.9898511177147326, 0.9872471814126621, 0.9816803819974786, 0.9715049969799117, 0.9573218712043627, 0.9394111946442293, 0.915711412196087, 0.8855183116174847], 'test_rmse': [1.0035472452012146, 1.0026357760874092, 1.002293275748977, 1.0023667431654093, 1.0021004011499084, 1.0021994196310011, 1.0018588111420783, 1.0008635748314514, 0.9981147491853798, 0.9943414328042629, 0.9898664065994027, 0.986686945253791, 0.9836349194548305, 0.9814612829018748]}
('mean', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [11:45<00:00, 47.01s/it]


{'train_rmse': [0.993178684180608, 0.9921301535247934, 0.9920342612548219, 0.9918615644095977, 0.9913127156205762, 0.9910582072884911, 0.9901196725675511, 0.9877997192728868, 0.9828657662529736, 0.9738962999855603, 0.9599528297201076, 0.9424794098531609, 0.9196670350965036, 0.8908251060142981, 0.8549678854103796], 'test_rmse': [1.0042768579393269, 1.0026371475069238, 1.0024712641303162, 1.0025526529182804, 1.0022516363990763, 1.0024455803510985, 1.0019721556028662, 1.000793772156946, 0.9988569007899241, 0.9956688518677214, 0.9910560622177415, 0.9868391685304034, 0.983857911347744, 0.9820960860024295, 0.9812946968068437]}
('zero', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [08:21<00:00, 38.61s/it]


{'train_rmse': [0.9975397255933051, 0.9932897182928064, 0.9919624968470417, 0.9913886224977155, 0.9906827901042217, 0.9887927391228296, 0.9854660097570822, 0.9783800831655398, 0.9662845681282917, 0.9488506986224607, 0.9272795867411283, 0.9005079508458329, 0.8689414897117923], 'test_rmse': [1.0049279568706495, 1.0029780377833761, 1.0021152070398445, 1.002060530826864, 1.0023375381143655, 1.0020720204951725, 1.000693224130877, 0.9981308363280761, 0.9943641055891105, 0.9896578856414164, 0.986321557043782, 0.9846217763968769, 0.9847765117381213]}
('zero', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:40<00:00, 30.79s/it]


{'train_rmse': [0.9975006421228775, 0.9934073930468581, 0.9925330485839544, 0.9915096214179375, 0.9909112372469041, 0.9900536737412643, 0.9874494441385503, 0.9826340603385173, 0.9729727009565161, 0.9584935476900738, 0.9401989237819113, 0.916383359973566, 0.8864187471945649], 'test_rmse': [1.0054022045268738, 1.0027586258326107, 1.0027452537640398, 1.001618958783991, 1.0019132560941355, 1.0022238377752068, 1.0008857365087342, 0.9993364927777297, 0.9957650384134985, 0.9911473515345604, 0.9875673006144748, 0.984395967513839, 0.9823353314870436]}
('zero', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [07:07<00:00, 30.55s/it]


{'train_rmse': [0.9974731103814101, 0.9933611833493534, 0.9922804255203584, 0.9920240155236167, 0.9916113604205282, 0.9908891578334393, 0.989712647099351, 0.9872315254324003, 0.98120162910832, 0.9710050959207552, 0.9567740649405431, 0.9386073882993083, 0.9150615860622402, 0.8849592320036165], 'test_rmse': [1.0053928490657662, 1.0025808757119425, 1.0022145108848461, 1.0020787686435464, 1.0020485959182739, 1.001956373237731, 1.0016126446460674, 1.0008882457778583, 0.9976462881280025, 0.9937873376996892, 0.9898231793390312, 0.985631681079198, 0.9834603146604539, 0.9811051555877633]}
('zero', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [07:17<00:00, 29.17s/it]


{'train_rmse': [0.9975620197516427, 0.9935868419509019, 0.9925459808634485, 0.9920702235978187, 0.9917150296687941, 0.9911318259926339, 0.9899610447145124, 0.9875594218999367, 0.9822872173347802, 0.9723111995315643, 0.9587604314036768, 0.9415887804461511, 0.9187632699450191, 0.8896493325375622, 0.8544635326097851], 'test_rmse': [1.0050769532000468, 1.0024623583869798, 1.0027225301550773, 1.002297366439915, 1.0020237422198344, 1.001860521816584, 1.0020323184536986, 1.0010106052917243, 0.9986852109021315, 0.9944037022532394, 0.9900300935303019, 0.9865408894603408, 0.9837108302644914, 0.9812400545838837, 0.9807740153058427]}
('mean', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [05:19<00:00, 24.55s/it]


{'train_rmse': [0.9931058770971536, 0.9923067393435149, 0.9916785294401452, 0.991198798427562, 0.9901694897408942, 0.9886591119784481, 0.9851725115095652, 0.9778570332848663, 0.9656811983256939, 0.9492489628860303, 0.9280290451648092, 0.9009526850311697, 0.8688460497212568], 'test_rmse': [1.002062880485979, 1.0009843034906536, 1.0005490679566302, 1.000160464323458, 0.9996627544842005, 0.9994093692393909, 0.998355283598872, 0.996285869021889, 0.9919070363152397, 0.9879086149560127, 0.9853254700843183, 0.9832201297359007, 0.9823894724622885]}
('mean', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [05:40<00:00, 26.23s/it]


{'train_rmse': [0.9930471004055815, 0.9920393407661183, 0.992049252028981, 0.9918283742291772, 0.9911376693979183, 0.9900730617929112, 0.9878785547189567, 0.9827274452550209, 0.9735939670208099, 0.9592499672705094, 0.9410158008982961, 0.9172717808576494, 0.8875196626520314], 'test_rmse': [1.0017454186118318, 1.0004468865749325, 1.000662267488143, 1.0005837236506265, 1.000460164803375, 1.0003220899842808, 0.9993653255626495, 0.9971303982880902, 0.9940902784529836, 0.9898617387265946, 0.986277068041265, 0.9836830167008405, 0.9821373523541912]}
('mean', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [06:40<00:00, 28.59s/it]


{'train_rmse': [0.9932840958286511, 0.9924490710949203, 0.9923025880981685, 0.9921852089434265, 0.9913863712516531, 0.9911679875662057, 0.9900624723156091, 0.98740524926874, 0.9819680795991031, 0.9718551235599218, 0.9577887157511016, 0.9394302332919995, 0.9160105300054485, 0.8858985759611742], 'test_rmse': [1.001933261631954, 1.0007692147873448, 1.0010236400723036, 1.0003103793617891, 1.000350569043078, 1.0000799390780792, 0.9999606371760332, 0.9989891838678961, 0.9967014056493309, 0.9927697203886514, 0.9890690902057375, 0.9852891950140019, 0.9825032440325511, 0.9804350891785]}
('mean', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [07:37<00:00, 30.51s/it]


{'train_rmse': [0.9932728857042248, 0.992379403472772, 0.9919482178928777, 0.9920067945725868, 0.991653687061498, 0.9910907228353559, 0.9900876484619517, 0.987611362810913, 0.9821102063924221, 0.9723344434276623, 0.9586906035780098, 0.9417860509753728, 0.919086410538229, 0.8903215689034979, 0.8551110445307116], 'test_rmse': [1.0015640076116925, 1.00100589775799, 1.000238691978944, 1.0008291109100529, 1.0003863510183688, 0.9997430913638191, 0.9999832304731427, 0.9984527294715984, 0.9958791527424028, 0.9926251628735541, 0.9884205546268756, 0.9858928721320201, 0.9825995779135513, 0.9806488019710909, 0.979819389807764]}
('zero', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [05:54<00:00, 27.25s/it]


{'train_rmse': [0.9976238556227606, 0.9933060289484908, 0.9924160186818445, 0.9916728109264786, 0.9908508264248322, 0.9889708076784783, 0.9857404779686604, 0.9787202109684635, 0.9663329038064034, 0.9494705399211443, 0.9279571839444498, 0.9012517007776242, 0.8700570605717428], 'test_rmse': [1.002799491315216, 0.999982479505925, 1.0002847435908198, 1.0004373657708394, 1.0003358237775768, 0.9994674358346083, 0.9988812022153244, 0.9963265296358524, 0.9919824984397577, 0.9882431404331123, 0.9855419211288475, 0.9834066845514166, 0.982644073438385]}
('zero', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:19<00:00, 29.18s/it]


{'train_rmse': [0.9979509952220138, 0.9936252314541617, 0.9925584784001512, 0.9917870636890708, 0.9913426416834618, 0.990154245298434, 0.9880668648236027, 0.9827168365300163, 0.9731662608358713, 0.958758850102472, 0.9398468894059911, 0.9158608649457788, 0.8858417836414895], 'test_rmse': [1.003114093745876, 1.0003294519963521, 0.9996862144089355, 0.999664452664708, 0.999983496463104, 0.9996564198290281, 0.9992288941471731, 0.9973358726477936, 0.9936082489256685, 0.989690496047129, 0.9861458093540361, 0.9830761376122542, 0.9819632845211382]}
('zero', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [06:52<00:00, 29.44s/it]


{'train_rmse': [0.9978220371150435, 0.9935321947356434, 0.9925462735139204, 0.9920178314653776, 0.9920690919098573, 0.9912713015831501, 0.9900422985580197, 0.9873354911077077, 0.9812636345447665, 0.971346213260262, 0.9568800241246426, 0.9384518253057177, 0.9147098247940814, 0.8847198073084444], 'test_rmse': [1.0030061310837073, 1.0000669779332876, 0.9998037204960221, 0.9997499945697413, 1.0004492217077874, 0.9998431888370657, 0.9999166499663327, 0.9988044534048466, 0.9958497642033021, 0.9922624102330777, 0.9882743574736096, 0.984938110051107, 0.9823324919145096, 0.9805488813559947]}
('zero', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [07:12<00:00, 28.82s/it]


{'train_rmse': [0.9979520709118048, 0.9938024384177115, 0.9928115522412808, 0.9922737757541782, 0.9918714129242588, 0.9911780925357325, 0.9903540689507029, 0.9874618352112355, 0.9820562658033636, 0.9721151569172595, 0.9584071645197623, 0.9408049323008864, 0.9181753147967909, 0.8890130999024757, 0.8538785344919161], 'test_rmse': [1.0028820017398667, 1.0006254830100307, 1.000501903355878, 1.0000245253674753, 0.9999466328226755, 0.9997338455320921, 1.0001074210059355, 0.9984733780858405, 0.9960661239380695, 0.9923388836426984, 0.9882186485154278, 0.9848515374109539, 0.9823250118750829, 0.9806748261189093, 0.9800614202578773]}
('mean', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [05:21<00:00, 24.72s/it]


{'train_rmse': [0.993220235607962, 0.99201831991994, 0.9916278973504562, 0.9915310483521537, 0.9905551727063873, 0.9891087285103344, 0.9856062492242545, 0.9784907919288081, 0.9663288656591338, 0.9494977451279469, 0.928355750265296, 0.9017664849341837, 0.870552708701339], 'test_rmse': [1.0008525821136132, 0.9994710010051625, 0.9995064798338258, 1.0000744769871426, 0.9997745959354549, 0.9989813590773688, 0.9978285963163288, 0.9951792380971386, 0.9914533917870225, 0.9883175735897572, 0.9854678372413591, 0.9835560651720313, 0.9839281861480147]}
('mean', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:32<00:00, 30.20s/it]


{'train_rmse': [0.99331243427323, 0.9923629081300244, 0.9922012681802715, 0.9917177037011364, 0.9910758907102005, 0.9904406079055718, 0.9881597324172788, 0.9837376523453137, 0.9747234759256792, 0.9604651482587907, 0.9417645090438205, 0.9180737663825677, 0.8885383315325621], 'test_rmse': [1.0008312516660427, 0.9996429950085611, 0.999705567783886, 0.9997617058822553, 0.9991520800046345, 0.9997507354946203, 0.9985072087051559, 0.997549829895591, 0.9944461001563517, 0.9901510687557664, 0.9862884327964112, 0.98319826181685, 0.9821236385362272]}
('mean', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [07:39<00:00, 32.80s/it]


{'train_rmse': [0.9934015326430418, 0.9924704676532364, 0.9921536975323675, 0.991920421613194, 0.9919548166982125, 0.9911553094753539, 0.990025487485889, 0.9874278331459766, 0.9818493731986779, 0.971722205744227, 0.9576864006752465, 0.9392365363332446, 0.9156137574117937, 0.8855314251102127], 'test_rmse': [1.0005971227638084, 1.0003431093974937, 0.9999289965692781, 0.9996231575659712, 0.9992385206506278, 0.9993341577635427, 0.9991229875851336, 0.9979860690491253, 0.9957702571841134, 0.9920228262838084, 0.9883054923234741, 0.9850131259606086, 0.9820957144180553, 0.980315227611152]}
('mean', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [08:01<00:00, 32.11s/it]


{'train_rmse': [0.9933904259045403, 0.9923803666078643, 0.9922053662324685, 0.9919912804681813, 0.9917486128422321, 0.9914876828423062, 0.9901379765026288, 0.987740880713389, 0.9823548891185749, 0.9723676784799024, 0.9588648078052429, 0.9417050114646156, 0.9192142267094644, 0.890336711299574, 0.855164202622096], 'test_rmse': [1.0009849133340931, 0.9996917543975071, 0.9997984381122244, 0.99975588127376, 0.9996171369611517, 1.0000623022476103, 0.9989042589694588, 0.9979981401775486, 0.9959093833733007, 0.9919925141275872, 0.9879515915559676, 0.985148645149295, 0.9828885120104799, 0.9804681527519187, 0.9795410702200774]}
('zero', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:55<00:00, 31.98s/it]


{'train_rmse': [0.9978912865588953, 0.9938706539115649, 0.9924465319422898, 0.9916763047264754, 0.9908131804928046, 0.9893057996971998, 0.9856517215339915, 0.9789667091537201, 0.966808634944418, 0.9502467425120708, 0.9285010805376452, 0.9018962267118011, 0.8703670739872226], 'test_rmse': [1.0018420795101803, 1.0000574233728734, 0.999572505422612, 0.9993941075871454, 0.9997452568132585, 0.9991369930110141, 0.9976794705917045, 0.9957274318796061, 0.9919195602568073, 0.9883618527295878, 0.9850971521446417, 0.9836342371457996, 0.9833037989722757]}
('zero', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:32<00:00, 30.18s/it]


{'train_rmse': [0.99803945698699, 0.993979968575791, 0.9926152156410998, 0.9919285733732585, 0.9915641351506268, 0.9901116945710045, 0.9875636842359058, 0.9821558164637176, 0.9722675725143228, 0.9575432369634339, 0.939203144713502, 0.9156491447865247, 0.8860750080026966], 'test_rmse': [1.002603624529893, 0.9995989232853922, 0.999504550093404, 0.9990005634349388, 0.9998184191889061, 0.9992703841992322, 0.9980304245433631, 0.9960361189556546, 0.9927626208568238, 0.9886607814882229, 0.9854918906511106, 0.9829023890366136, 0.9811300064576856]}
('zero', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [06:56<00:00, 29.75s/it]


{'train_rmse': [0.9981585315494934, 0.9938604627691412, 0.9928274815302922, 0.992271158786199, 0.9920999822886948, 0.9911519693468518, 0.9901992908171874, 0.9871381658939306, 0.9813322689293558, 0.9710754555224704, 0.9563386289980959, 0.9382182188139566, 0.9142473583319467, 0.8843818305213116], 'test_rmse': [1.0026077696486595, 0.9995392682458177, 0.9996035791715315, 0.999302132044952, 0.9996169198852208, 0.9994807914066618, 0.9989912300074619, 0.9977217274379819, 0.9958721224456323, 0.9919314992492647, 0.9876908635077609, 0.9851761319998414, 0.9819695667813533, 0.9807231516810383]}
('zero', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [07:48<00:00, 31.20s/it]


{'train_rmse': [0.9979069026010277, 0.9939592327175268, 0.9928861734367783, 0.9922821497573809, 0.9917171085370988, 0.9917255251606928, 0.9902032590323501, 0.9877894248203084, 0.9824191244950388, 0.9725663886423564, 0.958960096792933, 0.9411618758341049, 0.9182831490029387, 0.8892176420245576, 0.8537915618893037], 'test_rmse': [1.0016990442549645, 0.9997743712507219, 0.9998899866116838, 0.9990249684707443, 0.999100380290553, 1.0000024346755758, 0.9987491908966478, 0.9980630690779331, 0.9958958243465409, 0.9920430321164819, 0.9886012298580674, 0.9846236846272551, 0.9823284741020244, 0.9803138945883092, 0.9797638110311282]}
('mean', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:37<00:00, 30.61s/it]


{'train_rmse': [0.992887845413748, 0.9919594559725445, 0.9915355227276437, 0.9910809964093619, 0.9903125648043577, 0.9886521567244353, 0.9850819749741626, 0.9779122860888704, 0.9656889238797483, 0.9488324032023558, 0.9275112887638141, 0.900845334246288, 0.8696323371496218], 'test_rmse': [1.0031030260554283, 1.0025573246487327, 1.0025859961223793, 1.0024759657520501, 1.002270253067234, 1.001545781854671, 1.0007667971014953, 0.997882623721706, 0.994240769699928, 0.9906637253336169, 0.9875240267762256, 0.9855933687197035, 0.9854922506146541]}
('mean', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:51<00:00, 31.66s/it]


{'train_rmse': [0.9927902452793226, 0.9920132373078321, 0.9916612386326408, 0.9915141134321066, 0.9907673709807882, 0.9896858645755657, 0.9871751704298747, 0.9823014915033068, 0.9722248203566375, 0.9582771523090863, 0.9402270835816564, 0.9165710787842096, 0.8867987951261033], 'test_rmse': [1.0035257812142968, 1.0027592717922884, 1.0018566824268698, 1.002838041065408, 1.0026064201802958, 1.00196041361243, 1.0005992755617577, 0.9996962358671135, 0.9957415439744504, 0.9916037781919919, 0.9889872412806603, 0.9861660962626867, 0.9843308218811239]}
('mean', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [07:37<00:00, 32.70s/it]


{'train_rmse': [0.9930207082022683, 0.9922445945383074, 0.9916473098860422, 0.991745397412137, 0.9912912261879078, 0.9907885862792631, 0.9895759314438237, 0.9867978077485196, 0.9810092056658013, 0.9705313674178112, 0.9563958921294604, 0.9384040474441521, 0.9149684765218487, 0.8853480254611459], 'test_rmse': [1.0036508055437983, 1.002712955503385, 1.0023251910239224, 1.0026533124242047, 1.0024745471986494, 1.0020542440912545, 1.0019760591086786, 1.0006011371799117, 0.998431677886403, 0.9941367879251415, 0.9904428473714499, 0.9875343220078425, 0.9849766419093609, 0.9833073745755856]}
('mean', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [08:24<00:00, 33.61s/it]


{'train_rmse': [0.9931098900074806, 0.9921672747997776, 0.9921814795472497, 0.9918531892470845, 0.9915892259088764, 0.9910144092516483, 0.989781246661478, 0.98754316796059, 0.9823042189854134, 0.9726459263110498, 0.9587213155288948, 0.9411066287027302, 0.9182653983754668, 0.8896947608077679, 0.8545481067411983], 'test_rmse': [1.0038738723538516, 1.0026564741829462, 1.0028240941016482, 1.0023346736787, 1.0026405492150539, 1.0021680971589146, 1.001997061515661, 1.0010310165672471, 0.999047308165114, 0.9955195394907159, 0.9912309447389841, 0.9877824186867491, 0.9850529964554852, 0.9830834616752648, 0.9822001622664817]}
('zero', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:46<00:00, 31.26s/it]


{'train_rmse': [0.9973825647031865, 0.9933352283941732, 0.9920153258909529, 0.991085559125754, 0.9904051208119505, 0.9886067036784023, 0.9851621483953584, 0.9777676890953458, 0.965751136385296, 0.9485154937069626, 0.9274800444708726, 0.9010701684403032, 0.869939615181069], 'test_rmse': [1.0046160697091513, 1.002577693339734, 1.001858958731732, 1.001435263235481, 1.0025332876805781, 1.0014850712057037, 1.0006148132210986, 0.9981240567724284, 0.9947261737966806, 0.9902285844716193, 0.9882748402991376, 0.9865697777289882, 0.9862255142955922]}
('zero', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [07:05<00:00, 32.73s/it]


{'train_rmse': [0.997697663474902, 0.9934047399966828, 0.9924342569355411, 0.9917269431472555, 0.991046402841504, 0.9897920374245401, 0.9875680249427351, 0.9827409746002259, 0.9734005583508863, 0.958894978801668, 0.9403830043160565, 0.9166464832025882, 0.8871616153054923], 'test_rmse': [1.004686614248648, 1.0022823006578683, 1.0023211496237865, 1.0023144876565415, 1.0023274487658795, 1.0018616845875126, 1.001252631753469, 0.9996677384704578, 0.9964175797393134, 0.9923653802427097, 0.9888465456440829, 0.9858120943806467, 0.9845088561208146]}
('zero', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [07:55<00:00, 33.96s/it]


{'train_rmse': [0.9975733336051457, 0.993549205591803, 0.9924120221075048, 0.9921581250936474, 0.9916082134092458, 0.9908426068551718, 0.9896630520308877, 0.9870232354228624, 0.9807849414958538, 0.9705470539514869, 0.956640524804636, 0.9383818411749342, 0.9147407095384948, 0.8845973056301457], 'test_rmse': [1.0046055546461885, 1.0024350524567194, 1.0020569778616695, 1.0028009697978282, 1.002267926204685, 1.001858276854427, 1.0018459804119713, 1.0005450920939252, 0.9982316324154953, 0.9949947493796371, 0.9911488101440373, 0.9878645175273614, 0.985491104147908, 0.9835231220520029]}
('zero', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [08:13<00:00, 32.89s/it]


{'train_rmse': [0.9974978954195641, 0.9936224818774358, 0.9926658553053727, 0.9918650649095482, 0.9914995256731562, 0.9908415357936996, 0.9899681757651115, 0.9872219998589091, 0.981421122448016, 0.971527120875105, 0.9576961423275075, 0.940629359883608, 0.9177642905693341, 0.888884245131004, 0.8537848804584813], 'test_rmse': [1.0048151055730634, 1.0026556917651976, 1.002313665318457, 1.0021605249164083, 1.0018114587388998, 1.0018347155261766, 1.002159407063075, 1.0007312778076174, 0.9984569221858731, 0.994642236879459, 0.9905350436289242, 0.9877796109288579, 0.9848579706153289, 0.9829254316320172, 0.9819716401484962]}
('mean', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:49<00:00, 31.46s/it]


{'train_rmse': [0.9926892373976504, 0.9915171540230823, 0.9913874257665674, 0.9905677729773158, 0.9898624823056551, 0.9882124621893579, 0.9851025236609803, 0.977886088144949, 0.9659898022773373, 0.9492600939925364, 0.9275478414440934, 0.9011467001392738, 0.8701644484446996], 'test_rmse': [1.007205033974094, 1.0062702345383323, 1.0062990360368835, 1.005890178059552, 1.0051710998286454, 1.0050192747490891, 1.004538950670149, 1.0018231834083837, 0.998018629178468, 0.9939206957104166, 0.9905102042197048, 0.9887949163734393, 0.9885106828276442]}
('mean', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:58<00:00, 32.22s/it]


{'train_rmse': [0.992486100203434, 0.9916634239165627, 0.9913038485528555, 0.9910065356156904, 0.9905013973843678, 0.9894010974712566, 0.9873049691749046, 0.9825241101623926, 0.973482692766064, 0.9587131635026965, 0.9398497181085127, 0.9162023967981939, 0.8864373996721101], 'test_rmse': [1.0069235260236906, 1.005844370244003, 1.006129062666915, 1.0059843310913712, 1.0061736069498073, 1.005422948110445, 1.0051112022506865, 1.0033861999190368, 1.0003605045012511, 0.9956662134435612, 0.9913975317097471, 0.9884945223848117, 0.9863813502798489]}
('mean', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [08:01<00:00, 34.37s/it]


{'train_rmse': [0.9925817653428401, 0.9918037543525061, 0.9915327584048923, 0.9912595858227662, 0.9910071695130321, 0.990665231335652, 0.9894949810630165, 0.9865060612526599, 0.980658782431046, 0.9704601973312432, 0.9568508693052105, 0.9392001614980994, 0.9159866653813523, 0.886088835175212], 'test_rmse': [1.0065307789230933, 1.0061001175258801, 1.0061363785391193, 1.0061551623838256, 1.005576078164934, 1.0056252461228712, 1.005840471735746, 1.0039965609960646, 1.0014054895091786, 0.9974922220160704, 0.9938475493298214, 0.990590725439813, 0.987223550760106, 0.9858406160183795]}
('mean', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [08:33<00:00, 34.26s/it]


{'train_rmse': [0.9925712859346844, 0.9918380116411407, 0.9913481117696776, 0.9915389883063312, 0.9909032733285654, 0.9908759467361765, 0.9896909187786144, 0.9871278359091613, 0.9817985937080176, 0.9719761693973017, 0.9586698518807966, 0.9414708698170097, 0.9188259885192018, 0.8899059478875081, 0.8546435256947333], 'test_rmse': [1.0070749040363467, 1.0061825428295312, 1.006130402649177, 1.0061107016943525, 1.005950156356193, 1.0058526298928225, 1.005044270842148, 1.0043212194506426, 1.0023612707209077, 0.9981343116108244, 0.9941576251040193, 0.9908641296477114, 0.9877023858564925, 0.985417893501673, 0.985045914551221]}
('zero', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:28<00:00, 29.91s/it]


{'train_rmse': [0.9971881541548805, 0.9929029490254847, 0.991581120950477, 0.9908952269000461, 0.989889812982652, 0.9886317993788012, 0.9850916854307736, 0.9785449285384368, 0.9663276378424261, 0.9492252497079137, 0.9269147501451029, 0.8998539921445875, 0.8680258641864262], 'test_rmse': [1.0081575755301428, 1.0062084765414383, 1.0051237199412495, 1.0056356299206624, 1.0053465403355348, 1.0052720678305276, 1.004335940468794, 1.0019717936018384, 0.9980781396952427, 0.9936638375185726, 0.9907983824170707, 0.9884906308784696, 0.9885934293500679]}
('zero', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:37<00:00, 30.59s/it]


{'train_rmse': [0.9972957164500893, 0.9933263826838874, 0.9917731988598242, 0.9914095993648208, 0.9904827127306967, 0.9897149494058158, 0.986944328464397, 0.981865312556409, 0.9720746922580983, 0.9578140154726262, 0.9401871200759012, 0.916930478239619, 0.8876662512332497], 'test_rmse': [1.007976404208172, 1.0062830325034664, 1.0056983009038658, 1.0055569992847786, 1.0055532274478665, 1.006108065525139, 1.0043477956831288, 1.002865529849533, 0.9990139001387751, 0.9943978159320292, 0.9910510898044197, 0.9883149819077413, 0.9868811989366608]}
('zero', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [07:10<00:00, 30.72s/it]


{'train_rmse': [0.9970468868976128, 0.9930089211980284, 0.992233605340722, 0.9916059152890945, 0.9911407876261824, 0.9904792813953535, 0.9895598079418028, 0.9866462703334762, 0.9808359618386384, 0.9703731158564936, 0.9564799234582505, 0.9387292724243999, 0.9154196725843142, 0.8859875147066009], 'test_rmse': [1.0074956451534662, 1.0054928248685355, 1.0058097485950668, 1.0060677213714915, 1.0055051515457136, 1.0051773974951626, 1.0047840695264663, 1.0040942871388892, 1.001523710719919, 0.9974789476269265, 0.993585033623389, 0.9900837486956927, 0.9873666445417439, 0.9854299686539084]}
('zero', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [07:35<00:00, 30.36s/it]


{'train_rmse': [0.9972634126521392, 0.9929751005571221, 0.992516737347956, 0.9916833065854257, 0.9911130388200022, 0.990555126401567, 0.9896175536218504, 0.9868582099134209, 0.9815977982941783, 0.9716948792352684, 0.9582431625923445, 0.9405160063032815, 0.9182233518661675, 0.8891571530784531, 0.8539583769312491], 'test_rmse': [1.0078353160845666, 1.0057094268110445, 1.0059769700685932, 1.0056491036693116, 1.0056002278151779, 1.0055812374613329, 1.0052217723891235, 1.0043072915203108, 1.0022544375800864, 0.9978406290639407, 0.9939679397873452, 0.9902347453087813, 0.9876914696608903, 0.9857362958874865, 0.9843807907557832]}
('mean', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:28<00:00, 29.87s/it]


{'train_rmse': [0.992830206854356, 0.9918389289912855, 0.991695548869964, 0.9912386290267122, 0.9901815764345884, 0.9887518090395874, 0.9852844078056977, 0.9784900599780768, 0.966554761743529, 0.949494107057252, 0.9280672428183637, 0.901146719769826, 0.8700772672548035], 'test_rmse': [1.0037378112965087, 1.0023294480092542, 1.002752538767133, 1.0021562615043416, 1.0019025057099162, 1.0021800003060388, 1.0003737488627649, 0.9980363819805399, 0.9939259025409788, 0.990028893096325, 0.9870068251236674, 0.9849889398266282, 0.9851100579822639]}
('mean', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [07:18<00:00, 33.70s/it]


{'train_rmse': [0.9927443227657083, 0.9919604133701475, 0.9918388494552004, 0.9914939281636366, 0.9908947684489483, 0.9896771833688148, 0.987816023813145, 0.9825798549585943, 0.9735674371249051, 0.9591385830215465, 0.9405347539173167, 0.9163458651645456, 0.8865831629481832], 'test_rmse': [1.003079554392059, 1.0023731426033762, 1.0027038930913643, 1.0027327352893154, 1.0027358406896827, 1.001607199631242, 1.0014924599290258, 0.9990689038002867, 0.9964137877508434, 0.9913987913355438, 0.9882848755729992, 0.98534034868, 0.9838651165477694]}
('mean', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [07:14<00:00, 31.00s/it]


{'train_rmse': [0.9929650443284261, 0.9921918513906409, 0.9921563886972747, 0.9917324587068636, 0.9913824991087504, 0.9910979335786236, 0.9896652477843157, 0.9871484649526822, 0.9817347416767397, 0.9715905162290916, 0.9573863639890362, 0.9388171351967521, 0.914988795432051, 0.8850558919825288], 'test_rmse': [1.0034418002899315, 1.0026183147738406, 1.0027397670297238, 1.0019399899715216, 1.002144063205074, 1.0017373858429575, 1.0014769502998462, 1.0005394979511437, 0.9989191790524794, 0.9943733509826671, 0.9905174139201011, 0.9868334266777521, 0.9839390831381271, 0.9824242489429614]}
('mean', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [08:17<00:00, 33.20s/it]


{'train_rmse': [0.9931521703505191, 0.9922958715738647, 0.9919971046421698, 0.9917809220484083, 0.9915767455792991, 0.9909807675991344, 0.9899869003716258, 0.9872527847977044, 0.9821340941499133, 0.9721769253232073, 0.9588921890445814, 0.9416898769123948, 0.9190342726793403, 0.8900585421279557, 0.8546654878478724], 'test_rmse': [1.004182934982479, 1.0030717472935686, 1.0026673974101508, 1.002316271983314, 1.0022949155856797, 1.001990407054214, 1.0018285377971754, 1.0002885032810314, 0.9989365950409005, 0.9950403137276701, 0.9908465872355384, 0.9870028220837704, 0.9842490869598967, 0.9824498938583837, 0.9815264826894917]}
('zero', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:15<00:00, 28.90s/it]


{'train_rmse': [0.9978467347214277, 0.9934815689435443, 0.9923167744832893, 0.9915318945537441, 0.9906693864004735, 0.989095754258329, 0.9857163464369797, 0.9787248960195365, 0.9671051849882006, 0.9502622476086009, 0.9280229669876697, 0.9008476415241677, 0.869001200208467], 'test_rmse': [1.0049642065682005, 1.002588372980139, 1.0021563919317877, 1.0021698954223701, 1.0019137153701885, 1.001876228042965, 1.0003624276424006, 0.9982400146761587, 0.9947020548170986, 0.9906704579412369, 0.9871248203344336, 0.9850879676510782, 0.9849980103975693]}
('zero', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:28<00:00, 29.89s/it]


{'train_rmse': [0.9975191053429785, 0.9935427591563454, 0.9925287558636849, 0.9917023570272215, 0.9910813089505843, 0.9899731711251172, 0.9876647381795063, 0.9823377639648834, 0.9728057874435082, 0.9584133094042226, 0.940030663629683, 0.9164009410416809, 0.8866636822517214], 'test_rmse': [1.004909371529848, 1.0026994208987512, 1.0021985933803617, 1.0018693946990405, 1.0019985897919588, 1.0018651758566555, 1.0010071931651012, 0.9988723327209907, 0.9958824916744697, 0.9912694447379725, 0.9874416490598227, 0.9851331167943781, 0.9833418628758115]}
('zero', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [07:14<00:00, 31.06s/it]


{'train_rmse': [0.9975905475379249, 0.9934810451868478, 0.9922912643223049, 0.9920570383629418, 0.9917729135547194, 0.990926996950355, 0.9899762399761143, 0.9870671669421196, 0.9814592455608494, 0.9711274525712498, 0.9566316101148006, 0.9384471399612487, 0.9151796995852766, 0.8850739420510028], 'test_rmse': [1.0048289528582544, 1.0023898665023945, 1.001588168699538, 1.0025797427004683, 1.0021830223129216, 1.0019522089600905, 1.0016218733105287, 1.0009468908203163, 0.9981773319798195, 0.9942740659513141, 0.9898275126089271, 0.9863581352442548, 0.9842264310592623, 0.9819787792077737]}
('zero', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [08:31<00:00, 34.08s/it]


{'train_rmse': [0.9976764261685893, 0.9935313961503731, 0.9927214034538048, 0.9920983594948619, 0.9919941568153375, 0.9909719045483886, 0.9898860731878178, 0.9871721236190895, 0.9816486277128872, 0.9719638042465544, 0.958028416662844, 0.9408606882504358, 0.9178763748799771, 0.8887903610051169, 0.8537498120949554], 'test_rmse': [1.0046301581958024, 1.0025900997710482, 1.0021697422511981, 1.0026614395522593, 1.002419795472505, 1.0015187326806412, 1.001365409192312, 1.0003000468774137, 0.9978900310134902, 0.9940680968056761, 0.9901663150403669, 0.9875248268758449, 0.9842556492303486, 0.9822540799123846, 0.9819960455490515]}
('mean', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [07:03<00:00, 32.59s/it]


{'train_rmse': [0.9932479203328837, 0.9923808765535569, 0.9918362996468646, 0.991295219481971, 0.9907348872593229, 0.9890951215112777, 0.9855880336586729, 0.9785286382719891, 0.9663952105861393, 0.9491368287655838, 0.9274566138407845, 0.9009165324042924, 0.8701892950451957], 'test_rmse': [0.9997427620865237, 0.9991172774193041, 0.9986419401562373, 0.9988029556697285, 0.9988163518114688, 0.9983599209570729, 0.9968620664377642, 0.9944030192943236, 0.9911236337859717, 0.9869351106770842, 0.983895573470314, 0.9815912460245281, 0.9814651148262626]}
('mean', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [07:31<00:00, 34.72s/it]


{'train_rmse': [0.9932472946707529, 0.992411529790563, 0.9924703044484346, 0.9920648053272487, 0.9915335952395273, 0.990041669165535, 0.987693931600461, 0.9827868120275044, 0.9734372011995069, 0.9591887441883042, 0.940500430411008, 0.9165547090020144, 0.8868167079513556], 'test_rmse': [1.000036781369622, 0.9988911105900226, 0.9994426011186353, 0.999112103337696, 0.9991475566139496, 0.9983967077747182, 0.9974074016998519, 0.9957868769034846, 0.9924261739573964, 0.9880732145592199, 0.9849085804060271, 0.9820022379869635, 0.9799170509051456]}
('mean', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [07:52<00:00, 33.75s/it]


{'train_rmse': [0.9933544460343484, 0.9926846945669341, 0.9920954658946323, 0.9918643859063205, 0.9916754335839152, 0.9914643787871932, 0.9903031795804403, 0.9878138659611309, 0.982057051679697, 0.9720336679871546, 0.9580862100861234, 0.9395056329406247, 0.915935772445775, 0.885558043540147], 'test_rmse': [0.9999063864687601, 0.9995302778610875, 0.9989818589562952, 0.9986906667282458, 0.9986278851422185, 0.9985007291805725, 0.9983944654221826, 0.9977724622511516, 0.9951190488692856, 0.9914116753012766, 0.9874023812836891, 0.9838696379967863, 0.9814512591854956, 0.9791105684108284]}
('mean', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [07:25<00:00, 29.68s/it]


{'train_rmse': [0.9932457400779853, 0.99274488450028, 0.9924883863147494, 0.9919100306015921, 0.9918049519892025, 0.9914419863525227, 0.9902977299952849, 0.9881033005681575, 0.9826871616385555, 0.9725461224829806, 0.9589712486401963, 0.9415711452271524, 0.9190293704764799, 0.8901377055675276, 0.8545915991107109], 'test_rmse': [0.9998310165341826, 0.9993227072873421, 0.9991148955805638, 0.9986743823936624, 0.9986849681204037, 0.9986335820937143, 0.9981649913739451, 0.9976829002517358, 0.9949941230846332, 0.9909353554758966, 0.9869981144453381, 0.9834646566486481, 0.9810175169052178, 0.9791822921955152, 0.9780739714000336]}
('zero', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [05:58<00:00, 27.58s/it]


{'train_rmse': [0.9976191517041116, 0.9939473385310225, 0.9925772091353225, 0.9916001784582155, 0.9907423937276909, 0.9893477999648908, 0.9858012742248073, 0.9786014259054421, 0.9665882007568699, 0.949435224493428, 0.9280709429616735, 0.9010094179104503, 0.8694762120718545], 'test_rmse': [1.001011022680189, 0.9993531241311385, 0.9985319600761291, 0.9987755732708455, 0.9984541937012129, 0.9979056898202691, 0.9970816577770505, 0.9946836108742875, 0.9908986123949555, 0.9866453454622381, 0.9836597691542328, 0.9819792045468813, 0.9813398887005396]}
('zero', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:50<00:00, 31.58s/it]


{'train_rmse': [0.9979853499393363, 0.9935824797414767, 0.9927402230368023, 0.9920123254910196, 0.9913986561606001, 0.9901841374285404, 0.9877704732861189, 0.9825496259081279, 0.9724373173885902, 0.9578419654431045, 0.939435193161837, 0.9152684261711675, 0.8855373630983581], 'test_rmse': [1.001490234975203, 0.9987342562037329, 0.9984658329465302, 0.9988497564046196, 0.9986595885251077, 0.9981406873691802, 0.9975312509148319, 0.9955685157686697, 0.9920047205434499, 0.9877180930360164, 0.9846013064685464, 0.9814186365598222, 0.9805020360449156]}
('zero', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [08:26<00:00, 36.19s/it]


{'train_rmse': [0.9979609811833797, 0.9939012296790688, 0.9927934497014076, 0.9925189335448216, 0.9920356567073468, 0.9912610704588046, 0.9898746222667805, 0.9873738312042308, 0.9807836838056981, 0.9702947087632731, 0.9562851276540129, 0.9384780926616636, 0.9146679845093699, 0.8847199832544219], 'test_rmse': [1.0012247684131266, 0.9988142946535123, 0.9987236129179085, 0.9987944708921644, 0.9986107489791556, 0.9982695792276601, 0.9978318459505876, 0.9972498705179897, 0.9939343214821156, 0.9901018141013224, 0.986815757108454, 0.9835708793438576, 0.9811463934928338, 0.978669530122395]}
('zero', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [07:44<00:00, 30.94s/it]


{'train_rmse': [0.9981823413209348, 0.9939760432726503, 0.9930384276281738, 0.9925469187492438, 0.9918780124429893, 0.9914946577870689, 0.9903927277392303, 0.987876524183928, 0.98266318894931, 0.9727533492450974, 0.9588458363522279, 0.9415235809852934, 0.9186884552280556, 0.8897204171206544, 0.8545785737334762], 'test_rmse': [1.0012433977140447, 0.9988176891162182, 0.9990133399184875, 0.9989684349853886, 0.9985194781657499, 0.9981969619035419, 0.9980405471534025, 0.9972193076970508, 0.9948826081681639, 0.9913827608986051, 0.9870313599676179, 0.9840862365131362, 0.9808708533450591, 0.9791171247351794, 0.9784053982461748]}
('mean', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:23<00:00, 29.52s/it]


{'train_rmse': [0.9927357138470686, 0.9918552508730689, 0.9912210437811778, 0.9907612265824188, 0.989994023294378, 0.9886439195355748, 0.9853005866549844, 0.9787224747947277, 0.9670045917969537, 0.9497001039410684, 0.9281751187181652, 0.9014502150556071, 0.8703559007248204], 'test_rmse': [1.0058334268150477, 1.0046554467592492, 1.0048160963161665, 1.0042969895009348, 1.0041226306899578, 1.0039904514506692, 1.0027568680804129, 1.000390478900281, 0.9969471416122515, 0.9925186500633222, 0.9890108505584501, 0.9871724778873135, 0.9869441935373621]}
('mean', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [07:16<00:00, 33.57s/it]


{'train_rmse': [0.9928908019492947, 0.9920602257576033, 0.9914481490642668, 0.9910815750513808, 0.9907447682443034, 0.9895480215239983, 0.9871851950502043, 0.98272362149029, 0.9731482371253125, 0.9592532316033118, 0.9410186259596566, 0.9170322937885659, 0.8871723601063185], 'test_rmse': [1.0055842637411732, 1.0052557279309815, 1.0044806640820412, 1.0041619120739564, 1.004372319478967, 1.0037900166493443, 1.0035639115743296, 1.001489269251819, 0.9980367631487536, 0.994086333983074, 0.9904451075067298, 0.9874049286821479, 0.9858049946608555]}
('mean', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [10:09<00:00, 43.53s/it]


{'train_rmse': [0.9929274656612602, 0.991998655423485, 0.991711033761242, 0.9915467046795127, 0.991308256395205, 0.990546341519639, 0.9894576976497755, 0.9870173980460695, 0.9811415268033896, 0.9707866054956853, 0.9568041316801473, 0.9390077612835067, 0.9155580729495264, 0.8858620270101382], 'test_rmse': [1.0055191944494197, 1.0046045060973716, 1.0046317815874528, 1.004228761206842, 1.0044430484738884, 1.0040684009554997, 1.0041083896727048, 1.0030462293723656, 1.0005992502165655, 0.9964652174114699, 0.9924664102655097, 0.9889089205904583, 0.9854705237300374, 0.9839897749489156]}
('mean', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [08:41<00:00, 34.75s/it]


{'train_rmse': [0.9928699702262065, 0.9921748267445537, 0.9917723732126824, 0.9913680954357488, 0.9911319890033354, 0.9907105980933327, 0.9895953193371438, 0.9874565272132121, 0.9821973525075774, 0.972438910314029, 0.9586101877833408, 0.9414243320407657, 0.9189896691684463, 0.8902302366244113, 0.8551931526266005], 'test_rmse': [1.0059923351740137, 1.004271076314599, 1.0049254358651416, 1.00442097163029, 1.0044438512057552, 1.0040184842904039, 1.0035555966468226, 1.0028915094376363, 1.000853268153444, 0.9966810617898264, 0.9923372780497833, 0.9884783488239223, 0.9859500861671234, 0.9839968464985628, 0.9828847594634526]}
('zero', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:22<00:00, 29.42s/it]


{'train_rmse': [0.997300096025101, 0.9930014520007766, 0.9920425729170119, 0.9910339306954956, 0.9903592962448314, 0.9886886545831491, 0.9853630135501773, 0.9781881804541175, 0.9662654692911337, 0.9496590792378622, 0.9284613355561127, 0.9019604101045796, 0.8704249212789597], 'test_rmse': [1.006196708636388, 1.0042130614907083, 1.004568147918899, 1.0040677088138803, 1.0042106108028732, 1.0038220735440178, 1.0026444428353223, 1.0005081038510109, 0.9961585986612143, 0.9925093252204488, 0.989064244322051, 0.9874443988488085, 0.9866300034213921]}
('zero', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [07:24<00:00, 34.23s/it]


{'train_rmse': [0.9973218855290986, 0.9931966996963449, 0.9921863802667578, 0.9915276258936382, 0.9905527681336489, 0.9895280799787022, 0.9874672954148195, 0.9818316636309578, 0.9722459854618312, 0.9576885742629405, 0.9392198508430176, 0.9157140513366917, 0.8861492379810516], 'test_rmse': [1.006218495745305, 1.0042356760105904, 1.0041826299559797, 1.0037741189642437, 1.0036516357907819, 1.0037294272419712, 1.0029646173595128, 1.0007291728392687, 0.9974038369711681, 0.992906138196922, 0.9893899982854268, 0.9873794641384477, 0.9853242462386408]}
('zero', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [08:13<00:00, 35.26s/it]


{'train_rmse': [0.9974438415947671, 0.9932919276835961, 0.9923945899487427, 0.9916213143253464, 0.991398015496332, 0.9906678582288231, 0.9895731486219138, 0.986895252618705, 0.9810471945364215, 0.9707613059173807, 0.956618029603107, 0.9384942458553562, 0.9147748875587081, 0.8851011721239777], 'test_rmse': [1.0060475768300918, 1.0044056937739458, 1.0038398886235476, 1.0042816027261108, 1.004210924212487, 1.0037717444880483, 1.0036229894045774, 1.003011227531608, 0.9997098762853958, 0.99612951020315, 0.9918275541097433, 0.9888741645216911, 0.9855056974545273, 0.984215363673973]}
('zero', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [09:00<00:00, 36.01s/it]


{'train_rmse': [0.9974509571488125, 0.9931687560693768, 0.9921875707263141, 0.9918880557139257, 0.9912684773068678, 0.9907315637923476, 0.9896685582410204, 0.9870401877501331, 0.981670698441401, 0.9717182244837685, 0.9579488773579621, 0.9404460157030615, 0.9178986884573156, 0.888794151710917, 0.854039694861646], 'test_rmse': [1.0059894594071805, 1.004198340725238, 1.0037579260378082, 1.004296909437736, 1.0037540015370336, 1.0040148122376442, 1.0034108215199535, 1.0025201427646693, 1.0001343113933163, 0.9963735008733579, 0.991911239940728, 0.9890595885276114, 0.9857437461475462, 0.9839736880782103, 0.9833512804514286]}
('mean', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [08:53<00:00, 41.03s/it]


{'train_rmse': [0.9927462771343444, 0.9918381974673803, 0.9915799287321911, 0.9912561166873521, 0.9901700752884457, 0.9886430270350753, 0.9854952243538831, 0.9790265149711629, 0.9668535924126318, 0.9501916777430711, 0.9284586541378305, 0.9010993185214836, 0.869342060334873], 'test_rmse': [1.0037770677155287, 1.0026089479062057, 1.0021753088433307, 1.00227589193635, 1.0023138816861894, 1.0018852632185675, 1.000974159583219, 0.998970405822788, 0.9950278220502573, 0.9911337248727093, 0.9876427356544754, 0.9856739132474843, 0.9849162154922025]}
('mean', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [06:22<00:00, 29.43s/it]


{'train_rmse': [0.9928456447579093, 0.9921780499146677, 0.9917886364667531, 0.9913490005027611, 0.990675651093685, 0.98977115086637, 0.9879207363570202, 0.9827973486657652, 0.9735495719312987, 0.9594430900525699, 0.9407774345141725, 0.916931074840191, 0.8870646869654683], 'test_rmse': [1.003422234826156, 1.0025346313080887, 1.0028766150924298, 1.0029914921627863, 1.0021536127099404, 1.0019142986851626, 1.001616686675396, 0.9993302652274793, 0.9959540762769996, 0.9921811822949154, 0.9882362220502289, 0.9853512004806462, 0.9834666928835518]}
('mean', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [08:37<00:00, 36.93s/it]


{'train_rmse': [0.9929005924480705, 0.9922964358508413, 0.9918119736690915, 0.9917482255233869, 0.9914533013210289, 0.990752840695291, 0.9899166471343493, 0.9873611229644089, 0.9819449179833997, 0.9716982581282431, 0.9578392366070501, 0.9394912940553916, 0.9160971080359748, 0.886254257763087], 'test_rmse': [1.003356237517619, 1.0025652207825857, 1.0021354744662736, 1.002054160326767, 1.0022233043672635, 1.0017517273084402, 1.0021761687974757, 1.0008069555557297, 0.9991106693198256, 0.9951764693834738, 0.9909287421004424, 0.9872010029608038, 0.9840190837995352, 0.9826670527086068]}
('mean', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [09:49<00:00, 39.30s/it]


{'train_rmse': [0.9928062379571724, 0.9921702297121506, 0.9919207034418334, 0.9916794999597932, 0.991363141075744, 0.9908236219737638, 0.9901029042759754, 0.9876784427486869, 0.9822413182692269, 0.9726007280436141, 0.9585699057007124, 0.9409724408809752, 0.9182857152557677, 0.8894116836761431, 0.8544776037100754], 'test_rmse': [1.0032822716553145, 1.0024738081824005, 1.002298326980569, 1.0021625362789668, 1.0021535594883393, 1.0023113650114999, 1.0021152126231494, 1.0010310131872249, 0.9989493989373283, 0.9954144038389481, 0.9907051030317396, 0.987099551934759, 0.9846831458795152, 0.9824109688746797, 0.9819955202183692]}
('zero', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [05:59<00:00, 27.65s/it]


{'train_rmse': [0.99750170450308, 0.9932417278288546, 0.9919823399617412, 0.9913912022900028, 0.9903885918054467, 0.9886027724176626, 0.9852529833858684, 0.9784991054592989, 0.9661174910807119, 0.949018182616513, 0.9271332472164392, 0.9003731295148727, 0.8686039012370975], 'test_rmse': [1.004011716176053, 1.002036737265489, 1.0017756820523733, 1.0021390179095744, 1.001701911875399, 1.0014299740213812, 1.000261399566345, 0.998654337579808, 0.9942233345561834, 0.9904320071307645, 0.9873484416117193, 0.9852249820801866, 0.9847002331017639]}
('zero', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [09:29<00:00, 43.84s/it]


{'train_rmse': [0.9975282782805682, 0.9933683989028153, 0.9923845089334282, 0.9917954943062958, 0.9911374749989882, 0.9902206431327595, 0.9877101828001894, 0.9826004143872228, 0.9735658479094891, 0.9595941251062464, 0.9416913400356445, 0.918140380244719, 0.8884848992800997], 'test_rmse': [1.0041076160988451, 1.0023106060545217, 1.0022496937858667, 1.002061943101451, 1.0021855612700572, 1.001888932699475, 1.0013572986295545, 0.9989678418501338, 0.9964963353476523, 0.9925198672251848, 0.9888022014538352, 0.9856378531037896, 0.9839317466266791]}
('zero', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [13:49<00:00, 59.22s/it]


{'train_rmse': [0.9977781586147081, 0.9937754583957098, 0.9924255559874939, 0.9919970109157182, 0.9913716500099581, 0.9911556837317065, 0.9897015184739639, 0.9867445228016811, 0.9806837234707962, 0.9702770100620238, 0.9560010378042993, 0.9379317197900351, 0.9146009777694268, 0.884891581210234], 'test_rmse': [1.0044132965007377, 1.0023816019727925, 1.0020432204178884, 1.001807096687386, 1.0018103431066503, 1.001965631059997, 1.0016806259943187, 0.9999961899679671, 0.9978845799080589, 0.9939496113100207, 0.9901833150178446, 0.9868070193549286, 0.9841990154979087, 0.9823055224324317]}
('zero', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [14:00<00:00, 56.06s/it]


{'train_rmse': [0.9977909599372575, 0.9935602690390563, 0.992374745535495, 0.9919009121072466, 0.9915920466659706, 0.9910109742330938, 0.9902790449485234, 0.9875282192749865, 0.981536985624934, 0.9718180488311665, 0.95808199069298, 0.9405871171442932, 0.9179779112415273, 0.8888950793422133, 0.8534686630067104], 'test_rmse': [1.0045124302074588, 1.0021423940227379, 1.0023220045617058, 1.0021638674292, 1.0018424966658461, 1.0021365916566674, 1.002177522875511, 1.0008256825138324, 0.998432328047091, 0.9946012694471822, 0.9908379292325782, 0.9871002270178066, 0.9842480890837852, 0.9823378734852941, 0.9818574102690186]}
('mean', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [10:22<00:00, 47.86s/it]


{'train_rmse': [0.9934885481190907, 0.9923424837969292, 0.9918091842714215, 0.991412399449728, 0.9905201307744281, 0.9889449220305748, 0.985558245651737, 0.9782215067578747, 0.9662602653093129, 0.9491598081218595, 0.9273861791416458, 0.9006798339674433, 0.8695654513266178], 'test_rmse': [1.0008481890735441, 1.0004929086365242, 0.9997333446500524, 0.9998603977288899, 0.9995502242221554, 0.9988650142708386, 0.9980539543826559, 0.9949703459720082, 0.9913112645262429, 0.9874260365712674, 0.9836855099812997, 0.9818049736879074, 0.9816766340901389]}
('mean', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [08:33<00:00, 39.48s/it]


{'train_rmse': [0.9932251763924926, 0.9922544182150541, 0.9920209517980358, 0.9918547712831998, 0.9912746798562503, 0.9900251493085284, 0.987683824504488, 0.9830186526961703, 0.9733387067348913, 0.9592021189296803, 0.94105302940231, 0.9175844031374044, 0.8879813402592439], 'test_rmse': [1.0010871263548728, 0.9998590511435782, 0.9998090002390109, 0.999739490358477, 0.9998875470708424, 0.9990012379883985, 0.9982211202405948, 0.9961089560567772, 0.992752038640795, 0.9886733102425896, 0.9853902316975891, 0.9828273103857347, 0.9804061046965853]}
('mean', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [09:38<00:00, 41.32s/it]


{'train_rmse': [0.9935280757775453, 0.992468421514163, 0.9923476848436813, 0.9919610128839899, 0.9917412117155698, 0.9911684284648304, 0.9900611621842142, 0.9875549661857529, 0.9815314694663276, 0.9715134187004958, 0.957608110625204, 0.9397153624745621, 0.9160556724768953, 0.8858970271944883], 'test_rmse': [1.0011504617586295, 0.9999343689974786, 0.9996681793289661, 0.9990832286917327, 0.9998335933682461, 0.9993985975528491, 0.999223245438974, 0.9979659108177589, 0.9953398474282196, 0.9914593867950932, 0.9876097247411334, 0.984327351676766, 0.9812503269536262, 0.9792869586162919]}
('mean', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [09:52<00:00, 39.50s/it]


{'train_rmse': [0.9933628844358721, 0.9926395906775496, 0.992359854381119, 0.9921305696998077, 0.9916729367321198, 0.9912815247269545, 0.9901805413485353, 0.9879146299919158, 0.9822727601328427, 0.9724201771983196, 0.9590260767260542, 0.9414809051879454, 0.9189082202309795, 0.89002315582095, 0.8548045550991591], 'test_rmse': [1.0009700890658007, 0.9997363895732865, 1.000343822111202, 0.9997111593460547, 0.9995780594748276, 1.0000864138759382, 0.9991217338267917, 0.9986542704436191, 0.9956728199542597, 0.991545363237706, 0.9875670044048748, 0.9845488195876946, 0.9809743888628083, 0.979084765431965, 0.9786820617021904]}
('zero', 96, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [10:18<00:00, 47.57s/it]


{'train_rmse': [0.9977350164030481, 0.9934971236466884, 0.9922239582524536, 0.9915165731657696, 0.9907637634206327, 0.989142809854146, 0.9857284169369133, 0.9788092624009888, 0.9668404818137147, 0.9497539018208869, 0.9282736248378529, 0.9015156068231821, 0.8702157238798344], 'test_rmse': [1.001895075580192, 0.9993225353199128, 0.9991953313116683, 0.9992621456916579, 0.9994848036965204, 0.9985089299117927, 0.9977020465515792, 0.9954844007110281, 0.9920169137915651, 0.9870436496924948, 0.9841928696321296, 0.9821352989435234, 0.9814798169254185]}
('zero', 148, 0.01, 0.02, 0.05, 13)


100%|██████████| 13/13 [11:08<00:00, 51.42s/it]


{'train_rmse': [0.9979773856427812, 0.9939211979940756, 0.9927752235154953, 0.9921266504517946, 0.9915355330157224, 0.9902296159910365, 0.987837471013719, 0.9829942175158043, 0.9733226620899896, 0.9592269430290041, 0.9403040226089004, 0.9168167026771428, 0.8871249782079682], 'test_rmse': [1.0015954099754465, 0.9993291036682017, 0.9995232619793708, 0.9994713870429365, 0.9995290205563283, 0.9990107121540538, 0.9983634970186108, 0.9966841804746459, 0.9930296391462892, 0.9889678589060058, 0.9849430552273251, 0.9822906304642306, 0.9810604478059882]}
('zero', 296, 0.01, 0.02, 0.05, 14)


100%|██████████| 14/14 [10:48<00:00, 46.30s/it]


{'train_rmse': [0.9978367476011591, 0.9938619069039715, 0.9927568995415474, 0.9925513925983381, 0.9919614873096502, 0.991274365248086, 0.9898594449291217, 0.9872844133634654, 0.980777163962837, 0.9706382135244959, 0.956355203404782, 0.9385961460558978, 0.9150643567667364, 0.8853600827459915], 'test_rmse': [1.0021450549434487, 0.9997019142063313, 0.9992051779852257, 0.9995482852170878, 0.9997175786098853, 0.9990679288420454, 0.9984950353227099, 0.9974183391062483, 0.9945349843321124, 0.9912753021697555, 0.9869020066711341, 0.9836642189311807, 0.9808399294733364, 0.9790593754026972]}
('zero', 324, 0.01, 0.02, 0.05, 15)


100%|██████████| 15/15 [12:10<00:00, 48.69s/it]

{'train_rmse': [0.9977055205328389, 0.9939558779041601, 0.9925434473040412, 0.9924339561815033, 0.9919025814071708, 0.9914006195921122, 0.9901897161735718, 0.9879807655529308, 0.9822755742481141, 0.9726154644063061, 0.95901564757073, 0.9416656358370904, 0.9192315416326645, 0.8900887829748052, 0.8548278352987392], 'test_rmse': [1.0021254934047916, 0.9998212838332466, 0.9993715454894507, 0.9992448703525656, 0.9991425440246346, 0.999343469980891, 0.99884108451085, 0.998238677145012, 0.9956558209951555, 0.9917779387087191, 0.9875370272007973, 0.9839175336959616, 0.9808862289118564, 0.9790646954685105, 0.97819611523108]}





- Entire dataset training

In [None]:
params = (
    ("mean", 96, 0.01, 0.02, 0.05, 13),
    ("mean", 148, 0.01, 0.02, 0.05, 13),
    ("mean", 296, 0.01, 0.02, 0.05, 14),
    ("mean", 324, 0.01, 0.02, 0.05, 15),
    ("zero", 96, 0.01, 0.02, 0.05, 13),
    ("zero", 148, 0.01, 0.02, 0.05, 13),
    ("zero", 296, 0.01, 0.02, 0.05, 14),
    ("zero", 324, 0.01, 0.02, 0.05, 15),
)

# Fit every IRSVD configuration on all ratings in data_pd. Each
# reconstruction is passed to extract_for_ensemble with fold index 0 and
# train=False.
train_matrix = create_matrix_from_raw(data_pd)
for param in params:
    biases, features, eta, lambda1, lambda2, epochs = param
    fname = f"irsvd_{biases}_{features}"
    print(param)
    model = IRSVD(
        train_matrix,
        biases=biases,
        features=features,
        eta=eta,
        lambda1=lambda1,
        lambda2=lambda2,
        epochs=epochs,
    )
    print(model.train())
    rec_matrix = model.reconstruct_matrix()
    extract_for_ensemble(rec_matrix, fname, 0, train=False)

### Baseline

- For ensemble training

In [None]:
params = (
    (3, 0.1, 3),
)

# 10-fold CV: for each fold, fit on the training split, report the scores
# returned by train() against the held-out split, and pass the
# reconstruction to extract_for_ensemble under the 1-based fold number.
for fold, (train_set, test_set) in enumerate(kf.split(data_pd), start=1):
    train_data = data_pd.iloc[train_set]
    test_data = data_pd.iloc[test_set]

    train_matrix = create_matrix_from_raw(train_data)
    test_matrix = create_matrix_from_raw(test_data)

    for K, lambda1, epochs in params:
        fname = f"baseline_{K}_{epochs}"
        print((K, lambda1, epochs))
        model = Baseline(train_matrix, K=K, lambda1=lambda1, epochs=epochs)
        print(model.train(test_matrix=test_matrix))
        rec_matrix = model.reconstruct_matrix()
        extract_for_ensemble(rec_matrix, fname, fold, train=True)

- Entire dataset training

In [None]:
params = (
    (3, 0.1, 3),
)

# Fit the Baseline model on all ratings in data_pd. The reconstruction is
# passed to extract_for_ensemble with fold index 0 and train=False.
# (The leftover `print(data_pd.shape)` debug output was removed to match the
# other entire-dataset cells.)
train_matrix = create_matrix_from_raw(data_pd)
for param in params:
    K, lambda1, epochs = param
    fname = f"baseline_{K}_{epochs}"
    print(param)
    model = Baseline(train_matrix, K=K, lambda1=lambda1, epochs=epochs)
    print(model.train())
    rec_matrix = model.reconstruct_matrix()
    extract_for_ensemble(rec_matrix, fname, 0, train=False)

### Global biases

- For ensemble training

In [None]:
params = (
    (0.001, 5),
)

# 10-fold CV for the global-bias model. Each fold's reconstruction is passed
# to extract_for_ensemble under the 1-based fold number with train=True.
for fold, (train_set, test_set) in enumerate(kf.split(data_pd), start=1):
    train_data = data_pd.iloc[train_set]
    test_data = data_pd.iloc[test_set]

    train_matrix = create_matrix_from_raw(train_data)
    test_matrix = create_matrix_from_raw(test_data)

    for param in params:
        lambda1, epochs = param
        fname = f"global_{epochs}"
        print(param)
        model = GBias(train_matrix, lambda1=lambda1, epochs=epochs)
        print(model.train(test_matrix=test_matrix))
        rec_matrix = model.reconstruct_matrix()
        extract_for_ensemble(rec_matrix, fname, fold, train=True)

- Entire dataset training

In [None]:
params = (
    (0.001, 5),
)

# Fit the global-bias model on all ratings. The reconstruction is passed to
# extract_for_ensemble with fold index 0 and train=False.
train_matrix = create_matrix_from_raw(data_pd)
for lambda1, epochs in params:
    fname = f"global_{epochs}"
    print((lambda1, epochs))
    model = GBias(train_matrix, lambda1=lambda1, epochs=epochs)
    print(model.train())
    rec_matrix = model.reconstruct_matrix()
    extract_for_ensemble(rec_matrix, fname, 0, train=False)

### Singular Value Projection (SVP)

- For ensemble training

In [None]:
params = (
    (5, 3, 10),
)

# 10-fold CV for SVP. Each (eta, K, epochs) configuration is fit per fold,
# and its reconstruction is passed to extract_for_ensemble under the 1-based
# fold number with train=True.
for fold, (train_set, test_set) in enumerate(kf.split(data_pd), start=1):
    train_data = data_pd.iloc[train_set]
    test_data = data_pd.iloc[test_set]

    train_matrix = create_matrix_from_raw(train_data)
    test_matrix = create_matrix_from_raw(test_data)

    for param in params:
        eta, K, epochs = param
        fname = f"svp_{eta}_{K}_{epochs}"
        print(param)
        model = SVP(train_matrix, eta=eta, K=K, epochs=epochs)
        print(model.train(test_matrix=test_matrix))
        rec_matrix = model.reconstruct_matrix()
        extract_for_ensemble(rec_matrix, fname, fold, train=True)

- Entire dataset training

In [None]:
params = (
    (5, 3, 10),
)

# Fit SVP on all ratings. The reconstruction is passed to
# extract_for_ensemble with fold index 0 and train=False.
train_matrix = create_matrix_from_raw(data_pd)
for eta, K, epochs in params:
    fname = f"svp_{eta}_{K}_{epochs}"
    print((eta, K, epochs))
    model = SVP(train_matrix, eta=eta, K=K, epochs=epochs)
    print(model.train())
    rec_matrix = model.reconstruct_matrix()
    extract_for_ensemble(rec_matrix, fname, 0, train=False)

### Nuclear norm relaxation / Singular Value Thresholding (SVT)

- For ensemble training

In [None]:
params = (
    (1.2, 2000, 28),
    (1.2, 1000, 15),
    (1.2, 1500, 21),
    (1.5, 2000, 22),
)

# 10-fold CV for SVT. Each (eta, tau, epochs) configuration is fit on every
# fold, and its reconstruction is passed to extract_for_ensemble under the
# 1-based fold number with train=True.
for fold, (train_set, test_set) in enumerate(kf.split(data_pd), start=1):
    train_data = data_pd.iloc[train_set]
    test_data = data_pd.iloc[test_set]

    train_matrix = create_matrix_from_raw(train_data)
    test_matrix = create_matrix_from_raw(test_data)

    for eta, tau, epochs in params:
        fname = f"svt_{eta}_{tau}_{epochs}"
        print((eta, tau, epochs))
        model = SVT(train_matrix, eta=eta, tau=tau, epochs=epochs)
        print(model.train(test_matrix=test_matrix))
        rec_matrix = model.reconstruct_matrix()
        extract_for_ensemble(rec_matrix, fname, fold, train=True)

- Entire dataset training

In [None]:
params = (
    (1.2, 2000, 28),
    (1.2, 1000, 15),
    (1.2, 1500, 21),
    (1.5, 2000, 22),
)

# Fit each SVT configuration on all ratings. Each reconstruction is passed to
# extract_for_ensemble with fold index 0 and train=False.
train_matrix = create_matrix_from_raw(data_pd)
for param in params:
    eta, tau, epochs = param
    fname = f"svt_{eta}_{tau}_{epochs}"
    print(param)
    model = SVT(train_matrix, eta=eta, tau=tau, epochs=epochs)
    print(model.train())
    rec_matrix = model.reconstruct_matrix()
    extract_for_ensemble(rec_matrix, fname, 0, train=False)

### Regularized SVD

- For ensemble training

In [None]:
params = (
    (96, 0.01, 0.02, 13),
)

# 10-fold CV for regularized SVD. Each fold's reconstruction is passed to
# extract_for_ensemble under the 1-based fold number with train=True.
for fold, (train_set, test_set) in enumerate(kf.split(data_pd), start=1):
    train_data = data_pd.iloc[train_set]
    test_data = data_pd.iloc[test_set]

    train_matrix = create_matrix_from_raw(train_data)
    test_matrix = create_matrix_from_raw(test_data)

    for param in params:
        features, eta, lambda1, epochs = param
        fname = f"rsvd_{features}_{epochs}"
        print(param)
        model = RSVD(
            train_matrix,
            features=features,
            eta=eta,
            lambda1=lambda1,
            epochs=epochs,
        )
        print(model.train(test_matrix=test_matrix))
        rec_matrix = model.reconstruct_matrix()
        extract_for_ensemble(rec_matrix, fname, fold, train=True)

- Entire dataset training

In [None]:
params = (
    (96, 0.01, 0.02, 13),
)

# Fit regularized SVD on all ratings. The reconstruction is passed to
# extract_for_ensemble with fold index 0 and train=False.
train_matrix = create_matrix_from_raw(data_pd)
for features, eta, lambda1, epochs in params:
    fname = f"rsvd_{features}_{epochs}"
    print((features, eta, lambda1, epochs))
    model = RSVD(
        train_matrix,
        features=features,
        eta=eta,
        lambda1=lambda1,
        epochs=epochs,
    )
    print(model.train())
    rec_matrix = model.reconstruct_matrix()
    extract_for_ensemble(rec_matrix, fname, 0, train=False)