In [1]:
# test full training of maskcnn_polished_with_rcnn_k_bl
from os.path import join

import numpy as np

import torch

from thesis_v2 import dir_dict
from thesis_v2.data.prepared import combine_two_separate_datasets

from thesis_v2.data.prepared.cadena_plos_cb19 import get_data

from thesis_v2.training_extra.maskcnn_like.opt import get_maskcnn_v1_opt_config
from thesis_v2.training_extra.maskcnn_like.training import (train_one,
                                                            partial)

from thesis_v2.models.maskcnn_polished_with_rcnn_k_bl.builder import (
    gen_maskcnn_polished_with_rcnn_k_bl,
    load_modules
)

In [2]:
def pad_nan(y_this):
    """Surround ``y_this`` with NaN blocks along axis 0.

    Returns ``[NaN block, y_this, NaN block]`` concatenated along the first
    axis, where each NaN block has the same shape as ``y_this``. This pairs
    with ``dup``, which stacks three copies of the inputs, so only the middle
    copy carries real targets (training is run with ``handle_nan=True``).

    Integer/bool inputs are promoted to float64 first: NaN is not
    representable in those dtypes, and ``np.full_like`` would otherwise
    silently write an arbitrary integer instead of NaN. Floating and complex
    inputs keep their dtype.
    """
    y_arr = np.asarray(y_this)
    if not np.issubdtype(y_arr.dtype, np.inexact):
        # NaN needs a floating (or complex) dtype to survive the fill.
        y_arr = y_arr.astype(np.float64)
    pad = np.full_like(y_arr, fill_value=np.nan)
    return np.concatenate([pad, y_arr, pad], axis=0)

def dup(x_this):
    """Stack three copies of ``x_this`` along axis 0.

    Used for the stimuli so they line up with the NaN-padded responses
    produced by ``pad_nan`` (``[pad, y, pad]``).
    """
    copies = (x_this,) * 3
    return np.concatenate(copies, axis=0)

def train_one_maskcnn_polished_with_rcnn_k_bl(
    split_seed,
    model_seed,
    act_fn,
    loss_type,
    input_size,
    out_channel,
    num_layer,
    kernel_size_l1,
    pooling_ksize,
    scale, scale_name,
    smoothness, smoothness_name,
    pooling_type,
    n_timesteps,
):
    """Train a single maskcnn_polished_with_rcnn_k_bl model on CB19 data.

    Every split's stimuli are stacked three times along axis 0 (``dup``) and
    its responses are wrapped in NaN blocks (``pad_nan``), so only the middle
    copy carries real targets; ``train_one`` is called with NaN handling
    enabled for both training and evaluation.

    Args:
        split_seed: seed forwarded to ``get_data`` for the data split.
        model_seed: seed forwarded to ``train_one``; also the last component
            of the result key.
        input_size: forwarded to ``get_data`` as ``final_size``.
        act_fn, out_channel, num_layer, kernel_size_l1, pooling_ksize,
        pooling_type, n_timesteps: architecture settings forwarded to
            ``gen_maskcnn_polished_with_rcnn_k_bl``.
        loss_type, scale, smoothness: forwarded to
            ``get_maskcnn_v1_opt_config`` (``group`` is fixed to 0.0). This
            ``scale`` is unrelated to the ``scale=0.5`` given to ``get_data``.
        scale_name, smoothness_name: accepted for call-site compatibility;
            not used inside this function.

    Returns:
        ``(corr_mean, corr)`` read from
        ``result['stats_best']['stats']['test']`` of ``train_one``.
    """
    load_modules()
    raw_splits = get_data(
        px_kept=80, final_size=input_size, scale=0.5,
        seed=split_seed,
    )

    # raw_splits is indexed as (X_train, y_train, X_val, y_val, X_test, y_test).
    # Inputs are tripled along axis 0 and targets get a NaN block before and
    # after, so inputs and targets keep the same length.
    datasets = {}
    for split_idx, split_name in enumerate(('train', 'val', 'test')):
        x_raw = raw_splits[2 * split_idx]
        y_raw = raw_splits[2 * split_idx + 1]
        datasets[f'X_{split_name}'] = dup(x_raw.astype(np.float32))
        datasets[f'y_{split_name}'] = pad_nan(y_raw)

    for array_name, array in datasets.items():
        print(array_name, array.shape)

    # Architecture settings that are fixed for this experiment.
    fixed_arch_kwargs = dict(
        kernel_size_l23=3,
        factored_constraint=None,
        blstack_pool_ksize=1,
        blstack_pool_type=None,
        acc_mode='cummean',
        bn_after_fc=False,
        ff_1st_block=True,
    )

    def gen_cnn_partial(input_size_cnn, n):
        # Presumably invoked by train_one with the CNN input size and the
        # number of neurons; parameter names kept in case it passes kwargs.
        return gen_maskcnn_polished_with_rcnn_k_bl(
            input_size=input_size_cnn,
            num_neuron=n,
            out_channel=out_channel,
            kernel_size_l1=kernel_size_l1,
            act_fn=act_fn,
            pooling_ksize=pooling_ksize,
            pooling_type=pooling_type,
            num_layer=num_layer,
            n_timesteps=n_timesteps,
            **fixed_arch_kwargs,
        )

    opt_config_partial = partial(
        get_maskcnn_v1_opt_config,
        scale=scale,
        smoothness=smoothness,
        group=0.0,
        loss_type=loss_type,
    )

    # NaN-aware evaluation. Original note: "reduce on batch axis" (axis 1);
    # NOTE(review): confirm the axis layout against train_one's eval_fn.
    eval_options = {
        'yhat_reduce_axis': 1,
        'handle_nan': True,
    }

    result = train_one(
        arch_json_partial=gen_cnn_partial,
        opt_config_partial=opt_config_partial,
        datasets=datasets,
        key=f'debug/test_full_training_maskcnn_polished_with_rcnn_k_bl/ff_1st_block_cb19_data-handle_nan-nan_y_3/{model_seed}',
        show_every=100,
        max_epoch=40000,
        model_seed=model_seed,
        return_model=False,
        extra_params={'eval_fn': eval_options},
        val_test_every=150,
        print_model=True,
        handle_nan=True,
    )

    best_test_stats = result['stats_best']['stats']['test']
    return best_test_stats['corr_mean'], best_test_stats['corr']

In [3]:
# Architecture / seed settings shared by the configurations in this notebook.
maskcnn_param_template = {
    'out_channel': 16,
    'num_layer': 3,
    'kernel_size_l1': 9,
    'pooling_ksize': 3,
    'pooling_type': 'avg',
    'model_seed': 0,
    'split_seed': 0,
}

# Feed-forward (n_timesteps=1) configuration: template + training settings.
maskcnn_param_regular = {
    **maskcnn_param_template,
    'act_fn': 'relu',
    'loss_type': 'mse',
    # Smoothness sweep (presumably test corr_mean; the 0.5058... value equals
    # np.mean(that_one) below, so these numbers may come from the
    # rcnn_basic_kriegeskorte run — NOTE(review): confirm):
    #   0.000005  -> 0.505823562521005
    #   0.00005   -> 0.5051972572416786
    #   0.0000005 -> 0.5043545818049401
    'smoothness': 0.000005,
    'smoothness_name': '0.000005',
    # Scale sweep, presumably at smoothness = 0.000005 (scale=0.01 reproduces
    # that setting's score above; an earlier note said 0.00005, which would
    # give 0.5051972572416786 instead — NOTE(review): confirm):
    #   scale=0.01  -> 0.505823562521005
    #   scale=0.1   -> 0.48046038315427236
    #   scale=0.001 -> 0.4979968643549016
    # The previous hyperparameters worked best.
    'scale': 0.01,
    'scale_name': '0.01',
    'input_size': 40,
    'n_timesteps': 1,
}

print(train_one_maskcnn_polished_with_rcnn_k_bl(**maskcnn_param_regular))

# Parameter count for this configuration. One decomposition consistent with
# the printed "num_param 21872" (the 11*11 fc spatial size is inferred from
# the total — TODO confirm):
# 21872 = 1*2 + 16*1*9*9 + 16*2 + 2*(16*16*3*3 + 16*2) + 115*(11*11 + 16 + 1)
# An earlier note, "27629 = 1*2 + 16*1*9*9 + 16*2 + 2*(16*16*3*3+16*16*3*3+4*16*2)
# + 79*(14*14 + 16 + 1)", describes a different setup (79 neurons, 14x14 map,
# two convs per recurrent layer) and does not match this run.

X_train (13920, 1, 40, 40)
y_train (13920, 115)
X_val (3480, 1, 40, 40)
y_val (3480, 115)
X_test (4350, 1, 40, 40)
y_test (4350, 115)
{'fc', 'bn_output', 'pooling', 'final_act'}
['conv0', 'bl_stack.layer_list.0.b_conv', 'bl_stack.layer_list.1.b_conv']
neurons with NaN mean on train 0/115
num_param 21872
JSONNet(
  (moduledict): ModuleDict(
    (accumulator): RecurrentAccumulator()
    (act0): ReLU()
    (bl_stack): BLConvLayerStack(
      (layer_list): ModuleList(
        (0): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): BLConvLayer(
          (b_conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (bn_layer_list): ModuleList(
        (0): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (capture_list): Modu

400-0, train loss 0.266475647687912
train loss 0.266475647687912
500-0, train loss 0.2364434152841568
train loss 0.2364434152841568
600-0, train loss 0.2217712551355362
train loss 0.2217712551355362
val metric {'loss': 0.13252717150109156, 'loss_no_reg': 0.2632727324962616, 'corr': [0.2489061501674369, 0.4429272084508796, 0.42191214806822225, 0.44408555357557933, 0.6237332552563833, 0.4732914448432528, 0.37082004140428826, 0.41790649488525444, 0.3616185192573312, 0.15078707061083477, 0.4574108494663125, 0.14991129219861396, 0.39992056338621584, 0.2319842962010319, 0.4794134518359002, 0.15369543171562747, 0.3016874975128933, 0.22124403809069992, 0.24145878334115892, 0.46070978533925533, 0.4419903570051627, 0.528293390327593, 0.32292464515228175, 0.3924756922978002, 0.44059101887074503, 0.25501520841611564, 0.44992287880410836, 0.5041830698909748, 0.5946942260894752, 0.41388053555954457, 0.3623221503029438, 0.4556624226862427, 0.6120249778359954, 0.4450272185093362, 0.45202831859726317, 

1000-0, train loss 0.23978768289089203
train loss 0.23978768289089203
1100-0, train loss 0.2300010770559311
train loss 0.2300010770559311
1200-0, train loss 0.25294673442840576
train loss 0.25294673442840576
val metric {'loss': 0.12946148748908723, 'loss_no_reg': 0.2562335431575775, 'corr': [0.2370338117014749, 0.4681149285678013, 0.4200923365986798, 0.4670382010307643, 0.6126821019720234, 0.470337894356489, 0.37775300215542656, 0.41220765093510514, 0.3698637802626027, 0.15550804256894465, 0.48116144454824017, 0.16085280240020602, 0.5466230169199972, 0.260990343569532, 0.5177537922317613, 0.17830030789022538, 0.31376988695724667, 0.23979103575350444, 0.24231823107104558, 0.5115983364164964, 0.44543363000335795, 0.5784686038405547, 0.33633658702310665, 0.41832547108246093, 0.4846607950862578, 0.3895793073627023, 0.45275568180756137, 0.5160133739440751, 0.6045566560789062, 0.42306622220381934, 0.3629316381653696, 0.4458094747277014, 0.6234007652063714, 0.46070774355577426, 0.454285819932

1600-0, train loss 0.2813226580619812
train loss 0.2813226580619812
1700-0, train loss 0.242415651679039
train loss 0.242415651679039
1800-0, train loss 0.23701849579811096
train loss 0.23701849579811096
val metric {'loss': 0.12880197486707143, 'loss_no_reg': 0.25480857491493225, 'corr': [0.24036598119418695, 0.47465121696128615, 0.42349550572337447, 0.46689799803724247, 0.6192920506060332, 0.4609349485497905, 0.3783602967887545, 0.41161287395847973, 0.3606487237427195, 0.15959059928964564, 0.48023819629109726, 0.1799476509890073, 0.5558436259377844, 0.265706320265254, 0.5258753899558914, 0.18554694929473498, 0.3100094381634792, 0.2632630391999606, 0.24879126184839714, 0.5131338848906758, 0.4420386425294388, 0.5782085175571001, 0.33743360239309217, 0.42387190451606466, 0.4977406670405037, 0.40164091796390367, 0.4733639824669326, 0.5020592734521345, 0.6094157984761843, 0.4237825634450161, 0.37220253283137433, 0.45744751082429475, 0.6237196007685794, 0.4763742401773708, 0.441258177846393

2200-0, train loss 0.2636623680591583
train loss 0.2636623680591583
2300-0, train loss 0.22081314027309418
train loss 0.22081314027309418
2400-0, train loss 0.2709617614746094
train loss 0.2709617614746094
val metric {'loss': 0.12897393533161708, 'loss_no_reg': 0.255598783493042, 'corr': [0.25571208129402195, 0.47953327471170293, 0.43306188401784274, 0.46566121257287774, 0.6202109230922888, 0.4685907835976717, 0.39059012770928137, 0.4186191335073215, 0.3554801110630678, 0.15350287724391987, 0.47607276075233956, 0.19159471351718357, 0.5651196874548575, 0.2762665997276159, 0.5230617525796336, 0.1814123424408067, 0.30919108943151, 0.2564792136485628, 0.22418326786884266, 0.5043841815706589, 0.4460269805181714, 0.5824051558975318, 0.3359917825250749, 0.4196601695382747, 0.48736763668497524, 0.43349003467357855, 0.466453108023396, 0.50904094214951, 0.6057531156333991, 0.4161217423411775, 0.3620117189008014, 0.4536727800058049, 0.6220257358505575, 0.47211818354636675, 0.43429194583489605, 0.

2800-0, train loss 0.24423976242542267
train loss 0.24423976242542267
2900-0, train loss 0.23048752546310425
train loss 0.23048752546310425
3000-0, train loss 0.2737901508808136
train loss 0.2737901508808136
val metric {'loss': 0.12908870354294777, 'loss_no_reg': 0.25491204857826233, 'corr': [0.25816919215078793, 0.48100490821669145, 0.4322942666067398, 0.4721590955540591, 0.6268011359289163, 0.4676602824629863, 0.3795706880191484, 0.41931301163709445, 0.35332514035873297, 0.14065740881970418, 0.48147961763194075, 0.19140925175556853, 0.5542618086394177, 0.2680772462502662, 0.5242290427709947, 0.18776381761690325, 0.313393987202355, 0.2447745619299158, 0.2081717191996469, 0.5135110917380185, 0.4517477178487609, 0.5862360839506262, 0.3385601909511231, 0.4178462536092776, 0.49922529691498657, 0.4408175430680636, 0.47142611987162947, 0.5005157786592167, 0.6062740712417561, 0.42471343141200574, 0.36748593144337893, 0.4561799582879743, 0.6227473785283834, 0.4620924290707081, 0.4334464307518

3400-0, train loss 0.23462189733982086
train loss 0.23462189733982086
3500-0, train loss 0.23520562052726746
train loss 0.23520562052726746
3600-0, train loss 0.2604605555534363
train loss 0.2604605555534363
val metric {'loss': 0.12949559411832265, 'loss_no_reg': 0.2565039098262787, 'corr': [0.25158681184961046, 0.4895931590482408, 0.4422322394328152, 0.46989796751444146, 0.6241558345678248, 0.46794288687491714, 0.38284879458572824, 0.41236231660322337, 0.35485606914736406, 0.13633035096078214, 0.481257178662482, 0.1950774539844162, 0.5624259496363578, 0.2733166807254343, 0.5199931552462065, 0.18836700370640944, 0.311301998289616, 0.23401415747890317, 0.18876884689727352, 0.5057603745198189, 0.4478351453837718, 0.5821695568394984, 0.34172780681894926, 0.4207256372608424, 0.4961408968678774, 0.4255921492954382, 0.48196309216932887, 0.5061047839672218, 0.5966919155072189, 0.41450444453582286, 0.37467039605129315, 0.4506640145397693, 0.6204270917929467, 0.4682738537953899, 0.4329946238433

100-0, train loss 0.25902172923088074
train loss 0.25902172923088074
200-0, train loss 0.24384364485740662
train loss 0.24384364485740662
300-0, train loss 0.23024310171604156
train loss 0.23024310171604156
val metric {'loss': 0.12804021393614157, 'loss_no_reg': 0.25356757640838623, 'corr': [0.24774201933090384, 0.47956182907911693, 0.42861735064825807, 0.46938758895242294, 0.6197959666928982, 0.46649715569433253, 0.38151390727240014, 0.41643214889164215, 0.3592067039605938, 0.14756718989966489, 0.479942117499025, 0.19057533300398313, 0.5585965627983484, 0.2709957055842431, 0.5217714375103829, 0.18206574661823804, 0.3169172048248289, 0.25337133392193, 0.22750957602927865, 0.5152643683829133, 0.4484377758088517, 0.5828627559912847, 0.3395903214774901, 0.4198694404741782, 0.4978001991214075, 0.4272845649648819, 0.4751219101434805, 0.5039086333172915, 0.6099481597655747, 0.4270235859787011, 0.3716474919452204, 0.45336318479884247, 0.6215850590592107, 0.47201596562166764, 0.440522444326589

700-0, train loss 0.24759694933891296
train loss 0.24759694933891296
800-0, train loss 0.2633591294288635
train loss 0.2633591294288635
900-0, train loss 0.2476639300584793
train loss 0.2476639300584793
val metric {'loss': 0.12791026436856814, 'loss_no_reg': 0.2535460591316223, 'corr': [0.2509733285196499, 0.4823707588559286, 0.43251622464848954, 0.47088111536289473, 0.6246670591054253, 0.4767370132957161, 0.38624747498041917, 0.417451467411729, 0.3570816161142311, 0.1458817168399715, 0.48115621556420063, 0.19770801071151822, 0.5592349087991367, 0.2742692921653674, 0.524722509194119, 0.18202018172711432, 0.3177051706105275, 0.25565752111604634, 0.22445437233752433, 0.5129917910029985, 0.44764519540849057, 0.5853654291892765, 0.33942177217446273, 0.4167534656144152, 0.4935473020825677, 0.4274088233661847, 0.4748500991522907, 0.5019338373670362, 0.6114819282358513, 0.4213101877682218, 0.37186562278008123, 0.4539852414880057, 0.6252686944878915, 0.47107382392417896, 0.43332611223965684, 0

1300-0, train loss 0.22389091551303864
train loss 0.22389091551303864
1400-0, train loss 0.24210411310195923
train loss 0.24210411310195923
1500-0, train loss 0.24475573003292084
train loss 0.24475573003292084
val metric {'loss': 0.12810514947133406, 'loss_no_reg': 0.2537984848022461, 'corr': [0.2529782148250312, 0.48210440302622903, 0.4320759406784096, 0.4700723022428139, 0.6239879344372193, 0.4734968299892934, 0.3822406755560229, 0.4160874223687395, 0.3574577057093183, 0.14333385607388788, 0.48032687962703735, 0.19592240203747632, 0.5586615951741415, 0.2781621883145087, 0.5213770245533079, 0.1838358421480873, 0.3149364181798902, 0.25496385901539964, 0.22080939634692184, 0.5083270172991754, 0.44812706914106476, 0.58016547910446, 0.34343457542862105, 0.41777934551114476, 0.4948629295437341, 0.4320435368984991, 0.47525256862089277, 0.504176370940594, 0.6114869780957032, 0.42055241544209016, 0.37524139956993247, 0.45249762393651627, 0.6266128639273886, 0.4750171367696996, 0.4378416065589

1900-0, train loss 0.21892432868480682
train loss 0.21892432868480682
2000-0, train loss 0.24076442420482635
train loss 0.24076442420482635
2100-0, train loss 0.2573792338371277
train loss 0.2573792338371277
val metric {'loss': 0.12815565722329275, 'loss_no_reg': 0.2542612552642822, 'corr': [0.24797624011373134, 0.4818361751468645, 0.4334682234337525, 0.4652694468813166, 0.6241182345867682, 0.46753166999993856, 0.38372677940642874, 0.4172347205485676, 0.35469636015194356, 0.14650004631827868, 0.48083096947438747, 0.19917288393108518, 0.56102846142749, 0.27570851337312136, 0.5206220976395107, 0.1821495827015133, 0.31383445764927137, 0.2591069137592318, 0.21206685006045067, 0.5106141264151185, 0.44779210566804456, 0.5846328960453452, 0.3418185562264701, 0.4183577398596065, 0.5015640097281103, 0.42861525054201843, 0.4746350402262657, 0.50629312886749, 0.6095496144094458, 0.41987222792333184, 0.3717440997134384, 0.4535220497782385, 0.6257831036717987, 0.4712563698024928, 0.4342722489731345

100-0, train loss 0.2582601308822632
train loss 0.2582601308822632
200-0, train loss 0.22548827528953552
train loss 0.22548827528953552
300-0, train loss 0.1936332881450653
train loss 0.1936332881450653
val metric {'loss': 0.12781899023268903, 'loss_no_reg': 0.2533469498157501, 'corr': [0.25180126307470707, 0.4820917324579897, 0.43030231869336844, 0.47055557511531604, 0.6223912172551468, 0.46952134517785576, 0.3826650776144666, 0.41751647146580095, 0.35889381518635993, 0.14487693467638518, 0.47974849618039767, 0.19459900693569226, 0.5585887593532248, 0.2737006194363957, 0.5226548667934557, 0.18414122783724163, 0.31510424313281366, 0.2558549095602614, 0.2279635699402306, 0.5163177382162228, 0.4468423646643042, 0.5858321387884455, 0.33999517265557544, 0.42058120055773107, 0.4973498511486048, 0.42840674614031465, 0.4765821315932385, 0.5021216428320481, 0.6106096162357978, 0.42391459353300537, 0.36994594655083646, 0.4543670368694658, 0.6266041425840504, 0.47439976860103356, 0.4395839523222

700-0, train loss 0.23508524894714355
train loss 0.23508524894714355
800-0, train loss 0.25377166271209717
train loss 0.25377166271209717
900-0, train loss 0.238283172249794
train loss 0.238283172249794
val metric {'loss': 0.12790215787078654, 'loss_no_reg': 0.25355786085128784, 'corr': [0.25103437056847694, 0.4811210959707647, 0.43190129144180284, 0.4686550800391724, 0.622837007086664, 0.4684530937000882, 0.38130546125856957, 0.41754972760698367, 0.356752037209456, 0.14531695206656864, 0.4784745448478397, 0.19566600567283107, 0.5575416481151565, 0.27335059300686065, 0.5213005454576467, 0.1806503697545197, 0.31572808704782146, 0.2535400890777606, 0.2212940621387332, 0.516076557063498, 0.446312300633587, 0.5841409483411327, 0.3400521538230482, 0.4190098416828053, 0.49801162991443404, 0.4298437760479714, 0.4764673689926254, 0.5015831110503421, 0.6106054024837558, 0.4234047229385915, 0.3742394817813105, 0.45451520074797, 0.6263228142762364, 0.47256021448695945, 0.4385310648411649, 0.50525

1300-0, train loss 0.2613489031791687
train loss 0.2613489031791687
1400-0, train loss 0.26326271891593933
train loss 0.26326271891593933
1500-0, train loss 0.23571883141994476
train loss 0.23571883141994476
val metric {'loss': 0.12788234677697932, 'loss_no_reg': 0.2536022663116455, 'corr': [0.2503213042404424, 0.48064438639723817, 0.43025894288564537, 0.4692144131115562, 0.6206156093411063, 0.4690691701671108, 0.38155788466655366, 0.4178379064892543, 0.35725944352707134, 0.14453686494179208, 0.4791330982506485, 0.19767538177696312, 0.5580725253399498, 0.27660058732805126, 0.5203052035092464, 0.1818019665114586, 0.3158673403869492, 0.25464236279307517, 0.2181079152049995, 0.5154055254147079, 0.4472773865924407, 0.5853204808768235, 0.3409811364415988, 0.4188815079940643, 0.4971724271527724, 0.4268417875115791, 0.47634846455516194, 0.500783365091121, 0.6109682198896446, 0.42427481047159427, 0.37239800580906685, 0.4541033371536762, 0.6274913796016665, 0.4717324406659422, 0.439328728634703

1900-0, train loss 0.22918465733528137
train loss 0.22918465733528137
2000-0, train loss 0.28093022108078003
train loss 0.28093022108078003
2100-0, train loss 0.2693418264389038
train loss 0.2693418264389038
val metric {'loss': 0.12798239370541914, 'loss_no_reg': 0.25358596444129944, 'corr': [0.2546328900329053, 0.48219989797486573, 0.4312535372486304, 0.46900807280391416, 0.6218258884801057, 0.46927188888367066, 0.38397769225368283, 0.4172551448358972, 0.3571400276759161, 0.14503917123076948, 0.4771773398563201, 0.1994575505242578, 0.5570631172599183, 0.274739723680981, 0.5190754713874118, 0.18269612361774212, 0.3128069788563518, 0.25571727397306876, 0.21760491683125036, 0.5146490852427075, 0.4474534129214851, 0.5856641280994769, 0.33919838860555435, 0.41793896342536857, 0.49768673630686433, 0.4307968365756621, 0.47658997628284167, 0.5004599149578586, 0.6105900399842082, 0.42319432441038113, 0.3730177076239189, 0.4547805390869248, 0.6274826472106396, 0.47469874711323223, 0.43826258218

In [10]:
# Per-neuron test correlations (115 values, one per column of y_*), pasted
# from a previous run — presumably the `corr` list returned by the training
# call above for this model.
# NOTE(review): the original note read "get the non-zero part"; its meaning
# is unclear (perhaps the middle, non-NaN-padded block) — confirm.

this_one = [0.22018256617130091, 0.4869975832065882, 0.47228129011513914, 0.45905626983797665, 0.621237405594686, 0.46426432226386516, 0.34839494436971774, 0.4123847375504384, 0.38500018054438234, 0.1898628542931494, 0.5180525210007833, 0.23161433036499138, 0.5582238525570329, 0.27508576617398095, 0.5493608659489684, 0.26259932188708796, 0.32779096278928316, 0.22763917357350913, 0.35131070560370914, 0.4753232617592147, 0.3781793256852377, 0.5637278413559904, 0.3328218505644466, 0.40935767563562797, 0.4642024124169721, 0.353047776321486, 0.48645058432861826, 0.5152167344722238, 0.6403300501840744, 0.4355465319790888, 0.41721890094817254, 0.4553353106083733, 0.6344076546953171, 0.5017943823781861, 0.46854368501937615, 0.501619193078052, 0.3250352249809067, 0.5642936521021515, 0.5560437667742284, 0.1509992188646114, 0.3702957224213098, 0.5069721506565019, 0.38195588441841427, 0.3807745431223978, 0.2855451515468333, 0.41305625223412124, 0.6550295393804862, 0.5128155783540855, 0.48251573300464173, 0.5408566738027406, 0.679629619792504, 0.5273247288309529, 0.5795956242510364, 0.511063872998033, 0.6541735578759779, 0.6703670311615281, 0.6909364056111341, 0.4458990840591408, 0.7481969873329024, 0.29835837216296035, 0.4518130701605749, 0.4255887417097827, 0.49378461644367827, 0.54655161759555, 0.5214053140061996, 0.5448183650157287, 0.5523762234410113, 0.3790843819226691, 0.528031214563742, 0.6100903551259104, 0.5342520312914681, 0.7505603539727742, 0.5548928919875635, 0.41199462900110634, 0.5391605141173166, 0.46245470677944994, 0.6380301529311603, 0.6023656588929015, 0.18458395736131225, 0.5177750383520018, 0.573612072316065, 0.5868815425493878, 0.6153848168601009, 0.7203689212778182, 0.631646639125167, 0.7616860377682483, 0.6839071675548605, 0.5934934831659342, 0.43962265168066, 0.5366753357100587, 0.733738621276102, 0.6469569437367578, 0.42962818673869185, 0.5994876280729442, 0.5630509724773918, 0.6330377734061688, 0.5020569672433688, 0.7816336715184735, 
0.5016566966059309, 0.5301333643401223, 0.5435846386149144, 0.461490940614455, 0.5577139214372072, 0.7123123076103274, 0.5854145786710785, 0.6859403069332486, 0.5852227628919392, 0.6046339649959771, 0.5539064611684865, 0.4167772280998731, 0.5657065326094318, 0.40577058449613407, 0.6573483036373993, 0.6723945112676948, 0.6051675465467994]

In [11]:
# Reference per-neuron test correlations to compare against `this_one`,
# copied from rcnn_basic_kriegeskorte/test_full_training_ff_1st_cb19_data-handle_nan.ipynb

that_one = [0.22290867814140142, 0.4708074534173289, 0.4748005428071667, 0.4726406797165408, 0.6261681532403353, 0.47778410642306074, 0.355276401771913, 0.4059499988799383, 0.27306687854409467, 0.2060898428331426, 0.5066163009016792, 0.21927353933146942, 0.5601374651206583, 0.2686960626493656, 0.5464271834695443, 0.2654781651778954, 0.32077724062824814, 0.23326450810839897, 0.3647180023372379, 0.47421547330506253, 0.3785651481083519, 0.5456766079217238, 0.3364432552323184, 0.4076604394937908, 0.4603158349474313, 0.3521687998817147, 0.4759313387428043, 0.5159867024433585, 0.6517636553453058, 0.44227082061214457, 0.42572542630461446, 0.46016411728227063, 0.6354862858793571, 0.48589746898023056, 0.4898000842666996, 0.5060820441937934, 0.3344468595601764, 0.5575963223930174, 0.565608854549304, 0.14956376027860313, 0.3688827327773246, 0.5157210848584867, 0.37949105748215006, 0.38131425254857293, 0.29023178616260425, 0.4092873045632762, 0.6546056913526376, 0.5105307686644647, 0.4648774654180571, 0.5389574293108751, 0.6848201217249001, 0.5474474526035918, 0.5821566244878427, 0.51226782001166, 0.6609746645934759, 0.6702189332496112, 0.713059286114458, 0.45907239875451583, 0.7417166861688765, 0.30437573690915565, 0.4779997589486579, 0.427084702651328, 0.4953790066043401, 0.519572799621041, 0.5208000639149223, 0.5429269548689853, 0.5504635125870688, 0.3740523838454645, 0.5148547506060256, 0.617953049407335, 0.5446498963984046, 0.7483751444632495, 0.5351439216693811, 0.411975940704237, 0.5626833094802879, 0.48104878226225906, 0.6523571578566921, 0.6083369240852424, 0.15578991723678842, 0.5054916925086076, 0.5579364138230724, 0.5694297383452513, 0.6215502504670586, 0.7279874841137555, 0.6224736989941038, 0.761737156238404, 0.6861978108122229, 0.6067185997711673, 0.44793477697308465, 0.5357085593856187, 0.7331999385987633, 0.6475202233616871, 0.44000115178619764, 0.5841800693523288, 0.581523834811977, 0.6218072448705786, 0.5003692096558175, 0.7780975902298893, 
0.49409784451955313, 0.527406207860786, 0.5361577988956837, 0.4588214757156862, 0.5553959701146822, 0.7058145972361495, 0.5886926663525995, 0.6853363853392981, 0.5815182730815066, 0.5858139427845256, 0.5553463898266813, 0.42168832978027504, 0.5526428668634553, 0.40211562266310247, 0.6608694233816872, 0.6747063506035611, 0.59971232260701]

In [12]:
from scipy.stats import pearsonr

In [13]:
# The two per-neuron correlation profiles are basically the same
# (Pearson r ~ 0.994 in the recorded output).
pearson_this_vs_that = pearsonr(this_one, that_one)
print(pearson_this_vs_that)

(0.9941515334648424, 4.478703016824586e-111)


In [14]:
# Mean test correlation: (this model, reference model).
tuple(np.mean(corrs) for corrs in (this_one, that_one))

(0.5065552619018068, 0.505823562521005)

In [15]:
# Spread of per-neuron test correlations: (this model, reference model).
tuple(np.std(corrs) for corrs in (this_one, that_one))

(0.1333050489510738, 0.13470891707205376)