In [1]:
# test full training of maskcnn_polished_with_rcnn_k_bl
from os.path import join

import numpy as np

import torch

from thesis_v2 import dir_dict

from thesis_v2.data.prepared.cadena_plos_cb19 import get_data

from thesis_v2.training_extra.maskcnn_like.opt import get_maskcnn_v1_opt_config
from thesis_v2.training_extra.maskcnn_like.training import (train_one,
                                                            partial)

from thesis_v2.models.maskcnn_polished_with_rcnn_k_bl.builder import (
    gen_maskcnn_polished_with_rcnn_k_bl,
    load_modules
)

In [2]:
def train_one_maskcnn_polished_with_rcnn_k_bl(
    split_seed,
    model_seed,
    act_fn,
    loss_type,
    input_size,
    out_channel,
    num_layer,
    kernel_size_l1,
    pooling_ksize,
    scale, scale_name,
    smoothness, smoothness_name,
    pooling_type,
    n_timesteps,
):
    """Train one maskcnn_polished_with_rcnn_k_bl model on the CB19 data and
    return the best model's mean test correlation.

    Parameters
    ----------
    split_seed : forwarded to ``get_data`` as ``seed``.
    model_seed : forwarded to ``train_one`` as ``model_seed``; also the last
        component of the result key.
    act_fn, out_channel, num_layer, kernel_size_l1, pooling_ksize,
    pooling_type, n_timesteps : architecture hyperparameters forwarded to
        ``gen_maskcnn_polished_with_rcnn_k_bl``.
    loss_type, scale, smoothness : forwarded to ``get_maskcnn_v1_opt_config``.
    input_size : forwarded to ``get_data`` as ``final_size``.
    scale_name, smoothness_name : accepted but not used by this function
        (presumably kept so one param dict can also label runs — confirm).

    Returns
    -------
    ``result['stats_best']['stats']['test']['corr_mean']`` from ``train_one``.
    """
    load_modules()

    # get_data returns, by position: X_train, y_train, X_val, y_val, X_test,
    # y_test. Only the inputs (X_*) are cast to float32; targets pass through.
    # NB: scale=0.5 here is a get_data argument and is unrelated to the
    # `scale` hyperparameter, which only feeds the optimizer config below.
    raw_splits = get_data(
        px_kept=80, final_size=input_size, scale=0.5,
        seed=split_seed,
    )
    split_names = ('X_train', 'y_train', 'X_val', 'y_val', 'X_test', 'y_test')
    datasets = {}
    for position, split_name in enumerate(split_names):
        array = raw_splits[position]
        if split_name.startswith('X_'):
            array = array.astype(np.float32)
        datasets[split_name] = array

    def gen_cnn_partial(input_size_cnn, n):
        # Swept settings come from the enclosing arguments; everything else
        # is fixed for this experiment.
        arch_kwargs = dict(
            input_size=input_size_cnn,
            num_neuron=n,
            out_channel=out_channel,
            kernel_size_l1=kernel_size_l1,  # candidates: 5, 9, 13
            kernel_size_l23=3,
            act_fn=act_fn,
            pooling_ksize=pooling_ksize,  # candidates: 1, 3, 5, 7
            pooling_type=pooling_type,  # candidates: avg, max (author noted max may work well)
            num_layer=num_layer,
            n_timesteps=n_timesteps,
            factored_constraint=None,
            blstack_pool_ksize=1,
            blstack_pool_type=None,
            acc_mode='cummean',
            bn_after_fc=False,
            ff_1st_block=True,
        )
        return gen_maskcnn_polished_with_rcnn_k_bl(**arch_kwargs)

    opt_config_partial = partial(
        get_maskcnn_v1_opt_config,
        scale=scale, smoothness=smoothness, group=0.0, loss_type=loss_type,
    )

    # The key depends only on model_seed, so runs that differ in other
    # hyperparameters share it.
    # NOTE(review): check whether train_one overwrites or reuses existing results.
    result_key = f'debug/test_full_training_maskcnn_polished_with_rcnn_k_bl/ff_1st_block_cb19_data_cls4/{model_seed}'

    # NOTE(review): the original comment says this reduces on the batch axis;
    # what axis index 1 refers to depends on train_one's eval_fn — confirm.
    eval_settings = {'yhat_reduce_axis': 1}

    result = train_one(
        arch_json_partial=gen_cnn_partial,
        opt_config_partial=opt_config_partial,
        datasets=datasets,
        key=result_key,
        show_every=100,
        max_epoch=40000,
        model_seed=model_seed,
        return_model=False,
        extra_params={'eval_fn': eval_settings},
        print_model=True,
    )

    best_test_stats = result['stats_best']['stats']['test']
    return best_test_stats['corr_mean']

In [3]:
# Architecture and seed settings shared by every configuration in this notebook.
maskcnn_param_template = {
    'out_channel': 16,
    'num_layer': 3,
    'kernel_size_l1': 9,
    'pooling_ksize': 3,
    'pooling_type': 'avg',
    'model_seed': 0,
    'split_seed': 0,
}

# The "regular" run: the template plus activation, loss, regularization and
# input settings. Each *_name entry is a string label kept next to its
# numeric value.
maskcnn_param_regular = {
    **maskcnn_param_template,
    'act_fn': 'relu',
    'loss_type': 'mse',
    'smoothness': 0.000005,
    'smoothness_name': '0.000005',
    'scale': 0.01,
    'scale_name': '0.01',
    'input_size': 40,
    'n_timesteps': 4,
}

print(train_one_maskcnn_polished_with_rcnn_k_bl(**maskcnn_param_regular))

# Hand-derived parameter count from the original notes:
# 27629 = 1*2 + 16*1*9*9 + 16*2 + 2*(16*16*3*3+16*16*3*3+4*16*2) + 79*(14*14 + 16 + 1)
# NOTE(review): the model summary printed by this cell reports num_param 26672,
# so this breakdown does not match the built model — TODO reconcile.

{'fc', 'final_act', 'bn_output', 'pooling'}
['conv0', 'bl_stack.layer_list.0.b_conv', 'bl_stack.layer_list.1.b_conv']
num_param 26672
JSONNet(
  (moduledict): ModuleDict(
    (accumulator): RecurrentAccumulator()
    (act0): ReLU()
    (bl_stack): BLConvLayerStack(
      (layer_list): ModuleList(
        (0): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (l_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (bn_layer_list): ModuleList(
        (0): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (

100-0, train loss 0.30070820450782776
train loss 0.30070820450782776
val metric {'loss': 0.2941181123256683, 'loss_no_reg': 0.2887372374534607, 'corr': [0.07352977748565663, 0.3811018813473136, 0.3520577203636432, 0.39981510367635287, 0.5791615920631378, 0.34135180898475437, 0.33221141054807635, 0.3935964147239117, 0.2831746662724414, 0.11054512771068796, 0.39837875182113797, 0.1370861272883418, 0.34500693280112604, 0.2416060430807726, 0.3535946803993614, 0.16204168773708963, 0.1808962524893817, 0.1321784091506123, 0.10500250236115698, 0.21279174346469898, 0.2220435384494003, 0.35590177381481597, 0.11712833585863666, 0.2961904080378846, 0.20571036923226146, 0.1320560812065854, 0.11410519460260163, 0.4907988418826686, 0.33462458602906464, 0.30013530594158166, 0.23120493184170454, 0.42471128700828076, 0.5349432824554505, 0.23812645248138203, 0.33386457070542375, 0.44413263233141154, 0.11246128240038868, 0.49737754129441786, 0.2826940780822861, 0.04288648347361937, 0.13719316931940448, 0.

300-0, train loss 0.2764189541339874
train loss 0.2764189541339874
val metric {'loss': 0.2750251442193985, 'loss_no_reg': 0.26727405190467834, 'corr': [0.23111257071316416, 0.44644572007492467, 0.41821690281939, 0.44237683776371084, 0.5743599079667427, 0.45314026616227415, 0.3726407337608895, 0.40544406478834805, 0.2984002546400698, 0.1522970087683591, 0.4551578979415099, 0.16032560497434586, 0.4730173998644853, 0.23925862193593145, 0.4873703416918177, 0.17462482632419696, 0.2471674757145974, 0.11749208768045373, 0.23097334397138305, 0.4509510979507269, 0.3796921308916002, 0.5079739488207883, 0.32493949928830795, 0.3702229534327693, 0.40860338916268024, 0.22644818391684607, 0.41595618991386585, 0.5091686462952156, 0.5869266071784124, 0.4063863269997249, 0.347228707742129, 0.44395076919155385, 0.6057457757919283, 0.4253175762132737, 0.3974647018153592, 0.49839321050313595, 0.2590742325154642, 0.5307419686513625, 0.4638674368140465, 0.09176071331340085, 0.32110588081997093, 0.42680614931

500-0, train loss 0.24851959943771362
train loss 0.24851959943771362
val metric {'loss': 0.26748629510402677, 'loss_no_reg': 0.2595166563987732, 'corr': [0.2275702618681873, 0.47768319682385, 0.41702517708280423, 0.4481771088006814, 0.6070233497861612, 0.4785857071495804, 0.3867166108342255, 0.40237026792204367, 0.29608113306915135, 0.15102881293387702, 0.461729551286943, 0.16450380429267333, 0.5346240853130839, 0.25205221471440453, 0.5025243421605152, 0.17431590367344116, 0.2833087328957693, 0.10890155297430357, 0.2570600561527926, 0.5115957109984761, 0.40573260212898293, 0.5499348137185384, 0.34961972905147604, 0.400127970565628, 0.47072992961564863, 0.3909243607162256, 0.44414739853892954, 0.5096583257947754, 0.6021688397028154, 0.40751237013510444, 0.3448016259367966, 0.42709261484026173, 0.6144746426178224, 0.437485261372285, 0.4393329824971095, 0.4948195863784373, 0.26152547040423146, 0.5353074632937962, 0.5156740154490499, 0.10765723447304015, 0.3882109866978342, 0.5115410804882

700-0, train loss 0.25937291979789734
train loss 0.25937291979789734
val metric {'loss': 0.26404690444469453, 'loss_no_reg': 0.2560061514377594, 'corr': [0.21834686713541748, 0.4837841774367565, 0.42425903484790023, 0.4513272347733669, 0.6111861099021649, 0.4729798550041288, 0.39960057630671775, 0.40424430495141656, 0.2928051059683793, 0.1599528237131769, 0.4701047063125687, 0.19029171384091778, 0.5342727829211436, 0.2539011570876885, 0.5012074992164625, 0.18675770993298724, 0.2903608495064066, 0.17676295240942808, 0.26617455229896175, 0.5144303920865381, 0.4348930397176336, 0.5629179137168552, 0.35092215102012336, 0.40520563427874934, 0.5003827727340785, 0.44792876467850873, 0.4399763085295531, 0.5179707213908806, 0.6045674596026095, 0.40557439415168983, 0.3576158963381751, 0.42859564084669166, 0.6169225558530872, 0.4410022179714946, 0.44479683706065903, 0.5039488019917402, 0.27818635951307785, 0.542521760850924, 0.5390161689936367, 0.10927556542519526, 0.38477171000213717, 0.53172653

900-0, train loss 0.23992519080638885
train loss 0.23992519080638885
val metric {'loss': 0.26337526440620423, 'loss_no_reg': 0.25572216510772705, 'corr': [0.2371300399057273, 0.48972016592917633, 0.42229868590509106, 0.44911412611845675, 0.6113742417483408, 0.45452009287645423, 0.39594770917202365, 0.4094494299635464, 0.2863714614050659, 0.15701703752821167, 0.47586203053629306, 0.20214737080627468, 0.5412637550792179, 0.2685652337467055, 0.5039352350396651, 0.17743182723061846, 0.29505524652320203, 0.21361683425091638, 0.2619973914333102, 0.5115558405550028, 0.4299396143014692, 0.5575989612004507, 0.3528909262490127, 0.3999992116357667, 0.49176707701610173, 0.458781238806122, 0.4429025402751509, 0.5139129141914458, 0.5980743151549752, 0.41606273709125224, 0.3537737645631052, 0.4352061771546066, 0.6162336154977195, 0.4275536193748551, 0.4456113519904911, 0.5029449857928464, 0.27904828645436075, 0.5399230930453688, 0.5359470424287509, 0.11739113306521883, 0.37752863939351533, 0.52174555

1100-0, train loss 0.2411472350358963
train loss 0.2411472350358963
val metric {'loss': 0.2631061375141144, 'loss_no_reg': 0.25580137968063354, 'corr': [0.2306411605245516, 0.48519548339867347, 0.4279955467769671, 0.4444539947840025, 0.599146971730382, 0.4537840679951309, 0.4029916377242118, 0.41883265650488083, 0.28405721199357375, 0.1552535505657241, 0.48134357820813345, 0.19698263202383584, 0.5497076893961376, 0.27768538709787777, 0.516381484357788, 0.17964765485919465, 0.285821613460979, 0.2479255911457924, 0.2530607773423153, 0.5133606857628781, 0.43032506902141654, 0.5723668690980589, 0.35859934380950453, 0.4117583982637629, 0.48724744046636953, 0.46300059960889695, 0.44279521796854127, 0.5229355384628074, 0.5997028552429274, 0.4124531403563151, 0.3680534317279197, 0.442610839250977, 0.5993582153237706, 0.43790032557351344, 0.4368318790812738, 0.5021268670231771, 0.29694074517713764, 0.5439390360153508, 0.5421377441408253, 0.12273414592957602, 0.38396660345504624, 0.5327699543780

1300-0, train loss 0.2350364327430725
train loss 0.2350364327430725
val metric {'loss': 0.26559517085552214, 'loss_no_reg': 0.2606021463871002, 'corr': [0.2320026838924136, 0.4828004956976336, 0.4197334685515779, 0.4458923535561061, 0.6065560341013404, 0.41982580895104726, 0.39844671269096815, 0.41151287934464204, 0.281596814165224, 0.15989004649646968, 0.47294635751592956, 0.21236506450011233, 0.5399886138214792, 0.2782644318426365, 0.5151598915406328, 0.18092793697489082, 0.2953004993267186, 0.2561363557279384, 0.27060953774739516, 0.49615565197623246, 0.42498856802159174, 0.5774774938001879, 0.35268407347380937, 0.41538301490324797, 0.5004912524508187, 0.450790126079794, 0.44772983896046087, 0.5197448151495305, 0.6026811568114049, 0.4060555708321628, 0.37696282175400847, 0.44258843126530517, 0.5998173664076305, 0.4332342271867807, 0.43352975346103073, 0.5032125374735035, 0.2923500403632671, 0.5384907935728112, 0.5452912089475421, 0.12440943824037164, 0.3911382300460496, 0.5382723617

1500-0, train loss 0.24269318580627441
train loss 0.24269318580627441
val metric {'loss': 0.26311068534851073, 'loss_no_reg': 0.25498196482658386, 'corr': [0.2431073752189647, 0.47991267338039717, 0.42572447613260006, 0.4626283371711714, 0.6027242321967177, 0.46278536503285744, 0.4021267400959855, 0.4160867075500386, 0.2688437203343045, 0.16148644620141123, 0.478523632509284, 0.21510726957609277, 0.5451611000808003, 0.2774045053801859, 0.5136205866338281, 0.17791393227655644, 0.2948399510826162, 0.2470781959263136, 0.21769104062734468, 0.5051841148944828, 0.4340819852590296, 0.5704699804236307, 0.3574073288184366, 0.4189985482306736, 0.500064877060145, 0.45528597888227407, 0.45873627979825615, 0.5136705921355155, 0.6021573096684756, 0.41397424463764426, 0.38157031530525065, 0.4349684148715101, 0.6139695534820362, 0.4327845662386966, 0.44577132183728213, 0.49981602136033393, 0.29014052290176484, 0.5344388542913894, 0.5414424516037014, 0.11892279042045444, 0.382200257882632, 0.5300721974

1700-0, train loss 0.25313863158226013
train loss 0.25313863158226013
val metric {'loss': 0.26405700445175173, 'loss_no_reg': 0.25615760684013367, 'corr': [0.23627689975242533, 0.4828160340122262, 0.4163479227266098, 0.4547785419374698, 0.5948132098468168, 0.4478485933850442, 0.394047449793965, 0.41974407720806567, 0.2713408595349975, 0.1679076351490965, 0.47124824226792594, 0.21183150243531004, 0.5533110002477669, 0.2813743261735011, 0.5166314834085166, 0.1756028227739176, 0.295810174972328, 0.2522098972880132, 0.22147531849076044, 0.5045766411087153, 0.4388280966631643, 0.5784007077997937, 0.3468638002371004, 0.41631493393813507, 0.49263375170155177, 0.4588644927704657, 0.4611442399204811, 0.5120776510207343, 0.5950674350543467, 0.4128484459255421, 0.38448958038254816, 0.4357598277416445, 0.6132600853149742, 0.4443466040389624, 0.44939228775828927, 0.4998443123746291, 0.2997595713910392, 0.538238368912775, 0.5451275982366951, 0.12597294605013698, 0.37987335626036567, 0.53275926809600

100-0, train loss 0.2541203796863556
train loss 0.2541203796863556
val metric {'loss': 0.263776621222496, 'loss_no_reg': 0.2575703561306, 'corr': [0.23138011952777926, 0.48182763468346035, 0.4236760154073214, 0.4551092222738142, 0.5972539740186823, 0.4458302766695623, 0.39598769981060006, 0.41705471326305843, 0.27782303971735606, 0.1610976403413053, 0.4773049060855974, 0.21437227062866973, 0.5397645937506366, 0.2772901645368045, 0.517165663952007, 0.17792031650489176, 0.2905439569253483, 0.2535976746852443, 0.24908193649109522, 0.5022481690703511, 0.43003964417725093, 0.5758559838524613, 0.3550131269593772, 0.4158603255515267, 0.5011044887617718, 0.45178843935330737, 0.45508968740280986, 0.5176646958692466, 0.5984236687775001, 0.41621015233697106, 0.3777312892578097, 0.4371922689237649, 0.5995634263878407, 0.43893692242134, 0.44814759155615747, 0.5049761458000441, 0.29445568648984954, 0.5386438634795314, 0.5436034712137502, 0.1174880069857319, 0.38206851389342383, 0.5369087047689537, 0

300-0, train loss 0.2574828565120697
train loss 0.2574828565120697
val metric {'loss': 0.2636892259120941, 'loss_no_reg': 0.2567813992500305, 'corr': [0.24059185440636113, 0.4841769789460034, 0.4231250179199715, 0.4541228538441523, 0.5958710329550838, 0.4606173145258347, 0.39989629986126063, 0.4160734888272032, 0.2771205755558102, 0.16191140784020458, 0.4808540212075799, 0.2132985540077915, 0.5411889353951164, 0.28010932243743525, 0.5161910448539514, 0.1767899021861218, 0.29362773368254974, 0.2511443216757504, 0.22682855602220456, 0.5042461674081232, 0.43482464319756864, 0.5785299047105921, 0.3537224877438396, 0.4172386279716451, 0.5052876157321905, 0.45127418350931736, 0.45893672114003375, 0.519861339234938, 0.599282711564386, 0.41307677699119383, 0.377477846685627, 0.44395189373684535, 0.6045751979371028, 0.4345429721809942, 0.44345587474375237, 0.5033283806482071, 0.2986954466627875, 0.5390880899725132, 0.5478708125776027, 0.11935283769528969, 0.3865587561660706, 0.5368301722636549,

500-0, train loss 0.23885618150234222
train loss 0.23885618150234222
val metric {'loss': 0.26266120076179506, 'loss_no_reg': 0.2548428475856781, 'corr': [0.23514778117143634, 0.4840004577628173, 0.4199700781501948, 0.45574232447238483, 0.5965650105797023, 0.44235881925945797, 0.39319648095912285, 0.41796650796643386, 0.27374231215261136, 0.16028183193183843, 0.4805225106939598, 0.21052429030832778, 0.5457608816528634, 0.2773945095173317, 0.5123953077518326, 0.17768485625617667, 0.2939648819689644, 0.25355459059625274, 0.2221120535679841, 0.4970432294492205, 0.4332722170416648, 0.577999896317544, 0.35510631788595726, 0.41877712930384603, 0.5052254581547511, 0.45053321934541474, 0.4572823045292571, 0.5194934871944856, 0.5962816766968019, 0.4169484664973528, 0.38233467362716417, 0.43747306755917664, 0.6062856812224793, 0.43672503061682333, 0.44383633676242545, 0.5017003467910881, 0.2911600165343996, 0.5388364354149053, 0.5451461497263965, 0.11778036503105056, 0.3840835232025903, 0.5331699

100-0, train loss 0.2303396463394165
train loss 0.2303396463394165
val metric {'loss': 0.26162215769290925, 'loss_no_reg': 0.2536889910697937, 'corr': [0.23714432866366675, 0.48505711746914804, 0.4248369133015218, 0.45573895805802755, 0.6019489150095503, 0.46399628113509417, 0.39881482555688946, 0.4174068672164165, 0.27950209606917203, 0.15904872090041638, 0.47726651461045577, 0.21060136251684877, 0.5446359642737967, 0.27682362371441327, 0.5140855623324928, 0.17526785302053985, 0.29633232844748486, 0.24953638271202994, 0.23123329101487133, 0.5079612721465661, 0.4355867795403256, 0.5751167271435247, 0.35554150120413225, 0.4180189695938364, 0.504685292369763, 0.4550920970526156, 0.4552169152995411, 0.5215815153796576, 0.5996825097011762, 0.41438330246083566, 0.37687586633765285, 0.4408147025106698, 0.6066773878391314, 0.4373518376702043, 0.44255998529614127, 0.504284955548806, 0.2918382556092762, 0.5405185956278206, 0.5486305372060833, 0.11825397624996162, 0.38850242142993935, 0.53494683

300-0, train loss 0.2377917468547821
train loss 0.2377917468547821
val metric {'loss': 0.2619974732398987, 'loss_no_reg': 0.2541056275367737, 'corr': [0.2366442423846331, 0.48405766038307485, 0.4229614798465568, 0.4546967477350136, 0.6001694003257325, 0.4605995524258923, 0.3993219757491716, 0.41705671413812384, 0.2784508454455988, 0.16006757632904153, 0.4783787765512635, 0.21152480148466993, 0.544628090624244, 0.277823400046886, 0.5150140474170446, 0.17491769829017328, 0.2953544047438003, 0.2520251333931383, 0.23165809396207582, 0.5050555943953016, 0.4361392191689474, 0.5773331769616925, 0.3552977496340056, 0.41823910145582477, 0.5042980180848824, 0.4501441591414533, 0.45670253570994795, 0.5216817414281965, 0.5982660299902527, 0.41588603914148087, 0.3783027600198027, 0.44039915553077974, 0.6082746515848266, 0.4358105926201161, 0.4445433109461281, 0.5033793124349103, 0.2902068460947286, 0.5403131689116814, 0.548092212384813, 0.11734876331144796, 0.38551723184353226, 0.535201766633293, 0

500-0, train loss 0.26022714376449585
train loss 0.26022714376449585
val metric {'loss': 0.2620244026184082, 'loss_no_reg': 0.25406867265701294, 'corr': [0.2369008757162004, 0.48365205451844967, 0.4228876506248005, 0.4557517341378521, 0.601346016857262, 0.4595862132089544, 0.39827357419098, 0.41726324569765044, 0.2783917594940024, 0.1590279590601848, 0.4787750112004875, 0.21200599015461952, 0.5449391385778771, 0.27821918592835443, 0.5148943820172154, 0.1754774980644254, 0.2950686678026884, 0.2527000668917474, 0.22707447480758114, 0.5037484345241477, 0.4350717685788964, 0.5741354034113485, 0.35555890757216424, 0.4183466852742568, 0.5030656167929896, 0.452486648085735, 0.4562667233919014, 0.5207513560045706, 0.5978797765747886, 0.4149156524933793, 0.37940274279512815, 0.44063736307219276, 0.6101624574824717, 0.43565944531485307, 0.44237564664414986, 0.5014551304002552, 0.2909464886114778, 0.5405349835115054, 0.5473283036076255, 0.11864910508533022, 0.38439268609772653, 0.5339854305438388

0.5071455326246751
