In [1]:
# test full training of maskcnn_polished_with_rcnn_k_bl
from os.path import join

import numpy as np

import torch

from thesis_v2 import dir_dict

from thesis_v2.data.prepared.yuanyuan_8k import get_data

from thesis_v2.training_extra.maskcnn_like.opt import get_maskcnn_v1_opt_config
from thesis_v2.training_extra.maskcnn_like.training import (train_one,
                                                            partial)

from thesis_v2.models.maskcnn_polished_with_rcnn_k_bl.builder import (
    gen_maskcnn_polished_with_rcnn_k_bl,
    load_modules
)



In [2]:
def train_one_maskcnn_polished_with_rcnn_k_bl(
    split_seed,
    model_seed,
    act_fn,
    loss_type,
    input_size,
    out_channel,
    num_layer,
    kernel_size_l1,
    pooling_ksize,
    scale, scale_name,
    smoothness, smoothness_name,
    pooling_type,
    n_timesteps,
    multi_path_separate_bn,
):
    """Run one full training of the multi-path recurrent mask-CNN and report its test score.

    Steps, in order: call ``load_modules()``, fetch the data splits with
    ``get_data``, describe the architecture through
    ``gen_maskcnn_polished_with_rcnn_k_bl``, describe the optimizer through
    ``get_maskcnn_v1_opt_config`` and hand everything to ``train_one``.

    Parameters
    ----------
    split_seed
        Forwarded to ``get_data`` as ``seed``.
    model_seed
        Forwarded to ``train_one`` and used as the last component of the
        storage key.
    act_fn, out_channel, num_layer, kernel_size_l1, pooling_ksize,
    pooling_type, n_timesteps, multi_path_separate_bn
        Forwarded unchanged to ``gen_maskcnn_polished_with_rcnn_k_bl``.
        ``multi_path_separate_bn`` is additionally embedded in the storage key.
    input_size
        Forwarded to ``get_data`` (third positional argument).
    loss_type, scale, smoothness
        Forwarded to ``get_maskcnn_v1_opt_config``. Note that ``scale`` here is
        not the ``scale=0.5`` passed to ``get_data``; that one is hard-coded.
    scale_name, smoothness_name
        Accepted but not used anywhere in this function.

    Returns
    -------
    The value stored at ``result['stats_best']['stats']['test']['corr_mean']``
    in the dictionary returned by ``train_one``.
    """
    load_modules()

    raw_splits = get_data('a', 200, input_size, ('042318', '043018', '051018'),
                          scale=0.5, seed=split_seed)

    # Entries 0..5 of the get_data result are stored under these keys, in this
    # order. Even positions (the X arrays) are cast to float32; odd positions
    # (the y arrays) are passed through untouched.
    split_keys = ('X_train', 'y_train', 'X_val', 'y_val', 'X_test', 'y_test')
    datasets = {
        key: (raw_splits[pos].astype(np.float32) if pos % 2 == 0 else raw_splits[pos])
        for pos, key in enumerate(split_keys)
    }

    # Architecture settings that do not depend on the arguments the training
    # code supplies to gen_cnn_partial.
    fixed_arch_kwargs = {
        'out_channel': out_channel,
        'kernel_size_l1': kernel_size_l1,  # (try 5,9,13)
        'kernel_size_l23': 3,
        'act_fn': act_fn,
        'pooling_ksize': pooling_ksize,  # (try, 1,3,5,7)
        'pooling_type': pooling_type,  # try (avg, max)  # looks that max works well here?
        'num_layer': num_layer,
        'n_timesteps': n_timesteps,
        'factored_constraint': None,
        'blstack_pool_ksize': 1,
        'blstack_pool_type': None,
        'acc_mode': 'cummean',
        'bn_after_fc': False,
        'ff_1st_block': True,
        'multi_path': True,
        'multi_path_separate_bn': multi_path_separate_bn,
        # tried le5, gave same results as `test_full_training_ff_1st_multipath_train`
        'multi_path_hack': 'leD3',
    }

    def gen_cnn_partial(input_size_cnn, n):
        # Only the input size and the neuron count are left open for the caller.
        return gen_maskcnn_polished_with_rcnn_k_bl(
            input_size=input_size_cnn,
            num_neuron=n,
            **fixed_arch_kwargs
        )

    opt_config_partial = partial(
        get_maskcnn_v1_opt_config,
        scale=scale,
        smoothness=smoothness,
        group=0.0,
        loss_type=loss_type,
    )

    store_key = f'debug/test_full_training_maskcnn_polished_with_rcnn_k_bl/ff_1st_block_multipath_train/sep_bn{multi_path_separate_bn}_leD3/{model_seed}'

    # reduce on batch axis
    eval_overrides = {'eval_fn': {'yhat_reduce_axis': 1}}

    result = train_one(
        arch_json_partial=gen_cnn_partial,
        opt_config_partial=opt_config_partial,
        datasets=datasets,
        key=store_key,
        show_every=100,
        max_epoch=40000,
        model_seed=model_seed,
        return_model=False,
        extra_params=eval_overrides,
        print_model=True
    )

    best_test_stats = result['stats_best']['stats']['test']
    return best_test_stats['corr_mean']

In [3]:
# Settings shared by every configuration derived in this notebook.
maskcnn_param_template = {
    # architecture
    'out_channel': 16,
    'num_layer': 3,
    'kernel_size_l1': 9,
    'pooling_ksize': 3,
    'pooling_type': 'avg',
    # seeds
    'model_seed': 0,
    'split_seed': 'legacy',
}

# "Regular" configuration: a copy of the template extended with the activation,
# loss, regularization strengths (plus their string names), input size and the
# number of timesteps. The template itself is left unmodified.
maskcnn_param_regular = dict(maskcnn_param_template)
maskcnn_param_regular.update({
    'act_fn': 'relu',
    'loss_type': 'mse',
    'smoothness': 0.000005,
    'smoothness_name': '0.000005',
    'scale': 0.01,
    'scale_name': '0.01',
    'input_size': 50,
    'n_timesteps': 4,
})

# print(train_one_maskcnn_polished_with_rcnn_k_bl(**maskcnn_param_regular, multi_path_separate_bn=False))

# 27629 = 1*2 + 16*1*9*9 + 16*2 + 2*(16*16*3*3+16*16*3*3+4*16*2) + 79*(14*14 + 16 + 1) 

# results here are exactly the same as those in `scripts/debug/rcnn_basic_kriegeskorte/test_full_training_ff_1st.ipynb`

In [4]:
# Full training run using the `maskcnn_param_regular` settings with
# multi_path_separate_bn=True; prints the value returned by
# train_one_maskcnn_polished_with_rcnn_k_bl.
print(train_one_maskcnn_polished_with_rcnn_k_bl(**maskcnn_param_regular, multi_path_separate_bn=True))

{'final_act', 'pooling', 'bn_output', 'fc'}
['conv0', 'bl_stack.layer_list.0.b_conv', 'bl_stack.layer_list.1.b_conv']
num_param 29613
JSONNet(
  (moduledict): ModuleDict(
    (accumulator): RecurrentAccumulator()
    (act0): ReLU()
    (bl_stack): BLConvLayerStack(
      (layer_list): ModuleList(
        (0): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (bn_layer_list): ModuleList(
        (0): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (

0-0, train loss 0.1443062275648117
train loss 0.1443062275648117
val metric {'loss': 0.14617048501968383, 'loss_no_reg': 0.14376786351203918, 'corr': [0.0, 0.0, 0.0, -0.044446496495231275, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.07124348309139919, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.051931935038466284, -0.10188888848884817, -0.05308089244181002, 0.0, 0.0, 0.0, 0.0, 0.021506735929772716, 0.06134196192755239, 0.0, 0.0, 0.0, 0.03049137507294005, 0.0, 0.0, -0.019125879954220666, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.040459644679946075, 0.0, 0.0, 0.03976508194558501, 0.0, 0.0, -0.029561403917653578, 0.0, 0.0, 0.0, 0.0, 0.0, 0.00946054054482032, 0.0, 0.0, 0.0, 0.0, -0.2187269482809438, 0.0, 0.11438818706828995, 0.0, 0.0, 0.0327169933385157, 0.000989843008705124, 0.0, 0.0, 0.0, 0.0, 0.1964743168695262, -0.02605207757376831, 0.0, 0.0, 0.0, 0.0, 0.15677097469843732, 0.0], 'corr_mean': 0.0014082560825163287, 'corr_mean_neg': -0.0014082560825163287, 'corr2_mean': 0.002006392695203114, 'corr2_

300-0, train loss 0.11839597672224045
train loss 0.11839597672224045
val metric {'loss': 0.12002036869525909, 'loss_no_reg': 0.11479201167821884, 'corr': [0.24082215785922384, 0.48966062074981254, 0.5305569594204742, 0.7430141439927129, 0.4513862823682045, 0.5053026647056167, 0.3965475416029817, 0.5633162563511018, 0.652611235242958, 0.67237605935226, 0.6470326528868882, 0.4689408043242441, 0.5648097971100452, 0.6296432102046297, 0.5381462335222329, 0.5862710228403848, 0.6761375956522195, 0.5929916414399953, 0.5486559882287618, 0.6573427769807861, 0.4250955084369375, 0.4106458786842209, 0.2357417546757698, 0.10864133712669372, 0.7500073051754546, 0.5118596019817803, 0.47228996946380314, 0.3209710655569896, 0.6370846496917983, 0.5270218381730662, 0.29426185167252783, 0.5898401814162938, 0.4120267831812813, 0.53039733219575, 0.6170512440225806, 0.4813980696526585, 0.542561826504401, 0.5652190787368997, 0.4192841352622621, 0.39320361617421146, 0.32738721234108886, 0.5427845812813255, 0.43

600-0, train loss 0.10895277559757233
train loss 0.10895277559757233
val metric {'loss': 0.11431891620159149, 'loss_no_reg': 0.10942494869232178, 'corr': [0.27228939069779406, 0.4992913256461874, 0.6054113782679114, 0.7652901790831874, 0.5277660617684756, 0.521106537671029, 0.4797155453918558, 0.6105890145412671, 0.6836904636025046, 0.7167714066834833, 0.6994821697838096, 0.5248370205843235, 0.6047816596635615, 0.645214845297416, 0.64260527676675, 0.6035740316606163, 0.7045542030889714, 0.6102713606076242, 0.5929858164906712, 0.6943489322317038, 0.45969501251553385, 0.4442287490564263, 0.3030289217782229, 0.09960855733252666, 0.7690135263161822, 0.5945035015839286, 0.5793031791934784, 0.3739526351875549, 0.6669854800411402, 0.561000306793553, 0.34797805117753355, 0.581897058535904, 0.5062879098416586, 0.5698070994548065, 0.63548770485484, 0.5326573443344016, 0.6139121312353846, 0.6199804268960004, 0.46842250130011803, 0.40663696490373863, 0.41261984636205, 0.5871844232060323, 0.4614904

900-0, train loss 0.10974754393100739
train loss 0.10974754393100739
val metric {'loss': 0.11239622831344605, 'loss_no_reg': 0.10767494142055511, 'corr': [0.2655422074996363, 0.4955410304856157, 0.6057569160023717, 0.7560376066460871, 0.522588049256052, 0.5130031164373081, 0.4914938698132155, 0.6272391522551696, 0.6889767684854828, 0.7163394516672774, 0.6981648170887198, 0.5360134074922032, 0.6113176013479064, 0.6564247024019172, 0.6511892348386988, 0.6101108448957764, 0.7144588440915043, 0.6092329867677234, 0.5926656339435228, 0.6927085690808699, 0.4698604815663913, 0.4455424615027883, 0.341764201150615, 0.09894966502981133, 0.7694971276156799, 0.6072610871726454, 0.5817404144725686, 0.37801755730320835, 0.6706748141475091, 0.5655814441476866, 0.3492076874343629, 0.579819542892287, 0.5193074842758034, 0.5681424978701703, 0.6393708422594119, 0.529895044634791, 0.6339566715614905, 0.64218653785742, 0.4798252257782023, 0.4161167762338191, 0.4429265296191593, 0.5978051103957486, 0.4570580

1200-0, train loss 0.10687312483787537
train loss 0.10687312483787537
val metric {'loss': 0.11123971790075302, 'loss_no_reg': 0.10665827989578247, 'corr': [0.27602124611433715, 0.5053712740475789, 0.6187856872715061, 0.7583762579487541, 0.5415968903513912, 0.5233499442651723, 0.5037240856366485, 0.6381613491510123, 0.6908693043225838, 0.7190989467623856, 0.700463215913095, 0.5414202055895821, 0.6168999058300324, 0.6453616746002424, 0.6576410909217054, 0.615838817061676, 0.7135582753296346, 0.6203467128085867, 0.6042911305620577, 0.7003466888793544, 0.47916352403233264, 0.46047771245158176, 0.37408485046453854, 0.09430617839676593, 0.7769906981597603, 0.6273354859821496, 0.5834343447854622, 0.37229108463576327, 0.6738906436046244, 0.5633842230578192, 0.3789608057948153, 0.5839849308151079, 0.5229368045188876, 0.5817565886750662, 0.6372141982915802, 0.5306985409632647, 0.6399717138376073, 0.6524133064411192, 0.4934438040918038, 0.4075823990508271, 0.4634083937883287, 0.6009555206870129, 

1500-0, train loss 0.11425022035837173
train loss 0.11425022035837173
val metric {'loss': 0.110739766061306, 'loss_no_reg': 0.10627945512533188, 'corr': [0.2682499800015489, 0.4997463079572446, 0.6241242601124131, 0.7631683045962736, 0.5348759225126081, 0.5246503110371484, 0.5038721572145333, 0.6386436657375408, 0.6913181287738384, 0.7259246119844227, 0.7036033272360419, 0.5494653633638713, 0.621891589444992, 0.6531914999494143, 0.6581343748953963, 0.6176359996503069, 0.7130364084725574, 0.6217623697829622, 0.6128779938104799, 0.7030716724132386, 0.4840289881909501, 0.4586273292997224, 0.3785175795890732, 0.10361406573029355, 0.7796655014058469, 0.6369979559079022, 0.5870079138849162, 0.3790931656407607, 0.6745989061780215, 0.5646764661626363, 0.37637626633692356, 0.587289649587661, 0.5221175907022533, 0.5852391624406472, 0.6411348649247832, 0.5320062130291046, 0.643022880000361, 0.6583198636927868, 0.4957262814511248, 0.4129576764535739, 0.4554207773147561, 0.6065129610627555, 0.46548

1800-0, train loss 0.10882316529750824
train loss 0.10882316529750824
val metric {'loss': 0.11061991304159165, 'loss_no_reg': 0.10626615583896637, 'corr': [0.2702787950793053, 0.5096104335647866, 0.6222501708607684, 0.7587313353563695, 0.5372610235950037, 0.5223559463521631, 0.5051995670173218, 0.6457959118799664, 0.6900437112644117, 0.7285529631836757, 0.7060124297366777, 0.5489378860568582, 0.6209087385399505, 0.6528494382287986, 0.6616735116831498, 0.6186186691693445, 0.7205433769941559, 0.6204491765798315, 0.6128118155118443, 0.7047712947354678, 0.47828924807730827, 0.45479915491268075, 0.3758615541274095, 0.10811102518427017, 0.7750130398251391, 0.6367693082684217, 0.5879831019233583, 0.38218458400719946, 0.6720454568537675, 0.5636280514354319, 0.37911898141515243, 0.5888240647022414, 0.5327015750621975, 0.592721674738323, 0.642802083781814, 0.5357597096475237, 0.6415786605943838, 0.6615309246643861, 0.5061866838796418, 0.41331082592316437, 0.4531069390944068, 0.6104726286940649, 

2100-0, train loss 0.10033605992794037
train loss 0.10033605992794037
val metric {'loss': 0.11076121330261231, 'loss_no_reg': 0.10646969825029373, 'corr': [0.27562958339772126, 0.5165457963743192, 0.6248846612157757, 0.7637360753618403, 0.5411247196310639, 0.521996271115562, 0.5115318260227194, 0.6572354358009468, 0.6908960366144317, 0.7269319776124323, 0.705505452098542, 0.5464393458205659, 0.620131005962804, 0.6508943479554604, 0.6590089554319856, 0.6198929580298689, 0.7233255219523258, 0.6197829740845233, 0.6145587270954486, 0.7029964230909489, 0.4783260562756259, 0.45010536930772815, 0.3808526392019005, 0.1127721755299177, 0.7743890997366095, 0.6313564363279045, 0.5937496911784801, 0.3905597310412718, 0.6743157626469962, 0.565140798169228, 0.382904234564728, 0.5860547098666931, 0.5269430938636476, 0.6024807849746885, 0.6463540733104078, 0.5283980930323309, 0.6408565530833262, 0.6625460238064433, 0.5035536770117777, 0.40671122902856205, 0.45717846608692847, 0.6128858169368699, 0.463

2400-0, train loss 0.1072777807712555
train loss 0.1072777807712555
val metric {'loss': 0.1110390454530716, 'loss_no_reg': 0.10687718540430069, 'corr': [0.27070788716279093, 0.5154309960664197, 0.627793618701709, 0.7621758447766878, 0.5356150227748468, 0.5221627386282824, 0.5101739022924314, 0.6466781292133146, 0.6862035020224538, 0.7248810046244869, 0.7004482085026714, 0.546464056190517, 0.6189014881680701, 0.6541194351601474, 0.6541188016692497, 0.6198454478943941, 0.7266574470483695, 0.6243197403219924, 0.6128775529613275, 0.6992106377546958, 0.48425958455642903, 0.45403677988623525, 0.37616304161253356, 0.12482440294024284, 0.7619876907984241, 0.6412243792627962, 0.5893208635067088, 0.3828849920461291, 0.6775734889929794, 0.5600819203439003, 0.37627226518267826, 0.595940197531154, 0.5319620851414115, 0.5996503842637968, 0.6429856852891669, 0.531024147530659, 0.6414367804236453, 0.6574913503455999, 0.4960815057623035, 0.41185707621422557, 0.4472863485274232, 0.6134005429733207, 0.47

100-0, train loss 0.1041986346244812
train loss 0.1041986346244812
val metric {'loss': 0.11000189632177353, 'loss_no_reg': 0.10581265389919281, 'corr': [0.27169796007283165, 0.5132736949146459, 0.6283350859533563, 0.7621457774228443, 0.5417758778198687, 0.5233193793319005, 0.510359658032297, 0.6559567514701792, 0.692012390342895, 0.7284322443805415, 0.706267341466156, 0.5496057019805226, 0.6207318120147168, 0.6516425621913045, 0.6601885051974636, 0.6198521889679908, 0.7247818996490726, 0.6236225954778196, 0.6147645199204289, 0.704590053611984, 0.48344001648103274, 0.45748932085464894, 0.380577025581027, 0.11110764128943723, 0.7786075059150388, 0.6389107954874698, 0.5913483565784895, 0.3864470829025828, 0.6738415722219627, 0.5625095438067261, 0.38127225609910154, 0.5918941982667146, 0.5328053673525707, 0.6040205054433478, 0.647068064205487, 0.5339729497291372, 0.6399966403425598, 0.6636455643506928, 0.5052111875181549, 0.40704165052943575, 0.4540593408634459, 0.6157363164166592, 0.46676

400-0, train loss 0.10388478636741638
train loss 0.10388478636741638
val metric {'loss': 0.11003468632698059, 'loss_no_reg': 0.10584288835525513, 'corr': [0.27073148990256973, 0.5122043534785324, 0.6277014726637536, 0.7611534158969546, 0.5416529164347335, 0.5225264740240265, 0.5130993348723338, 0.653101715335769, 0.6918333917844144, 0.7270433734566273, 0.7067392899413112, 0.5485855840512368, 0.621517954408714, 0.6532290450802468, 0.6595719822676974, 0.6184553956080826, 0.7256970472433741, 0.6241531340621749, 0.6139800749520536, 0.703441680647529, 0.4827981041919986, 0.4555274080012144, 0.3826541004100047, 0.11200898581565324, 0.7781038365542678, 0.6419198260387156, 0.591944280010382, 0.3888106008529757, 0.6731362674798793, 0.5625832720753239, 0.38455305604869816, 0.5917872420594403, 0.5307737032569679, 0.604010843657705, 0.6461789515947612, 0.5319465662685541, 0.6405772309827519, 0.6657768316792478, 0.5056168283695552, 0.41340032001591664, 0.45100836960727486, 0.6157231423733371, 0.467

700-0, train loss 0.1006242111325264
train loss 0.1006242111325264
val metric {'loss': 0.11033628582954406, 'loss_no_reg': 0.10621291399002075, 'corr': [0.27669785282994896, 0.515316167416465, 0.6280011482569099, 0.760301715381489, 0.5380715262092086, 0.5226239030903038, 0.5121496850359246, 0.6533041077006032, 0.6911159502810307, 0.7291403264625667, 0.7074618967214998, 0.5491909537605737, 0.6220845695451356, 0.6541235713437034, 0.6591911026049243, 0.6197890831677975, 0.7251495581226781, 0.6210121247630154, 0.6153379687283982, 0.7051360682171586, 0.4837125144407102, 0.4548778746771289, 0.38153448064896467, 0.11343559675413648, 0.7763204444623542, 0.6418022365291958, 0.5890181879961944, 0.3896941765628707, 0.6742080559351925, 0.5638922219131499, 0.38244310748193044, 0.5942356720896838, 0.5322070531786991, 0.603844541246205, 0.6451819193987738, 0.5297788806501778, 0.6454964727847025, 0.6616891647766258, 0.5071552510417204, 0.41088565612746386, 0.4518306206539404, 0.614934035724945, 0.4649

100-0, train loss 0.10506466031074524
train loss 0.10506466031074524
val metric {'loss': 0.10987830609083175, 'loss_no_reg': 0.10576222091913223, 'corr': [0.2698635465474989, 0.515920073183799, 0.6290976776990456, 0.7616288997986524, 0.539475841804244, 0.5228432614590733, 0.5114359550028558, 0.6529355095673357, 0.6930574003689476, 0.7286657719221937, 0.7062836507610873, 0.5494645852974652, 0.6219526720113822, 0.651676647862533, 0.6589303092111193, 0.6194858458947983, 0.7253502335826855, 0.6244195446804879, 0.6154344649147352, 0.7061164490076515, 0.4813720593845998, 0.4545932107624179, 0.38026802590520037, 0.11255444756700639, 0.7784724562092598, 0.640252191272703, 0.592224074001193, 0.39000160489668173, 0.6730820573245745, 0.5650439790855329, 0.38211459561225747, 0.5916016188997425, 0.5317634298729319, 0.602295608024707, 0.6472139467289495, 0.5331379898072163, 0.6448557245680588, 0.6637547966459669, 0.5059443564365498, 0.4091043696644172, 0.45424700275301594, 0.6152359593830004, 0.4682

400-0, train loss 0.10083360970020294
train loss 0.10083360970020294
val metric {'loss': 0.10982789397239685, 'loss_no_reg': 0.10569710284471512, 'corr': [0.2691447897496235, 0.5150371338945938, 0.6277020037229238, 0.7625209062474152, 0.5383681738975189, 0.5229631766890088, 0.5119247297611587, 0.6517714539301922, 0.6925335958091136, 0.7285504925352644, 0.7069646792358215, 0.5486142986647788, 0.6222296648023514, 0.6535084384141795, 0.6602807020275067, 0.619839586716428, 0.7243388473524348, 0.6246852235371418, 0.6145675344045531, 0.7053066811670196, 0.4813537204192289, 0.4550440426490058, 0.38118174728075394, 0.11210926442358878, 0.7779253827180643, 0.6407298978537728, 0.5931571305461923, 0.39064935368393483, 0.6744655961451175, 0.5645054695634548, 0.3831381850242369, 0.5922769285552081, 0.5340026121154149, 0.6025310044574859, 0.6470544031780933, 0.5327429098030816, 0.645348075763442, 0.6637867808930829, 0.506413464615801, 0.4116452769320589, 0.45515192585926817, 0.6148021804504153, 0.46

700-0, train loss 0.10096006840467453
train loss 0.10096006840467453
val metric {'loss': 0.10994871705770493, 'loss_no_reg': 0.10581935197114944, 'corr': [0.27030379557975315, 0.514730957367114, 0.6277889997228692, 0.7615756479883158, 0.5383207878500271, 0.522098564609013, 0.5114244198566471, 0.6530410229437509, 0.6928177011961494, 0.7287569247232412, 0.7063658256083635, 0.5486447756851307, 0.6224809375332628, 0.6524805759338619, 0.6596174308804184, 0.61899999539001, 0.7241715511494671, 0.623644809966228, 0.6151394102002175, 0.7048243541525618, 0.4815582957551896, 0.4556540730962322, 0.3820860662161506, 0.11283100183606139, 0.7777277497950901, 0.6395891331953514, 0.592054667821255, 0.392222890153165, 0.6731006260254693, 0.5645242698008182, 0.3835107932300969, 0.5904013753659805, 0.5311902240253621, 0.6036532740542457, 0.646473571743946, 0.5317325513855315, 0.6442728286785442, 0.6640959766983634, 0.506860018352387, 0.4103322437378467, 0.4507780035018498, 0.6151746467463056, 0.4665847498