In [1]:
import os, sys
import numpy as np
import pandas as pd

# Insert the project code directory at the front of sys.path so that it is searched
# first by the imports below (presumably where `plotting` and `utils` live - confirm).
# NOTE(review): this is an absolute, machine-specific path; it must be edited to run
# the notebook anywhere else.
codepath = '/user_data/mmhender/image_stats_gabor/code/'
sys.path.insert(0,codepath)

from plotting import summary_plots, load_fits
from utils import roi_utils, default_paths

#### Load fit results

In [2]:
# Subject numbers 1 through 8.
subjects = np.arange(1,9)
n_subjects = len(subjects)

# Full fitting-type identifiers handed to the loader, and short labels in the same order.
feature_type_list_long = ['gabor_solo_ridge_12ori_8sf',
                          'alexnet_all_conv_pca']
feature_type_list = ['gabor','alexnet']
n_models = len(feature_type_list)

# Load the fit results for every subject, one model at a time.
# The result is indexed as out[model index][subject index].
out_list = []
for fi, fitting_type in enumerate(feature_type_list_long):
    fits_this_model = [
        load_fits.load_fit_results(subject=ss, fitting_type=fitting_type,
                                   n_from_end=0, verbose=False)
        for ss in subjects
    ]
    out_list.append(fits_this_model)
out = out_list

fig_save_folder = None

# create ROI definitions
roi_def = roi_utils.multi_subject_roi_def(subjects, remove_ret_overlap=True,
                                          remove_categ_overlap=True)
roi_names = roi_def.roi_names
n_rois = roi_def.n_rois

# Length of each subject's `retlabs` array (presumably the voxel count - confirm),
# followed by a flat integer array that repeats each subject's index that many times.
n_vox_each_subj = [roi_def.ss_roi_defs[si].retlabs.shape[0] for si in range(n_subjects)]
subject_inds = np.repeat(np.arange(n_subjects), n_vox_each_subj)


#### Raw counts per ROI (no thresholding)

In [3]:
# For every subject (rows) and ROI (columns), sum the index mask returned by
# get_indices(rr) (presumably a boolean per-voxel mask, so this is a count - confirm).
counts = np.array([[np.sum(roi_def.ss_roi_defs[ss].get_indices(rr))
                    for rr in range(n_rois)]
                   for ss in range(n_subjects)],
                  dtype=int).reshape(n_subjects, n_rois)

# Append one extra row holding the per-ROI totals across subjects.
counts = np.vstack([counts, counts.sum(axis=0)])

# Row labels: S01, S02, ... followed by 'total'.
row_labels = ['S{:02d}'.format(ss) for ss in subjects] + ['total']
roi_df = pd.DataFrame(data=counts, columns=roi_names, index=row_labels)

# Save the table alongside the figures, then display it as the cell output.
roi_df.to_csv(os.path.join(default_paths.fig_path,'roi_defs_raw.csv'))
roi_df


Unnamed: 0,V1,V2,V3,hV4,V3ab,IPS,OPA,PPA,RSC,OFA,FFA,EBA
S01,2392,2096,1632,568,497,1921,1611,1033,566,355,794,2798
S02,1630,1496,1467,447,555,2036,1380,994,813,441,869,3195
S03,2275,1732,1412,362,656,2117,1330,1251,838,701,1093,3237
S04,1526,1297,1139,380,332,2131,1235,960,813,602,942,2889
S05,1705,1433,1219,506,358,1532,1332,1221,771,782,907,4050
S06,1771,1719,1799,788,373,1792,1362,1229,845,519,826,4005
S07,1926,1475,1275,418,586,1452,1083,912,694,316,484,2785
S08,1791,1494,1355,417,359,1866,1360,946,799,331,1204,2901
total,15016,12742,11298,3886,3716,14847,10693,8546,6139,4047,7119,25860


#### Counts after thresholding by noise ceiling
(this is the set of voxels we use when plotting overall R2 for each model)

In [4]:
# Build the subject-by-ROI count table one subject (row) at a time.
rows_nc = []
for ss in range(n_subjects):

    # take all voxels that are above a noise ceiling threshold
    # (noise ceiling is read from the first model's fit results, out[0])
    abv_nc_thresh = summary_plots.get_noise_ceiling(out[0][ss])>0.01

    ss_def = roi_def.ss_roi_defs[ss]
    # Per ROI: size of the overlap between the ROI mask and the above-threshold mask.
    rows_nc.append([np.sum(ss_def.get_indices(rr) & abv_nc_thresh)
                    for rr in range(n_rois)])

counts = np.array(rows_nc, dtype=int).reshape(n_subjects, n_rois)

# Final row: totals across subjects for each ROI.
counts = np.vstack([counts, counts.sum(axis=0)])

row_labels = ['S{:02d}'.format(ss) for ss in subjects] + ['total']
roi_df = pd.DataFrame(data=counts, columns=roi_names, index=row_labels)

# Save the table alongside the figures, then display it as the cell output.
roi_df.to_csv(os.path.join(default_paths.fig_path,'roi_defs_thresh_nc.csv'))
roi_df


Unnamed: 0,V1,V2,V3,hV4,V3ab,IPS,OPA,PPA,RSC,OFA,FFA,EBA
S01,2262,1980,1605,564,471,1891,1592,997,531,353,786,2788
S02,1590,1429,1422,437,529,1980,1371,969,797,421,850,3156
S03,2174,1579,1337,354,612,1946,1303,1179,737,689,1026,3135
S04,1430,1196,1082,375,306,1897,1204,912,731,579,910,2806
S05,1685,1418,1208,502,351,1503,1326,1190,742,773,884,3962
S06,1760,1691,1773,780,367,1771,1351,1196,803,518,804,3972
S07,1877,1443,1215,413,543,1413,1079,874,658,313,450,2736
S08,1516,1279,1111,372,242,1469,1263,846,623,288,1116,2701
total,14294,12015,10753,3797,3421,13870,10489,8163,5622,3934,6826,25256


#### Counts after thresholding by Gabor model R2

In [11]:
# Look the Gabor model up by its short label instead of a hard-coded position, so
# that reordering the model lists cannot silently change which model is thresholded.
gabor_ind = feature_type_list.index('gabor')

counts = np.zeros((n_subjects, n_rois),dtype=int)

for ss in range(n_subjects):
    
    # take all voxels that are above r2 threshold for the Gabor model
    # (column 0 of that model's 'val_r2' array)
    abv_r2_thresh = out[gabor_ind][ss]['val_r2'][:,0]>0.01
    
    for rr in range(n_rois):
        
        # size of the overlap between this ROI's mask and the above-threshold mask
        counts[ss,rr] = np.sum(roi_def.ss_roi_defs[ss].get_indices(rr) & abv_r2_thresh)
        
# append a final row holding the per-ROI totals across subjects
counts = np.concatenate([counts, np.sum(counts, axis=0, keepdims=True)], axis=0)
roi_df = pd.DataFrame(data=counts, columns=roi_names, \
                      index=['S%02d'%ss for ss in subjects] + ['total'])

# save the table alongside the figures, then display it as the cell output
roi_df.to_csv(os.path.join(default_paths.fig_path,'roi_defs_gabor_thresh.csv'))
roi_df


Unnamed: 0,V1,V2,V3,hV4,V3ab,IPS,OPA,PPA,RSC,OFA,FFA,EBA
S01,1679,1501,1328,474,282,716,1204,725,350,263,560,2049
S02,1197,990,1053,388,272,892,1104,775,683,295,603,2149
S03,1358,1096,940,270,325,556,823,759,361,520,506,1576
S04,904,850,717,262,149,318,782,635,541,350,526,1560
S05,1251,1073,859,399,231,766,1142,956,432,594,552,2362
S06,1232,1161,1241,575,172,565,823,775,295,343,409,2361
S07,1234,1001,655,292,118,479,815,565,293,208,246,1683
S08,911,784,663,276,41,121,550,492,217,178,654,1018
total,9766,8456,7456,2936,1590,4413,7243,5682,3172,2751,4056,14758


#### Counts after thresholding by AlexNet model R2

In [12]:
# Fit results for the AlexNet model (second entry of the model lists), one per subject.
alexnet_fits = out[1]

rows_alexnet = []
for ss in range(n_subjects):

    # take all voxels that are above r2 threshold for AlexNet model
    abv_r2_thresh = alexnet_fits[ss]['val_r2'][:,0]>0.01

    ss_def = roi_def.ss_roi_defs[ss]
    # Per ROI: size of the overlap between the ROI mask and the above-threshold mask.
    rows_alexnet.append([np.sum(ss_def.get_indices(rr) & abv_r2_thresh)
                         for rr in range(n_rois)])

counts = np.array(rows_alexnet, dtype=int).reshape(n_subjects, n_rois)

# Final row: totals across subjects for each ROI.
counts = np.vstack([counts, counts.sum(axis=0)])

row_labels = ['S{:02d}'.format(ss) for ss in subjects] + ['total']
roi_df = pd.DataFrame(data=counts, columns=roi_names, index=row_labels)

# Save the table alongside the figures, then display it as the cell output.
roi_df.to_csv(os.path.join(default_paths.fig_path,'roi_defs_alexnet_thresh.csv'))
roi_df


Unnamed: 0,V1,V2,V3,hV4,V3ab,IPS,OPA,PPA,RSC,OFA,FFA,EBA
S01,1709,1590,1397,520,344,1187,1428,851,403,304,648,2438
S02,1261,1093,1146,414,335,1344,1236,841,713,344,686,2605
S03,1526,1185,1031,305,407,1123,1056,898,391,586,721,2189
S04,1022,952,855,327,201,845,1004,757,585,445,702,2139
S05,1421,1209,989,446,263,1061,1247,1029,527,688,688,2973
S06,1444,1263,1470,666,253,1011,1115,945,479,437,552,3004
S07,1364,1124,753,349,242,800,952,670,377,248,323,2161
S08,954,871,727,311,66,357,799,601,257,198,772,1510
total,10701,9287,8368,3338,2111,7728,8837,6592,3732,3250,5092,19019
