In [1]:
# you need to first run `/model_fitting/cnn_popluation/sub.py`

In [2]:
import pandas as pd

In [3]:
from itertools import product

In [4]:
from maskcnn import postprocess, training_aux_wrapper
from tang_jcompneuro.cell_stats import compute_ccmax
from tang_jcompneuro.cell_classification import get_ready_to_use_classification
from tang_jcompneuro import dir_dictionary
import os.path

In [5]:
# Classification lookup for the 'MkA_Shape' entry; indexed further below by
# neuron subset name (e.g. 'HO', 'OT') to select the neurons in that subset.
class_dict_mka = get_ready_to_use_classification()['MkA_Shape']

In [6]:
# Names of every opt config known to training_aux_wrapper (keys of `all_opt_configs`).
opt_names = list(training_aux_wrapper.all_opt_configs.keys())

In [7]:
# Show the opt config names that will be swept over.
opt_names

['poisson_10000',
 'mse_10000',
 'poisson_1000',
 'mse_1000',
 'poisson_100',
 'mse_100',
 'poisson_10',
 'mse_10',
 'poisson_1',
 'mse_1']

In [8]:
# Names of every architecture config generated for ('MkA_Shape', 'all', 'OT').
# NOTE(review): these names are reused for the 'HO' subset as well — presumably
# the key set does not depend on the subset; confirm against gen_all_arch_config.
arch_names = list(training_aux_wrapper.gen_all_arch_config('MkA_Shape', 'all', 'OT').keys())

In [9]:
# Show the architecture config names that will be swept over.
arch_names

['5_3_100',
 '7_3_100',
 '9_3_100',
 '11_3_100',
 '13_3_100',
 '5_3_75',
 '7_3_75',
 '9_3_75',
 '11_3_75',
 '13_3_75',
 '5_3_50',
 '7_3_50',
 '9_3_50',
 '11_3_50',
 '13_3_50',
 '5_3_25',
 '7_3_25',
 '9_3_25',
 '11_3_25',
 '13_3_25']

In [10]:
# Gather test-set performance for every (arch, opt) combination of one neuron
# subset. Picking the best row afterwards means selecting on test performance
# (i.e. overfitting to it); the goal is only to see how good the models can get.
def collect_one_model_performance(neuron_subset):
    """Tabulate normalized squared correlation for all (arch, opt) pairs.

    For each combination of ``arch_names`` x ``opt_names`` the stored 'corr'
    performance is loaded for run index 0 and run index 1, divided (in place)
    by the subset's ccmax values, squared, and averaged over the two runs.

    Parameters
    ----------
    neuron_subset
        Key into ``class_dict_mka`` (e.g. 'HO' or 'OT'); also forwarded to
        ``postprocess.load_model_performance``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ('arch', 'opt') and sorted by index, with columns
        'score' (mean over neurons of the averaged squared normalized
        correlation) and 'score_raw' (the per-neuron array behind 'score').
        No selection of a "best" row happens here; that is left to the caller.
    """
    ccmax_full = compute_ccmax('MkA_Shape', 'all', 5)
    subset_selector = class_dict_mka[neuron_subset]
    assert ccmax_full.shape == subset_selector.shape
    # keep only the ccmax entries belonging to this subset.
    ccmax_subset = ccmax_full[subset_selector]

    records = []

    for arch_name, opt_name in product(arch_names, opt_names):
        print(arch_name, opt_name)

        # NOTE(review): the 4th positional argument (0 / 1) presumably selects
        # a seed or data split — confirm against load_model_performance.
        normalized_pair = []
        for run_idx in (0, 1):
            corr = postprocess.load_model_performance(
                'MkA_Shape', 'all', neuron_subset, run_idx, arch_name, opt_name
            )['corr']
            assert corr.shape == ccmax_subset.shape
            # in-place normalization by ccmax, as in the original code path.
            corr /= ccmax_subset
            normalized_pair.append(corr)

        first_run, second_run = normalized_pair
        per_neuron_score = (first_run**2 + second_run**2)/2

        records.append({
            'arch': arch_name,
            'opt': opt_name,
            'score': per_neuron_score.mean(),
            'score_raw': per_neuron_score,
        })

    table = pd.DataFrame(records, columns=['arch', 'opt', 'score', 'score_raw'])
    # verify_integrity guards against duplicated (arch, opt) pairs.
    return table.set_index(['arch', 'opt'], verify_integrity=True).sort_index()

In [11]:
def _load_or_collect(neuron_subset, cache_path):
    """Return the performance table for `neuron_subset`, using `cache_path` as a pickle cache.

    If the cache file exists it is read back with pandas; otherwise the table
    is computed via `collect_one_model_performance` and written to the cache.
    """
    if os.path.exists(cache_path):
        return pd.read_pickle(cache_path)
    collected = collect_one_model_performance(neuron_subset)
    collected.to_pickle(cache_path)
    return collected


# Directory holding the cached summary tables.
dir_to_save = os.path.join(dir_dictionary['analyses'], 'cnn_population')
os.makedirs(dir_to_save, exist_ok=True)

# NOTE: despite the `.hdf5` extension these cache files are pandas pickles.
HO_file = os.path.join(dir_to_save, 'MkA_HO.hdf5')
HO_perm = _load_or_collect('HO', HO_file)

OT_file = os.path.join(dir_to_save, 'MkA_OT.hdf5')
OT_perm = _load_or_collect('OT', OT_file)

In [12]:
# Mean score per architecture (rows) and opt config (columns) for the 'HO' subset.
HO_perm['score'].unstack('opt')
# No configuration does particularly well here, so picking a single example
# configuration should be fine.

opt,mse_1,mse_10,mse_100,mse_1000,mse_10000,poisson_1,poisson_10,poisson_100,poisson_1000,poisson_10000
arch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
11_3_100,0.399199,0.410067,0.419266,0.410759,0.361716,0.409338,0.408817,0.411599,0.416296,0.374328
11_3_25,0.310795,0.303171,0.321143,0.321133,0.286965,0.307785,0.307979,0.319969,0.336827,0.309618
11_3_50,0.369067,0.379955,0.383386,0.38691,0.338444,0.371989,0.375304,0.389621,0.381568,0.357645
11_3_75,0.397475,0.395505,0.409628,0.402732,0.357568,0.402417,0.39526,0.390011,0.418665,0.366693
13_3_100,0.421422,0.425573,0.422722,0.425343,0.360856,0.416911,0.420422,0.42626,0.428312,0.377287
13_3_25,0.318987,0.308643,0.326237,0.322736,0.279274,0.316925,0.318703,0.319765,0.341867,0.318486
13_3_50,0.378448,0.382162,0.384678,0.377676,0.333847,0.392466,0.384741,0.392768,0.387769,0.353227
13_3_75,0.395974,0.398413,0.407364,0.406249,0.351292,0.414627,0.415396,0.406563,0.411332,0.374861
5_3_100,0.258368,0.273673,0.308568,0.335048,0.310052,0.262647,0.266401,0.254096,0.359279,0.320075
5_3_25,0.261562,0.266322,0.260754,0.306232,0.273222,0.262208,0.261648,0.26624,0.309833,0.310306


In [13]:
# Mean score per architecture (rows) and opt config (columns) for the 'OT' subset.
OT_perm['score'].unstack('opt')

opt,mse_1,mse_10,mse_100,mse_1000,mse_10000,poisson_1,poisson_10,poisson_100,poisson_1000,poisson_10000
arch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
11_3_100,0.484178,0.486279,0.523496,0.522604,0.453815,0.490168,0.500603,0.490377,0.516349,0.471477
11_3_25,0.411236,0.41589,0.425472,0.420416,0.372857,0.417995,0.413199,0.421729,0.427902,0.398613
11_3_50,0.479872,0.47517,0.478429,0.475144,0.422511,0.476123,0.464528,0.472762,0.470014,0.443042
11_3_75,0.50237,0.496062,0.506977,0.50135,0.442766,0.487234,0.481382,0.496623,0.504132,0.461329
13_3_100,0.51986,0.51011,0.520724,0.51123,0.465207,0.514292,0.515988,0.521773,0.521024,0.479272
13_3_25,0.414743,0.417628,0.411659,0.411245,0.375285,0.41852,0.415134,0.423411,0.421938,0.387691
13_3_50,0.47538,0.466297,0.473827,0.475775,0.426995,0.470828,0.475008,0.475671,0.484725,0.443482
13_3_75,0.501945,0.502512,0.497516,0.494683,0.447089,0.499774,0.505949,0.491439,0.504136,0.456436
5_3_100,0.395952,0.399841,0.386465,0.416588,0.419087,0.382637,0.398809,0.441832,0.43725,0.404659
5_3_25,0.381836,0.372955,0.385247,0.405395,0.376542,0.371552,0.370725,0.385914,0.403044,0.385551
