In [None]:
import os
import math
import os.path as pth
from scipy import io
import numpy as np
import nibabel as nib
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Global default font size for matplotlib figures.
# NOTE(review): the plot-prep cell calls sns.set(..., font_scale=1.3), which re-applies
# seaborn's context font sizes and likely overrides this value — confirm which is intended.
plt.rcParams.update({'font.size': 13})

In [None]:
# Paths and experiment-wide constants.
base = '/users/jmy/data/image_sound'
# NOTE(review): `mod` is not referenced in the visible cells; the neural RDM path in the
# RSA cell hard-codes '_i' — confirm whether it should use this value.
mod = 'i'
# NOTE(review): not referenced in the visible cells.
data_base_path = pth.join('/data', '01_experiment_data', 'image_sound', 'prep_new_template')
# 16 stimuli: digits zero..nine (indices 0-9) followed by six objects (indices 10-15).
label_list = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven',
              'eight', 'nine', 'bed', 'bird', 'cat', 'dog', 'house', 'tree']

sbj_list = [25, 26, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 43, 44]
# Derived from sbj_list so adding/removing a participant cannot desynchronise the two.
sbjs_num = len(sbj_list)
# Contrasts analysed throughout the notebook (order matters: cells index lists by position).
interest_list = ['Digit', 'Object', 'Vs', 'Magnitude', 'Animacy']

In [None]:
# Plot styling shared by the bar-graph cells: one colour and one x position per CNN layer.
sns.set(style='white', font_scale=1.3)
col_list = sns.color_palette('Spectral_r', 15)
col_arr = np.asarray(col_list)
xticks = np.arange(len(col_list))
# Layer labels: 1..13, then the 'Fc' and output layers.
xlabels = [str(layer) for layer in range(1, 14)] + ['Fc', 'Out\nput']

In [None]:
# Load the CNN RDMs: one 16x16 RDM per (subject, layer) into cnnrdms[sbj_idx, layer_idx].
cdd_n, epoc = 5, 56  # cdd_n is part of the filename; epoc is not used in the visible cells
depth = 15  # number of layers (labelled 1..13, Fc, Output in xlabels)
cnnrdms = np.zeros((sbjs_num, depth, 16, 16))
cnn_rdm_dir = pth.join(base, 'RDMs', 'CNN', 'PCA', '16x16')
for li in range(depth):
    for sbj_idx in range(sbjs_num):
        # Layer number is zero-padded to two digits (L01 .. L15).
        filename = f'_ind192to16_90_P{sbj_list[sbj_idx]}_cdd0{cdd_n}_L{li + 1:02d}.npz'
        savepath = pth.join(cnn_rdm_dir, filename)
        # The context manager closes the NpzFile handle instead of leaving it open until GC.
        with np.load(savepath) as npz:
            cnnrdms[sbj_idx, li] = npz['rdm']

In [None]:
# Build one binary mask per cluster (ROI) from each contrast's cluster-index volume,
# where voxel value k marks membership of cluster k (k = 1..max).
tot_roin = 51  # expected total number of clusters over all contrasts
roi_masks = np.zeros((tot_roin, 64, 76, 64))
# RSA comparison type per entry of interest_list. Defined here because `stim_type_list`
# is only assigned in a later cell (and reassigned with different values by the t-test
# cell), so reading it here failed on a fresh kernel or picked up the wrong values.
roi_stim_types = ['dpt', 'opt', 'offd', 'mag', 'ani']
cnt = 0
stim_type_list_roi = []
for ti, type_ in enumerate(interest_list):
    fchar = type_[0].lower()
    path = base + '/01_STEP/16x16/offd/5e-02_' + fchar + 'l2_20_r3_mask+tlrc.HEAD'
    tmp_vol = nib.load(path).get_fdata().squeeze()
    tmpn = int(tmp_vol.max())
    if cnt + tmpn > tot_roin:
        raise ValueError(f'{type_}: {cnt + tmpn} clusters exceed tot_roin={tot_roin}')
    for roin in range(tmpn):
        roi_masks[cnt] = (tmp_vol == roin + 1)
        stim_type_list_roi.append(roi_stim_types[ti])
        cnt += 1

In [None]:
# Calculate rhos
cd_n = len(interest_list)
# rsa() comparison type for each contrast in interest_list
# (Digit, Object, Vs, Magnitude, Animacy). Animacy must be 'ani' so rsa() runs the
# reordered 3x3 animacy block; 'opt' silently repeated the Object comparison.
stim_type_list = ['dpt', 'opt', 'offd', 'mag', 'ani']

def reorder_ani(rdm_org):
    """Extract and permute the object part of a 16x16 RDM.

    Rows/columns 10..15 of ``rdm_org`` (bed, bird, cat, dog, house, tree in
    ``label_list`` order) are reordered to bird, cat, dog, house, tree, bed, so the
    three animals come first. Returns a new 6x6 float64 array; the input is untouched.
    """
    # (Translated original note: "axis 1 runs downward".)
    order = [11, 12, 13, 14, 15, 10]
    return np.asarray(rdm_org, dtype=float)[np.ix_(order, order)]
    
def rsa(stim_type, cnnrdm, roirdm=None):
    """Spearman-correlate a sub-block of a neural RDM with the same sub-block of a CNN RDM.

    Parameters
    ----------
    stim_type : str
        Which part of the 16x16 RDMs to compare:
        'dpt'  - upper triangle (k=1) of the digit block [:10, :10]
        'opt'  - upper triangle (k=1) of the object block [10:, 10:]
        'offd' - the full digit-by-object block [:10, 10:]
        'mag'  - rows 7..9 (seven..nine) against columns 1..3 (one..three)
        'ani'  - block [3:, :3] of reorder_ani(): house/tree/bed against bird/cat/dog
    cnnrdm : array
        CNN RDM, indexed like a 16x16 matrix.
    roirdm : array, optional
        Neural RDM, indexed like ``cnnrdm``. When omitted, the module-level ``roirdm``
        is used (what callers relied on before this parameter existed).

    Returns
    -------
    rho, p : float
        Spearman correlation and its p-value from ``scipy.stats.spearmanr``.

    Raises
    ------
    ValueError
        For an unknown ``stim_type`` (previously any unknown value silently ran 'ani').
    NameError
        If ``roirdm`` is omitted and no global ``roirdm`` exists.
    """
    if roirdm is None:
        try:
            roirdm = globals()['roirdm']
        except KeyError:
            raise NameError("rsa(): pass `roirdm` or define a global `roirdm`") from None

    if stim_type == 'dpt':
        iu = np.triu_indices(10, 1)
        neural, model = roirdm[:10, :10][iu], cnnrdm[:10, :10][iu]
    elif stim_type == 'opt':
        iu = np.triu_indices(6, 1)
        neural, model = roirdm[10:, 10:][iu], cnnrdm[10:, 10:][iu]
    elif stim_type == 'offd':
        neural, model = roirdm[:10, 10:].flatten(), cnnrdm[:10, 10:].flatten()
    elif stim_type == 'mag':
        neural, model = roirdm[7:10, 1:4].flatten(), cnnrdm[7:10, 1:4].flatten()
    elif stim_type == 'ani':
        neural = reorder_ani(roirdm)[3:, :3].flatten()
        model = reorder_ani(cnnrdm)[3:, :3].flatten()
    else:
        raise ValueError(f"rsa(): unknown stim_type {stim_type!r}")

    rho, p = stats.spearmanr(neural, model)
    return rho, p

def save_nifti(vol, savepath, affine=None):
    """Write ``vol`` to ``savepath`` as a NIfTI-1 image.

    Parameters
    ----------
    vol : array
        Volume data passed to ``nib.Nifti1Image``.
    savepath : str
        Output filename.
    affine : array, optional
        Voxel-to-world affine. When omitted, the module-level ``affine`` is used, as
        before. NOTE(review): no visible cell defines ``affine``; it must come from
        elsewhere in the session. Confirm the intended template affine.

    Raises
    ------
    NameError
        If ``affine`` is omitted and no global ``affine`` exists.
    """
    if affine is None:
        try:
            affine = globals()['affine']
        except KeyError:
            raise NameError("save_nifti(): pass `affine` or define a global `affine`") from None
    nib.save(nib.Nifti1Image(vol, affine=affine), savepath)

# Whole-brain RSA: for every subject, contrast, and CNN layer, correlate each entry of the
# subject's neural RDM stack with the CNN RDM and write rho / p volumes as NIfTI files.
sv_key = 'PCA3_90'  # output sub-directory name
savebase = base + '/RSA_Wholebrain_CNN/' + sv_key
os.makedirs(savebase, exist_ok=True)

import time
for sbj_idx in range(sbjs_num):
    start_time = time.time()
    sbj_tag = 'P' + str(sbj_list[sbj_idx])
    # load brain rdm
    load_path = base + '/RDMs/Neural/16x16/' + sbj_tag + '_i.npz'  # r3
    # allow_pickle is required for the object array 'info_list' (pickle: only load trusted files).
    with np.load(load_path, allow_pickle=True) as npz:
        brain_rdm = npz['rdm']
        brain_inform_list = npz['info_list']
    nvv = brain_rdm.shape[0]
    for ti, type_ in enumerate(interest_list):
        # stim_type must be assigned before it is used in the output key; previously the key
        # used the value from the previous iteration (NameError on a fresh kernel).
        stim_type = stim_type_list[ti]
        if type_ in ['Animacy', 'Magnitude']:
            out_key = type_ + '_3x3'
        else:
            out_key = type_ + '_' + stim_type
        for li in range(depth):
            vol_rho = np.zeros((64, 76, 64))
            vol_p = np.zeros((64, 76, 64))
            cnnrdm = cnnrdms[sbj_idx, li]
            for vi in range(nvv):
                # NOTE(review): element [1] is presumably the voxel index tuple — confirm.
                idx = brain_inform_list[vi][1]
                roirdm = brain_rdm[vi]  # rsa() reads this module-level name
                vol_rho[idx], vol_p[idx] = rsa(stim_type, cnnrdm)
            layer_tag = '_L' + str(li + 1) + '.nii'
            save_nifti(vol_rho, savebase + '/' + sbj_tag + '_rho_' + out_key + layer_tag)
            save_nifti(vol_p, savebase + '/' + sbj_tag + '_pvl_' + out_key + layer_tag)
    print(f'{sbj_tag}: {time.time() - start_time:.1f} s')

In [None]:
# t-test: group-level one-sample t-test (vs 0) on Fisher-z transformed rho maps,
# for every contrast and layer, restricted to voxels non-zero in every subject.
# Filename suffix of each contrast's rho maps, matching the RSA cell's '<type>_<stim_type>'
# / '<type>_3x3' keys (the old '3x3' lacked the underscore, so those files were not found).
# Kept separate from `stim_type_list` so re-running the RSA cell later is not affected.
rho_suffix_list = ['_dpt', '_opt', '_offd', '_3x3', '_3x3']
rho_dir = base + '/RSA_Wholebrain_CNN/PCA3_90/'
savebase = base + '/For_thesis/RSA_ROI_CNN/'
os.makedirs(savebase, exist_ok=True)
for ti, type_ in enumerate(interest_list):
    suffix = rho_suffix_list[ti]
    for li in range(depth):
        # individual rho map load
        arr_all = np.zeros((sbjs_num, 64, 76, 64))
        for sbj_idx in range(sbjs_num):
            rho_path = rho_dir + 'P' + str(sbj_list[sbj_idx]) + '_rho_' + type_ + suffix + '_L' + str(li + 1) + '.nii'
            arr_all[sbj_idx] = nib.load(rho_path).get_fdata()
        # group mask: voxels with a non-zero rho in every subject
        group_mask = np.all(arr_all != 0, axis=0)
        msk_idx = np.where(group_mask)
        nvv = msk_idx[0].shape[0]
        print(type_, ':', nvv)
        # r to z, one row per masked voxel: shape (nvv, sbjs_num).
        # Note: a rho of exactly +/-1 maps to +/-inf.
        vecs = np.arctanh(arr_all[:, group_mask].T)
        # t-test across subjects, vectorised over voxels
        tmpt = np.zeros(nvv)
        tmpp = np.zeros(nvv)
        if nvv:
            tmpt, tmpp = stats.ttest_1samp(vecs, 0., axis=1)
        # save t and p maps (uses the notebook-level `affine`)
        for prefix, values in (('tmap_', tmpt), ('pval_', tmpp)):
            vol_tmp = np.zeros((64, 76, 64))
            vol_tmp[msk_idx] = values
            img1 = nib.Nifti1Image(vol_tmp, affine=affine)
            nib.save(img1, savebase + prefix + type_ + '_L' + str(li + 1) + '_r3.nii')

In [None]:
# Load the group t-maps and p-maps written by the t-test cell into
# tmaps / tpmaps[contrast_idx, layer_idx]. The directory is spelled out rather than read
# from `savebase`, which the RSA cell also assigns (to a different directory).
tmap_dir = base + '/For_thesis/RSA_ROI_CNN/'
tmaps = np.zeros((len(interest_list), depth, 64, 76, 64))
tpmaps = np.zeros_like(tmaps)
for ti, type_ in enumerate(interest_list):
    for li in range(depth):
        stem = type_ + '_L' + str(li + 1) + '_r3.nii'
        tmaps[ti, li] = nib.load(tmap_dir + 'tmap_' + stem).get_fdata()
        tpmaps[ti, li] = nib.load(tmap_dir + 'pval_' + stem).get_fdata()

In [None]:
from pathlib import Path
import matplotlib as mpl
# Font file used as `font=` for the bar-graph axis labels.
# The old Path(mpl.get_data_path(), '/abs/path') discarded the first component because the
# second is absolute, so the matplotlib data path never took part; this is the same file.
fpath = Path(base, 'arial.ttf')

In [None]:
def save_bar_graph_mod2(ti, ylim=None, svk='', ncols=4):
    """Show, for each cluster of contrast ``interest_list[ti]``, a bar per CNN layer.

    Each bar is the mean of the cluster's voxels with t > 0 in ``tmaps[ti][layer]``; the
    error bar is std / sqrt(n) over those voxels. Layers with no positive voxel get no bar.

    Parameters
    ----------
    ti : int
        Index into ``interest_list`` / ``tmaps``.
    ylim : (lo, hi), optional
        If given, y ticks are placed every 1.0 from lo to hi (axis limits are not set).
    svk : str
        Unused. NOTE(review): presumably meant as a filename prefix; the figure is only shown.
    ncols : int
        Subplots per row (default 4, the original layout).
    """
    type_ = interest_list[ti]
    fchar = type_[0].lower()
    path = base + '/01_STEP/16x16/offd/5e-02_' + fchar + 'l2_20_r3_mask+tlrc.HEAD'
    clst_ = nib.load(path).get_fdata().squeeze()
    totn = int(clst_.max())
    if totn < 1:
        print(f'{type_}: no clusters to plot')
        return
    rown = int(np.ceil(totn / ncols))
    height = rown * 4
    fig, axes = plt.subplots(rown, ncols, figsize=(4 * ncols, height - 1), facecolor='white',
                             sharex=True, constrained_layout=True, squeeze=False)
    ax = axes.flat
    # Remove the unused panels of the last row.
    for ai in range(totn, rown * ncols):
        fig.delaxes(ax[ai])
    for clstn in range(totn):
        clst_idx = np.where(clst_ == clstn + 1)
        lay_mean = np.zeros(depth)
        lay_ste = np.zeros(depth)
        sig_info = np.ones(depth)
        for li2 in range(depth):
            vals = tmaps[ti][li2][clst_idx]
            vec = vals[vals > 0]
            if vec.size == 0:
                # No positive voxel: fall back to all voxels, and mark the layer as not drawn.
                vec = vals
                sig_info[li2] = 0
            lay_mean[li2] = vec.mean()
            # Standard error of the mean (previously std / n, understating it by sqrt(n)).
            lay_ste[li2] = vec.std() / np.sqrt(vec.size)
        cor_idx = np.where(sig_info == 1)
        ax[clstn].bar(xticks[cor_idx], lay_mean[cor_idx], color=col_arr[cor_idx], width=0.9, linewidth=1.5,
                      yerr=lay_ste[cor_idx], error_kw=dict(lw=1.2), ecolor='k')
        if ylim is not None:
            ax[clstn].set_yticks(np.arange(ylim[0], ylim[1] + 0.1, 1.0))
        ax[clstn].get_xaxis().set_visible(False)
        if clstn % ncols == 0:  # first panel of each row
            ax[clstn].set_ylabel('t-score', font=fpath, fontsize=20)
        sns.despine(ax=ax[clstn])
    plt.show()
    plt.close(fig)
for ti in range(len(interest_list)):
    save_bar_graph_mod2(ti, ylim=(0, 2), svk='Pos_')

In [None]:
# For Vs: six-column layout. NOTE: this redefinition shadows the 4-column version above.
def save_bar_graph_mod2(ti, ylim=None, svk=''):
    """Show, for each cluster of contrast ``interest_list[ti]``, a bar per CNN layer (6 per row).

    Each bar is the mean of the cluster's voxels with t > 0 in ``tmaps[ti][layer]``; the
    error bar is std / sqrt(n) over those voxels. Layers with no positive voxel get no bar.

    Parameters
    ----------
    ti : int
        Index into ``interest_list`` / ``tmaps``.
    ylim : (lo, hi), optional
        If given, y ticks are placed every 1.0 from lo to hi (axis limits are not set).
    svk : str
        Unused. NOTE(review): presumably meant as a filename prefix; the figure is only shown.
    """
    ncols = 6
    type_ = interest_list[ti]
    fchar = type_[0].lower()
    path = base + '/01_STEP/16x16/offd/5e-02_' + fchar + 'l2_20_r3_mask+tlrc.HEAD'
    clst_ = nib.load(path).get_fdata().squeeze()
    totn = int(clst_.max())
    if totn < 1:
        print(f'{type_}: no clusters to plot')
        return
    rown = int(np.ceil(totn / ncols))
    height = rown * 4
    fig, axes = plt.subplots(rown, ncols, figsize=(24, height - 0.5), facecolor='white',
                             sharex=True, constrained_layout=True, squeeze=False)
    ax = axes.flat
    # Remove the unused panels of the last row (previously left as empty axes).
    for ai in range(totn, rown * ncols):
        fig.delaxes(ax[ai])
    for clstn in range(totn):
        clst_idx = np.where(clst_ == clstn + 1)
        lay_mean = np.zeros(depth)
        lay_ste = np.zeros(depth)
        sig_info = np.ones(depth)
        for li2 in range(depth):
            vals = tmaps[ti][li2][clst_idx]
            vec = vals[vals > 0]
            if vec.size == 0:
                # No positive voxel: fall back to all voxels, and mark the layer as not drawn.
                vec = vals
                sig_info[li2] = 0
            lay_mean[li2] = vec.mean()
            # Standard error of the mean (previously std / n, understating it by sqrt(n)).
            lay_ste[li2] = vec.std() / np.sqrt(vec.size)
        cor_idx = np.where(sig_info == 1)
        ax[clstn].bar(xticks[cor_idx], lay_mean[cor_idx], color=col_arr[cor_idx], width=0.9, linewidth=1.5,
                      yerr=lay_ste[cor_idx], error_kw=dict(lw=1.2), ecolor='k')
        if ylim is not None:
            ax[clstn].set_yticks(np.arange(ylim[0], ylim[1] + 0.1, 1.0))
        ax[clstn].get_xaxis().set_visible(False)
        if clstn == 0:  # label only the first panel in this layout
            ax[clstn].set_ylabel('t-score', font=fpath, fontsize=20)
        sns.despine(ax=ax[clstn])
    plt.show()
    plt.close(fig)
ti = 2  # index of 'Vs' in interest_list
save_bar_graph_mod2(ti, ylim=(0, 2), svk='Pos_')