In [1]:
# test full training of maskcnn_polished_with_rcnn_k_bl
from os.path import join

import numpy as np

import torch

from thesis_v2 import dir_dict

from thesis_v2.data.prepared.yuanyuan_8k import get_data

from thesis_v2.training_extra.maskcnn_like.opt import get_maskcnn_v1_opt_config
from thesis_v2.training_extra.maskcnn_like.training import (train_one,
                                                            partial)

from thesis_v2.models.maskcnn_polished_with_rcnn_k_bl.builder import (
    gen_maskcnn_polished_with_rcnn_k_bl,
    load_modules
)



In [2]:
def train_one_maskcnn_polished_with_rcnn_k_bl(
    split_seed,
    model_seed,
    act_fn,
    loss_type,
    input_size,
    out_channel,
    num_layer,
    kernel_size_l1,
    pooling_ksize,
    scale, scale_name,
    smoothness, smoothness_name,
    pooling_type,
    n_timesteps,
    multi_path_separate_bn,
    multi_path_hack='geD3',
):
    """Train one ``maskcnn_polished_with_rcnn_k_bl`` model and return its test score.

    The model is built in multi-path mode with a feed-forward first block
    (``ff_1st_block=True``, ``multi_path=True``) and trained via ``train_one``
    on sessions ``042318``, ``043018`` and ``051018`` of the yuanyuan_8k data.

    Parameters
    ----------
    split_seed
        Passed to ``get_data`` as ``seed``.
    model_seed
        Passed to ``train_one``; also the last component of the result key.
    act_fn, out_channel, num_layer, kernel_size_l1, pooling_ksize,
    pooling_type, n_timesteps
        Forwarded unchanged to ``gen_maskcnn_polished_with_rcnn_k_bl``.
    input_size
        Passed to ``get_data`` only. The model's own input size is supplied
        by ``train_one`` through the ``gen_cnn_partial`` callback.
    scale, smoothness, loss_type
        Forwarded to ``get_maskcnn_v1_opt_config``.
    scale_name, smoothness_name
        Not used by this function; accepted so a full parameter dict such as
        ``maskcnn_param_regular`` can be splatted in with ``**``.
    multi_path_separate_bn
        Forwarded to the model builder and embedded in the result key.
    multi_path_hack
        Forwarded to the model builder and embedded in the result key.
        Defaults to ``'geD3'``, the value this experiment was run with, so the
        default result key is unchanged.

    Returns
    -------
    ``result['stats_best']['stats']['test']['corr_mean']`` from ``train_one``,
    i.e. the mean test correlation stored under its ``stats_best`` entry.
    """
    load_modules()

    # NOTE: `scale=0.5` is a `get_data` argument (presumably the image
    # rescaling factor) and is unrelated to this function's `scale` parameter,
    # which is the regularization weight handed to the optimizer config below.
    raw_splits = get_data('a', 200, input_size,
                          ('042318', '043018', '051018'),
                          scale=0.5, seed=split_seed)

    # Inputs are cast to float32; targets are passed through unchanged.
    datasets = {
        'X_train': raw_splits[0].astype(np.float32),
        'y_train': raw_splits[1],
        'X_val': raw_splits[2].astype(np.float32),
        'y_val': raw_splits[3],
        'X_test': raw_splits[4].astype(np.float32),
        'y_test': raw_splits[5],
    }

    def gen_cnn_partial(input_size_cnn, n):
        # Architecture factory; presumably invoked by `train_one` with the
        # model input size and the number of output neurons.
        return gen_maskcnn_polished_with_rcnn_k_bl(
            input_size=input_size_cnn,
            num_neuron=n,
            out_channel=out_channel,
            kernel_size_l1=kernel_size_l1,  # (try 5, 9, 13)
            kernel_size_l23=3,
            act_fn=act_fn,
            pooling_ksize=pooling_ksize,  # (try 1, 3, 5, 7)
            pooling_type=pooling_type,  # (try avg, max); max looked good here?
            num_layer=num_layer,
            n_timesteps=n_timesteps,
            factored_constraint=None,
            blstack_pool_ksize=1,
            blstack_pool_type=None,
            acc_mode='cummean',
            bn_after_fc=False,
            ff_1st_block=True,
            multi_path=True,
            multi_path_separate_bn=multi_path_separate_bn,
            # 'le5' was also tried and gave the same results as
            # `test_full_training_ff_1st_multipath_train`.
            multi_path_hack=multi_path_hack,
        )

    opt_config_partial = partial(get_maskcnn_v1_opt_config,
                                 scale=scale,
                                 smoothness=smoothness,
                                 group=0.0,
                                 loss_type=loss_type,
                                 )

    result = train_one(
        arch_json_partial=gen_cnn_partial,
        opt_config_partial=opt_config_partial,
        datasets=datasets,
        key=(
            'debug/test_full_training_maskcnn_polished_with_rcnn_k_bl/'
            'ff_1st_block_multipath_train/'
            f'sep_bn{multi_path_separate_bn}_{multi_path_hack}/{model_seed}'
        ),
        show_every=100,
        max_epoch=40000,
        model_seed=model_seed,
        return_model=False,
        extra_params={
            # Original note: "reduce on batch axis".
            # NOTE(review): with acc_mode='cummean' the prediction may carry a
            # per-timestep axis; confirm which axis index 1 refers to.
            'eval_fn': {
                'yhat_reduce_axis': 1,
            }
        },
        print_model=True
    )

    return result['stats_best']['stats']['test']['corr_mean']

In [3]:
# Architecture and bookkeeping settings shared by every run in this notebook.
maskcnn_param_template = dict(
    out_channel=16,
    num_layer=3,
    kernel_size_l1=9,
    pooling_ksize=3,
    pooling_type='avg',
    model_seed=0,
    split_seed='legacy',
)

# Activation, loss, regularization and input settings layered on top of the
# template. Each `*_name` entry is a string label for the matching numeric value.
maskcnn_param_regular = dict(
    maskcnn_param_template,
    act_fn='relu',
    loss_type='mse',
    smoothness=5e-6,
    smoothness_name='0.000005',
    scale=0.01,
    scale_name='0.01',
    input_size=50,
    n_timesteps=4,
)

# Shared-BN variant (multi_path_separate_bn=False); uncomment to rerun:
# print(train_one_maskcnn_polished_with_rcnn_k_bl(**maskcnn_param_regular, multi_path_separate_bn=False))

# Expected parameter count (presumably for the shared-BN run above):
# 27629 = 1*2 + 16*1*9*9 + 16*2 + 2*(16*16*3*3+16*16*3*3+4*16*2) + 79*(14*14 + 16 + 1)

# The results here are exactly the same as those in
# `scripts/debug/rcnn_basic_kriegeskorte/test_full_training_ff_1st.ipynb`.

In [4]:
# Full training run with a separate BN per path; prints the returned mean test correlation.
print(
    train_one_maskcnn_polished_with_rcnn_k_bl(
        multi_path_separate_bn=True,
        **maskcnn_param_regular,
    )
)

{'pooling', 'bn_output', 'final_act', 'fc'}
['conv0', 'bl_stack.layer_list.0.b_conv', 'bl_stack.layer_list.1.b_conv']
num_param 29613
JSONNet(
  (moduledict): ModuleDict(
    (accumulator): RecurrentAccumulator()
    (act0): ReLU()
    (bl_stack): BLConvLayerStack(
      (layer_list): ModuleList(
        (0): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (bn_layer_list): ModuleList(
        (0): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (

val metric init {'loss': 0.1466030567884445, 'loss_no_reg': 0.1437675952911377, 'corr': [-0.0362335761909663, 0.044940387383060804, -0.029566950474332224, 0.011795256262101655, -0.031140188436121343, -0.039093811160902395, -0.007890446792992003, 0.08460442715268071, 0.09801664504455312, -0.06407798555165395, -0.04768519216356399, -0.01057439902128484, -0.036860150743686385, 0.1054442864535316, 0.04897745874745626, 0.06055593573742159, 0.05019300086222476, -0.09938403432129646, 0.03859283725626592, 0.038611465156174786, 0.026499758614732156, -0.017763510987413154, -0.11028192928038792, -0.014094010124800395, -0.10894144493096433, -0.019942928053973654, 0.009548227515124507, -0.02154646532888021, -0.03784327781438692, -0.13144405242906348, 0.006196397911585212, 0.023144433892569573, -0.04204017020347693, 0.045040761203512, -0.0506513351482142, -0.16261780730467817, 0.03582333389781861, -0.09511260073432891, -0.15377644593473438, -0.12164877044060218, 0.006320371200581088, 0.0063461938894

test metric {'loss': 0.12136983126401901, 'loss_no_reg': 0.1179540678858757, 'corr': [0.30220842505738055, 0.452830809347405, 0.5526564128199478, 0.728192191684349, 0.4448023366785528, 0.5542492766128811, 0.4435556268624202, 0.5193351042192103, 0.6443746440550715, 0.6645958763044617, 0.6516934421655391, 0.43370039957740164, 0.5308048499123819, 0.6414768146968899, 0.4488593415614998, 0.5679618010599843, 0.6734645438174613, 0.5959281794568668, 0.5432399078612805, 0.593329874014169, 0.45295205651864656, 0.43422468959914895, 0.19678474724146405, 0.06880077039388897, 0.7120179638360273, 0.5211233739350649, 0.4247738711412061, 0.24678412533568414, 0.5945746363220267, 0.490354016961152, 0.331279695672541, 0.607487831097343, 0.3539140195421211, 0.5120647992311824, 0.5852559455055473, 0.4751928959240102, 0.5324738201877641, 0.5394509783963664, 0.3763240334621425, 0.37784728506495435, 0.346504511368948, 0.5196850161440937, 0.4254093095749967, 0.39660599551188525, 0.489691254010465, 0.38931041535

500-0, train loss 0.10779531300067902
train loss 0.10779531300067902
val metric {'loss': 0.11362311840057374, 'loss_no_reg': 0.11031544953584671, 'corr': [0.2429847077198355, 0.4988461888651744, 0.600266253238706, 0.7663282859545517, 0.5196518347572994, 0.5167842197379918, 0.4519340525303339, 0.6189506738325281, 0.6700912302147726, 0.7060590111179694, 0.6783856723596455, 0.5208374007529237, 0.6025440845259247, 0.6521480507870471, 0.6101374651042687, 0.6092776880628078, 0.7116496518339754, 0.6117747434373877, 0.5939386055043697, 0.6897305091483183, 0.48159128152050623, 0.4356583943885032, 0.30291424796983957, 0.10289987758560949, 0.7640799069085947, 0.6107669515360508, 0.5605532555509656, 0.3798241817154161, 0.6666998943184494, 0.5558055967062858, 0.35034970305905433, 0.5837313553144987, 0.4633283592671213, 0.5875660807859923, 0.6299752776555627, 0.512137038155576, 0.6147304283350622, 0.6274791407584197, 0.47817287923012, 0.4196639158163969, 0.40327377712439294, 0.5850451186342243, 0.46

test metric {'loss': 0.11294279886143548, 'loss_no_reg': 0.10944553464651108, 'corr': [0.3296317131133165, 0.4967044298443778, 0.607217059151564, 0.776682169313325, 0.5171273605641225, 0.5813329613505145, 0.5181847073255021, 0.6403307283924552, 0.678100917531864, 0.7217953104016497, 0.7047319893594787, 0.47104734285925304, 0.5924183190173957, 0.6828118243553309, 0.6166704447064755, 0.6241767677854756, 0.7211000254609119, 0.6485662398073744, 0.599895943696884, 0.674116280015247, 0.5165178540240781, 0.4636538063978256, 0.27577166920583995, 0.07595756343064887, 0.7502249595538366, 0.615414173074238, 0.5419148541227752, 0.3292921730802043, 0.657372562574063, 0.5371461720534991, 0.37649415071161524, 0.5994130653912872, 0.3879037683392113, 0.5946038974964509, 0.6382880554384956, 0.5469082687932691, 0.6346865279171147, 0.6184010180264801, 0.46920406355362837, 0.41987731134854567, 0.4548136375816915, 0.5858198093862935, 0.4690122462519437, 0.45341933146242286, 0.611347935612395, 0.469686169177

1000-0, train loss 0.11013109236955643
train loss 0.11013109236955643
val metric {'loss': 0.11149416714906693, 'loss_no_reg': 0.10825810581445694, 'corr': [0.24430121044374528, 0.5042591730196341, 0.6243529974007321, 0.7713502135667201, 0.5322506083011782, 0.5328798363731708, 0.49511501015395365, 0.6330141050611711, 0.6795271155449323, 0.7139262213865807, 0.7016306569257733, 0.5364401709116746, 0.6141803822570486, 0.6497398430546748, 0.6312755932713896, 0.6160367218700988, 0.7186577019286898, 0.6231311994513522, 0.593667627764437, 0.6927320834830193, 0.4951682372419569, 0.4582760634906038, 0.3415912654559833, 0.09438425475657802, 0.7736349926109632, 0.6316104545894413, 0.5622594260315134, 0.39280885218477235, 0.6796629067499698, 0.5587077038894431, 0.37834358903785603, 0.601145138334599, 0.5403704212285872, 0.5935834876556363, 0.6337336663238459, 0.520421900833974, 0.6263467636373741, 0.6504047178127259, 0.4975649822236211, 0.4125942873817182, 0.44864949800199083, 0.6042610521578511, 0

test metric {'loss': 0.11141393865857806, 'loss_no_reg': 0.10791531950235367, 'corr': [0.33436734666897505, 0.501465913069682, 0.6037088988826571, 0.7683348260150548, 0.528010819882116, 0.5835770635772799, 0.5037724401776797, 0.647407557059178, 0.6814179604720805, 0.7304182712530196, 0.7162645266901433, 0.4884050241779611, 0.6077233404673527, 0.6816122184785675, 0.6276648402540506, 0.6333471612012052, 0.7367029312705844, 0.6442010031710199, 0.6087517518879012, 0.6698007828093955, 0.5298835939087836, 0.4702409908784505, 0.3091622834022253, 0.06405766157305048, 0.7612704500140287, 0.6288971462412758, 0.5476202005665167, 0.3323900034146185, 0.6484205310219102, 0.5463387478354642, 0.4016594183828186, 0.5985855595088659, 0.4588611278976471, 0.5965366059678417, 0.6425221340897389, 0.5540345499033346, 0.6477235781079539, 0.6293068758423273, 0.4721217514372652, 0.42537209348200605, 0.46735601407139904, 0.5901418460431866, 0.47004260232471917, 0.46054834136168826, 0.6507820914404684, 0.47424372

1500-0, train loss 0.1083960086107254
train loss 0.1083960086107254
val metric {'loss': 0.11278568804264069, 'loss_no_reg': 0.11023711413145065, 'corr': [0.2477754102159426, 0.5057911089851685, 0.6267121170400825, 0.7691366894635454, 0.5273866987565929, 0.5336368292859945, 0.5133790658916912, 0.6311952264150936, 0.6871327256587026, 0.7208991146001199, 0.7075827010288153, 0.533442304224222, 0.6211872895517463, 0.6399472564123876, 0.6387726508830349, 0.6151999391701177, 0.7204273917754154, 0.6266898413228964, 0.6012408143606764, 0.6960465420473759, 0.4939493376602352, 0.457742857199708, 0.35993239387824993, 0.0978403683048682, 0.7791579419021765, 0.6310852385521919, 0.5733917601676156, 0.3868172726063579, 0.6833985156839903, 0.55852889500305, 0.37580334580540153, 0.5950339434049782, 0.5348215920915046, 0.5968597144395049, 0.636097398628586, 0.5262558399627975, 0.6373393377366359, 0.6618056819673068, 0.48741743927211045, 0.41347824898170726, 0.433348826500799, 0.6140951245967381, 0.467047

0-0, train loss 0.10006961226463318
train loss 0.10006961226463318
val metric {'loss': 0.11089855879545212, 'loss_no_reg': 0.10747833549976349, 'corr': [0.2512974731343881, 0.5073571414066107, 0.6227148683158257, 0.7662171860220249, 0.5288915889543361, 0.5324489414782696, 0.5156774380988481, 0.638366097002145, 0.6827076517105894, 0.7198665888016385, 0.7024956493563771, 0.5299129141038611, 0.6165926599485273, 0.6410023277383324, 0.626204061751533, 0.6034550419036722, 0.7235740942032416, 0.628161714879677, 0.5905136537141938, 0.6902459152858118, 0.4975791810318397, 0.46467684507691, 0.3481939080367595, 0.09362435679177242, 0.775257104626403, 0.6273150334619206, 0.5737840042021926, 0.3959517508745256, 0.676556261479512, 0.5605773155758456, 0.37614848282973123, 0.6021049516717885, 0.5550152012841607, 0.6032117219336066, 0.6387533251858994, 0.5236717614596172, 0.6266465178077716, 0.6530460536943505, 0.5025612203716151, 0.42088830221296636, 0.4401583896223597, 0.603101264731366, 0.4709429488

test metric {'loss': 0.110567393047469, 'loss_no_reg': 0.10701245814561844, 'corr': [0.3396503000062927, 0.5032161008401619, 0.6171445401511003, 0.7759271779334418, 0.5187574664688477, 0.5826434971825463, 0.5355507782105717, 0.6456532609194, 0.6850342037487309, 0.7338442399094298, 0.7199125676831939, 0.48834374372741896, 0.6148956532307704, 0.6819052311145353, 0.6285432229473609, 0.6355424857567874, 0.7399677260948883, 0.6521771290370821, 0.6143992632794921, 0.6762331445101195, 0.5300183715418345, 0.4695580401312712, 0.30793233018970695, 0.06828797892331574, 0.7617791392435901, 0.6358407008889122, 0.5488089943676195, 0.3442902781233139, 0.6536381417868713, 0.5456674226554509, 0.3980908644052304, 0.6018322284341362, 0.465941587109721, 0.5959778870622994, 0.645253040481635, 0.558019752387986, 0.6449788250481123, 0.6352705810558548, 0.4754057999190997, 0.4242868815657518, 0.4648363873605445, 0.5896739436473217, 0.46667069434767255, 0.46808856822135625, 0.6531424334198693, 0.46713099916386

500-0, train loss 0.10441022366285324
train loss 0.10441022366285324
val metric {'loss': 0.11049116253852845, 'loss_no_reg': 0.10711252689361572, 'corr': [0.25779641251411883, 0.5091274758918001, 0.6270070753909734, 0.7686913463742095, 0.5366791072440799, 0.5321556009944919, 0.525080231549152, 0.6418038134379608, 0.6867158405326373, 0.7232111175025857, 0.7046151035368129, 0.5340194623179885, 0.6232421502454437, 0.6497358050187558, 0.6304815827954952, 0.6152155790287374, 0.7219651658772641, 0.6286026132691902, 0.6004896182550201, 0.6954344704927262, 0.4972666584620786, 0.4536526372869933, 0.36651228294195837, 0.08956047843088474, 0.7749460642854233, 0.6295689705238174, 0.5695854279896517, 0.39013728796593583, 0.6799477606886668, 0.558813902060584, 0.38212336736370056, 0.6016975817614524, 0.5544321408704446, 0.6012993934709246, 0.6422353023628066, 0.5281855271202003, 0.6334639287600456, 0.6626175362748282, 0.49756554994472224, 0.428961920482406, 0.43838740893338674, 0.6110979318081263, 0

0-0, train loss 0.10819145292043686
train loss 0.10819145292043686
val metric {'loss': 0.10989842861890793, 'loss_no_reg': 0.10654143989086151, 'corr': [0.25082776766866727, 0.5055573044866325, 0.6288072852425646, 0.7727579293633613, 0.5357311550422379, 0.5315549430002654, 0.5161414011996778, 0.6425171981007968, 0.6844245248543064, 0.7231015735275959, 0.7060016571053418, 0.5400411900849565, 0.6240194002264696, 0.6509277354947531, 0.6397460498182301, 0.6172639726016955, 0.7201734351297087, 0.6310513185183799, 0.6027133389241175, 0.699645856740472, 0.49487541989340267, 0.45899316358887865, 0.35788260372434877, 0.09295051107394131, 0.7764136627758784, 0.6304706511086661, 0.5708750286524547, 0.39428735908777834, 0.6825539960863592, 0.5629298884083778, 0.3764856922127285, 0.6046934659755132, 0.5554078940826022, 0.6058436809405073, 0.6417012582837714, 0.5279676582638972, 0.637468312511054, 0.6616632904433343, 0.4981124430630579, 0.4280572188879701, 0.4481250818765838, 0.6111648308231878, 0.4

test metric {'loss': 0.11029268694775445, 'loss_no_reg': 0.1068507581949234, 'corr': [0.33817024998986056, 0.5035013522237863, 0.6197651253787495, 0.7780575373896722, 0.5189682333987207, 0.5829125619757801, 0.5408034105307233, 0.6487836496685866, 0.6840208241798313, 0.7356163969091387, 0.7212449614994811, 0.4896210624324013, 0.6147339174548143, 0.6827248130549158, 0.6312275810917625, 0.6369573572989162, 0.7423306147534159, 0.650153395207782, 0.6153453432953878, 0.6764759576953354, 0.5290219770396902, 0.46993806712912406, 0.31239091269473473, 0.07309578301944804, 0.765158821181979, 0.6396894171186273, 0.5496248757894473, 0.3451413535105323, 0.6544601890867088, 0.5493169618792614, 0.3986765518862052, 0.5987664933511057, 0.47054102120354874, 0.5991052728074646, 0.6473659591339884, 0.5575862722026317, 0.6475469401745831, 0.6368177155391738, 0.4760391950097234, 0.42454378010120447, 0.4665345619019065, 0.5911911719547442, 0.4698845230057044, 0.46822768945764165, 0.6567635791537885, 0.4718397

500-0, train loss 0.10637934505939484
train loss 0.10637934505939484
val metric {'loss': 0.10998515635728837, 'loss_no_reg': 0.10668722540140152, 'corr': [0.24912718322815258, 0.5068998236307227, 0.6309832269463294, 0.7705714315357137, 0.5319059287576023, 0.5346846850788727, 0.520120371703077, 0.6408317123454845, 0.6865919840773927, 0.7231067981413972, 0.7077028760410898, 0.5371326334728767, 0.6229793904031748, 0.6478457782767866, 0.637029434802604, 0.6156284886895599, 0.7215564609023009, 0.6293974466905188, 0.6016678166109821, 0.6987688842180309, 0.5001981183784345, 0.45999706631323356, 0.36199050338434113, 0.0902284778922372, 0.7777633357601963, 0.632452120421794, 0.5730396843442693, 0.393078128057028, 0.6811300290358493, 0.560517142003863, 0.37885819502000023, 0.6018311702301788, 0.5541220438159059, 0.6023757345968273, 0.6419972165029831, 0.5288795048939283, 0.6367802186383819, 0.6630429407627324, 0.5009342289011907, 0.42658782279747626, 0.4410075798280881, 0.6123643853073322, 0.471