In [1]:
# test full training of maskcnn_polished_with_rcnn_k_bl
from os.path import join

import numpy as np

import torch

from thesis_v2 import dir_dict

from thesis_v2.data.prepared.yuanyuan_8k import get_data

from thesis_v2.training_extra.maskcnn_like.opt import get_maskcnn_v1_opt_config
from thesis_v2.training_extra.maskcnn_like.training import (train_one,
                                                            partial)

from thesis_v2.models.maskcnn_polished_with_rcnn_k_bl.builder import (
    gen_maskcnn_polished_with_rcnn_k_bl,
    load_modules
)



In [2]:
def train_one_maskcnn_polished_with_rcnn_k_bl(
    split_seed,
    model_seed,
    act_fn,
    loss_type,
    input_size,
    out_channel,
    num_layer,
    kernel_size_l1,
    pooling_ksize,
    scale, scale_name,
    smoothness, smoothness_name,
    pooling_type,
    n_timesteps,
    multi_path_separate_bn,
):
    """Run one full training of `maskcnn_polished_with_rcnn_k_bl` and return its test score.

    The model is always built with `ff_1st_block=True` and `multi_path=True`;
    `multi_path_separate_bn` is the flag this debug run toggles.

    Parameters
    ----------
    split_seed
        Forwarded to `get_data` as `seed`.
    model_seed
        Forwarded to `train_one` and also used as the last component of the
        run key.
    act_fn, out_channel, num_layer, kernel_size_l1, pooling_ksize,
    pooling_type, n_timesteps, multi_path_separate_bn
        Forwarded unchanged to `gen_maskcnn_polished_with_rcnn_k_bl`.
    input_size
        Forwarded to `get_data` (third positional argument).
    loss_type, scale, smoothness
        Forwarded to `get_maskcnn_v1_opt_config`; `group` is fixed at 0.0.
    scale_name, smoothness_name
        Accepted so a shared parameter dict can be splatted in; not read
        anywhere in this function.

    Returns
    -------
    The value stored at `result['stats_best']['stats']['test']['corr_mean']`
    in what `train_one` returns.
    """
    load_modules()

    raw_splits = get_data(
        'a', 200, input_size, ('042318', '043018', '051018'),
        scale=0.5,
        seed=split_seed,
    )

    # `raw_splits` is indexed as (X, y) pairs in train / val / test order;
    # only the X arrays are cast to float32, y is passed through untouched.
    datasets = {}
    for pair_idx, split_name in enumerate(('train', 'val', 'test')):
        datasets[f'X_{split_name}'] = raw_splits[2 * pair_idx].astype(np.float32)
        datasets[f'y_{split_name}'] = raw_splits[2 * pair_idx + 1]

    # Everything the builder needs except the two arguments supplied by the
    # training code at call time.
    fixed_arch_kwargs = dict(
        out_channel=out_channel,
        kernel_size_l1=kernel_size_l1,  # (try 5,9,13)
        kernel_size_l23=3,
        act_fn=act_fn,
        pooling_ksize=pooling_ksize,  # (try, 1,3,5,7)
        pooling_type=pooling_type,  # try (avg, max)  # looks that max works well here?
        num_layer=num_layer,
        n_timesteps=n_timesteps,
        factored_constraint=None,
        blstack_pool_ksize=1,
        blstack_pool_type=None,
        acc_mode='cummean',
        bn_after_fc=False,
        ff_1st_block=True,
        multi_path=True,
        multi_path_separate_bn=multi_path_separate_bn,
    )

    def gen_cnn_partial(input_size_cnn, n):
        return gen_maskcnn_polished_with_rcnn_k_bl(
            input_size=input_size_cnn,
            num_neuron=n,
            **fixed_arch_kwargs,
        )

    opt_config_partial = partial(
        get_maskcnn_v1_opt_config,
        scale=scale,
        smoothness=smoothness,
        group=0.0,
        loss_type=loss_type,
    )

    run_key = f'debug/test_full_training_maskcnn_polished_with_rcnn_k_bl/ff_1st_block_multipath_train/sep_bn{multi_path_separate_bn}/{model_seed}'

    eval_fn_params = {
        # reduce on batch axis
        'yhat_reduce_axis': 1,
    }

    result = train_one(
        arch_json_partial=gen_cnn_partial,
        opt_config_partial=opt_config_partial,
        datasets=datasets,
        key=run_key,
        show_every=100,
        max_epoch=40000,
        model_seed=model_seed,
        return_model=False,
        extra_params={'eval_fn': eval_fn_params},
        print_model=True,
    )

    test_stats = result['stats_best']['stats']['test']
    return test_stats['corr_mean']

In [3]:
# Settings held fixed for every run in this notebook.
maskcnn_param_template = dict(
    out_channel=16,
    num_layer=3,
    kernel_size_l1=9,
    pooling_ksize=3,
    pooling_type='avg',
    model_seed=0,
    split_seed='legacy',
)

# Template plus the remaining keyword arguments the training function needs
# (everything except `multi_path_separate_bn`, which each call sets itself).
maskcnn_param_regular = dict(
    maskcnn_param_template,
    act_fn='relu',
    loss_type='mse',
    smoothness=0.000005,
    smoothness_name='0.000005',
    scale=0.01,
    scale_name='0.01',
    input_size=50,
    n_timesteps=4,
)

# First run: multi_path_separate_bn=False.
print(
    train_one_maskcnn_polished_with_rcnn_k_bl(
        multi_path_separate_bn=False,
        **maskcnn_param_regular,
    )
)

# Parameter-count check; 27629 matches the `num_param` value printed in this cell's output:
# 27629 = 1*2 + 16*1*9*9 + 16*2 + 2*(16*16*3*3+16*16*3*3+4*16*2) + 79*(14*14 + 16 + 1)

# results here are exactly the same as those in `scripts/debug/rcnn_basic_kriegeskorte/test_full_training_ff_1st.ipynb`

{'fc', 'final_act', 'pooling', 'bn_output'}
['conv0', 'bl_stack.layer_list.0.b_conv', 'bl_stack.layer_list.1.b_conv']
num_param 27629
JSONNet(
  (moduledict): ModuleDict(
    (accumulator): RecurrentAccumulator()
    (act0): ReLU()
    (bl_stack): BLConvLayerStack(
      (layer_list): ModuleList(
        (0): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (bn_layer_list): ModuleList(
        (0): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (

In [4]:
# Second run: identical settings, but with multi_path_separate_bn=True.
print(
    train_one_maskcnn_polished_with_rcnn_k_bl(
        multi_path_separate_bn=True,
        **maskcnn_param_regular,
    )
)

{'fc', 'final_act', 'pooling', 'bn_output'}
['conv0', 'bl_stack.layer_list.0.b_conv', 'bl_stack.layer_list.1.b_conv']
num_param 29613
JSONNet(
  (moduledict): ModuleDict(
    (accumulator): RecurrentAccumulator()
    (act0): ReLU()
    (bl_stack): BLConvLayerStack(
      (layer_list): ModuleList(
        (0): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (bn_layer_list): ModuleList(
        (0): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (

val metric init {'loss': 0.14660580158233644, 'loss_no_reg': 0.14377044141292572, 'corr': [0.025953612617967373, 0.06104704123918449, 0.023214588972358716, -0.015401625471615048, 0.01084001978212662, 0.047446847083557606, 0.04751394740794557, 0.044813407282045664, 0.11627672076307788, 0.042581391427912174, -0.030209710836390347, -0.06819483045948871, -0.1722509112831257, 0.1402520756842594, 0.006931831810783454, 0.15311124247646657, -0.033644680980999084, -0.024810609706296102, 0.006269847022357063, -0.09549268349417248, -0.028799940037346164, -0.005884994208127126, 0.057746755059920565, 0.006341857667328713, -0.216329817859501, 0.006169171857100775, 0.025304915867388075, -0.008566454235207632, 0.020416239924452642, 0.14458311892280074, 0.05034601033002685, 0.04991975947681433, -0.07579586415908462, -0.03322505323487416, 0.08175337810125836, -0.21376084904135945, 0.03504463808618788, -0.15038714543546094, -0.1113710104181104, -0.14618462578746222, -0.023334777764781246, -0.031625366031

test metric {'loss': 0.12506073181118285, 'loss_no_reg': 0.12044830620288849, 'corr': [0.25530327498061217, 0.44048187153798046, 0.5104142033027428, 0.7216850160845549, 0.4437268178876011, 0.489187788654863, 0.43321508351349747, 0.43966694017360464, 0.6394482307777658, 0.666013193484752, 0.63575953070083, 0.4168380290972187, 0.5225858849855805, 0.648193541667712, 0.4026444941570211, 0.5362936335376635, 0.6105525140187484, 0.5834531094587853, 0.542972265443011, 0.5768095788117292, 0.4421297625108152, 0.4068565403797053, 0.19425863987520997, 0.08195336276107577, 0.7111861023148017, 0.4634220690864044, 0.32432671126541895, 0.23688775653369437, 0.5920807171401622, 0.4816378519622494, 0.2692632971023675, 0.5623712534553453, 0.35425061119782447, 0.40153833365042935, 0.5802989787550801, 0.48488505088601075, 0.5332282174745412, 0.5106555024972874, 0.36468880901202017, 0.3879273333368405, 0.3040576183017861, 0.48861214459675556, 0.42279190670897115, 0.38126934674301527, 0.46526683857889717, 0.3

500-0, train loss 0.11132407933473587
train loss 0.11132407933473587
val metric {'loss': 0.11534538269042968, 'loss_no_reg': 0.11060020327568054, 'corr': [0.2661510658215487, 0.49902934848143543, 0.5857138926304856, 0.7536791130673335, 0.5170731751898362, 0.514653640503831, 0.4327602961260444, 0.5995249964798628, 0.6730141769885325, 0.6955243691644151, 0.679912658939944, 0.5161316521332393, 0.5911891690598189, 0.6505445474850491, 0.6148570301135302, 0.6021631327738203, 0.7034382321196697, 0.594990054898395, 0.577119858042459, 0.6815023519161385, 0.46542003511270286, 0.44534507960798864, 0.31292748195985, 0.09922616086931095, 0.7554176688577541, 0.5851655242397442, 0.5458796878405036, 0.3531454974258576, 0.6719030972138642, 0.5545932020098191, 0.32731572041476, 0.5725826843627093, 0.47243751528502953, 0.5757467358489561, 0.6289383933345126, 0.5275040374452103, 0.6150399851437718, 0.6143202063266044, 0.4564859967366363, 0.41787871462255044, 0.4232292540632996, 0.5804129152088553, 0.46051

test metric {'loss': 0.11414583453110286, 'loss_no_reg': 0.10952426493167877, 'corr': [0.3088737890376755, 0.5046133304020431, 0.6014249344997131, 0.7728283794941253, 0.498733915047482, 0.577842358430407, 0.5024687749092831, 0.6342606367807717, 0.6714129311131474, 0.7213272799250497, 0.6892739952119156, 0.47425996991759, 0.582898513506602, 0.6741481758692051, 0.5890038549854464, 0.6163743494603575, 0.7244407997315091, 0.643119033339023, 0.5910952895719592, 0.66716885469453, 0.5061655899647484, 0.46810105841003274, 0.2994364294710012, 0.07581799508332113, 0.747754217810888, 0.6134062925256445, 0.5267264688144087, 0.3338686355441878, 0.6508304200636273, 0.5417547640934142, 0.35608427803683185, 0.5955018163974581, 0.4113282523759123, 0.579085058075399, 0.6443970594284207, 0.552682538009297, 0.6407849523060635, 0.6112418734358445, 0.45467274510343214, 0.42193041748187515, 0.44630536818674404, 0.5821842168136072, 0.4758083847404803, 0.4763126413491364, 0.5760681451667802, 0.4687686814105161

1000-0, train loss 0.11321166157722473
train loss 0.11321166157722473
val metric {'loss': 0.11245597153902054, 'loss_no_reg': 0.10811379551887512, 'corr': [0.2698320036332936, 0.5073108507167183, 0.6139962691754288, 0.7728857320328759, 0.5300048959188318, 0.5190286112892568, 0.49405600144759254, 0.631405780022525, 0.6740593274612829, 0.70905468077387, 0.6880385907525539, 0.5392686863744877, 0.6140001417226979, 0.6508819042272441, 0.6352058483346663, 0.6116296686541336, 0.7183029055793999, 0.6247868132884717, 0.5924039968420451, 0.6912369885570347, 0.4806615706295464, 0.46683615813283724, 0.3285543255852732, 0.10465817217498508, 0.7663120611932439, 0.6288395817657092, 0.5485707521824841, 0.3799111291821493, 0.6825341977697895, 0.563702967359708, 0.3723368103927446, 0.5909933848185764, 0.5163763276671747, 0.5998842459422751, 0.6328524749478422, 0.5286185796898123, 0.6388714830287459, 0.6522224232940173, 0.49879069395296627, 0.40624248138686353, 0.4585846086172166, 0.5961209571316037, 0.4

test metric {'loss': 0.1129358389547893, 'loss_no_reg': 0.10882596671581268, 'corr': [0.33941816834538874, 0.5193140576248756, 0.603547632029083, 0.7760919558863415, 0.5159218762167679, 0.5775797882307349, 0.512793799036131, 0.6457033192193505, 0.6674205301559814, 0.713462483584714, 0.7001954249302499, 0.49496223894452623, 0.6007776252996606, 0.6667363527141443, 0.6045476843143076, 0.6187307638067641, 0.7305771797464894, 0.6343686442421336, 0.6054119215878406, 0.6671729430329718, 0.5114104360547855, 0.4739894531843103, 0.307571240419671, 0.08372261460580041, 0.7539856256966407, 0.6400413756126605, 0.5498527595397497, 0.34058478939178416, 0.6531541529263801, 0.553795040370757, 0.3814377100476463, 0.5966532365394828, 0.4333425529344109, 0.6050886236130584, 0.6448522944469711, 0.5518987631901278, 0.64582504251039, 0.6329911410991298, 0.46548956110946044, 0.41175646160891016, 0.4560488558704211, 0.5844503952696447, 0.4577065304764574, 0.4757517889400789, 0.6078172300939981, 0.4756240671417

1500-0, train loss 0.1131918802857399
train loss 0.1131918802857399
val metric {'loss': 0.11161967366933823, 'loss_no_reg': 0.10744824260473251, 'corr': [0.27829683971512476, 0.505801043078581, 0.6155858169430689, 0.772392375917517, 0.5464663516678756, 0.5213413161944705, 0.5259071594092731, 0.6230113670475772, 0.6752267861194156, 0.7108938249714107, 0.6924158806428, 0.5497367670563084, 0.6196040409102646, 0.652577549608788, 0.6339655052031705, 0.6236969562913857, 0.7189816824758448, 0.6174510625587665, 0.6022618098371767, 0.6926541355536184, 0.49114812446261236, 0.4743141512565766, 0.3362319352158733, 0.1047020017813941, 0.7699884004513733, 0.6375779761701632, 0.5582500090620133, 0.3783560002431752, 0.6886319318371054, 0.5692154747108835, 0.371469369581373, 0.5942823088675793, 0.5237774502263911, 0.5906934083910995, 0.6265104404341625, 0.5301015168303586, 0.6415002451981371, 0.6559472504225832, 0.5007795687540012, 0.4090879190522493, 0.4472039169391008, 0.6106631129866753, 0.474166305

test metric {'loss': 0.11221536993980408, 'loss_no_reg': 0.10838744789361954, 'corr': [0.34115909334694855, 0.5305212666288373, 0.6046115103958393, 0.7765311650542065, 0.5256570190362448, 0.5767262732379005, 0.5399185449880224, 0.649496919527948, 0.6657202934103416, 0.7208193736932091, 0.7056078361781184, 0.49446015859025527, 0.6120085545088916, 0.6789723487407122, 0.6117150892236743, 0.6333769016628626, 0.7479086430788964, 0.6337789957625145, 0.6058187009999941, 0.6664442370859702, 0.5218255875392199, 0.47635840240701083, 0.3033147487935064, 0.0885403014619722, 0.7624786571154967, 0.6369541363861824, 0.5457782624069945, 0.35338996100859565, 0.6596188996130155, 0.5536585958448947, 0.38772749383670446, 0.5979834066486278, 0.45288581632661873, 0.6005658173589146, 0.6440547602714437, 0.5547416570169377, 0.6437372329138665, 0.6337579392058802, 0.4678514845567862, 0.4205814045275195, 0.46083687758765146, 0.5943386959111736, 0.458420734345659, 0.47204143965143147, 0.6285713836308469, 0.46473

2000-0, train loss 0.10324560105800629
train loss 0.10324560105800629
val metric {'loss': 0.11232100278139115, 'loss_no_reg': 0.10857202857732773, 'corr': [0.28157265458791, 0.5174799472628289, 0.6223020488941189, 0.7755690774977801, 0.5451837028637261, 0.5223037316275477, 0.5236329271714464, 0.6349766438809163, 0.6766344786200368, 0.7123031475233784, 0.6989605746584795, 0.544758093672064, 0.6245865156225339, 0.6477746263762478, 0.6383705648831304, 0.6173456384317256, 0.7195940217115852, 0.6255433488017013, 0.5982524090250616, 0.6941299485816712, 0.49698637400706286, 0.4672660274575159, 0.33280571764488853, 0.08848347330004017, 0.7603347661492943, 0.6317719375482378, 0.545389812102735, 0.3863009384955719, 0.6894244976088854, 0.5736457468717883, 0.38203752028101057, 0.5912789194432155, 0.5432641779512558, 0.5958408528848163, 0.6358506064959925, 0.5289122387612979, 0.6335194159285907, 0.6587646813781034, 0.513313782809266, 0.4099871367460155, 0.4461435550943011, 0.610258480376873, 0.4726

100-0, train loss 0.10232219845056534
train loss 0.10232219845056534
val metric {'loss': 0.11104811280965805, 'loss_no_reg': 0.10693413764238358, 'corr': [0.280867646046561, 0.5089150919297851, 0.6215669048983046, 0.7734984741980204, 0.5461573121946682, 0.5241504626778154, 0.5307030533684698, 0.6297318348694985, 0.6813897584517099, 0.7136929308437192, 0.6983121878537437, 0.550533309710166, 0.6225722280115026, 0.6536477849742569, 0.6321778083161949, 0.6210986941310463, 0.7204842226022341, 0.6200511305498831, 0.6025024850809056, 0.698664431756413, 0.49041744065904136, 0.47141751728005593, 0.33652942657285223, 0.09652060309932056, 0.7633594484373372, 0.6331040767472659, 0.5519017248331484, 0.38918997811923994, 0.6889948895092716, 0.5729325400128311, 0.37294573855977414, 0.5973833832317386, 0.5336732971992137, 0.5958963749690527, 0.6367051464930533, 0.5262902619518156, 0.631450579453473, 0.6626041319787104, 0.5135870208111355, 0.40843486243341875, 0.44511770413339014, 0.6148263270683444, 0

test metric {'loss': 0.11143513875348228, 'loss_no_reg': 0.10728379338979721, 'corr': [0.34453554218852844, 0.52495898577295, 0.6064453407110462, 0.7759083310064598, 0.5225966227361716, 0.5810400064565283, 0.5397632869801401, 0.6499055942442071, 0.6748563651238654, 0.7289655044678232, 0.7078339914216913, 0.49354697608729264, 0.6131828404137246, 0.6775392930646148, 0.6135080296494955, 0.6323274139710399, 0.7458073107723568, 0.6387469913180462, 0.605856481419806, 0.6704922886142128, 0.5199609081227483, 0.4739957141487714, 0.31134389045458344, 0.09047371953748506, 0.7645796196005124, 0.6413407918857038, 0.5469907988556062, 0.3560450196937568, 0.6618199211184019, 0.5581323864109637, 0.3814394019781192, 0.5968537369670002, 0.4538257409507394, 0.6006907479562317, 0.6446179092586419, 0.5587759462828366, 0.6430497798261814, 0.6414334485805733, 0.47356342643760135, 0.4215688572247421, 0.4614231486657372, 0.592220369940007, 0.45772308096252307, 0.47866862758257955, 0.6288701891576822, 0.47026610

600-0, train loss 0.10754410177469254
train loss 0.10754410177469254
val metric {'loss': 0.11098619997501373, 'loss_no_reg': 0.10690272599458694, 'corr': [0.2833731596196008, 0.5129765814366751, 0.6226870196017542, 0.774643876727952, 0.5522951566645391, 0.5228401971737717, 0.5341065041686178, 0.637190267568083, 0.6783704983224889, 0.7100859802863879, 0.696325604965491, 0.5542177851461818, 0.6223980746388096, 0.6531276680506534, 0.6328891276998421, 0.6177703755153514, 0.7217830658572211, 0.6231963863583871, 0.6011794772712398, 0.702713145593733, 0.49519021604298324, 0.4768630641134183, 0.3327132510936455, 0.09255280116150745, 0.7666568162013263, 0.6329391916821905, 0.5496091042579051, 0.3871338034218663, 0.6894516812684834, 0.5689226994016212, 0.3838517597779091, 0.5924055025596684, 0.5359732265099514, 0.6002295813127903, 0.6344588011375805, 0.5257479395140079, 0.6318360669353282, 0.6624061879424478, 0.5141360277432703, 0.4073065249673039, 0.4426080474040731, 0.6156827130300442, 0.47757

100-0, train loss 0.10987178981304169
train loss 0.10987178981304169
val metric {'loss': 0.11065865308046341, 'loss_no_reg': 0.10649164021015167, 'corr': [0.28332393175416193, 0.5102958258657135, 0.6221008665599759, 0.7731854290326421, 0.5515378987478388, 0.5240586461218358, 0.5329642260826092, 0.6358024072604733, 0.6813725108999659, 0.7129759843635826, 0.6974831765168499, 0.551895904620493, 0.622091023488911, 0.6528310116095211, 0.6345877537930193, 0.6215119744372063, 0.7199068230664207, 0.6209944882551094, 0.6029480418600757, 0.6983738249725454, 0.4922465773143243, 0.4745422639804199, 0.33549525243354683, 0.09447164146038124, 0.7679574171737049, 0.6346542856861849, 0.5526398121740335, 0.38836992733022024, 0.6891640898502108, 0.5711168619640548, 0.37778301132136294, 0.5983462103184275, 0.5333887698460057, 0.6001633773238717, 0.636443830598153, 0.5280709029633487, 0.6348124166294477, 0.6626093954799959, 0.5131963850797461, 0.4084439689720603, 0.4512346948372905, 0.6156213537827854, 0.4

test metric {'loss': 0.11109844063009534, 'loss_no_reg': 0.10680586844682693, 'corr': [0.3435322037279452, 0.5229044483777043, 0.6066387119614538, 0.7772257974370624, 0.5221817065348872, 0.5814006364951955, 0.5404535500305649, 0.6495425930184737, 0.674178769500079, 0.7283860496750589, 0.7087015702658561, 0.4946432394700373, 0.6126292785533851, 0.6780046919612612, 0.6141973225374078, 0.6328546046910248, 0.7442171768850538, 0.6385917867003542, 0.6083899528488952, 0.6709959891568259, 0.5223470043261937, 0.4776000078282837, 0.3115620484110664, 0.09025308702216722, 0.7654258361278959, 0.6405623832682339, 0.5495410882842574, 0.3539608985431094, 0.661159841152035, 0.5572719423760806, 0.38606755586695485, 0.5952110079485811, 0.45483630988984725, 0.6021089194336232, 0.6469919820483571, 0.557300903143677, 0.6447773111748075, 0.6406428355548988, 0.47276000164094706, 0.41943770165878835, 0.4615345378126047, 0.59190520403006, 0.4627795522201504, 0.4770668703082299, 0.6310860124238502, 0.47319883381

600-0, train loss 0.10657568275928497
train loss 0.10657568275928497
val metric {'loss': 0.11087391674518585, 'loss_no_reg': 0.106857068836689, 'corr': [0.2816586736602015, 0.510481569550925, 0.6227996470346546, 0.7723293793559962, 0.549220493789303, 0.5233865553440916, 0.5337955818648594, 0.6373011352397506, 0.6812459147173893, 0.7135579477029309, 0.6969608824581837, 0.5512156262160601, 0.6217575741294663, 0.6513425431690562, 0.6333716736085477, 0.6209655138831955, 0.7200613720973877, 0.619962415110173, 0.6030528867577329, 0.6987784252637239, 0.49308667685931323, 0.47277250163129614, 0.3356338189061966, 0.09380757385208237, 0.766751433833895, 0.634312980785996, 0.550381792144467, 0.3881328886897168, 0.6888924460055856, 0.5715503709229601, 0.3789000570091461, 0.5982910196141291, 0.5323257964918621, 0.6003815668389477, 0.6373128484685409, 0.5285675601227271, 0.6325182932083041, 0.6633456500547728, 0.5139175657690893, 0.4067040711641081, 0.4495560280958442, 0.6148863922586146, 0.47273994