In [1]:
from helpers import *
from vbi.feature_extraction.features_utils import get_fc

In [2]:
# Directory of this notebook's own artefacts: the output_ppc*.npz archives are
# read from it and the figures are written below it.
path = "output/"
# Directory from which peaks.csv / peaks_ks.csv are read and to which
# optimal_indices.txt is written.
data_path = "../../../data/SBI/PRE/"

In [3]:
# Open the two simulation archives (main groups and "Th").
data_ppc = np.load(path + "/output_ppc.npz")
data_ppc_Th = np.load(path + "/output_ppc_Th.npz")

# Reverse the axis order of the stored "fmri_d" arrays so that the first axis
# indexes simulations (the recorded output below is (3700, 108, 900)).
Bolds_ppc = np.transpose(data_ppc["fmri_d"], (2, 1, 0))
Bolds_ppc_Th = np.transpose(data_ppc_Th["fmri_d"], (2, 1, 0))

# Drop the first 50 samples and the last one, then keep at most 900 samples.
Bold_ppc = Bolds_ppc[:, :, 50:-1][:, :, :900]
Bold_ppc_Th = Bolds_ppc_Th[:, :, 50:-1][:, :, :900]

# Stack the "Th" simulations after the others; later cells walk this array
# sequentially in the group order RSC, ACA, CTRL, Th.
Bold_ppc = np.concatenate([Bold_ppc, Bold_ppc_Th], axis=0)
Bold_ppc.shape

(3700, 108, 900)

In [4]:
def interpolate_vector(x, nt_new):
    """Linearly resample a 1-D signal to ``nt_new`` points.

    Parameters
    ----------
    x : sequence or np.ndarray, shape (nt,)
        Samples of the signal, taken to lie at positions 0, 1, ..., nt - 1.
    nt_new : int
        Number of samples in the returned signal.

    Returns
    -------
    np.ndarray of shape (nt_new,), or ``x`` itself (not a copy) when it
    already has ``nt_new`` samples.

    Notes
    -----
    The new samples are spread evenly over the whole span of the original
    ones, so the first and last output values equal ``x[0]`` and ``x[-1]``.
    """
    nt = len(x)
    if nt == nt_new:
        return x
    # The last original sample sits at position nt - 1. Querying up to nt
    # would run past the data, where np.interp just repeats the final value.
    return np.interp(np.linspace(0, nt - 1, nt_new), np.arange(nt), x)
    
def interpolate_bold(bold, nt_new):
    '''
    Resample every row of `bold` to `nt_new` samples, then zero-pad the result.

    bold : np.ndarray (nn x nt)

    Returns an array of shape (nn, 50 + nt_new + 1): 50 columns of zeros,
    the rows resampled with interpolate_vector, and one final column of zeros.
    '''
    n_rows, _ = bold.shape
    resampled = np.zeros((n_rows, nt_new))
    for row_idx, row in enumerate(bold):
        resampled[row_idx] = interpolate_vector(row, nt_new)

    # Zero padding: an nn x 50 block in front and an nn x 1 block at the back.
    lead = np.zeros((n_rows, 50))
    tail = np.zeros((n_rows, 1))
    return np.concatenate((lead, resampled, tail), axis=1)

def load_pickle(filename):
    """Open `filename` in binary mode and return the object unpickled from it.

    NOTE(review): unpickling can execute arbitrary code, so only use this on
    files from a trusted source.
    """
    with open(filename, "rb") as handle:
        obj = pickle.load(handle)
    return obj
    
def get_bold_emp(group, subject_id, prepost, BOLD, nt=900, offset=50):
    """Return a window of one empirical recording together with its time axis.

    The recording is read from BOLD[group]['<subject_id>_<prepost>'] and
    transposed. For group "Th" it is additionally passed through
    interpolate_bold with a fixed target of 900 samples.
    NOTE(review): that 900 is independent of `nt` - confirm this is intended.

    Returns (bold[:, offset:nt+offset], t[offset:nt+offset]), where t is the
    column index multiplied by the group's TR (0.3 for every group here).
    """
    tr_by_group = dict.fromkeys(("ACA", "RSC", "CTRL", "Th"), 0.3)
    record_key = f'{subject_id}_{prepost}'
    bold = BOLD[group][record_key].T
    if group == "Th":
        bold = interpolate_bold(bold, 900)

    window = slice(offset, nt + offset)
    t = np.arange(0, bold.shape[1]) * tr_by_group[group]

    return bold[:, window], t[window]


def preprocess(bold, nt=900):
    """Subtract each row's mean and return the first `nt` columns.

    The mean is taken over the full row, i.e. before the truncation to `nt`.
    The input array is not modified.
    """
    row_means = np.mean(bold, axis=1, keepdims=True)
    centered = bold - row_means
    return centered[:, :nt]

# The empirical BOLD data may not be available. If the pickle is available,
# uncomment the BOLD_file and BOLD lines below and point BOLD_file at it
# (the path shown is an absolute path on the original author's machine).
# BOLD_file = (
#     "/home/ziaee/Desktop/workstation/Giovanni/OCT/G/Allen_connectome/BOLD_data.pkl"
# )
groups = ["RSC", "ACA", "CTRL"]  # "Th"
# BOLD = load_pickle(BOLD_file)

In [8]:
# Per-group sampling frequency, presumably 1 / TR.
# NOTE(review): "Th" is 1/1.75 here, whereas TR_map in get_bold_emp uses 0.3
# for "Th" - confirm which value is intended.
fs = {"ACA": 1./0.3, "RSC": 1./0.3, "Th": 1./1.75, "CTRL": 1./0.3}
# Number of subjects in each group.
n_subjects = {"ACA": 7, "RSC": 14, "Th": 8, "CTRL": 8}
# Per-group list of empirical BOLD arrays. It stays empty unless the loop
# below is uncommented, which requires BOLD to have been loaded.
Bold = {"ACA":[], "RSC":[], "CTRL":[], "Th":[]}

# for g in ['RSC', 'ACA', "CTRL", "Th"]:
#     for p in ["pre"]:
#         for i in range(n_subjects[g]):
#             bold, times = get_bold_emp(g, i, p, BOLD, nt=900, offset=50)
#             Bold[g].append(bold)                

In [9]:
# len(Bold['ACA']), Bold['ACA'][0].shape

(7, (108, 900))

In [5]:
def get_features(bold):
    """
    Compute the FC and FCD matrices of one BOLD recording.

    bold: [nn, nt]

    The signal is first passed through preprocess. Returns (fc, fcd), the
    'full' entries of the results of get_fc and of get_fcd (wwidth=50).
    """
    signal = preprocess(bold)
    fcd_full = get_fcd(signal, wwidth=50)['full']
    fc_full = get_fc(signal)['full']

    return fc_full, fcd_full


In [6]:
def plot_half_matrix(
    A,
    ax,
    cmap="hot",
    mask="upper",
    k=2,
    vmax=None,
    vmin=None,
    colorbar_ticks=None,
    colorbar=True,
):
    """Draw one triangle of the matrix `A` on `ax`, hiding everything else.

    mask="upper" shows the entries on and above the k-th super-diagonal,
    mask="lower" those on and below the k-th sub-diagonal. Any other value
    raises ValueError. When `colorbar` is true a narrow colorbar axis is
    appended to the right of `ax`. Axis ticks are always removed.
    """
    if mask not in ("upper", "lower"):
        raise ValueError("mask must be either 'upper' or 'lower'")

    template = np.ones_like(A)
    # Non-zero entries of `visible` mark the cells that remain drawn.
    visible = np.triu(template, k=k) if mask == "upper" else np.tril(template, k=-k)

    from mpl_toolkits.axes_grid1 import make_axes_locatable

    image = ax.imshow(
        np.ma.masked_where(visible == 0, A), cmap=cmap, vmax=vmax, vmin=vmin
    )

    if colorbar:
        cax = make_axes_locatable(ax).append_axes("right", size="3%", pad=0.05)
        plt.colorbar(image, cax=cax, ax=ax, ticks=colorbar_ticks)
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])


def plot_hist(ax, data, bins=100, color="blue", alpha=0.5, label=None):
    """Draw a histogram of `data` on `ax` and remove the y-axis ticks."""
    hist_kwargs = dict(bins=bins, color=color, alpha=alpha, label=label)
    ax.hist(data, **hist_kwargs)
    ax.set_yticks([])

In [10]:
# print(Bold["ACA"][0].shape)
# Shape of the first simulated BOLD array.
print(Bold_ppc[0].shape)

((108, 900), (108, 900))

In [11]:
# Total number of simulated recordings (size of the first axis of Bold_ppc).
len(Bold_ppc)

3700

In [23]:
import tqdm
from scipy.stats import ks_2samp

def get_ks_dist(n_ensemble=100):
    """Score every simulated recording against its subject's empirical one.

    Bold_ppc is walked sequentially: for each group in the order RSC, ACA,
    CTRL, Th and for each subject of that group, the next `n_ensemble`
    entries of Bold_ppc are compared with the empirical recording Bold[g][i].

    Parameters
    ----------
    n_ensemble : int
        Number of consecutive simulations belonging to each subject.

    Returns
    -------
    ks_dist_fc, ks_dist_fcd, corr_fc : list of [idx, value]
        `idx` is the position of the simulation in Bold_ppc (0-based). The
        values are the two-sample KS statistic between the flattened
        empirical and simulated FC matrices, the same for the FCD matrices,
        and the Pearson correlation of the flattened FC matrices.
        A simulation containing NaN gets KS = 1 for both and correlation = 0.

    Raises
    ------
    IndexError
        If Bold holds fewer recordings than expected for a group, or Bold_ppc
        holds fewer simulations than n_ensemble * (total number of subjects).
    """
    n_subjects = {"ACA": 7, "RSC": 14, "CTRL": 8, "Th": 8}
    group_order = ['RSC', "ACA", "CTRL", "Th"]

    # Validate up front: the loop below takes minutes, so a missing input
    # should not surface only after part of the work has been done.
    for g in group_order:
        if len(Bold[g]) < n_subjects[g]:
            raise IndexError(
                f"Bold[{g!r}] holds {len(Bold[g])} recordings but "
                f"{n_subjects[g]} are required; load the empirical data first"
            )
    n_required = n_ensemble * sum(n_subjects[g] for g in group_order)
    if len(Bold_ppc) < n_required:
        raise IndexError(
            f"Bold_ppc holds {len(Bold_ppc)} simulations but {n_required} are required"
        )

    idx = 0
    ks_dist_fc = []
    ks_dist_fcd = []
    corr_fc = []
    for g in group_order:
        for i in tqdm.tqdm(range(n_subjects[g])):
            fc_e, fcd_e = get_features(Bold[g][i])
            # Flatten the empirical features once per subject rather than
            # once per simulation.
            fc_e_flat = fc_e.flatten()
            fcd_e_flat = fcd_e.flatten()

            for _j in range(n_ensemble):
                bold_s = Bold_ppc[idx]
                if np.isnan(bold_s).any():
                    # Failed simulation: assign the worst possible scores.
                    ks_dist_fc.append([idx, 1])
                    ks_dist_fcd.append([idx, 1])
                    corr_fc.append([idx, 0])
                else:
                    fc_s, fcd_s = get_features(bold_s)
                    fc_s_flat = fc_s.flatten()
                    ks_dist_fcd.append([idx, ks_2samp(fcd_e_flat, fcd_s.flatten())[0]])
                    ks_dist_fc.append([idx, ks_2samp(fc_e_flat, fc_s_flat)[0]])
                    corr_fc.append([idx, np.corrcoef(fc_e_flat, fc_s_flat)[0, 1]])
                # Advance only after recording, so the stored idx is the
                # position that was actually read from Bold_ppc.
                idx += 1

    print(idx)
    return ks_dist_fc, ks_dist_fcd, corr_fc

In [24]:
# Requires Bold to hold the empirical recordings (get_ks_dist reads Bold[g][i]).
# The recorded progress bars below sum to roughly three and a half minutes.
ks_dist_fc, ks_dist_fcd, corr_fc = get_ks_dist()

100%|██████████| 14/14 [01:14<00:00,  5.35s/it]
100%|██████████| 7/7 [00:33<00:00,  4.85s/it]
100%|██████████| 8/8 [00:40<00:00,  5.10s/it]
100%|██████████| 8/8 [01:02<00:00,  7.79s/it]

3700





In [25]:
# Number of scored simulations in each list.
len(ks_dist_fc), len(ks_dist_fcd)

(3700, 3700)

In [12]:
# Table read from peaks.csv; the recorded shape below is (3700, 10), the same
# number of rows as there are simulations in Bold_ppc.
peak_df = pd.read_csv(data_path + "/peaks.csv")
peak_df.shape

(3700, 10)

In [27]:
# Attach the scores to a copy of peak_df (peak_df itself is left untouched).
# The score lists hold [index, value] pairs; only the values are kept.
# The flattened correlations go into a new name instead of overwriting
# corr_fc, so this cell can be re-run without failing on already-flat data.
ks_fc = [pair[1] for pair in ks_dist_fc]
ks_fcd = [pair[1] for pair in ks_dist_fcd]
corr_fc_values = [pair[1] for pair in corr_fc]

df = peak_df.copy()
df['ks_fc'] = ks_fc
df['ks_fcd'] = ks_fcd
# Combined score: sum of the FC and FCD KS statistics.
df['ks_fc_fcd'] = np.array(ks_fcd) + np.array(ks_fc)
df['corr_fc'] = corr_fc_values

In [14]:
# df.to_csv(path + "/peaks_ks.csv", index=False)
# Replace df with the previously saved table of scores.
# NOTE(review): the commented-out to_csv above writes under `path`, while this
# reads from `data_path` - confirm the file is expected in data_path.
df = pd.read_csv(data_path + "/peaks_ks.csv")
print(df.shape)
df.head()

(3700, 14)


Unnamed: 0,group,subject_id,ensemble,g,eta1,eta2,eta3,eta4,eta5,eta6,ks_fc,ks_fcd,ks_fc_fcd,corr_fc
0,RSC,0,0,0.382328,-4.544004,-5.971716,-4.206444,-4.049391,-4.943881,-4.069718,1.0,1.0,2.0,0.0
1,RSC,0,1,0.571165,-3.69264,-4.888504,-4.900546,-4.597633,-4.413955,-4.107019,0.108711,0.763261,0.871971,0.110446
2,RSC,0,2,0.643984,-3.627785,-5.271295,-4.567267,-4.341366,-4.802573,-4.135063,1.0,1.0,2.0,0.0
3,RSC,0,3,0.570303,-3.735444,-4.474549,-4.982661,-4.18489,-5.887011,-4.087141,0.086763,0.792793,0.879556,0.084613
4,RSC,0,4,0.532017,-3.701018,-5.512277,-4.41649,-4.064951,-4.359322,-3.695864,0.105453,0.703418,0.808871,0.130877


In [32]:

# Create the directory that plot() saves its figures into, if it is missing.
os.makedirs(path+"/figs/fcfcd", exist_ok=True)

def plot(n_ensemble=100, groups=('RSC', "ACA", "CTRL", "Th")):
    """Save one empirical-vs-simulated comparison figure per subject.

    For every subject of the requested groups, the row of `df` with the
    smallest ks_fc_fcd is selected and the matching simulation from Bold_ppc
    is compared with the empirical recording Bold[g][i]. Each figure has four
    panels: FC matrices (empirical in the lower triangle, simulated in the
    upper), FCD matrices (same layout), FC histograms and FCD histograms.
    Figures are written to {path}/figs/fcfcd/{g}_{i}.png.

    Parameters
    ----------
    n_ensemble : int
        Not used by this function; kept so existing calls keep working.
    groups : iterable of str
        Groups to plot; each must be a key of the subject-count table below.
    """
    n_subjects = {"ACA": 7, "RSC": 14, "CTRL": 8, "Th": 8}

    for g in groups:
        for i in range(n_subjects[g]):
            # Best simulation for this subject = smallest combined KS score.
            # (ks_fc or ks_fcd alone could be minimised here instead.)
            subject_rows = df[(df.group == g) & (df.subject_id == i)]
            best = subject_rows.ks_fc_fcd.idxmin()

            # idxmin returns an index label of df; it is used as a position
            # in Bold_ppc, which assumes df keeps its default 0..N-1 index in
            # the same order as Bold_ppc - TODO confirm.
            fc_e, fcd_e = get_features(Bold[g][i])
            fc_s, fcd_s = get_features(Bold_ppc[best])
            fc_e_flat, fc_s_flat = fc_e.flatten(), fc_s.flatten()
            fcd_e_flat, fcd_s_flat = fcd_e.flatten(), fcd_s.flatten()

            fig, axs = plt.subplots(1, 4, figsize=(13, 3))
            plot_half_matrix(fc_e, axs[0], cmap="coolwarm", vmax=1, vmin=-0.2, colorbar=False, mask="lower")
            plot_half_matrix(fc_s, axs[0], cmap="coolwarm", vmax=1, vmin=-0.2, colorbar=True, mask='upper')

            plot_half_matrix(fcd_e, axs[1], cmap="coolwarm", vmax=1, vmin=-0.2, colorbar=False, mask="lower")
            plot_half_matrix(fcd_s, axs[1], cmap="coolwarm", vmax=1, vmin=-0.2, colorbar=True, mask='upper')

            plot_hist(axs[2], fc_e_flat, bins=100, color="blue", alpha=0.5, label="Empirical")
            plot_hist(axs[2], fc_s_flat, bins=100, color="red", alpha=0.5, label="Predicted")

            plot_hist(axs[3], fcd_e_flat, bins=100, color="blue", alpha=0.5, label="Empirical")
            plot_hist(axs[3], fcd_s_flat, bins=100, color="red", alpha=0.5, label="Predicted")

            # Only the KS statistic is reported; the p-value is not needed.
            fcd_ks_dist = ks_2samp(fcd_e_flat, fcd_s_flat)[0]
            fc_ks_dist = ks_2samp(fc_e_flat, fc_s_flat)[0]
            fc_corr = np.corrcoef(fc_e_flat, fc_s_flat)[0, 1]

            axs[2].legend(frameon=False)
            axs[0].set_title('FC predicted')
            axs[0].set_ylabel('FC empirical')
            axs[1].set_title('FCD predicted')
            axs[1].set_ylabel('FCD empirical')
            axs[2].set_title("KS FC: {:.2f}, Corr: {:.2f}".format(fc_ks_dist, fc_corr))
            axs[3].set_title("KS FCD: {:.2f}".format(fcd_ks_dist))
            fig.tight_layout()
            fig.savefig(f"{path}/figs/fcfcd/{g}_{i}.png", bbox_inches='tight')
            # Close this specific figure so figures do not pile up in memory.
            plt.close(fig)

In [33]:
# Produce the comparison figures for the "Th" group only.
# Requires Bold["Th"] to hold the empirical recordings.
plot(groups=['Th'])

In [16]:
import tqdm 

In [17]:
# optimal parameters by minimizing ks_fc_fcd
#
# For every subject, pick the row of df with the smallest combined KS score.
# The result is df_opt (one row per subject) and optimal_indices (the index
# label of each selected row of df, in the group order used below).

opt_columns = ["group", "subject_id", "ks_fc_fcd", 'ks_fc', 'ks_fcd',
               'g', 'eta1', 'eta2', 'eta3', 'eta4', 'eta5', 'eta6']
# Columns copied verbatim from the selected row (everything after the keys).
value_columns = opt_columns[2:]

n_subjects = {"ACA": 7, "RSC": 14, "CTRL": 8, "Th": 8}
data_opt = []
optimal_indices = []
for g in ['RSC', "ACA", "CTRL", "Th"]:
    for i in range(n_subjects[g]):
        ii = df[(df.group == g) & (df.subject_id == i)].ks_fc_fcd.idxmin()
        # idxmin returns an index label, so the row is fetched by label and
        # only once.
        best_row = df.loc[ii]
        data_opt.append([g, i] + [best_row[col] for col in value_columns])
        optimal_indices.append(ii)

df_opt = pd.DataFrame(data_opt, columns=opt_columns)

# df_opt.to_csv(path + "/peaks_ks_opt.csv", index=False)
df_opt.head()


Unnamed: 0,group,subject_id,ks_fc_fcd,ks_fc,ks_fcd,g,eta1,eta2,eta3,eta4,eta5,eta6
0,RSC,0,0.246743,0.161351,0.085392,0.647236,-3.850671,-4.965111,-4.645602,-4.107086,-5.199185,-4.656292
1,RSC,1,0.296758,0.073045,0.223713,0.853356,-3.807516,-4.64815,-4.33748,-5.691139,-3.733423,-4.690051
2,RSC,2,0.445386,0.314986,0.1304,0.543796,-4.182987,-4.90388,-4.78125,-4.013793,-4.155088,-4.265911
3,RSC,3,0.099736,0.066358,0.033378,0.738217,-4.236919,-4.705533,-5.444082,-3.511773,-4.412962,-4.820805
4,RSC,4,0.225897,0.086248,0.139649,0.874624,-4.578669,-5.199881,-5.47266,-5.333493,-4.941657,-4.204763


In [20]:
# Write the selected row indices as integers, one per line.
# Note that the file goes into data_path, not into the output directory `path`.
np.savetxt(data_path + "/optimal_indices.txt", optimal_indices, fmt='%d')

In [23]:
# Bold_ppc.shape

# Bold_optimal = []
# for i in optimal_indices:
#     bold = Bold_ppc[i]
#     if np.isnan(bold).any():
#         print(i)
#     Bold_optimal.append(bold)
# Bold_optimal = np.array(Bold_optimal)
# np.savez(path + "/Bold_optimal.npz", Bold=Bold_optimal)