In [1]:
# test full training of maskcnn_polished_with_rcnn_k_bl
from os.path import join

import numpy as np

import torch

from thesis_v2 import dir_dict

from thesis_v2.data.prepared.yuanyuan_8k import get_data

from thesis_v2.training_extra.maskcnn_like.opt import get_maskcnn_v1_opt_config
from thesis_v2.training_extra.maskcnn_like.training import (train_one,
                                                            partial)

from thesis_v2.models.maskcnn_polished_with_rcnn_k_bl.builder import (
    gen_maskcnn_polished_with_rcnn_k_bl,
    load_modules
)



In [2]:
def train_one_maskcnn_polished_with_rcnn_k_bl(
    split_seed,
    model_seed,
    act_fn,
    loss_type,
    input_size,
    out_channel,
    num_layer,
    kernel_size_l1,
    pooling_ksize,
    scale, scale_name,
    smoothness, smoothness_name,
    pooling_type,
    n_timesteps,
    multi_path_separate_bn,
):
    """Run one full training and return the test ``corr_mean`` of the best stats.

    Steps performed, in order:

    1. ``load_modules()`` is called.
    2. Data is fetched with ``get_data`` and packed into a dict with keys
       ``X_train``/``y_train``/``X_val``/``y_val``/``X_test``/``y_test``;
       the ``X_*`` entries are cast to ``np.float32``.
    3. An architecture generator (wrapping
       ``gen_maskcnn_polished_with_rcnn_k_bl``) and an optimizer-config
       partial (wrapping ``get_maskcnn_v1_opt_config``) are built.
    4. ``train_one`` is invoked with ``max_epoch=40000`` and
       ``return_model=False``.

    Parameters
    ----------
    split_seed
        Passed to ``get_data`` as ``seed``.
    model_seed
        Passed to ``train_one`` as ``model_seed`` and used as the last
        component of the storage ``key``.
    act_fn, out_channel, num_layer, kernel_size_l1, pooling_ksize,
    pooling_type, n_timesteps, multi_path_separate_bn
        Forwarded unchanged to ``gen_maskcnn_polished_with_rcnn_k_bl``.
        ``multi_path_separate_bn`` is also interpolated into ``key``.
    loss_type, scale, smoothness
        Forwarded to ``get_maskcnn_v1_opt_config``. Note that ``scale`` here
        is distinct from the fixed ``scale=0.5`` given to ``get_data``.
    input_size
        Third positional argument of ``get_data``.
    scale_name, smoothness_name
        Accepted for signature compatibility; not read anywhere in this body.

    Returns
    -------
    The value stored at ``result['stats_best']['stats']['test']['corr_mean']``
    in the object returned by ``train_one``.
    """
    load_modules()

    raw_splits = get_data(
        'a', 200, input_size, ('042318', '043018', '051018'),
        scale=0.5,
        seed=split_seed,
    )

    # `raw_splits` is indexed as (X_train, y_train, X_val, y_val, X_test, y_test);
    # only the X arrays are cast to float32, the y entries are stored untouched.
    datasets = {}
    for pos, split_name in enumerate(('train', 'val', 'test')):
        datasets[f'X_{split_name}'] = raw_splits[2 * pos].astype(np.float32)
        datasets[f'y_{split_name}'] = raw_splits[2 * pos + 1]

    # Architecture options that do not depend on the arguments handed to the
    # generator below.
    arch_kwargs = dict(
        out_channel=out_channel,
        kernel_size_l1=kernel_size_l1,  # (try 5,9,13)
        kernel_size_l23=3,
        act_fn=act_fn,
        pooling_ksize=pooling_ksize,  # (try, 1,3,5,7)
        pooling_type=pooling_type,  # try (avg, max)  # looks that max works well here?
        num_layer=num_layer,
        n_timesteps=n_timesteps,
        factored_constraint=None,
        blstack_pool_ksize=1,
        blstack_pool_type=None,
        acc_mode='cummean',
        bn_after_fc=False,
        ff_1st_block=True,
        multi_path=True,
        multi_path_separate_bn=multi_path_separate_bn,
        # tried le5, gave same results as `test_full_training_ff_1st_multipath_train`
        multi_path_hack='geD4',
    )

    def gen_cnn_partial(input_size_cnn, n):
        # Parameter names are kept as-is in case `train_one` passes them by keyword.
        return gen_maskcnn_polished_with_rcnn_k_bl(
            input_size=input_size_cnn,
            num_neuron=n,
            **arch_kwargs,
        )

    opt_kwargs = {
        'scale': scale,
        'smoothness': smoothness,
        'group': 0.0,
        'loss_type': loss_type,
    }
    opt_config_partial = partial(get_maskcnn_v1_opt_config, **opt_kwargs)

    # The key only varies with `multi_path_separate_bn` and `model_seed`.
    run_key = f'debug/test_full_training_maskcnn_polished_with_rcnn_k_bl/ff_1st_block_multipath_train/sep_bn{multi_path_separate_bn}_geD4/{model_seed}'

    eval_overrides = {
        # reduce on batch axis
        'eval_fn': {
            'yhat_reduce_axis': 1,
        }
    }

    result = train_one(
        arch_json_partial=gen_cnn_partial,
        opt_config_partial=opt_config_partial,
        datasets=datasets,
        key=run_key,
        show_every=100,
        max_epoch=40000,
        model_seed=model_seed,
        return_model=False,
        extra_params=eval_overrides,
        print_model=True,
    )

    best_test_stats = result['stats_best']['stats']['test']
    return best_test_stats['corr_mean']

In [3]:
# Base hyperparameters; `maskcnn_param_regular` below extends this mapping.
maskcnn_param_template = dict(
    out_channel=16,
    num_layer=3,
    kernel_size_l1=9,
    pooling_ksize=3,
    pooling_type='avg',
    model_seed=0,
    split_seed='legacy',
)

# Template entries first, followed by the activation / loss / regularization /
# input settings. The `*_name` entries hold the string form of the matching
# numeric value.
maskcnn_param_regular = dict(
    maskcnn_param_template,
    act_fn='relu',
    loss_type='mse',
    smoothness=0.000005,
    smoothness_name='0.000005',
    scale=0.01,
    scale_name='0.01',
    input_size=50,
    n_timesteps=4,
)

# print(train_one_maskcnn_polished_with_rcnn_k_bl(**maskcnn_param_regular, multi_path_separate_bn=False))

# 27629 = 1*2 + 16*1*9*9 + 16*2 + 2*(16*16*3*3+16*16*3*3+4*16*2) + 79*(14*14 + 16 + 1) 

# results here are exactly the same as those in `scripts/debug/rcnn_basic_kriegeskorte/test_full_training_ff_1st.ipynb`

In [4]:
# Run the training helper with the regular parameter set and
# `multi_path_separate_bn=True`, then print its return value (the test
# `corr_mean` taken from the best stats reported by `train_one`).
print(train_one_maskcnn_polished_with_rcnn_k_bl(**maskcnn_param_regular, multi_path_separate_bn=True))

{'fc', 'pooling', 'final_act', 'bn_output'}
['conv0', 'bl_stack.layer_list.0.b_conv', 'bl_stack.layer_list.1.b_conv']
num_param 29613
JSONNet(
  (moduledict): ModuleDict(
    (accumulator): RecurrentAccumulator()
    (act0): ReLU()
    (bl_stack): BLConvLayerStack(
      (layer_list): ModuleList(
        (0): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (bn_layer_list): ModuleList(
        (0): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (

0-0, train loss 0.14437025785446167
train loss 0.14437025785446167
val metric {'loss': 0.14625344872474672, 'loss_no_reg': 0.14376747608184814, 'corr': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'corr_mean': 0.0, 'corr_mean_neg': -0.0, 'corr2_mean': 0.0, 'corr2_mean_neg': -0.0, 'acc': None}
test metric {'loss': 0.14607491024902888, 'loss_no_reg': 0.14356966316699982, 'corr': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'corr_mean': 0.0, 'corr_mean_neg': -0.0, 'corr2_mean': 0.0, 'corr2_mean_neg': -0.0, 'acc': None}
100-0, train loss 0.12135522812604904
train loss 0.12135522812604904
val metric {'loss': 0

300-0, train loss 0.11342887580394745
train loss 0.11342887580394745
val metric {'loss': 0.11566703617572785, 'loss_no_reg': 0.11270786821842194, 'corr': [0.2569424866635066, 0.4904678427282851, 0.562971320071463, 0.7526071964064864, 0.458873799284007, 0.5108498754777377, 0.4314608681787703, 0.5754219335690138, 0.6610083935385654, 0.683411999654859, 0.6569889807190783, 0.48550559474540217, 0.5799746970411022, 0.646149229240939, 0.5574724390222003, 0.5988056250224245, 0.6649026924473259, 0.591267784202351, 0.5673268187216496, 0.6597692857135946, 0.44686393899459453, 0.41905221065745357, 0.268207771256167, 0.08279702993069982, 0.739492137741825, 0.5603078400538783, 0.5413781143425187, 0.33236131050559825, 0.6441112101435045, 0.532060326559849, 0.3154529393866041, 0.5863017370759447, 0.4977355719120833, 0.539311305621759, 0.6165525903676798, 0.49058665768567095, 0.5742060315652115, 0.5946707934995625, 0.4384687591794634, 0.41001932737119695, 0.3632251174078376, 0.5638306693989863, 0.44885

test metric {'loss': 0.11468637628214699, 'loss_no_reg': 0.11135408282279968, 'corr': [0.3004339120751045, 0.49681659282484814, 0.6026841081063142, 0.7552197323487734, 0.4860702396210518, 0.5670692803622278, 0.5169559484694615, 0.6111538191165413, 0.6741630839913946, 0.720806027367465, 0.6959345375296114, 0.4629319417493281, 0.5741434104490102, 0.679468799649956, 0.5926373256946831, 0.6076580617118739, 0.6761087639733208, 0.6241564908559947, 0.5813030298665793, 0.6583871687655738, 0.4940243340121394, 0.4522439828956247, 0.25053796039056875, 0.0514144178404559, 0.7497429785598787, 0.5856778097516088, 0.5409086793274235, 0.29733122492949465, 0.641132804174086, 0.5277665651250658, 0.3466694726275532, 0.6038766948905987, 0.43918554774214835, 0.5995454169781624, 0.6402684293725289, 0.529131607586224, 0.620065735678364, 0.5894803064097505, 0.4488112812441418, 0.4070889628610957, 0.4423600945594274, 0.5630169628178523, 0.46051716115106034, 0.4396106473412853, 0.5708462240830618, 0.44321629618

800-0, train loss 0.10717889666557312
train loss 0.10717889666557312
val metric {'loss': 0.11307193487882614, 'loss_no_reg': 0.10995587706565857, 'corr': [0.26786721529496116, 0.5019730405073076, 0.6046173201128344, 0.77120727541305, 0.49965680467614637, 0.5106977421935612, 0.4631682520652085, 0.627433105629921, 0.6716261398862805, 0.7107709161338833, 0.6949289826326226, 0.5352284500367582, 0.6084871173304809, 0.6497474544019017, 0.6095119162285404, 0.6040058382048769, 0.6901080641555692, 0.6243837892153596, 0.6008384085272934, 0.6814925580786105, 0.4741133062582914, 0.4417694011263635, 0.3146589651496812, 0.08813097973106611, 0.7619604803650157, 0.6205107039304477, 0.5666698761781124, 0.3433170936172386, 0.6620359037122892, 0.5492196752362112, 0.3473604633459764, 0.5906229931992184, 0.5219347777335008, 0.5872935777011701, 0.6404301202381768, 0.5292019288706893, 0.6231330340842955, 0.6313710353695744, 0.4767073217494084, 0.41293479654154414, 0.4389038358927327, 0.5925421570144838, 0.45

test metric {'loss': 0.11328572886330741, 'loss_no_reg': 0.10962720215320587, 'corr': [0.3153830943058401, 0.49958467109928967, 0.6035197104909193, 0.7629248409445992, 0.49399943527597234, 0.5698913223339312, 0.5180077281067983, 0.6292606935940535, 0.675843481007431, 0.7219995038791082, 0.7031153063781004, 0.48319642400187124, 0.593632771755068, 0.6750342715847498, 0.6272417984725329, 0.6229890973930382, 0.7115831308342908, 0.637911758330366, 0.5997071975629131, 0.6658929192081662, 0.5110433103382851, 0.469383732193023, 0.2958913791691786, 0.0744497287856185, 0.7561215685337073, 0.6336754818633246, 0.5370191029182947, 0.32746990546973825, 0.6453023110139853, 0.545198387513263, 0.3893348978821959, 0.5969519416899001, 0.45870368021930136, 0.6036284435981424, 0.6406072615477061, 0.5450449468199314, 0.6394320047784506, 0.6190256851954283, 0.46574855379483016, 0.4071403307834953, 0.46687740987982274, 0.582283210023047, 0.4660626287639418, 0.4600323147622897, 0.6140804767831388, 0.4645861256

1300-0, train loss 0.10040703415870667
train loss 0.10040703415870667
val metric {'loss': 0.11216034144163131, 'loss_no_reg': 0.1088283583521843, 'corr': [0.26947230854619203, 0.5044001396902437, 0.6155835104699281, 0.765353275756047, 0.5204478841631701, 0.5249838187866023, 0.5036552260875629, 0.6160208554950066, 0.6769038708512021, 0.7131401885337022, 0.693138777798387, 0.5438594204522864, 0.611148284784709, 0.6407879157426848, 0.6140059915926873, 0.5987918527562915, 0.7115640719907402, 0.6129151305083435, 0.5986703928350336, 0.6994931401894748, 0.48296004012104204, 0.4524858917226909, 0.3348030928898146, 0.0856805405759627, 0.7628651238157849, 0.627377765307988, 0.5566720431396772, 0.36435024839280217, 0.6576970755743383, 0.5524815205085786, 0.3670566117733818, 0.5879582103648264, 0.531258026448767, 0.5729197736278613, 0.6383150285380984, 0.5299202341964092, 0.6301462700958482, 0.6364245582384509, 0.4761535085731133, 0.3906683030312611, 0.44278297906651004, 0.6003599892398066, 0.4604

test metric {'loss': 0.1152859064085143, 'loss_no_reg': 0.11141239106655121, 'corr': [0.3338793611373846, 0.4942110261812803, 0.5956640195344267, 0.768094477873229, 0.4858599763139402, 0.5710491756125209, 0.5232508449633281, 0.6333533488972126, 0.6715849027875865, 0.7283248127224056, 0.703767679598778, 0.48412487286335215, 0.6003867828652086, 0.6754073756282973, 0.6239337414913396, 0.6252487204858881, 0.715593477915289, 0.6375482806003827, 0.6023945763165279, 0.6687557243900332, 0.5141328623073145, 0.47202440854005956, 0.3122591881410609, 0.0750093664513277, 0.7507662595812684, 0.6306592188195522, 0.5455602968315592, 0.32653331843029754, 0.6461762244023852, 0.5474511553425837, 0.39445888872999596, 0.5926214254625629, 0.4544131554111951, 0.5993044618833517, 0.6415422501554147, 0.5520665715306818, 0.6429901865181302, 0.6276833288333196, 0.46755429351556654, 0.4068198748654105, 0.4534852534503078, 0.5721496498525602, 0.4582212976426763, 0.4490325663853568, 0.6168505787669141, 0.4645613596

early stopping after epoch 1750 metric 0.10848529636859894
for grp of sz 152, lr from 0.001000 to 0.000333
val metric init {'loss': 0.11196275800466537, 'loss_no_reg': 0.10848529636859894, 'corr': [0.25896542206421225, 0.5020849899096234, 0.6157454713292274, 0.7683021724931325, 0.5175515996299141, 0.5195439853130711, 0.49990878985417053, 0.6271062952979298, 0.6774249106116257, 0.7162363821219689, 0.6962498121400194, 0.545658360500873, 0.6165925690972097, 0.637581440565578, 0.6130749389529325, 0.6003569649802394, 0.7070444255661659, 0.6138249417116464, 0.6007124410125946, 0.6974249677462214, 0.48054911309002074, 0.4500828744043385, 0.3351254983122398, 0.0916234300675795, 0.7621763408429834, 0.6377777933288326, 0.5542819785199951, 0.36369503279223725, 0.6608367529778789, 0.5490562581419773, 0.3713971111845744, 0.5732908464338554, 0.5374319864989201, 0.5791068293478128, 0.6350774402829272, 0.5402510526551294, 0.6384191648521507, 0.6412177697783115, 0.4858144357550874, 0.3997829390138202, 

200-0, train loss 0.10197165608406067
train loss 0.10197165608406067
val metric {'loss': 0.11146320998668671, 'loss_no_reg': 0.10813881456851959, 'corr': [0.27209262225744735, 0.5038290803790263, 0.6200700372035878, 0.7714257575473844, 0.5226835143597324, 0.527338678922894, 0.5010233672244602, 0.6365828827461573, 0.6791124749902916, 0.7207960390502188, 0.698828969744737, 0.5470189716139862, 0.6176019335669937, 0.6408519021989764, 0.6250029253149867, 0.6013872800840868, 0.7127912045576774, 0.6149087463606768, 0.6063024395536121, 0.7056286170213465, 0.48584623150053136, 0.45453235267379405, 0.34144934025618084, 0.0817016290478407, 0.765738739853665, 0.6362346815210502, 0.5601541948988799, 0.36412380093145025, 0.6687784610077238, 0.5572585391802829, 0.36636068722243714, 0.5897088109952521, 0.533398802315957, 0.5871114562607248, 0.638686447143879, 0.5312603060734279, 0.6394716742291868, 0.6456443251345413, 0.48432349500663985, 0.39547954885161324, 0.4441775169426919, 0.6002390272019396, 0.

test metric {'loss': 0.113714656659535, 'loss_no_reg': 0.10997943580150604, 'corr': [0.3335762730264861, 0.49385956727683733, 0.6072476064635571, 0.7722033947523781, 0.5016123066908728, 0.5725170102285722, 0.5350596711378044, 0.6310137420521192, 0.67398849610655, 0.7303498793294263, 0.7041600066321366, 0.48729894487000763, 0.5996674248862177, 0.6747775602427131, 0.6256280523125326, 0.6282375379651015, 0.710272416839566, 0.6423795571843347, 0.6068205231749925, 0.6726439445863035, 0.5160455642860691, 0.4739376794995839, 0.3035640298180715, 0.06896853529927638, 0.7540212726829693, 0.637485381495903, 0.5412973773408281, 0.33460589089676973, 0.6464908596063552, 0.5457640039802577, 0.3944336750932687, 0.59394647550885, 0.46466397688350414, 0.5983868543198737, 0.6457084486070253, 0.5434275107300288, 0.6356648325645037, 0.630568502927568, 0.4700143493859764, 0.4082325344449207, 0.4569759900705006, 0.5726010043540045, 0.45821791325464944, 0.4532353864024985, 0.6131422416402189, 0.46398193851322

early stopping after epoch 650 metric 0.10792459547519684
for grp of sz 152, lr from 0.000333 to 0.000111
val metric init {'loss': 0.11122055351734161, 'loss_no_reg': 0.10792459547519684, 'corr': [0.2753894248975638, 0.5086790982852742, 0.6191033734097999, 0.7707959696802291, 0.5242515368055723, 0.5261468376231453, 0.5059249168976301, 0.6390626930317762, 0.6800627028773079, 0.7200800811423335, 0.6979372181739179, 0.5429618625943244, 0.6163287280426571, 0.6444963301886646, 0.6255008754850006, 0.6029316241843868, 0.7137080783856022, 0.619678554548304, 0.6047478884619704, 0.701774923396661, 0.48624601735752854, 0.4527849738256935, 0.3422184452144121, 0.08389168467970083, 0.762861640929076, 0.6343301773869541, 0.5588156994171677, 0.3672626956796639, 0.6680526437257173, 0.5546665363667496, 0.3719647044339758, 0.590716036462178, 0.5350997973563945, 0.5908096399302427, 0.6394956491566479, 0.5347597516776876, 0.6397491265813504, 0.6463698175169379, 0.4812216337834541, 0.3995122776950188, 0.444

200-0, train loss 0.09987127035856247
train loss 0.09987127035856247
val metric {'loss': 0.11126098483800888, 'loss_no_reg': 0.10801573097705841, 'corr': [0.27388072646739026, 0.5081987406752184, 0.6204648552622811, 0.7681614184489821, 0.523519413105027, 0.5265051262749126, 0.5060247097782946, 0.6374999260094706, 0.6803880266936181, 0.7211054964464914, 0.697616941767002, 0.5443733413100895, 0.6159285802528183, 0.6431944228076648, 0.6240647590821035, 0.6017773382791103, 0.7137182215142244, 0.6168564584652668, 0.6051620360054495, 0.7035725163805083, 0.48779594174290103, 0.45265411043969306, 0.33941273150304985, 0.08388949577626237, 0.7617950872247439, 0.6375366445419898, 0.5585212655011194, 0.36674853831720367, 0.667662440545145, 0.5553489289358968, 0.37017466765887375, 0.588092232062568, 0.5374321829694586, 0.5896444238029241, 0.6395119732590626, 0.5326911582100653, 0.6379800827522608, 0.6463920133932045, 0.48339185019776065, 0.39770241673173584, 0.4404815246659873, 0.6009151982244065, 

test metric {'loss': 0.11292390099593572, 'loss_no_reg': 0.10925576835870743, 'corr': [0.3363229291915206, 0.4948821723714802, 0.6077975838411895, 0.771405460256988, 0.5023731943393681, 0.5751688570338345, 0.5308005565125367, 0.6363567303722184, 0.6717013409950109, 0.7308008776788284, 0.702674078573817, 0.4857437152696993, 0.5992822508981481, 0.6745330449504101, 0.6251045280587006, 0.6298047591070126, 0.7151741487677472, 0.6431255333110786, 0.6092644045926583, 0.6708608815348888, 0.5161949650584324, 0.47084006496045094, 0.31055691713862543, 0.06956255774629082, 0.7546073532732765, 0.6385727924751641, 0.5424551991456452, 0.3348292814238848, 0.6484062830415196, 0.5448710987654103, 0.3998371408262874, 0.5939922184277545, 0.464562214536401, 0.6045258481946453, 0.6442353707831225, 0.5481836692441895, 0.6360516744466949, 0.6292518335851843, 0.46911894165389467, 0.40680027461050694, 0.45365364684404036, 0.5742977135334724, 0.458943224557232, 0.45466251672593505, 0.6197010569934858, 0.46706155