In [1]:
# test full training of maskcnn_polished_with_rcnn_k_bl
from os.path import join

import numpy as np

import torch

from thesis_v2 import dir_dict

from thesis_v2.data.prepared.yuanyuan_8k import get_data

from thesis_v2.training_extra.maskcnn_like.opt import get_maskcnn_v1_opt_config
from thesis_v2.training_extra.maskcnn_like.training import (train_one,
                                                            partial)

from thesis_v2.models.maskcnn_polished_with_rcnn_k_bl.builder import (
    gen_maskcnn_polished_with_rcnn_k_bl,
    load_modules
)

In [2]:
def train_one_maskcnn_polished_with_rcnn_k_bl(
    split_seed,
    model_seed,
    act_fn,
    loss_type,
    input_size,
    out_channel,
    num_layer,
    kernel_size_l1,
    pooling_ksize,
    scale, scale_name,
    smoothness, smoothness_name,
    pooling_type,
    n_timesteps,
):
    """Train one ``maskcnn_polished_with_rcnn_k_bl`` model and return its test score.

    Steps, in order: call ``load_modules()``, fetch six arrays from
    ``get_data``, pack them into a ``datasets`` dict (the ``X_*`` entries are
    cast to ``float32``), then hand an architecture factory and an optimizer
    config partial to ``train_one``.

    Parameters
    ----------
    split_seed
        Forwarded to ``get_data`` as ``seed``.
    model_seed
        Forwarded to ``train_one`` and embedded in the result ``key``.
    act_fn, out_channel, num_layer, kernel_size_l1, pooling_ksize,
    pooling_type, n_timesteps
        Forwarded unchanged to ``gen_maskcnn_polished_with_rcnn_k_bl``.
    loss_type, scale, smoothness
        Forwarded to ``get_maskcnn_v1_opt_config``. Note that ``scale`` here
        is unrelated to the fixed ``scale=0.5`` keyword given to ``get_data``.
    input_size
        Third positional argument of ``get_data``.
    scale_name, smoothness_name
        Accepted for interface compatibility; not used inside this function.

    Returns
    -------
    The value stored at ``result['stats_best']['stats']['test']['corr_mean']``
    in the object returned by ``train_one``.
    """
    # Module loading happens first, ahead of data loading and model generation.
    load_modules()

    raw_splits = get_data(
        'a', 200, input_size, ('042318', '043018', '051018'),
        scale=0.5, seed=split_seed,
    )

    # raw_splits is consumed as (X, y) pairs in train / val / test order;
    # only the X member of each pair is cast to float32.
    datasets = {}
    split_key_pairs = (
        ('X_train', 'y_train'),
        ('X_val', 'y_val'),
        ('X_test', 'y_test'),
    )
    for pair_idx, (x_key, y_key) in enumerate(split_key_pairs):
        datasets[x_key] = raw_splits[2 * pair_idx].astype(np.float32)
        datasets[y_key] = raw_splits[2 * pair_idx + 1]

    # Architecture settings that do not depend on the arguments train_one
    # supplies to the factory below.
    arch_fixed_kwargs = dict(
        out_channel=out_channel,
        kernel_size_l1=kernel_size_l1,  # candidates noted by the author: 5, 9, 13
        kernel_size_l23=3,
        act_fn=act_fn,
        pooling_ksize=pooling_ksize,  # candidates noted by the author: 1, 3, 5, 7
        pooling_type=pooling_type,  # 'avg' or 'max'; author's note: max may work well here
        num_layer=num_layer,
        n_timesteps=n_timesteps,
        factored_constraint=None,
        blstack_pool_ksize=1,
        blstack_pool_type=None,
        acc_mode='cummean',
        bn_after_fc=False,
        ff_1st_block=True,
    )

    def gen_cnn_partial(input_size_cnn, n):
        """Build the architecture spec for a given input size and neuron count."""
        return gen_maskcnn_polished_with_rcnn_k_bl(
            input_size=input_size_cnn,
            num_neuron=n,
            **arch_fixed_kwargs,
        )

    opt_config_partial = partial(
        get_maskcnn_v1_opt_config,
        scale=scale,
        smoothness=smoothness,
        group=0.0,
        loss_type=loss_type,
    )

    result_key = (
        f'debug/test_full_training_maskcnn_polished_with_rcnn_k_bl/ff_1st_block/{model_seed}'
    )
    # Author's note on this setting: reduce on batch axis.
    eval_extra_params = {'eval_fn': {'yhat_reduce_axis': 1}}

    result = train_one(
        arch_json_partial=gen_cnn_partial,
        opt_config_partial=opt_config_partial,
        datasets=datasets,
        key=result_key,
        show_every=100,
        max_epoch=40000,
        model_seed=model_seed,
        return_model=False,
        extra_params=eval_extra_params,
        print_model=True,
    )

    best_stats = result['stats_best']['stats']
    return best_stats['test']['corr_mean']

In [3]:
# Architecture and seed settings used as the base for the run below.
maskcnn_param_template = dict(
    out_channel=16,
    num_layer=3,
    kernel_size_l1=9,
    pooling_ksize=3,
    pooling_type='avg',
    model_seed=0,
    split_seed='legacy',
)

# Full argument set for train_one_maskcnn_polished_with_rcnn_k_bl: the
# template plus activation, loss, optimizer-config and input settings.
# The *_name entries are string spellings of the numeric values next to them.
maskcnn_param_regular = dict(
    maskcnn_param_template,
    act_fn='relu',
    loss_type='mse',
    smoothness=0.000005,
    smoothness_name='0.000005',
    scale=0.01,
    scale_name='0.01',
    input_size=50,
    n_timesteps=4,
)

# Run the full training once and print the returned test 'corr_mean' value.
print(train_one_maskcnn_polished_with_rcnn_k_bl(**maskcnn_param_regular))

# 27629 = 1*2 + 16*1*9*9 + 16*2 + 2*(16*16*3*3+16*16*3*3+4*16*2) + 79*(14*14 + 16 + 1) 

{'final_act', 'bn_output', 'fc', 'pooling'}
['conv0', 'bl_stack.layer_list.0.b_conv', 'bl_stack.layer_list.1.b_conv']
num_param 27629
JSONNet(
  (moduledict): ModuleDict(
    (accumulator): RecurrentAccumulator()
    (act0): ReLU()
    (bl_stack): BLConvLayerStack(
      (layer_list): ModuleList(
        (0): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (bn_layer_list): ModuleList(
        (0): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (

test metric {'loss': 0.13860258566481726, 'loss_no_reg': 0.13956981897354126, 'corr': [0.14860692933452235, 0.25416953724197694, 0.33744270929263703, 0.546258764234558, 0.2822274029505801, 0.23521413672176042, 0.3271851028805739, 0.312357490546624, 0.5331901713262637, 0.5346084506739779, 0.48691087358317964, 0.31105541855955976, 0.4141836057000511, 0.56420434669882, 0.24477889114864276, 0.4294639822785121, 0.21669760720986642, 0.42013276485508344, 0.39507585749676555, 0.4248954333047475, 0.33305866764540315, 0.27478069849045217, 0.12522043019110346, 0.041683770943703145, 0.47831911140858674, 0.3452504551616487, 0.17236247851487407, 0.19119505079856844, 0.4521112753057789, 0.38071501449149897, 0.13720941118529514, 0.2797350854256789, 0.25441102347400557, 0.3083272473832652, 0.4577498524449984, 0.3993244430224905, 0.39737422573836667, 0.3938491291385501, 0.30651268823908073, 0.3383366263518027, 0.21740314601665503, 0.3737167840162677, 0.15047215656164997, 0.2324630744946616, 0.2233772031

400-0, train loss 0.11778386682271957
train loss 0.11778386682271957
val metric {'loss': 0.11964204907417297, 'loss_no_reg': 0.11480004340410233, 'corr': [0.23753142294458046, 0.4851689568960317, 0.5247224091269544, 0.7305381827748376, 0.4409963283662571, 0.5068502001209796, 0.3544804553146137, 0.5738624825427648, 0.6541237337792177, 0.6731278639514499, 0.6444108901542658, 0.47131471299960703, 0.5601488190689282, 0.6375729594362596, 0.5369702501896936, 0.5805227985424579, 0.6571378614957151, 0.5659066465767151, 0.5453153139204865, 0.6349045758753779, 0.43248422555205646, 0.39081119383676716, 0.249180311805732, 0.1156314947304509, 0.7489845078081934, 0.5194247182431857, 0.5057109533774716, 0.32333090362490247, 0.628918995087143, 0.5190701064233916, 0.3029096262732701, 0.5389595215457672, 0.3997857647207157, 0.5430763488763359, 0.6068392898791509, 0.49548181266684704, 0.5547365495042753, 0.5684078820928835, 0.4315164885444356, 0.3959643199015475, 0.30286968315435703, 0.5441841466903573, 

test metric {'loss': 0.11658035750899996, 'loss_no_reg': 0.11180318892002106, 'corr': [0.3142636716623632, 0.5004718261393499, 0.5817844646953374, 0.7544649692616301, 0.46526902969831796, 0.5647823433529315, 0.4165367045742264, 0.6040166006297001, 0.6697979007800896, 0.7221616265898618, 0.696898716793749, 0.46975083082557045, 0.582487572032788, 0.6694307917581749, 0.5851400330207408, 0.6152206344723539, 0.7074186441217706, 0.6099806053198991, 0.5721976780971512, 0.6214747577534656, 0.49133730457718383, 0.44045169100583736, 0.2201882977437229, 0.08358205554284374, 0.7369621753867178, 0.5880222776981822, 0.5205335670517942, 0.3333701440513802, 0.6372807086031058, 0.5248736014494061, 0.34919692305924344, 0.5944470974055228, 0.3744618346568098, 0.5887453818292439, 0.635947367173409, 0.544463372822144, 0.6086060749694674, 0.5833513522635722, 0.4396800812236648, 0.4052510137702149, 0.41403324024715055, 0.5712118736309596, 0.44910948644861415, 0.4224363311238591, 0.5737261431970206, 0.4250498

900-0, train loss 0.11122670769691467
train loss 0.11122670769691467
val metric {'loss': 0.11356288194656372, 'loss_no_reg': 0.10872039943933487, 'corr': [0.27493438078022187, 0.49591314482232135, 0.6095396191722203, 0.7609251104665373, 0.5157524057177048, 0.5114117600398455, 0.37613024865976646, 0.611330386261545, 0.6821726099380778, 0.7135247167277441, 0.6918158270974454, 0.5261745402079776, 0.6141814167358844, 0.6471535401392061, 0.6409303762438668, 0.611086969292044, 0.7120868946675776, 0.6179575801572722, 0.5877196768130085, 0.6933657229782871, 0.47975903568630285, 0.44725612646382606, 0.3235842121397407, 0.1009242133433201, 0.7548534775483768, 0.6091651071098022, 0.5537952552830112, 0.35504043304481636, 0.6629811749373041, 0.5539360771632104, 0.3587021089570176, 0.5733640609986768, 0.45421140243463554, 0.5772837967820563, 0.6406222497330516, 0.5325471021693344, 0.6368938066419991, 0.6389076987121172, 0.493066593532123, 0.412546547184711, 0.4201598348155952, 0.5922455060409639, 0.

test metric {'loss': 0.11310180595942906, 'loss_no_reg': 0.10815174877643585, 'corr': [0.34638573847930276, 0.48960161984745193, 0.6053577750114418, 0.7634290869639808, 0.5171876638026096, 0.5766591331133465, 0.430817281260504, 0.6444399304141037, 0.6673766960474339, 0.7167863086961143, 0.7066532623771198, 0.49222310365517663, 0.6128561161933073, 0.6837965365517629, 0.6445137929654315, 0.6265327695170644, 0.717126507330474, 0.6456790921051473, 0.6113081244469192, 0.6663363694208755, 0.5094319873933018, 0.46255765262019, 0.3047907394472492, 0.08026550903106913, 0.7488075500932352, 0.6287829185592003, 0.5383914341647296, 0.33400113655953256, 0.6438164481067277, 0.5481867847988242, 0.38931312874091617, 0.5955447801527812, 0.3697882273571255, 0.5998242245208409, 0.6379153414042216, 0.5451775876797558, 0.6420440214133578, 0.6245261429454354, 0.4668898887588322, 0.4241480184137717, 0.464370428705237, 0.5831004569501426, 0.4532732353525223, 0.4566084601306202, 0.6394714093416038, 0.4730346711

1400-0, train loss 0.11277475208044052
train loss 0.11277475208044052
val metric {'loss': 0.11209384799003601, 'loss_no_reg': 0.10751261562108994, 'corr': [0.27629469084880576, 0.49976575756864167, 0.618785929224273, 0.7580246441712502, 0.5306229558825748, 0.5217822226586366, 0.3829903653702402, 0.6284133327567701, 0.678694506759398, 0.7173853144915041, 0.6958430016161135, 0.5467897333582729, 0.6192234492225288, 0.6501882576066228, 0.647789932731529, 0.6163664473080142, 0.7176808997271371, 0.621271741899846, 0.601600604435723, 0.6954102241800582, 0.4861823063486512, 0.46050036106817116, 0.3407401904902183, 0.11017265526903536, 0.7543605628020846, 0.6300377654173003, 0.5570542702554764, 0.3642220436806654, 0.6644564866684253, 0.5603073607318968, 0.3740192354335632, 0.5833401595426513, 0.4609990472154718, 0.5901330928827273, 0.6467492441688665, 0.5292126998597773, 0.6503680222001367, 0.6536105327704879, 0.5085923969376318, 0.4191971389869915, 0.4486684885407149, 0.6021238077358704, 0.448

test metric {'loss': 0.11224257520266942, 'loss_no_reg': 0.10763007402420044, 'corr': [0.33961150922747996, 0.5018333272392694, 0.6030074991834518, 0.757866884546095, 0.5348475267401468, 0.5819027968038643, 0.4449614728673239, 0.6493724292516145, 0.6793371953655903, 0.715292623540763, 0.7128823952559173, 0.494780342188578, 0.6195202805953799, 0.6819722823432148, 0.6407689597435693, 0.6313525016972288, 0.7219776504592649, 0.6577244581712748, 0.6134554673206496, 0.6663273584246092, 0.5096980918783743, 0.47239953181849514, 0.3079069049184574, 0.09041282421586236, 0.7493542227030439, 0.6425292339255058, 0.5445913497253712, 0.339894498125491, 0.6403864723588282, 0.553384179184466, 0.3880260203901177, 0.5982675833625175, 0.38220447951720804, 0.6002688281062858, 0.6420777054834896, 0.5528038305329076, 0.6400147699679049, 0.6342047621176311, 0.4777108794275019, 0.4243005064496444, 0.4659761167678583, 0.5915400096538055, 0.45793285513054083, 0.469466782150058, 0.6431300372812407, 0.481840971390

1900-0, train loss 0.10999299585819244
train loss 0.10999299585819244
val metric {'loss': 0.11150295585393906, 'loss_no_reg': 0.10705718398094177, 'corr': [0.2881664711353198, 0.506637236994047, 0.6156127965682113, 0.7590296984328238, 0.5398440120447721, 0.5185285548546569, 0.4038239195436695, 0.6349202040993647, 0.6776039927355333, 0.7232723668370422, 0.7020957382520536, 0.5412438058291142, 0.6240511170339686, 0.658146230865964, 0.6533258725183198, 0.6137156597600912, 0.7201446766389654, 0.6240564931342107, 0.6029287418356408, 0.69356557138483, 0.49572918667574944, 0.4601558433431267, 0.3394829710403974, 0.11958151243912726, 0.7654978577480351, 0.6447134800184791, 0.5585391957841535, 0.37330831862493846, 0.6670345463909768, 0.5668351818547703, 0.3820750224216978, 0.5828405411633514, 0.4863055702951676, 0.603447049755193, 0.642740800411553, 0.529048573639752, 0.6531403303567374, 0.6659690963915013, 0.5183612023682878, 0.4188629051192155, 0.4585761987673881, 0.6065746610984688, 0.442843

test metric {'loss': 0.11158216957535062, 'loss_no_reg': 0.10713132470846176, 'corr': [0.34167897116012813, 0.5129931813634505, 0.6046508181032645, 0.7672349447229171, 0.5443332856164975, 0.5816436299014572, 0.4592620091295228, 0.6458718078848902, 0.6767221221712205, 0.7246015881809281, 0.7148108055720158, 0.4957630186083188, 0.6275135729072948, 0.6820049458109536, 0.6394796715724897, 0.6362646407805387, 0.7235321112342759, 0.6502293886758148, 0.6190821292777933, 0.6678232053549054, 0.5111608779553668, 0.4639980145031005, 0.30501931566103574, 0.104497678672007, 0.7510666701553899, 0.63652919035082, 0.5396371502022524, 0.343461134442814, 0.642483270255928, 0.5471054894072905, 0.3879985472369233, 0.5993492938712661, 0.41328967732354627, 0.5991983908280001, 0.646658425142362, 0.5541703991943225, 0.6542368581815337, 0.6372875100416115, 0.47239350495437293, 0.41830471849806594, 0.46983275900572247, 0.5918103626471881, 0.4533729398913723, 0.4586826662382555, 0.649747738269311, 0.471058871816

test metric {'loss': 0.1117876757468496, 'loss_no_reg': 0.10703791677951813, 'corr': [0.34213191636941054, 0.5088999588708638, 0.6042377424386604, 0.7664295287149474, 0.5368029682817272, 0.5736553470697044, 0.4354275979945321, 0.6470137092447539, 0.6743158877962266, 0.7222167914905254, 0.7129788635688966, 0.4904340844948572, 0.6269512201830206, 0.6861219012014171, 0.6444556939440311, 0.6352474792694011, 0.7219880412828411, 0.6522116457358189, 0.6143717496283152, 0.6699399401423164, 0.5143000525227978, 0.45613557434132623, 0.2998463617008118, 0.09196158488636844, 0.7527929810801518, 0.64122201533087, 0.5398625857903178, 0.3398655540713262, 0.6434998951308526, 0.5529958545936857, 0.39183345008758885, 0.5935658916677725, 0.3874100257516435, 0.6041823185375704, 0.640152998821805, 0.5513436056717542, 0.654519974686872, 0.629644925077941, 0.4759310645992404, 0.42643584965941334, 0.4719157188478638, 0.5894244121982097, 0.4504254443452422, 0.456517521215675, 0.6401499993113435, 0.4733796428976

300-0, train loss 0.10803961008787155
train loss 0.10803961008787155
val metric {'loss': 0.11096786111593246, 'loss_no_reg': 0.10661537200212479, 'corr': [0.28140090631544673, 0.5042892507680508, 0.6117037239743293, 0.7632397701555984, 0.5340099628528784, 0.5208399520599654, 0.40313616268803654, 0.6438152375983431, 0.6785619838549417, 0.720007467179399, 0.7031237998178568, 0.5457563603841254, 0.6273341300277149, 0.6536672708814182, 0.6520490786386777, 0.6176259929272195, 0.7232179186499074, 0.6247850812092032, 0.6033005657188164, 0.687967676922308, 0.49370184073024664, 0.4561333200711375, 0.3476852310078471, 0.11481481879626027, 0.7586639317894791, 0.6399945731100272, 0.5566352653335793, 0.37676906415737144, 0.6692185130209517, 0.5699699637749762, 0.3844643142114236, 0.5815104754694971, 0.47212739963915873, 0.6056620850922967, 0.6422515826740317, 0.5290187722081676, 0.656298640504553, 0.6635400505447231, 0.5131962816817224, 0.4106886629714106, 0.46676471704946404, 0.6085415221715094, 0

test metric {'loss': 0.11137162574699946, 'loss_no_reg': 0.10691983252763748, 'corr': [0.3451528929706317, 0.5056375447224128, 0.6078958998218369, 0.7674307335190157, 0.5423437983711026, 0.5800362790323883, 0.4498187566846854, 0.6413611319773858, 0.6787898151952083, 0.7259745763593525, 0.716636180700659, 0.4928477159843405, 0.6290933091289173, 0.6829135315051156, 0.6424822668728132, 0.6352589687146126, 0.7238724376846999, 0.6515449053403546, 0.6202270324368007, 0.6662519637965154, 0.5177190771437423, 0.46774334937158696, 0.29764986097322366, 0.09258401667699238, 0.7506874422019602, 0.6442301328456839, 0.5443551914218728, 0.337590781804488, 0.6427711907438454, 0.5502610008720673, 0.3813189204842273, 0.6038578057380226, 0.4018078910683176, 0.6089582513552428, 0.6445395406464614, 0.5537724643916067, 0.655051615047251, 0.6355057980650548, 0.47895944002554514, 0.42129397485461356, 0.47245274230708323, 0.5914133250533368, 0.4591922361497429, 0.4624565937792436, 0.6483350541101296, 0.48038127

test metric {'loss': 0.11091699983392443, 'loss_no_reg': 0.1063288226723671, 'corr': [0.34751385302460475, 0.5068112606401122, 0.607759726422712, 0.7667144291152227, 0.5404506590090279, 0.5792940399460258, 0.44141967733628706, 0.6500050432599925, 0.6781810035154566, 0.7272214376897759, 0.7173424508200394, 0.49364941452489886, 0.6293272545339024, 0.6850838813608227, 0.6457838491756467, 0.6375132773973442, 0.725060783553295, 0.6534765243508514, 0.6188985726535967, 0.6711171440191857, 0.5156454729673945, 0.46456806616500423, 0.30319989630791716, 0.09191542244569659, 0.7534269856499353, 0.6457735819471656, 0.5469113613110405, 0.3435725988819768, 0.6441095829893756, 0.5531042693915107, 0.38596728689691195, 0.6010616690311399, 0.39262388887762806, 0.6096161136083473, 0.6430692807612496, 0.5566012819509416, 0.6532086210629516, 0.635005131898247, 0.4785330285684967, 0.4265969257132305, 0.4715130738023038, 0.5931020454632367, 0.4540913974044555, 0.4626305375652729, 0.6491311289445796, 0.4811379

400-0, train loss 0.10303798317909241
train loss 0.10303798317909241
val metric {'loss': 0.1106079563498497, 'loss_no_reg': 0.10625215619802475, 'corr': [0.2840046794928104, 0.5110491514091009, 0.6158222282228474, 0.7615776008927688, 0.5362408326739302, 0.52231697567874, 0.39881435837709817, 0.639210507684058, 0.6794136238176357, 0.7234676051378456, 0.7025121584727038, 0.5446348868748399, 0.6281592554351199, 0.6546656438178486, 0.6541042420008193, 0.6155048249689741, 0.7253952040481266, 0.624489755731501, 0.603620256796419, 0.6929005324417205, 0.49235054226322184, 0.4614438893129841, 0.34630115269499223, 0.11302070306918355, 0.7606913742996033, 0.6388042781098817, 0.5593753047767611, 0.3742152145059144, 0.668814837062974, 0.5707138585459172, 0.38288621235228665, 0.5857748588869909, 0.4733070845865712, 0.6020141728364773, 0.6438965314009772, 0.5323124820815409, 0.659524005808354, 0.6641559986732081, 0.5160585201252341, 0.4115719218341367, 0.4627843858671376, 0.6075701591677574, 0.444512

test metric {'loss': 0.11096459094967161, 'loss_no_reg': 0.10640909522771835, 'corr': [0.34564708286226153, 0.5065926431210779, 0.6087588859265121, 0.7669720578804073, 0.5413999343013299, 0.5775035952237632, 0.4450139354470495, 0.6485885494853596, 0.6782469952033705, 0.7260112412269615, 0.716919246486047, 0.4928299631621821, 0.6307162894518169, 0.6846642176609694, 0.6453871724302449, 0.6368706309570912, 0.7238907984751829, 0.6521677582509318, 0.6198722160788722, 0.6691702947734224, 0.5162150410867614, 0.4639411042224031, 0.3042010831737762, 0.09441895162191621, 0.7512698287210741, 0.6471589557618185, 0.5458967634800198, 0.34213270308635746, 0.6440682643346198, 0.5544626995096003, 0.38458652771488, 0.6004509902898773, 0.39455044964458635, 0.6108946912203999, 0.6426992509691861, 0.5567319844242662, 0.6542772362426873, 0.6351171848129062, 0.4791473543684379, 0.42348481428749674, 0.47143619296172623, 0.5923799205650392, 0.4539396509181237, 0.4609722942571355, 0.6482724170664829, 0.48073478