In [10]:
import os, sys

# Directory holding the project code. It is inserted at the front of sys.path,
# so it takes precedence over other locations when modules are resolved.
# Presumably this is what makes the project-local imports below
# (plotting, utils, feature_extraction) importable — TODO confirm.
# NOTE(review): this is an absolute, user-specific path; the notebook will not
# run on another machine without editing it.
codepath = '/user_data/mmhender/image_stats_gabor/code/'
sys.path.insert(0,codepath)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib inline
import scipy.stats
import cmocean
import pandas as pd

from plotting import plot_utils, load_fits
from utils import roi_utils, default_paths
from feature_extraction import gabor_feature_extractor


#### Load model fits

In [2]:
# Subject numbers 1..8.
subjects = np.arange(1, 9)
n_subjects = len(subjects)

# Base name of the fit; each trial subset below was fit separately and saved
# under '<fitting_type>_<subset name>'.
fitting_type = 'gabor_solo_ridge_12ori_8sf'

# Three consecutive entries per semantic axis: the balanced subset first,
# followed by the two single-category subsets.
trial_subsets = ['balance_indoor_outdoor', 'outdoor_only', 'indoor_only',
                 'balance_animacy', 'animate_only', 'inanimate_only',
                 'balance_real_world_size_binary', 'large_only', 'small_only']
n_trial_subsets = len(trial_subsets)

# out_all[xi][si] holds the fit results for trial subset xi and the si-th
# entry of `subjects`.
out_all = []
for ts in trial_subsets:

    # 'all_trials' is not in trial_subsets above, so the suffixed name is always
    # used here; the branch is kept so the un-subsetted fit can be added to the
    # list without further changes.
    if ts == 'all_trials':
        ft = fitting_type
    else:
        ft = '%s_%s' % (fitting_type, ts)

    # NOTE(review): n_from_end=0 presumably selects the most recently saved
    # fit — confirm against load_fits.load_fit_results.
    out_all.append([load_fits.load_fit_results(subject=ss, fitting_type=ft,
                                               n_from_end=0,
                                               verbose=False)
                    for ss in subjects])

# Later cells index the nested list through the name `out`.
out = out_all

fig_save_folder = None

# ROI definitions pooled over all subjects.
roi_def = roi_utils.multi_subject_roi_def(subjects,
                                          remove_ret_overlap=True,
                                          remove_categ_overlap=True)
roi_names = roi_def.roi_names
n_rois = len(roi_names)

#### Get info about the Gabor feature space

In [3]:
# Size of the Gabor feature bank: 8 spatial frequencies x 12 orientations.
n_sf = 8
n_ori = 12
_gabor_ext_complex = gabor_feature_extractor.gabor_extractor_multi_scale(n_ori=n_ori, n_sf=n_sf)
gabor_feature_table = _gabor_ext_complex.feature_table

# Convert spatial frequency from cycles per stimulus to cycles per degree.
# NOTE(review): 8.4 is presumably the stimulus extent in degrees of visual
# angle — confirm.
screen_eccen_deg = 8.4
sf_cyc_per_stim = gabor_feature_table['SF: cycles per stim']
sf_cyc_per_deg = sf_cyc_per_stim / screen_eccen_deg
# sf_unique: sorted distinct frequencies; sf_inds: for every table entry, the
# position of its frequency within sf_unique.
sf_unique, sf_inds = np.unique(sf_cyc_per_deg, return_inverse=True)

# Same decomposition for orientation.
ori_deg = gabor_feature_table['Orientation: degrees']
ori_unique, orient_inds = np.unique(ori_deg, return_inverse=True)


#### threshold voxels within groups by R2

Each group corresponds to one semantic axis.

For instance, group 1 is [balanced indoor-outdoor / outdoor only / indoor only], each fit with the same number of trials.
We always use the same set of voxels when comparing fits within a group, so the voxels are thresholded separately for each group here.

In [8]:

# Each group is one semantic axis, given as indices into trial_subsets:
# [balanced subset, first single-category subset, second single-category subset].
subsets2group = [[0,1,2], [3,4,5], [6,7,8]]
n_groups = len(subsets2group)
# Name every group after its balanced subset, e.g. 'balance_animacy' -> 'animacy'.
group_names = [trial_subsets[group[0]].split('balance_')[1] for group in subsets2group]

# For every trial subset, the index of the group that contains it.
group_inds = [np.where([ti in group for group in subsets2group])[0][0] \
    for ti in range(n_trial_subsets)]

# voxels_use[gi][si] is filled in below with a boolean mask over voxels.
voxels_use = [ [[] for ss in subjects] for ai in range(n_groups)]

# A voxel is kept for a group only if its validation R2 exceeds this value in
# every fit belonging to that group.
r2_cutoff = 0.01

# Smallest number of trials that any sub-sampled trial order may contain.
min_trials = 10

# NOTE(review): number of pRFs is hard-coded; it must match the length of the
# 'min_counts_*' arrays stored in the resampled trial order files.
n_prfs = 1456
n_trn_trials = np.zeros((n_subjects, n_prfs, n_groups))
n_val_trials = np.zeros((n_subjects, n_prfs, n_groups))
n_out_trials = np.zeros((n_subjects, n_prfs, n_groups))

for si, subject in enumerate(subjects):

    for gi, group in enumerate(subsets2group):

        # Stack validation R2 over the fits in this group. The result is
        # [n_fits_in_group x n_voxels]: one row per fit.
        # NOTE(review): column 0 of 'val_r2' is presumably the full model — confirm.
        val_r2_alltrials = np.array([out[xi][si]['val_r2'][:,0] for xi in group])

        # Reduce over the fits (axis 0): keep voxels above the cutoff in ALL of them.
        voxels_use[gi][si] = np.all(val_r2_alltrials>r2_cutoff, axis=0)

        # load the file that has the sub-sampled trial order in it.
        # use this to get the exact number of trials that were used, so that
        # we can make sure it was not too small.
        # there is one of these files for each sub-sampled order in each "group"
        # but can just load one since they have the same num trials.
        subset = 'both_'+group_names[gi]
        fn2load = os.path.join(default_paths.stim_labels_root, 'resampled_trial_orders',\
                       'S%d_trial_resamp_order_%s.npy'%\
                               (subject, subset))
        if si==0:
            print(group)
            print(np.array(trial_subsets)[np.array(group)])
            print('loading balanced trial order (pre-computed) from %s'%fn2load)

        # The file stores a pickled dict inside a 0-d object array, hence .item().
        # NOTE(review): allow_pickle=True can execute arbitrary code while
        # loading; only use it on files generated by this project.
        trials = np.load(fn2load, allow_pickle=True).item()

        # Every pRF must have at least min_trials trials in each data partition.
        for partition in ('trn', 'val', 'out'):
            assert np.all(trials['min_counts_%s'%partition]>=min_trials), \
                'S%d, %s: fewer than %d trials in min_counts_%s'%\
                (subject, subset, min_trials, partition)

        n_trn_trials[si,:,gi] = trials['min_counts_trn']
        n_val_trials[si,:,gi] = trials['min_counts_val']
        n_out_trials[si,:,gi] = trials['min_counts_out']


[0, 1, 2]
['balance_indoor_outdoor' 'outdoor_only' 'indoor_only']
loading balanced trial order (pre-computed) from /user_data/mmhender/nsd/labels/resampled_trial_orders/S1_trial_resamp_order_both_indoor_outdoor.npy
[3, 4, 5]
['balance_animacy' 'animate_only' 'inanimate_only']
loading balanced trial order (pre-computed) from /user_data/mmhender/nsd/labels/resampled_trial_orders/S1_trial_resamp_order_both_animacy.npy
[6, 7, 8]
['balance_real_world_size_binary' 'large_only' 'small_only']
loading balanced trial order (pre-computed) from /user_data/mmhender/nsd/labels/resampled_trial_orders/S1_trial_resamp_order_both_real_world_size_binary.npy


#### Print how many trials were actually included in each sub-sampled order
The counts differ across pRFs because they depend on pRF-specific labels.

##### Training

In [11]:
# Column headers: a "median" block followed by a "min" block, each with one
# column per semantic axis. Axes are named after the balanced subsets, which
# sit at positions 0, 3 and 6 of trial_subsets.
axis_labels = [trial_subsets[first].split('balance_')[1] for first in (0, 3, 6)]
column_names = ['%s: %s' % (label, stat)
                for stat in ('median', 'min')
                for label in axis_labels]

# Summarise the training-trial counts over pRFs (axis 1) for each subject.
trn_median = np.median(n_trn_trials, axis=1).astype(int)
trn_min = n_trn_trials.min(axis=1).astype(int)
pd.DataFrame(np.hstack([trn_median, trn_min]), index=subjects, columns=column_names)

Unnamed: 0,indoor_outdoor: median,animacy: median,real_world_size_binary: median,indoor_outdoor: min,animacy: min,real_world_size_binary: min
1,2518,1338,1994,2518,364,292
2,2440,1370,2140,2440,380,256
3,2290,1280,1948,2290,286,242
4,2248,1270,1890,2248,364,238
5,2478,1360,2028,2478,342,250
6,2240,1312,1984,2240,338,248
7,2448,1421,2094,2448,336,246
8,2242,1277,1918,2242,300,252


##### Validation set

In [12]:
# Summarise the validation-trial counts over pRFs (axis 1) for each subject:
# medians in the first block of columns, minima in the second.
val_median = np.median(n_val_trials, axis=1).astype(int)
val_min = n_val_trials.min(axis=1).astype(int)
pd.DataFrame(np.hstack([val_median, val_min]), index=subjects, columns=column_names)

Unnamed: 0,indoor_outdoor: median,animacy: median,real_world_size_binary: median,indoor_outdoor: min,animacy: min,real_world_size_binary: min
1,242,218,279,242,16,12
2,242,218,279,242,16,12
3,226,206,254,226,14,10
4,218,198,248,218,12,10
5,242,218,279,242,16,12
6,226,206,254,226,14,10
7,242,218,279,242,16,12
8,218,198,248,218,12,10


##### Held-out test set 

In [13]:
# Summarise the held-out-trial counts over pRFs (axis 1) for each subject:
# medians in the first block of columns, minima in the second.
out_median = np.median(n_out_trials, axis=1).astype(int)
out_min = n_out_trials.min(axis=1).astype(int)
pd.DataFrame(np.hstack([out_median, out_min]), index=subjects, columns=column_names)

Unnamed: 0,indoor_outdoor: median,animacy: median,real_world_size_binary: median,indoor_outdoor: min,animacy: min,real_world_size_binary: min
1,246,170,246,246,36,28
2,262,154,220,262,38,28
3,242,134,218,242,32,22
4,230,130,220,230,26,22
5,274,148,228,274,38,26
6,274,126,196,274,26,32
7,240,148,230,240,38,26
8,250,146,214,250,32,24


#### Print how many voxels can be used for each of these analyses
This is based only on R2 and has nothing to do with the trial counts,
since all the pRFs had enough trials to do the analysis.

In [14]:
# Number of voxels per subject and ROI before any R2 thresholding.
orig_sizes = np.zeros((n_subjects, n_rois), dtype=int)

# Number of voxels per subject, ROI and semantic group that pass that group's
# R2 threshold (voxels_use[gi][si]).
thresh_sizes = np.zeros((n_subjects, n_rois, n_groups), dtype=int)

for si in range(n_subjects):
    for ri in range(n_rois):
        # The ROI mask does not depend on the group, so look it up once per
        # (subject, ROI) instead of twice per group.
        # NOTE(review): assumes get_indices is a side-effect-free lookup — confirm.
        roi_mask = roi_def.ss_roi_defs[si].get_indices(ri)
        orig_sizes[si, ri] = np.sum(roi_mask)

        for gi in range(n_groups):
            thresh_sizes[si, ri, gi] = np.sum(roi_mask & voxels_use[gi][si])

In [15]:
print('original (unthresholded) sizes')
# Rows are subjects, columns are ROIs.
orig_sizes_table = pd.DataFrame(data=orig_sizes, columns=roi_names, index=subjects)
orig_sizes_table

original (unthresholded) sizes


Unnamed: 0,V1,V2,V3,hV4,V3ab,IPS,OPA,PPA,RSC,OFA,FFA,EBA
1,2392,2096,1632,568,497,1921,1611,1033,566,355,794,2798
2,1630,1496,1467,447,555,2036,1380,994,813,441,869,3195
3,2275,1732,1412,362,656,2117,1330,1251,838,701,1093,3237
4,1526,1297,1139,380,332,2131,1235,960,813,602,942,2889
5,1705,1433,1219,506,358,1532,1332,1221,771,782,907,4050
6,1771,1719,1799,788,373,1792,1362,1229,845,519,826,4005
7,1926,1475,1275,418,586,1452,1083,912,694,316,484,2785
8,1791,1494,1355,417,359,1866,1360,946,799,331,1204,2901


In [16]:
print('indoor vs outdoor')
# Group 0 of thresh_sizes: rows are subjects, columns are ROIs.
indoor_outdoor_sizes = pd.DataFrame(data=thresh_sizes[..., 0], columns=roi_names, index=subjects)
indoor_outdoor_sizes

indoor vs outdoor


Unnamed: 0,V1,V2,V3,hV4,V3ab,IPS,OPA,PPA,RSC,OFA,FFA,EBA
1,1316,1139,967,324,115,222,577,392,169,146,317,1154
2,960,722,777,293,132,260,583,538,455,161,387,1069
3,1021,788,547,124,79,71,194,334,129,279,245,672
4,621,572,330,115,32,47,262,294,300,137,191,379
5,950,715,529,232,103,228,691,626,202,370,287,1087
6,876,734,592,215,60,96,256,383,76,191,179,830
7,856,645,345,157,29,99,361,246,93,101,118,610
8,613,477,339,172,10,8,100,186,52,58,239,224


In [17]:
print('animacy')
# Group 1 of thresh_sizes: rows are subjects, columns are ROIs.
animacy_sizes = pd.DataFrame(data=thresh_sizes[..., 1], columns=roi_names, index=subjects)
animacy_sizes

animacy


Unnamed: 0,V1,V2,V3,hV4,V3ab,IPS,OPA,PPA,RSC,OFA,FFA,EBA
1,1114,987,876,273,112,105,358,248,151,114,73,219
2,832,650,768,290,121,114,315,246,167,149,143,327
3,938,691,493,108,97,36,60,82,88,153,18,210
4,567,474,297,60,37,15,140,114,239,71,21,91
5,889,684,474,155,110,80,324,204,81,171,57,323
6,794,631,469,169,24,30,169,215,64,32,5,99
7,822,647,358,171,20,42,222,125,106,30,38,151
8,518,375,283,131,6,2,64,65,48,26,50,84


In [18]:
print('real-world size')
# Group 2 of thresh_sizes: rows are subjects, columns are ROIs.
real_world_size_sizes = pd.DataFrame(data=thresh_sizes[..., 2], columns=roi_names, index=subjects)
real_world_size_sizes

real-world size


Unnamed: 0,V1,V2,V3,hV4,V3ab,IPS,OPA,PPA,RSC,OFA,FFA,EBA
1,978,848,846,271,139,261,686,439,208,121,290,1115
2,769,563,697,280,139,320,650,594,531,127,355,1113
3,900,667,480,133,119,98,227,389,166,288,246,737
4,553,443,299,102,40,47,304,321,375,147,204,774
5,931,714,544,223,124,278,749,678,224,369,271,1255
6,725,611,500,229,65,153,397,454,111,132,159,1081
7,705,549,283,163,38,137,446,294,119,110,116,786
8,468,375,315,148,12,13,192,260,60,63,214,263
