In [1]:
# test full training of maskcnn_polished_with_rcnn_k_bl
from os.path import join

import numpy as np

import torch

from thesis_v2 import dir_dict

from thesis_v2.data.prepared.yuanyuan_8k import get_data

from thesis_v2.training_extra.maskcnn_like.opt import get_maskcnn_v1_opt_config
from thesis_v2.training_extra.maskcnn_like.training import (train_one,
                                                            partial)

from thesis_v2.models.maskcnn_polished_with_rcnn_k_bl.builder import (
    gen_maskcnn_polished_with_rcnn_k_bl,
    load_modules
)

In [2]:
def train_one_maskcnn_polished_with_rcnn_k_bl(
    split_seed,
    model_seed,
    act_fn,
    loss_type,
    input_size,
    out_channel,
    num_layer,
    kernel_size_l1,
    pooling_ksize,
    scale, scale_name,
    smoothness, smoothness_name,
    pooling_type,
    n_timesteps,
):
    """Train one ``maskcnn_polished_with_rcnn_k_bl`` model and report its test score.

    The function loads the data splits via ``get_data``, builds an
    architecture factory and an optimizer-config factory, and hands both to
    ``train_one``.

    Parameters
    ----------
    split_seed
        Forwarded to ``get_data`` as ``seed``.
    model_seed
        Forwarded to ``train_one`` as ``model_seed``; also the last component
        of the run ``key``.
    act_fn, out_channel, num_layer, kernel_size_l1, pooling_ksize,
    pooling_type, n_timesteps
        Forwarded unchanged to ``gen_maskcnn_polished_with_rcnn_k_bl``.
    loss_type, scale, smoothness
        Forwarded unchanged to ``get_maskcnn_v1_opt_config``.
    input_size
        Forwarded to ``get_data`` (third positional argument).
    scale_name, smoothness_name
        Accepted for interface compatibility; not used in this function body.

    Returns
    -------
    The value stored at ``result['stats_best']['stats']['test']['corr_mean']``
    in the object returned by ``train_one``.
    """
    load_modules()

    # NOTE: the data `scale=0.5` below is a fixed argument of `get_data`; it is
    # unrelated to this function's `scale` parameter (which goes to the
    # optimizer config).
    raw_splits = get_data(
        'a', 200, input_size, ('042318', '043018', '051018'),
        scale=0.5,
        seed=split_seed,
    )

    # Even indices are cast to float32 as model inputs; odd indices are passed
    # through untouched as targets.
    datasets = {
        'X_train': raw_splits[0].astype(np.float32),
        'y_train': raw_splits[1],
        'X_val': raw_splits[2].astype(np.float32),
        'y_val': raw_splits[3],
        'X_test': raw_splits[4].astype(np.float32),
        'y_test': raw_splits[5],
    }

    # Architecture options that do not depend on the arguments `train_one`
    # supplies when it calls the factory.
    arch_kwargs = {
        'out_channel': out_channel,
        'kernel_size_l1': kernel_size_l1,  # (try 5,9,13)
        'kernel_size_l23': 3,
        'act_fn': act_fn,
        'pooling_ksize': pooling_ksize,  # (try, 1,3,5,7)
        'pooling_type': pooling_type,  # try (avg, max)  # looks that max works well here?
        'num_layer': num_layer,
        'n_timesteps': n_timesteps,
        'factored_constraint': None,
        'blstack_pool_ksize': 1,
        'blstack_pool_type': None,
        'acc_mode': 'cummean',
        'bn_after_fc': False,
    }

    def gen_cnn_partial(input_size_cnn, n):
        """Build the architecture spec for a given input size and neuron count."""
        return gen_maskcnn_polished_with_rcnn_k_bl(
            input_size=input_size_cnn,
            num_neuron=n,
            **arch_kwargs,
        )

    opt_config_partial = partial(
        get_maskcnn_v1_opt_config,
        scale=scale,
        smoothness=smoothness,
        group=0.0,
        loss_type=loss_type,
    )

    run_key = f'debug/test_full_training_maskcnn_polished_with_rcnn_k_bl/yhat_reduce_pick-none/{model_seed}'

    # Original author's note on these settings: "reduce on batch axis".
    eval_settings = {
        'yhat_reduce_axis': 1,
        'yhat_reduce_pick': 'none',
    }

    result = train_one(
        arch_json_partial=gen_cnn_partial,
        opt_config_partial=opt_config_partial,
        datasets=datasets,
        key=run_key,
        show_every=100,
        max_epoch=40000,
        model_seed=model_seed,
        return_model=False,
        extra_params={'eval_fn': eval_settings},
    )

    best_test_stats = result['stats_best']['stats']['test']
    return best_test_stats['corr_mean']

In [3]:
# Settings shared by every configuration in this notebook.
maskcnn_param_template = {
    'out_channel': 16,
    'num_layer': 3,
    'kernel_size_l1': 9,
    'pooling_ksize': 3,
    'pooling_type': 'avg',
    'model_seed': 0,
    'split_seed': 'legacy',
}

# The "regular" configuration: the shared template plus the remaining
# arguments required by `train_one_maskcnn_polished_with_rcnn_k_bl`.
# The `*_name` entries are the string spellings of the numeric values next to them.
maskcnn_param_regular = {
    **maskcnn_param_template,
    'act_fn': 'relu',
    'loss_type': 'mse',
    'smoothness': 5e-06,
    'smoothness_name': '0.000005',
    'scale': 0.01,
    'scale_name': '0.01',
    'input_size': 50,
    'n_timesteps': 4,
}

# Run the full training once and print the returned test-set `corr_mean`.
print(
    train_one_maskcnn_polished_with_rcnn_k_bl(
        **maskcnn_param_regular
    )
)

{'bn_output', 'pooling', 'fc', 'final_act'}
['bl_stack.layer_list.0.b_conv', 'bl_stack.layer_list.1.b_conv', 'bl_stack.layer_list.2.b_conv']
num_param 55808
num of phase:  3
val metric init {'loss': 0.1476345956325531, 'loss_no_reg': 0.14376665651798248, 'corr': [0.05479143575273985, 0.021254413116750567, 0.028083031191425682, -0.02689761876208481, 0.018408089976197063, -0.05305871449502127, 0.000890616674787484, -0.018328340050741153, 0.029774580188336328, 0.13997672398676186, 0.01945270579508741, 0.13919383005375596, 0.08593291578839071, 0.03413738215403807, -0.03227486692390781, -0.031208009155148257, -0.00396864610094288, 0.01026658771114074, 0.01001400298929947, 0.16845532532160407, 0.033960530713351, -0.031763678907935376, -0.02325077379200151, 0.040679488549285055, 0.019994279188352098, -0.0018366361673360287, -0.01355438184938194, 0.05334537127348192, 0.04617807397273281, -0.031888750882798605, -0.005485851927773809, -0.054832435986664416, 0.07274804282236529, 0.000314636888594

test metric {'loss': 0.12537304844175065, 'loss_no_reg': 0.12184615433216095, 'corr': [0.26668891239803905, 0.39829047475286317, 0.45453317003273463, 0.6914040391040088, 0.41299787606732696, 0.4981890257873129, 0.3027705867212853, 0.5107645953588396, 0.6036425895586957, 0.6110931651728992, 0.5832254207845864, 0.39442819899249015, 0.5040171062356245, 0.607859327954043, 0.32827786790082725, 0.5246501155656365, 0.5863361525519557, 0.5277001081554435, 0.5113616541659035, 0.5312200893080252, 0.41891026512015905, 0.38905708407465933, 0.16074591256640264, 0.09637126622703787, 0.648115645518495, 0.4570973484779144, 0.3608319488845919, 0.2511371218244277, 0.5651034406553264, 0.46937721509891284, 0.28810596311629855, 0.5367209096529011, 0.3037700372813904, 0.4852842675068276, 0.5379635830279356, 0.4567378641237111, 0.4935610023883735, 0.4988579773077788, 0.3650292144033006, 0.36065699050076183, 0.2575148751919426, 0.47168248244450217, 0.3436032867679716, 0.34344477195906087, 0.41188187044247965,

500-0, train loss 0.11549799144268036
train loss 0.11549799144268036
val metric {'loss': 0.11965274065732956, 'loss_no_reg': 0.11558865010738373, 'corr': [0.22730440714370068, 0.47316325924155217, 0.5332681839698357, 0.7323002230908262, 0.422388044148917, 0.4931126834590033, 0.45671382756698603, 0.547287791011817, 0.6288736539782532, 0.659829780681612, 0.6314066163872069, 0.4464401604060618, 0.5531312304225846, 0.6238361293341677, 0.503656380490772, 0.5620182873892606, 0.664591633137332, 0.577118455183671, 0.5410333216081862, 0.6320849671224102, 0.44354875310520997, 0.40122852844749723, 0.2018918503474933, 0.11801865326303516, 0.7343665496072092, 0.5120152367028997, 0.44750135820838455, 0.3263896287728895, 0.626395462140385, 0.5253062101922128, 0.29825150426977537, 0.546969225333971, 0.4908262044341599, 0.5030977823936278, 0.5975111196583625, 0.4989098884524522, 0.521145115952514, 0.5645079479086619, 0.41749818305083336, 0.3930624756545986, 0.28755156590829833, 0.5339888287791088, 0.42

test metric {'loss': 0.11776485506977354, 'loss_no_reg': 0.11367756128311157, 'corr': [0.2970933530811369, 0.4771421173066087, 0.5691253261797521, 0.7288675580037851, 0.4460102916046485, 0.5584772654471953, 0.5097644353783501, 0.6138727045085907, 0.6601127575574856, 0.6920722916835105, 0.6626799595241968, 0.4482655463643821, 0.5535856766672195, 0.6600484890605963, 0.5020553736402089, 0.5764360699507545, 0.7024386064251313, 0.6087693358703335, 0.5385038485550447, 0.6069361827217596, 0.47119714317142153, 0.44607182113110505, 0.21569069219200263, 0.09579695288068271, 0.7184786366562899, 0.5696074836741213, 0.4877789346788742, 0.31148816822731706, 0.6071198964655588, 0.5072907258767441, 0.32831753909059885, 0.5874971958486774, 0.4325887313358724, 0.5794853152104096, 0.6090338076462732, 0.526932559759876, 0.5689239186850291, 0.5633226797837316, 0.4022149101168063, 0.3944119075997369, 0.36632602690109195, 0.5294166526083696, 0.4417007276186615, 0.41455956663090077, 0.5508784937774691, 0.3939

1000-0, train loss 0.11705164611339569
train loss 0.11705164611339569
val metric {'loss': 0.11635163724422455, 'loss_no_reg': 0.1120227575302124, 'corr': [0.25792804531158614, 0.4960454838507307, 0.5807795534603429, 0.746272191652334, 0.48396559601710226, 0.5154515849036406, 0.46932794423044766, 0.6246674614671707, 0.6631558913654731, 0.6932695904346653, 0.6643484649461733, 0.4839919409817641, 0.570534544996481, 0.6361206770560782, 0.574177138342662, 0.5842238068640182, 0.688204723660721, 0.602674964158334, 0.550498252867883, 0.6569618747729253, 0.4744631975446428, 0.42699764486590464, 0.23959740794025103, 0.11425772194608952, 0.7469540366286425, 0.5583206076485498, 0.525331367182894, 0.3759208305578598, 0.6475910267584597, 0.5385950232018734, 0.3324512393094988, 0.5836291283678183, 0.5314387683442346, 0.5637882061354548, 0.6295834176784691, 0.512606367463429, 0.5776318261950146, 0.5769457552248822, 0.4391285767661593, 0.39602434443571766, 0.3570466768306616, 0.5425733368536875, 0.4609

test metric {'loss': 0.11607620758669716, 'loss_no_reg': 0.11152325570583344, 'corr': [0.31330864702161426, 0.4983377807845962, 0.5921891876813896, 0.7609112766955664, 0.4911767870764512, 0.5671669074144279, 0.5021780236481721, 0.6356825230513434, 0.664521714988532, 0.6949480209934541, 0.6882564509928143, 0.4659686897672769, 0.583455017635763, 0.6734196966899519, 0.5583857525148329, 0.5957300783530559, 0.7051205196820466, 0.6312919560627792, 0.5976160689342391, 0.6483963062409106, 0.4970573433534661, 0.45368875179822776, 0.2152424928615526, 0.08966375754222627, 0.7360253081339666, 0.5997476502681188, 0.5102021197298028, 0.33992837308040263, 0.6388860204869977, 0.5280450774515855, 0.34940345967565645, 0.5999719677095343, 0.4408312291945106, 0.5957286814043816, 0.6213211965263694, 0.5334078592870318, 0.6154513918537714, 0.5950693844865655, 0.43559245068924257, 0.4190338724597471, 0.40508785551137955, 0.5508015709152931, 0.45739598462621883, 0.4392803540052137, 0.5915218372954657, 0.45035

1500-0, train loss 0.11329194903373718
train loss 0.11329194903373718
val metric {'loss': 0.11505215615034103, 'loss_no_reg': 0.11053723841905594, 'corr': [0.25142244420584337, 0.4977878631279186, 0.5833912630101554, 0.7497633631584117, 0.5000281082276017, 0.5142742147358843, 0.48664610796713137, 0.6385696173443631, 0.6646559804395232, 0.7032348663871391, 0.6755824634238154, 0.5098476458421363, 0.5941647104632524, 0.6460077407500296, 0.5908599195725279, 0.6008518200274582, 0.705022385597404, 0.5987575315710978, 0.5818466411994216, 0.6895075821274543, 0.4626007463961561, 0.421335515565371, 0.340816828156542, 0.11231464314650783, 0.7561510311498965, 0.5843130143002189, 0.557875186748195, 0.37546956593196645, 0.654241997052815, 0.545310862769286, 0.34504337868766494, 0.5695204144320745, 0.5060338018080205, 0.5924941856644386, 0.6362732793701633, 0.5171824495423656, 0.5999672159886074, 0.6057443177738678, 0.45516816755432155, 0.40281694108008526, 0.3644032417518022, 0.5766389859928147, 0.4

test metric {'loss': 0.11864792662007469, 'loss_no_reg': 0.11389122158288956, 'corr': [0.31378950939536504, 0.5059527497143838, 0.5905230240373076, 0.735534848524776, 0.4913738543260583, 0.5621699827830193, 0.5054100529822316, 0.5900262266904519, 0.6667787025529768, 0.7046715324340442, 0.6871997901137334, 0.4569090251535475, 0.5826473888742835, 0.6632793296219835, 0.5643582947314867, 0.5928704416969204, 0.7064727540955259, 0.6231708445091968, 0.562253858449298, 0.6429693668758374, 0.4845347391221979, 0.4358530884829327, 0.3102539689055172, 0.09782164118001972, 0.7374462768890822, 0.5701098546541915, 0.45250142750221006, 0.34211178927249514, 0.6372766069731948, 0.5196795445088954, 0.3536429699286313, 0.5745080338589018, 0.43488458148480746, 0.5667684301837993, 0.6262097005621406, 0.5434652015573627, 0.6261709189177758, 0.5942911382235143, 0.43012210495812997, 0.41003826247778735, 0.4185160945187715, 0.5372205352744309, 0.45396457170528015, 0.4439475991854896, 0.6152532648413448, 0.42757

2000-0, train loss 0.10203954577445984
train loss 0.10203954577445984
val metric {'loss': 0.11501713693141938, 'loss_no_reg': 0.11052625626325607, 'corr': [0.24894215995541077, 0.49815272552512707, 0.5937115547114158, 0.7631070708655637, 0.505369982492488, 0.5210842636806687, 0.503128039961031, 0.6532258799631461, 0.6725657817546481, 0.711613093461384, 0.6830323975854029, 0.5160922652361435, 0.6014886009450068, 0.663400512964972, 0.5911711861726549, 0.6103568206585386, 0.6988089124431549, 0.6172941932838025, 0.5880813105533472, 0.6849273352861132, 0.4738754045019742, 0.4376358487456943, 0.350214110897047, 0.11115915908415828, 0.7539080122997539, 0.5993652611437199, 0.5584749241800482, 0.38844485657276984, 0.6642116339780085, 0.5468362481477216, 0.3571458643829677, 0.5723011383510895, 0.5295141681553762, 0.6063775791286141, 0.6415927988881867, 0.5189571664501309, 0.6116457761759111, 0.619795729311926, 0.47232468828859836, 0.41025431484774905, 0.3719592864441077, 0.5799907348895341, 0.46

test metric {'loss': 0.11584983340331487, 'loss_no_reg': 0.11116697639226913, 'corr': [0.3223623394298879, 0.5181742703721819, 0.6002715276991646, 0.7629135003515498, 0.5072726905973182, 0.5874201460354505, 0.5348683161289547, 0.6377586229878468, 0.6788171742141971, 0.7150677785039495, 0.6970865208188732, 0.46790585391476036, 0.5984023469948818, 0.6781625483769852, 0.593756485491552, 0.6161387590922491, 0.7119716784450374, 0.6470406701515437, 0.5998530545774575, 0.6687317046762462, 0.5096913638286746, 0.46176032422893887, 0.3236530478180653, 0.09939321701205567, 0.7495169445158135, 0.6076998647179949, 0.5342269056858284, 0.3483504979703558, 0.6449581356627409, 0.5324036744677846, 0.3727207845167701, 0.5938232333143921, 0.456243158797877, 0.605897728166174, 0.633239468999508, 0.5458504584962471, 0.6305573111756166, 0.6117971745681365, 0.45817347460295377, 0.40906240604205757, 0.4399557572100652, 0.5713322741532247, 0.46146832476080374, 0.463462635198875, 0.645882172520298, 0.45340344899

early stopping after epoch 2450 metric 0.10970117896795273
for grp of sz 35, lr from 0.001000 to 0.000333
val metric init {'loss': 0.11421069204807281, 'loss_no_reg': 0.10970117896795273, 'corr': [0.24529224642080882, 0.49655545284980124, 0.5945273962816132, 0.7655117723457139, 0.5114546783322542, 0.5204009343056589, 0.49981684084761246, 0.6400567398761772, 0.6698076377687445, 0.7064877380516769, 0.6854281679557548, 0.5193583463352558, 0.6049081375317104, 0.6629787901066913, 0.6097455215602932, 0.6094756161125325, 0.6970625889887907, 0.6130738392862591, 0.5944879409823656, 0.6856532877077518, 0.47439382488529896, 0.4467708402423675, 0.3542268470513504, 0.11172729015590353, 0.7576653590621698, 0.5999833810587943, 0.5481526504898153, 0.3894064666600926, 0.6695665560119921, 0.5507742068169914, 0.3590025087178114, 0.5807059745259749, 0.5273385703971958, 0.5896208494664159, 0.6390562399301019, 0.5142556883581317, 0.6192560299043636, 0.6206379589844806, 0.46767675901636746, 0.405885452060713

200-0, train loss 0.1036161258816719
train loss 0.1036161258816719
val metric {'loss': 0.11397016197443008, 'loss_no_reg': 0.10962973535060883, 'corr': [0.2541773450355762, 0.49816866601922283, 0.5973978796099332, 0.7654184893773615, 0.5163776597157753, 0.5212356380435261, 0.50881508449593, 0.6534904677150739, 0.6732331137070049, 0.7123230446713341, 0.6914044787578288, 0.5243891070462874, 0.6060504103737616, 0.6648054838264377, 0.612223686104286, 0.6137761977893588, 0.7017139249532915, 0.6158084015751573, 0.5953471255498479, 0.6901557611533922, 0.4795618122676295, 0.44167981206216683, 0.35603847678906253, 0.11115737403085876, 0.7611828517933679, 0.6053652848047938, 0.5614300760134613, 0.39083777113514895, 0.6708257438781269, 0.5545278407148124, 0.36600100261388613, 0.580352600715138, 0.5265589798406562, 0.6000755790308322, 0.6397844617538506, 0.5200503250355833, 0.6194180300627871, 0.6275087172958353, 0.4775397671218702, 0.40935024942088527, 0.39019046385712286, 0.5869655607109866, 0.4

test metric {'loss': 0.1153543580855642, 'loss_no_reg': 0.11075492203235626, 'corr': [0.32483287893725066, 0.5113070992592473, 0.6000821045442188, 0.7657816470248935, 0.5068759184510732, 0.5891387071168996, 0.5332705940238043, 0.6427702399477102, 0.6780449265765809, 0.7164128782662105, 0.7001495872055408, 0.47653658006383354, 0.6009434215368371, 0.6796523437055793, 0.5939625191855036, 0.6174005509064642, 0.7193718461098303, 0.6480664239609677, 0.6004859014217188, 0.6692641658489112, 0.5070528153051896, 0.4636253245700416, 0.32530413907390704, 0.10320996641141766, 0.7499201303646179, 0.6133972434935435, 0.5361426919583596, 0.3503270238127063, 0.6483815339103091, 0.5333933271222977, 0.36980948966387206, 0.5935946680592254, 0.4569216121612725, 0.6071516140181026, 0.6341791004868389, 0.5478390605807447, 0.6353250074409507, 0.6124024329842473, 0.45723960781624884, 0.4117219065841926, 0.43885755488261163, 0.567851293509698, 0.4661210822337364, 0.46370980794894684, 0.6430940597292776, 0.45672

early stopping after epoch 650 metric 0.1094382107257843
for grp of sz 35, lr from 0.000333 to 0.000111
val metric init {'loss': 0.11374663710594177, 'loss_no_reg': 0.1094382107257843, 'corr': [0.25371723316736106, 0.4966919048279558, 0.5988807783821628, 0.7705370772359479, 0.5140785141497628, 0.5212304615761163, 0.5082065036198061, 0.6468059646311198, 0.6723822310006499, 0.7097046724709647, 0.6875946973728886, 0.5210920222474233, 0.606279224540635, 0.6639321311724495, 0.6110319048439367, 0.6097079078132172, 0.7002130894034012, 0.6177441215797566, 0.59526159988906, 0.6909966505461462, 0.48041202983208586, 0.44685534472643906, 0.35813966302115424, 0.1106297808489182, 0.761253952278475, 0.6030705362068536, 0.555312412025966, 0.39124809167395863, 0.6719663471936526, 0.5507433081943678, 0.36461755197935475, 0.5837592171584198, 0.5303960781722901, 0.5961753495550206, 0.6388076743766458, 0.5199843296483715, 0.6182576333507622, 0.6235008116269531, 0.4749193253114405, 0.4081371023151762, 0.396

200-0, train loss 0.10236421227455139
train loss 0.10236421227455139
val metric {'loss': 0.11390146166086197, 'loss_no_reg': 0.10965630412101746, 'corr': [0.25181998268098393, 0.496324207275021, 0.5994084017447593, 0.7681101339386238, 0.5139236018303277, 0.5220324098857808, 0.5073887754303528, 0.6528775968190228, 0.6732038254613988, 0.712434712891973, 0.6906874689472164, 0.5230529897409537, 0.6066677453989378, 0.6650818665804166, 0.6105909113544734, 0.6121713973669358, 0.7002520102160794, 0.6171570790779382, 0.5961374601531417, 0.6925191262673649, 0.4802310974464288, 0.4457581121860523, 0.3583867602243841, 0.11163568118048525, 0.7613878676165302, 0.6058514032996889, 0.561598627561031, 0.3912222216029063, 0.6718378024407343, 0.5536473461977436, 0.3628576147372517, 0.5842787416111725, 0.5290207969280469, 0.6017136651020242, 0.6398127499762923, 0.5195444672198783, 0.6196861975273125, 0.6248597490516496, 0.476312787935647, 0.40765891187946357, 0.3902280446351901, 0.5859147373012391, 0.4706

test metric {'loss': 0.115038843027183, 'loss_no_reg': 0.11050818115472794, 'corr': [0.3250700208817723, 0.5133494879687417, 0.6025044806377774, 0.7651935174598432, 0.509298039797472, 0.5882156147118883, 0.5321554695157775, 0.6395893070793394, 0.6784906581632972, 0.7167811922218781, 0.7000825872433099, 0.4741174620357152, 0.6004756775388458, 0.6801455929716932, 0.5888122320534771, 0.6156280290027288, 0.7197103679197039, 0.6488018230116734, 0.6032645124373184, 0.6715309208428821, 0.5104563952336778, 0.4635284837260045, 0.3246785702135029, 0.1017586743064561, 0.7497772432361246, 0.6142283975156391, 0.5379973616425399, 0.3518890450870999, 0.6461416589622775, 0.5301091562549799, 0.3669647774810556, 0.5895296256959393, 0.4564097799252102, 0.6093264094783014, 0.6338716172668171, 0.5473445219251833, 0.6369512990178885, 0.6122308805546121, 0.458996612340873, 0.411730770302956, 0.43929529178325744, 0.5688489451673264, 0.46324439528142597, 0.4664976450285423, 0.641517658617246, 0.457387208649424