In [1]:
# test full training of maskcnn_polished_with_rcnn_k_bl
from os.path import join

import numpy as np

import torch

from thesis_v2 import dir_dict

from thesis_v2.data.prepared.yuanyuan_8k import get_data

from thesis_v2.training_extra.maskcnn_like.opt import get_maskcnn_v1_opt_config
from thesis_v2.training_extra.maskcnn_like.training import (train_one,
                                                            partial)

from thesis_v2.models.maskcnn_polished_with_rcnn_k_bl.builder import (
    gen_maskcnn_polished_with_rcnn_k_bl,
    load_modules
)

In [2]:
def train_one_maskcnn_polished_with_rcnn_k_bl(
    split_seed,
    model_seed,
    act_fn,
    loss_type,
    input_size,
    out_channel,
    num_layer,
    kernel_size_l1,
    pooling_ksize,
    scale, scale_name,
    smoothness, smoothness_name,
    pooling_type,
    n_timesteps,
):
    """Run one full training of maskcnn_polished_with_rcnn_k_bl and return its test score.

    Parameters
    ----------
    split_seed
        Forwarded to ``get_data`` as ``seed``.
    model_seed
        Forwarded to ``train_one`` and appended to the result ``key``.
    act_fn, out_channel, num_layer, kernel_size_l1, pooling_ksize,
    pooling_type, n_timesteps
        Forwarded unchanged to ``gen_maskcnn_polished_with_rcnn_k_bl``.
    loss_type, scale, smoothness
        Forwarded to ``get_maskcnn_v1_opt_config`` (together with ``group=0.0``).
    input_size
        Third positional argument of ``get_data``.
    scale_name, smoothness_name
        Accepted for signature compatibility; not read inside this function.

    Returns
    -------
    ``result['stats_best']['stats']['test']['corr_mean']`` taken from the
    value returned by ``train_one``.
    """
    load_modules()

    # NOTE: the literal scale=0.5 handed to get_data is unrelated to this
    # function's `scale` argument, which only reaches the optimizer config.
    data_tags = ('042318', '043018', '051018')
    raw_splits = get_data('a', 200, input_size, data_tags, scale=0.5,
                          seed=split_seed)

    # raw_splits is indexed positionally; entries whose key starts with 'X_'
    # are cast to float32, the 'y_' entries are passed through untouched.
    split_keys = ('X_train', 'y_train', 'X_val', 'y_val', 'X_test', 'y_test')
    datasets = {
        name: (raw_splits[pos].astype(np.float32)
               if name.startswith('X_') else raw_splits[pos])
        for pos, name in enumerate(split_keys)
    }

    # Architecture settings that do not depend on the (input size, neuron
    # count) pair supplied later by train_one.
    fixed_arch_kwargs = dict(
        out_channel=out_channel,
        kernel_size_l1=kernel_size_l1,  # (try 5,9,13)
        kernel_size_l23=3,
        act_fn=act_fn,
        pooling_ksize=pooling_ksize,  # (try, 1,3,5,7)
        pooling_type=pooling_type,  # try (avg, max)  # looks that max works well here?
        num_layer=num_layer,
        n_timesteps=n_timesteps,
        factored_constraint=None,
        blstack_pool_ksize=1,
        blstack_pool_type=None,
        acc_mode='cummean',
        bn_after_fc=False,
    )

    def build_arch(input_size_cnn, n):
        # Parameter names are kept as in the original closure in case the
        # caller passes them by keyword.
        return gen_maskcnn_polished_with_rcnn_k_bl(
            input_size=input_size_cnn,
            num_neuron=n,
            **fixed_arch_kwargs
        )

    optimizer_kwargs = dict(
        scale=scale,
        smoothness=smoothness,
        group=0.0,
        loss_type=loss_type,
    )
    build_opt_config = partial(get_maskcnn_v1_opt_config, **optimizer_kwargs)

    # reduce on batch axis (comment carried over from the original code)
    eval_overrides = {'eval_fn': {'yhat_reduce_axis': 1}}

    outcome = train_one(
        arch_json_partial=build_arch,
        opt_config_partial=build_opt_config,
        datasets=datasets,
        key=f'debug/test_full_training_maskcnn_polished_with_rcnn_k_bl/regular/{model_seed}',
        show_every=100,
        max_epoch=40000,
        model_seed=model_seed,
        return_model=False,
        extra_params=eval_overrides,
    )

    return outcome['stats_best']['stats']['test']['corr_mean']

In [3]:
# Settings shared by every configuration built from this template.
maskcnn_param_template = dict(
    out_channel=16,
    num_layer=3,
    kernel_size_l1=9,
    pooling_ksize=3,
    pooling_type='avg',
    model_seed=0,
    split_seed='legacy',
)

# The "regular" run: template entries first, then the run-specific settings.
# The *_name entries are the string spellings of the matching numeric values.
maskcnn_param_regular = dict(
    maskcnn_param_template,
    act_fn='relu',
    loss_type='mse',
    smoothness=0.000005,
    smoothness_name='0.000005',
    scale=0.01,
    scale_name='0.01',
    input_size=50,
    n_timesteps=4,
)

# Kick off the training run and print the returned test corr_mean.
print(train_one_maskcnn_polished_with_rcnn_k_bl(**maskcnn_param_regular))

{'final_act', 'bn_output', 'fc', 'pooling'}
['bl_stack.layer_list.0.b_conv', 'bl_stack.layer_list.1.b_conv', 'bl_stack.layer_list.2.b_conv']
num_param 55808
val metric init {'loss': 0.1476345956325531, 'loss_no_reg': 0.14376674592494965, 'corr': [0.058109296390085594, 0.02109747215932515, 0.027515191091892426, -0.03901109963146236, 0.01123947523614571, -0.05336567342002975, -0.00433439661308649, -0.01787438968481538, 0.03863201141127208, 0.15008972727755626, 0.023520704968540695, 0.1414776669269659, 0.08566432043736724, 0.029155014142259017, -0.03247472192507334, -0.029545185814865985, 0.0015026930543709923, 0.008137061881077714, 0.007465024719027984, 0.17541698514456022, 0.02855167940107918, -0.024510255505319403, -0.03355362851738701, 0.0415603074538, 0.024597719129324878, -0.0012241833990332763, -0.011557944100550114, 0.049707678594448926, 0.045988430414181665, -0.05450135989435649, -0.0036988108943219744, -0.05916807907626321, 0.07476149338066111, 0.0024263811188696928, -0.03706809

test metric {'loss': 0.12537304844175065, 'loss_no_reg': 0.12169387936592102, 'corr': [0.2640944698583386, 0.39853824833118345, 0.45451993854284034, 0.6945781093040657, 0.40965658030646035, 0.4846753140930752, 0.3012851611922463, 0.5069808098963553, 0.604526811899726, 0.6111520791415926, 0.5867179794380383, 0.39607203023578863, 0.503588709328948, 0.6043223743578908, 0.3252115388033378, 0.5233069907698216, 0.5833182637432823, 0.5259455059957385, 0.5100950576072043, 0.5279141245075993, 0.41974065521737786, 0.38763355466107674, 0.1567147118971369, 0.0929028895100969, 0.642274338816769, 0.4577207775371074, 0.365939869668403, 0.24906836112080463, 0.5691474649823973, 0.46873473816983163, 0.2833958909896383, 0.5178815113261437, 0.30132270630498026, 0.48846165444818757, 0.5393707888980279, 0.4592736838821383, 0.4968273609972808, 0.49701852722669393, 0.3662424239756368, 0.36309519868549095, 0.25685240429650513, 0.46997775831109595, 0.3388458359038314, 0.3441267788759555, 0.4154574471511596, 0.3

500-0, train loss 0.11549799144268036
train loss 0.11549799144268036
val metric {'loss': 0.11965274065732956, 'loss_no_reg': 0.11553753912448883, 'corr': [0.2261752526667472, 0.4720380604592146, 0.5302207283444262, 0.7313928967661143, 0.41555725513911373, 0.4881337180353353, 0.45550017860825365, 0.5387738810347298, 0.6250514360946369, 0.6569412923164155, 0.6338948716077628, 0.44606272715264883, 0.5534456204641975, 0.6106806369664659, 0.5033404970663641, 0.5629554743027476, 0.6605597008533499, 0.5759892024722122, 0.5449549815420831, 0.6315815938982899, 0.44424920053945904, 0.4019351103268985, 0.20256743839087832, 0.1163372618008698, 0.7315470797790292, 0.5098829942261722, 0.4421866155115973, 0.3293831351737197, 0.6287313065363634, 0.5278769812769307, 0.2994225955577393, 0.5400140973516222, 0.4833036124901677, 0.494246068939426, 0.594145331231263, 0.49157741499482654, 0.5211582638008629, 0.5624799530679154, 0.41387317452947414, 0.38470730164680955, 0.28552232032489455, 0.5319926010202363

test metric {'loss': 0.11776485506977354, 'loss_no_reg': 0.11372369527816772, 'corr': [0.29709497611845426, 0.47089579446666374, 0.5664735505095151, 0.7334289549622994, 0.44171977852461525, 0.5585991751295571, 0.5118284481705428, 0.6083062686113916, 0.6536716503351345, 0.6899869802932843, 0.6603241814136084, 0.44720215231293003, 0.5513191133187365, 0.6536488530550445, 0.4953738433702479, 0.5742424775404227, 0.698282479413654, 0.6092399248573497, 0.5337009513003119, 0.6053023218001599, 0.472554461831702, 0.44682995413125054, 0.2145302476166825, 0.09316198363923303, 0.7134237954373198, 0.567077725409288, 0.48641403604294803, 0.3110143769170534, 0.6070957274388507, 0.5061521733347399, 0.3299119656799985, 0.5881601863812862, 0.43450140094299644, 0.5750049628837655, 0.6056286610862986, 0.5259040391767746, 0.5652503329570076, 0.5611666018230301, 0.40136661873212076, 0.39907111018276126, 0.36548543365674907, 0.5288511937606472, 0.43237236661605116, 0.41604813145626734, 0.5424044100302309, 0.3

1000-0, train loss 0.11705164611339569
train loss 0.11705164611339569
val metric {'loss': 0.11635163724422455, 'loss_no_reg': 0.11225222796201706, 'corr': [0.26115502959012743, 0.49175355375423263, 0.5790859098198946, 0.7347143016256874, 0.48242675903598486, 0.5086637230970547, 0.47027274729621615, 0.626066400366576, 0.656550995377327, 0.6914012350674132, 0.6645056590947398, 0.48767292345498786, 0.5717103830854339, 0.6266053884442369, 0.5692760319972376, 0.5831987554353173, 0.6823626045820093, 0.6033598328185181, 0.5523044381237812, 0.6563961177218078, 0.4735187109100071, 0.4244188878907167, 0.2459054298672036, 0.11501404018408753, 0.7367110561136636, 0.5610884918344548, 0.5241942218666671, 0.38160064841808583, 0.6463337015093007, 0.5352844151109815, 0.33166411641588667, 0.5830962782559074, 0.5356303545315015, 0.5615333448369358, 0.6290032176025527, 0.5090357804130023, 0.5787700705711032, 0.5749481964696863, 0.43866608695420306, 0.3864159127291078, 0.35553712536357274, 0.53938998511561

test metric {'loss': 0.11607620758669716, 'loss_no_reg': 0.11200768500566483, 'corr': [0.3146012545756074, 0.4993813962345295, 0.5860080219487412, 0.7629155181128242, 0.48653943333898664, 0.5588651638371472, 0.5031405996531957, 0.6346695573859074, 0.6541073130351818, 0.6869184248388593, 0.685665757715407, 0.46447414873954357, 0.5811996241438956, 0.6589184817445772, 0.550947168094869, 0.5923732461920365, 0.7086982540899721, 0.6245256796597265, 0.5983639809252395, 0.6438477120055904, 0.4976547154187826, 0.45113062006805615, 0.2242516204602734, 0.09100319412797471, 0.7276229081387913, 0.5956348708086969, 0.5138980909606994, 0.3388791094464779, 0.6373497660032384, 0.5282015058342182, 0.35101995362939925, 0.5986912970592396, 0.4321291768294508, 0.5939166829551825, 0.6123993969928073, 0.5282180211403369, 0.6136585059166384, 0.5926228499141395, 0.43258628050718345, 0.4118495841717237, 0.3924690437295346, 0.548285462157333, 0.45321472779540817, 0.43154433774992207, 0.5693888282614485, 0.448370

1500-0, train loss 0.11329194903373718
train loss 0.11329194903373718
val metric {'loss': 0.11505215615034103, 'loss_no_reg': 0.11154534667730331, 'corr': [0.25399277923893826, 0.4917945927314893, 0.5817936442361638, 0.7518447840898075, 0.49835195456224646, 0.5099341491984958, 0.48965216659592636, 0.6387599425548025, 0.6560421073739492, 0.6983963781522484, 0.6769737138061085, 0.5119998911891934, 0.5903095210024928, 0.6348233424278906, 0.5833854764159718, 0.5981844606322237, 0.6964381107595213, 0.5880920856831636, 0.5861564225368949, 0.6848567883251895, 0.46302466156900485, 0.41589387743270295, 0.3520865132314674, 0.11206230371598162, 0.7440372430339062, 0.5860643822979068, 0.5616625383438891, 0.37905729143178535, 0.656357074508214, 0.5383624587258224, 0.34525773240871405, 0.5706105624986763, 0.49914231858256547, 0.5891813729610798, 0.6276052949706737, 0.5112263181080161, 0.599774442966547, 0.6038253395464994, 0.45459320101372924, 0.39025612844020763, 0.3538242285790812, 0.5734063218462

test metric {'loss': 0.11864792662007469, 'loss_no_reg': 0.11518799513578415, 'corr': [0.31612951692960817, 0.504412750892172, 0.5863343442497587, 0.7412299660548421, 0.4878830620306238, 0.549404816421001, 0.5057626014076466, 0.5773568476977705, 0.660453262161071, 0.6974805595292227, 0.6856383400648024, 0.4539630822638644, 0.5797735032563779, 0.6520469722631028, 0.5546914546212287, 0.5887909011911698, 0.708287787794068, 0.6165872251785145, 0.5639990352331327, 0.6392821708203508, 0.4852996580651291, 0.43567472813125596, 0.31316789583398263, 0.10181215369564134, 0.7284343541005368, 0.5650622885340608, 0.44588130734164255, 0.34056147363485334, 0.6367592057665894, 0.5197561004317597, 0.35594093732607845, 0.568996620855075, 0.43182471026630076, 0.5595350014813315, 0.6166091997509693, 0.5381601879274833, 0.6250679653464096, 0.5929334340601176, 0.4270110481713373, 0.40318357294303153, 0.40449871793207287, 0.531662299027563, 0.44789038185896635, 0.4392716910242298, 0.6009268733696596, 0.427062

2000-0, train loss 0.10203954577445984
train loss 0.10203954577445984
val metric {'loss': 0.11501713693141938, 'loss_no_reg': 0.11145175993442535, 'corr': [0.2529313846004355, 0.4924323406957351, 0.5937335789859535, 0.7644286330951867, 0.5080750390615764, 0.5171117431185475, 0.5068377932471995, 0.6551982505157368, 0.6685248427924384, 0.7083686743473967, 0.6833915008299964, 0.5156881919977749, 0.5989602132477595, 0.6578858990614075, 0.5875704679157474, 0.60700261366855, 0.6915047749039367, 0.6116028970325099, 0.5890360525208316, 0.6814722290144712, 0.4721920578395511, 0.4359153965121768, 0.35769710309145386, 0.10907352955747175, 0.7432678118789499, 0.5994307933883565, 0.554054101052627, 0.391942470505656, 0.6651387806640937, 0.5422613142313245, 0.36095315734308064, 0.5709426452568576, 0.5254410938139087, 0.6009657509422078, 0.6370163568455574, 0.5128887312474271, 0.6098133481977133, 0.616472875675491, 0.4695972739326508, 0.40125630842444737, 0.3588146439394073, 0.5777020075910491, 0.468

test metric {'loss': 0.11584983340331487, 'loss_no_reg': 0.1122511625289917, 'corr': [0.3296229127130987, 0.5169628443359394, 0.5972396631088634, 0.7623623537431665, 0.5103012443986098, 0.5853676804466579, 0.5339913070419314, 0.6253632553965661, 0.675853856644909, 0.7099748281556373, 0.6945936093577039, 0.4641826036138699, 0.5942189620092009, 0.6716410737259576, 0.5894684321794433, 0.6130185872532703, 0.7070394550569747, 0.6451078399879895, 0.6027412530272817, 0.6673119913810674, 0.5071599704367572, 0.46027759887221426, 0.3353044513059251, 0.10370008306596595, 0.7475553348916947, 0.6026493076167856, 0.5328064157132468, 0.34468071232359243, 0.640411935070286, 0.5314281173970793, 0.3765804960978605, 0.5952577666238751, 0.4495666676192306, 0.5972208261936542, 0.6258273273707781, 0.5398754545196263, 0.6287267114930403, 0.6097119157615049, 0.4518493204110802, 0.40280772359225814, 0.42683185618839126, 0.5692858176501462, 0.4555073037945569, 0.46021324976908085, 0.6400415291315618, 0.45421780

test metric {'loss': 0.11541333262409482, 'loss_no_reg': 0.11113443970680237, 'corr': [0.3235984567982407, 0.5095222651850762, 0.5955531434596519, 0.75376565138384, 0.4967928717380987, 0.5791202711784563, 0.5213510294508029, 0.6238246614791733, 0.670730351313945, 0.7002138092627447, 0.6886439230515216, 0.46907700123142676, 0.5858112626256915, 0.6605217485688393, 0.5730063282528761, 0.6103649801645846, 0.7184936916343818, 0.6353652566718735, 0.5920704852676459, 0.6640449204089776, 0.4955631143847633, 0.4591939974762439, 0.3167367616036144, 0.10587972079621405, 0.7411317751862105, 0.6065775598744112, 0.5362136100773184, 0.34594850169703206, 0.6396902752835303, 0.5246015881463704, 0.3593044581833511, 0.5942513347701321, 0.44879769054628377, 0.5962558090551876, 0.6199699556608873, 0.5329232417121168, 0.6301762387734133, 0.6032634384307262, 0.44474385616358836, 0.3993933938990343, 0.4134753767313699, 0.5553806259249581, 0.461445169416704, 0.45616582764148716, 0.6240859389129765, 0.440481369

300-0, train loss 0.10609088838100433
train loss 0.10609088838100433
val metric {'loss': 0.1139493703842163, 'loss_no_reg': 0.11024188995361328, 'corr': [0.2587007856581291, 0.49260368801591725, 0.5980820662196364, 0.7682757356454697, 0.5123388522908421, 0.5145038518039864, 0.5092216915723142, 0.6420010599084367, 0.6659540528662531, 0.7052051736738478, 0.6856275443755615, 0.5210619340165424, 0.6039872265454371, 0.6566449700615403, 0.6060938112814251, 0.6055986193986003, 0.6935324211252596, 0.6096473508325143, 0.5967063972245388, 0.688595863011753, 0.4780999394612501, 0.44573707437804216, 0.3671148554572965, 0.10771850209844691, 0.7477935242436577, 0.6015826683017975, 0.5492161077088272, 0.39316681967730216, 0.6714766772608485, 0.5464389638878187, 0.36459893819830075, 0.5829413607742536, 0.5263808091217057, 0.58650861600497, 0.6337617707659944, 0.5116754807309492, 0.6146318260298668, 0.6183947370410172, 0.47138393980985516, 0.39908808431996123, 0.38675200396589937, 0.5790238583075299, 0

test metric {'loss': 0.11562533463750567, 'loss_no_reg': 0.1118086576461792, 'corr': [0.32755021072280344, 0.509606486495397, 0.5985972316528211, 0.7634540587784541, 0.5074204357884706, 0.5834576293858831, 0.5302572628847961, 0.6272135116036959, 0.6739360277756821, 0.7107975742167782, 0.6956759959485335, 0.4713059564152645, 0.5959966986854943, 0.6721490054264503, 0.5820240130915508, 0.6114018104177549, 0.7197298803847583, 0.6419042392749422, 0.6030246596672759, 0.6682610329880991, 0.5059208383122569, 0.4633467483225867, 0.3240334730372007, 0.10363983184545242, 0.7457908113458402, 0.6068591546702786, 0.5370114519009852, 0.34765433104189625, 0.6420894672391788, 0.5278643856727245, 0.36426831217280686, 0.5929097356883039, 0.4486725441217666, 0.5989674086740613, 0.6265155288565626, 0.5393250567337045, 0.6322084892493514, 0.6079862148643844, 0.4508729853128798, 0.403810221965577, 0.421061406588489, 0.5608821484647725, 0.45861033547544783, 0.46135761660213315, 0.6310148838935418, 0.451341577

test metric {'loss': 0.1145750635436603, 'loss_no_reg': 0.11065530776977539, 'corr': [0.3274205136544549, 0.5082924814417211, 0.5991866503691072, 0.7612942438183246, 0.5043237106260434, 0.5805312887008068, 0.5315291388574823, 0.6303793285870372, 0.6740880698946917, 0.7070338155766015, 0.6933404090459204, 0.4717545642641444, 0.5955924630370479, 0.6691604763424812, 0.5686200892701686, 0.6125834066953975, 0.720329656445783, 0.6404793462041998, 0.6009388883642327, 0.6675259444965194, 0.5032729929822021, 0.4623074074839426, 0.3306431871957383, 0.10510107680915909, 0.7455249934903734, 0.6093394953373674, 0.5405802647071447, 0.34678475491652744, 0.6434103162391146, 0.5270200213169838, 0.3647121907008106, 0.5916971775729976, 0.4476087920370389, 0.6014743271027815, 0.6222104725371715, 0.537394727115849, 0.629461386563748, 0.6078865967692229, 0.44974036051244065, 0.4034705707230865, 0.4134359819518321, 0.5602932404608965, 0.45871160087787366, 0.4571008819143404, 0.6293543891592337, 0.45263295725

400-0, train loss 0.10422135144472122
train loss 0.10422135144472122
val metric {'loss': 0.11375600546598434, 'loss_no_reg': 0.11018171906471252, 'corr': [0.255759262345406, 0.4930739057958955, 0.6002465367218832, 0.7651421887416534, 0.5114942186287571, 0.5150428223721665, 0.5045897209236782, 0.6509587787078803, 0.664764369473938, 0.7072705382026773, 0.688264253349347, 0.5209861256901361, 0.6026811092670847, 0.6572353143693331, 0.607243219211193, 0.6064911868783216, 0.6956592406161847, 0.6088714262729616, 0.5973174823723859, 0.6904987524388231, 0.4776016388992874, 0.4437319011677402, 0.36703183634788117, 0.10907666369574483, 0.748390223685092, 0.6037500911392146, 0.5600472080227015, 0.393340675593045, 0.6707275239389672, 0.5487003218502987, 0.36180681634619793, 0.5852667052996152, 0.5209145719090098, 0.5945896371108422, 0.6347672053331489, 0.5110330433840558, 0.6146021655247221, 0.6184035493257489, 0.47215688825350965, 0.3977510803172619, 0.3819742758640588, 0.5810948431908554, 0.47388

test metric {'loss': 0.11483048647642136, 'loss_no_reg': 0.1109386757016182, 'corr': [0.3294093533535788, 0.5093959052504565, 0.600066643580512, 0.7643679334238298, 0.5089854636137634, 0.5827314258982443, 0.5308919435985806, 0.6301130965182367, 0.6743148456902663, 0.7098773267741918, 0.6959813285529892, 0.4717097622909127, 0.5957594035023419, 0.6706337021089852, 0.5790119293170148, 0.6134006027635808, 0.7200318989876173, 0.6422419774408124, 0.6018744576745585, 0.6679810954763218, 0.5041184542962783, 0.46240834524590946, 0.33173495249957, 0.10396752271012191, 0.7440340253548361, 0.6082628065574177, 0.5381500679149134, 0.34803598130851293, 0.6434924952566211, 0.5292376252589575, 0.36756842403844725, 0.5935065612893988, 0.4482519409301169, 0.6032035079826867, 0.6235042890988475, 0.5398882771003072, 0.6321379993574261, 0.6080497961433169, 0.45145802427602444, 0.40670884763397874, 0.41795888363117184, 0.5613986466801182, 0.45868563729550027, 0.45982833551140045, 0.6301020210118535, 0.453652