In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from os.path import exists
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from fastprogress import progress_bar
from scipy.spatial.distance import pdist, squareform
from matplotlib.patches import Rectangle


from jsputils import classes

import sys
sys.path.append('/home/jovyan/work/DropboxSandbox/GSN')
import gsn
from gsn.rsa_noise_ceiling import rsa_noise_ceiling

In [None]:
# Directory, relative to the current working directory, that holds the cached
# noise-ceiling result files written and read by the cells below.
savedir = '{}/analysis_outputs/3c-NoiseCeilings'.format(os.getcwd())

In [None]:
# Subject identifiers subj01 ... subj08.
subjs = [f'subj0{s}' for s in range(1,9)]

# Regions of interest to analyse; the plotting cells rely on this order.
roi_list = ['FFA-1','FFA-2','OFA',
            'PPA','OPA','EBA','FBA-1','FBA-2',
            'VWFA-1','VWFA-2','OWFA']

# Image sets handed to ROI.load_encoding_data (train / validation / test).
train_imageset = 'nonshared1000-3rep-batch0'
val_imageset = 'nonshared1000-3rep-batch1'
test_imageset = 'special515'

# Arguments for classes.fMRISubject.
space = 'nativesurface'
beta_version = 'betas_fithrf_GLMdenoise_RR'

# Threshold passed to ROI.get_ncsnr_mask; it is also embedded in the result
# file names. Defined once here so the two uses cannot drift apart.
ncsnr_threshold = 0.3

In [None]:
def compute_gsn_noise_ceiling(data, threshold=1000, n_sims=3, nvox_per_sample=1000, rng=None):
    """Run ``rsa_noise_ceiling`` on one ROI's responses, sub-sampling large ROIs.

    Parameters
    ----------
    data : array-like, 3-D
        Response array whose LAST axis indexes voxels. The array is copied and
        transposed with ``(2, 0, 1)`` so that voxels come first before it is
        handed to ``rsa_noise_ceiling``.
        NOTE(review): the first two axes are presumably (images, repetitions);
        confirm against ``EncodingProcedure.get_encoding_voxels``.
    threshold : int, optional
        If the ROI has more voxels than this, it is sub-sampled instead of
        being analysed in one piece.
    n_sims : int, optional
        Number of independent voxel sub-samples drawn for large ROIs.
    nvox_per_sample : int, optional
        Voxels per sub-sample (capped at the number of voxels available).
    rng : None, int or numpy.random.Generator, optional
        ``None`` keeps using NumPy's global random state; anything else is
        passed to ``numpy.random.default_rng`` for reproducible sub-samples.

    Returns
    -------
    dict
        Keys ``'ncs'``, ``'ncdists'`` and ``'results'``. Each value is a list
        with one entry per ``rsa_noise_ceiling`` call (``n_sims`` entries for
        sub-sampled ROIs, otherwise a single entry).
    """
    # Work on a private copy so the caller's array is never modified, then
    # move the voxel axis to the front.
    data_gsn = np.transpose(np.array(data, copy=True), (2, 0, 1))
    n_voxels = data_gsn.shape[0]

    # The two summary functions whose noise ceilings are estimated: the mean
    # over axis 1 of the transposed input, and correlation-distance pdist.
    rdmfuns = [lambda x: np.mean(x.T, axis=1),
               lambda x: pdist(x.T, 'correlation')]

    out = {'ncs': [], 'ncdists': [], 'results': []}

    if n_voxels > threshold:
        sampler = np.random if rng is None else np.random.default_rng(rng)
        sample_size = min(nvox_per_sample, n_voxels)
        n_runs = n_sims
    else:
        sampler = None
        n_runs = 1

    for _ in range(n_runs):
        if sampler is None:
            subset = data_gsn
        else:
            # replace=False: every voxel appears at most once per sub-sample.
            # The default (with replacement) would duplicate voxels.
            idx = sampler.choice(n_voxels, sample_size, replace=False)
            subset = data_gsn[idx]

        ncs, ncdists, results = rsa_noise_ceiling(subset,
                                                  rdmfuns=rdmfuns,
                                                  wantverbose=False)

        out['ncs'].append(ncs)
        out['ncdists'].append(ncdists)
        out['results'].append(results)
        print(ncs)

    return out
    

In [None]:
# Set to True to recompute results even when a cached file already exists.
overwrite = False

# np.save does not create missing directories, so make sure the output
# directory exists before the (expensive) loop starts.
os.makedirs(savedir, exist_ok=True)

for subj in progress_bar(subjs):

    NSDsubj = classes.fMRISubject(subj, space, beta_version)

    for roi in progress_bar(roi_list):

        print(roi, subj)

        savefn = f'{savedir}/GSN-NC_{roi}_{subj}_{test_imageset}_nc-{ncsnr_threshold}.npy'

        # Cached result present and no recompute requested: nothing to do.
        if exists(savefn) and not overwrite:
            print('skipping')
            continue

        ROI = classes.BrainRegion(NSDsubj, roi)
        ROI.load_betas()
        ROI.get_ncsnr_mask(threshold = ncsnr_threshold)
        ROI.load_encoding_data(train_imageset, val_imageset, test_imageset)

        # FIXME(review): `DNN` is not defined anywhere in this notebook, so this
        # line raises NameError on a fresh kernel. Define it in an earlier cell
        # (or drop the dependency) so Restart & Run All works.
        encoder = classes.EncodingProcedure(ROI, DNN,
                                     method = 'lasso',
                                     positive = True,
                                     alphas = [0.1])

        y = encoder.get_encoding_voxels(mean = False)

        print(y['test'].shape)

        # Axis 2 is the voxel axis; ROIs with two or fewer voxels are skipped.
        if y['test'].shape[2] > 2:
            NC = compute_gsn_noise_ceiling(y['test'])
            # NC is a dict, so it is stored as a pickled object array.
            np.save(savefn, NC, allow_pickle=True)
        else:
            print('skipping, insufficient voxels')

In [None]:

# Collect the cached noise ceilings into one row per (ROI, subject).
nc_rows = []

for roi in progress_bar(roi_list):
    for subj in progress_bar(subjs):

        savefn = f'{savedir}/GSN-NC_{roi}_{subj}_{test_imageset}_nc-{ncsnr_threshold}.npy'

        # Report combinations that have no saved result and move on.
        if not exists(savefn):
            print(savefn, 'does not exist')
            continue

        # The file holds a pickled dict wrapped in a 0-d object array.
        nc = np.load(savefn, allow_pickle=True).item()
        ncs_runs = nc['ncs']

        # Several runs (sub-sampled ROI) are averaged; a single run is used as is.
        this_ncs = (np.mean(np.vstack(ncs_runs), axis=0)
                    if len(ncs_runs) > 1
                    else np.array(ncs_runs[0]))

        # Entry 0 goes to the 'Univariate' column, entry 1 to the 'RSA' column.
        univariate_nc, rsa_nc = this_ncs[0], this_ncs[1]
        nc_rows.append([roi, subj, univariate_nc, rsa_nc])

df = pd.DataFrame(nc_rows, columns=['ROI', 'Subject', 'Univariate', 'RSA'])


In [None]:
# Custom palette: one color per entry of roi_list, in roi_list order.
palette = ['tomato']*3 + ['limegreen']*2 + ['dodgerblue']*3 + ['purple']*3
assert len(palette) == len(roi_list), 'palette needs exactly one color per ROI in roi_list'

# Look colors up by ROI name rather than by plotting position, so an ROI that
# has no saved results does not shift the colors of the ROIs after it.
roi_colors = dict(zip(roi_list, palette))

# ROIs in order of appearance in df; subjects sorted by name so each subject
# always lands in the same slot even when it is missing from the first ROI.
unique_rois = df['ROI'].unique()
unique_subjs = sorted(df['Subject'].unique())

# Create subplots (one row per metric)
fig, ax = plt.subplots(2, 1, figsize=(16, 10))

# Marker size as (width, height) in data units
marker_size = (0.12, 0.02)

# Horizontal offset between neighbouring subjects inside one ROI group
subj_step = 0.1

# Metrics to plot
metrics = ['Univariate', 'RSA']

for i, metric in enumerate(metrics):
    for j, roi in enumerate(unique_rois):
        # Rows belonging to this ROI
        roi_data = df[df['ROI'] == roi]

        for k, subj in enumerate(unique_subjs):
            subj_data = roi_data[roi_data['Subject'] == subj]

            # No result for this subject in this ROI: leave its slot empty
            if subj_data.empty:
                continue

            value = subj_data[metric].iloc[0]

            # Small rectangle centred on (slot position, value)
            rect = Rectangle((j + k*subj_step - marker_size[0]/2, value - marker_size[1]/2),
                             marker_size[0], marker_size[1], color=roi_colors[roi])
            ax[i].add_patch(rect)

    # Put one tick at the centre of each ROI's group of subjects
    ax[i].set_xticks([j + (len(unique_subjs)-1)*subj_step/2 for j in range(len(unique_rois))])
    ax[i].set_xticklabels(unique_rois)

    # Titles and labels
    ax[i].set_title(f'{metric} Noise Ceiling')
    ax[i].set_xlabel('ROI')
    ax[i].set_ylabel('Noise Ceiling (r)')

    # Axis limits
    ax[i].set_ylim([0, 1])
    ax[i].set_xlim([-0.5, len(unique_rois)])

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Custom palette: one color per entry of roi_list, in roi_list order.
palette = ['tomato']*3 + ['limegreen']*2 + ['dodgerblue']*3 + ['purple']*3
assert len(palette) == len(roi_list), 'palette needs exactly one color per ROI in roi_list'

# Look colors up by ROI name rather than by plotting position, so an ROI that
# has no saved results does not shift the colors of the ROIs after it.
roi_colors = dict(zip(roi_list, palette))

# ROIs in order of appearance in df; subjects sorted by name so each subject
# always lands in the same slot even when it is missing from the first ROI.
unique_rois = df['ROI'].unique()
unique_subjs = sorted(df['Subject'].unique())

# Create subplots (one row per metric)
fig, ax = plt.subplots(2, 1, figsize=(20, 12))

# Metrics to plot
metrics = ['Univariate', 'RSA']

# Extra horizontal gap between ROI groups. Deliberately NOT named `space`:
# that name holds the 'nativesurface' setting passed to classes.fMRISubject,
# and overwriting it here would break a later re-run of the compute cell.
roi_gap = 0.3

# Horizontal offset between neighbouring subjects inside one ROI group
subj_step = 0.1

for i, metric in enumerate(metrics):
    # Secondary x-axis (top) carries the ROI labels
    ax2 = ax[i].twiny()

    for j, roi in enumerate(unique_rois):
        # Rows belonging to this ROI
        roi_data = df[df['ROI'] == roi]

        for k, subj in enumerate(unique_subjs):
            subj_data = roi_data[roi_data['Subject'] == subj]

            # No result for this subject in this ROI: leave its slot empty
            if subj_data.empty:
                continue

            value = subj_data[metric].iloc[0]
            x_pos = j * (1 + roi_gap) + k * subj_step

            # Dot at the value, with a dashed stem rising from zero
            ax[i].scatter(x_pos, value, color=roi_colors[roi], s=30)
            ax[i].vlines(x_pos, 0, value, color=roi_colors[roi], linestyle='dashed')

    # One tick per (ROI, subject) slot, labelled with the actual subject name
    ax[i].set_xticks([j * (1 + roi_gap) + k * subj_step
                      for j in range(len(unique_rois))
                      for k in range(len(unique_subjs))])
    ax[i].set_xticklabels([subj for _ in unique_rois for subj in unique_subjs],
                          rotation='vertical', fontsize='small')

    # Secondary ticks: one per ROI, centred over that ROI's group of subjects
    ax2.set_xticks([j * (1 + roi_gap) + (len(unique_subjs) - 1) * subj_step / 2
                    for j in range(len(unique_rois))])
    ax2.set_xticklabels(unique_rois, fontsize=14)

    # Titles and labels
    ax[i].set_title(f'{metric} Noise Ceiling', fontsize=20)
    ax[i].set_ylabel('Noise Ceiling (r)', fontsize=14)

    # Axis limits; the secondary axis must match the primary one so the ROI
    # labels line up with their groups
    ax[i].set_ylim([0, 1])
    ax[i].set_xlim([-0.5, len(unique_rois) * (1 + roi_gap)])
    ax2.set_xlim(ax[i].get_xlim())

plt.tight_layout()
plt.show()
