In [54]:
import os
import pickle
import pandas as pd
import glob
from tqdm import tqdm, trange
import joblib

# math imports
import numpy as np
import scipy
import sklearn

import warnings
import matplotlib
from matplotlib import pyplot

import copy
from functools import reduce

from dataclasses import dataclass

from method.ICQF.ICQF import ICQF
import sys

In [55]:
@dataclass
class matrix_class:
    """Container for one data matrix, its confounders and the factorization variables.

    Fields are declared in the order used by the generated positional
    constructor; the comment next to each field gives its role.
    """

    # --- data and confounders ---
    M: np.ndarray             # (column)-normalized data matrix
    M_raw: np.ndarray         # raw data matrix
    confound: np.ndarray      # normalized confounder matrix
    confound_raw: np.ndarray  # raw confounder matrix
    nan_mask: np.ndarray      # mask matrix for missing entries (0=missing, 1=available)

    # --- indexing (used when several data matrices are handled together) ---
    row_idx: np.ndarray       # global row index (for multiple data matrices)
    col_idx: np.ndarray       # global column index (for multiple data matrices)
    mask: np.ndarray          # global mask (for multiple data matrices)

    # --- metadata ---
    dataname: str             # dataname
    subjlist: list            # information on subjects (row information)
    itemlist: list            # information on items (column information)

    # --- factorization variables ---
    W: np.ndarray             # subject embedding (recall M = [W, C]Q^T)
    Q: np.ndarray             # item embedding (recall M = [W, C]Q^T)
    C: np.ndarray             # confounder matrix
    Qc: np.ndarray            # confounders' loadings (recall Q = [RQ, CQ])
    Z: np.ndarray             # auxiliary Z=WQ^T (ADMM)
    aZ: np.ndarray            # auxiliary variables (ADMM)


In [56]:
# Load the demographics table (`demo`) and the questionnaire table (`data`),
# in that order.
demo, data = (
    pd.read_csv(csv_path)
    for csv_path in ('./data/participants_post_motion_QA.csv', './data/snycq.csv')
)

In [57]:
# Preview the first rows of both tables (rich display, one table after the other).
display(demo.head(), data.head())

Unnamed: 0,Subject,gender,age (5-year bins)
0,sub-010014,M,20-25
1,sub-010080,F,30-35
2,sub-010082,F,25-30
3,sub-010084,F,20-25
4,sub-010015,M,20-25


Unnamed: 0,Subject,Run,Positive,Negative,Future,Past,Myself,People,Surroundings,Images,Words,Specific,Intrusive
0,sub-010014,post-ses-02-run-01-acq-AP,80.0,15.0,50.0,55.0,30.0,65.0,55.0,95.0,90.0,25.0,30.0
1,sub-010014,post-ses-02-run-01-acq-PA,60.0,10.0,75.0,70.0,65.0,65.0,80.0,95.0,85.0,25.0,55.0
2,sub-010014,post-ses-02-run-02-acq-AP,55.0,60.0,65.0,40.0,50.0,60.0,65.0,75.0,85.0,30.0,60.0
3,sub-010014,post-ses-02-run-02-acq-PA,40.0,55.0,95.0,5.0,55.0,60.0,60.0,80.0,90.0,30.0,55.0
4,sub-010080,post-ses-02-run-02-acq-AP,75.0,20.0,50.0,30.0,35.0,70.0,75.0,70.0,50.0,40.0,10.0


In [58]:
# Sorted unique age bins present in the demographics table.  The column is
# selected by name (as the confounder loop does) rather than by position, so a
# change in the CSV column order cannot silently pick another column.
ageclass = np.unique(demo['age (5-year bins)'].values)

In [59]:
# Centers of 7 equal-width bins partitioning [0, 1]: average each pair of
# consecutive bin edges.
age_bin_edges = np.linspace(0, 1, 8)
normal_ageclass = (age_bin_edges[:-1] + age_bin_edges[1:]) / 2

In [60]:
# Show the observed age bins next to the normalized centers they map to.
for _label, _values in (
    ('age class: ', ageclass),
    ('normalized age class center: ', normal_ageclass),
):
    print(_label, _values)

age class:  ['20-25' '25-30' '30-35' '35-40' '40-45' '45-50' '60-65']
normalized age class center:  [0.07142857 0.21428571 0.35714286 0.5        0.64285714 0.78571429
 0.92857143]


In [61]:
# Build the per-row confounders (gender, normalized age) for every row of
# `data`, looking each subject up in the demographics table `demo`.
#
# The look-up tables are built once instead of filtering `demo` twice per row,
# and scalars are read directly (the previous version relied on the truth value
# of an array comparison and on float() of a 1-element array).
if demo['Subject'].duplicated().any():
    raise ValueError('duplicated subjects in the demographic table')

gender_of = dict(zip(demo['Subject'], demo['gender']))
age_bin_of = dict(zip(demo['Subject'], demo['age (5-year bins)']))

# The k-th observed age bin (sorted order) maps to the k-th normalized center.
# NOTE(review): the mapping is rank-based on the bins present in `demo`, so
# non-adjacent bins (e.g. '45-50' then '60-65') get adjacent centers — confirm
# this is intended.
if len(ageclass) > len(normal_ageclass):
    raise ValueError('more age classes ({}) than normalized centers ({})'.format(
        len(ageclass), len(normal_ageclass)))
age_center_of = dict(zip(ageclass, normal_ageclass))

subjlist = data['Subject'].values
unknown_subjects = sorted(set(subjlist) - set(gender_of), key=str)
if unknown_subjects:
    raise KeyError('subjects missing from the demographic table: {}'.format(unknown_subjects))

# gender: 1.0 for 'M', 0.0 for anything else.
gender_list = [1.0 if gender_of[subj] == 'M' else 0.0 for subj in subjlist]
# age: normalized center of the subject's age bin, as a plain Python float.
age_list = [float(age_center_of[age_bin_of[subj]]) for subj in subjlist]
    

In [62]:
# Add the confounder columns right after 'Subject' (positions 1 and 2).
# `DataFrame.insert` raises if the column already exists, so on a re-run of
# this cell the existing columns are overwritten in place instead.
for _position, (_column, _values) in enumerate(
        (('gender', gender_list), ('age', age_list)), start=1):
    if _column in data.columns:
        data[_column] = _values
    else:
        data.insert(_position, _column, _values)

In [63]:
# Persist the questionnaire table with the added confounder columns
# (the DataFrame index is not written).
data.to_csv(
    './data/processed_data.csv',
    index=False,
)

In [64]:
# Stratification label per row of `data`: age-class index, shifted by the
# number of age classes for rows with gender == 1.0, so that every
# (gender, age class) pair gets its own label.
age_values = np.asarray(age_list)
is_gender_one = np.asarray(gender_list) == 1.0

# Offset derived from the age-class table instead of a hard-coded 7, so the
# two gender groups cannot overlap if the number of age classes changes.
n_age_classes = len(normal_ageclass)

age_split = np.zeros(data.shape[0])
for idx, age in enumerate(normal_ageclass):
    # Exact float comparison is intended: the entries of `age_list` are taken
    # from `normal_ageclass` itself.
    age_split[age_values == age] += idx

classes = age_split.copy()
classes[is_gender_one] += n_age_classes

In [65]:
# Preview the questionnaire item columns (everything from column index 4 on).
questionnaire_items = data.iloc[:, 4:]
questionnaire_items

Unnamed: 0,Positive,Negative,Future,Past,Myself,People,Surroundings,Images,Words,Specific,Intrusive
0,80.0,15.0,50.0,55.0,30.0,65.0,55.0,95.0,90.0,25.0,30.0
1,60.0,10.0,75.0,70.0,65.0,65.0,80.0,95.0,85.0,25.0,55.0
2,55.0,60.0,65.0,40.0,50.0,60.0,65.0,75.0,85.0,30.0,60.0
3,40.0,55.0,95.0,5.0,55.0,60.0,60.0,80.0,90.0,30.0,55.0
4,75.0,20.0,50.0,30.0,35.0,70.0,75.0,70.0,50.0,40.0,10.0
...,...,...,...,...,...,...,...,...,...,...,...
466,75.0,5.0,75.0,45.0,25.0,80.0,40.0,65.0,20.0,10.0,5.0
467,50.0,55.0,55.0,80.0,55.0,75.0,20.0,85.0,25.0,55.0,60.0
468,85.0,35.0,85.0,45.0,40.0,75.0,55.0,55.0,15.0,45.0,40.0
469,60.0,30.0,85.0,20.0,45.0,60.0,45.0,60.0,15.0,40.0,40.0


In [45]:
# Generate N_SPLITS independent stratified splits of the rows of `data` into
# two populations, and export each population as CSV and NPZ.
#
# StratifiedKFold is what the loop below uses; it was previously never imported
# (only train_test_split was), so the cell failed on a fresh kernel.
from sklearn.model_selection import StratifiedKFold, train_test_split

N_SPLITS = 10    # number of two-population splits to generate
SPLIT_SEED = 0   # base seed; split k is shuffled with seed SPLIT_SEED + k


def _population_arrays(frame):
    """Return the (M, nan_mask, confound) arrays for one population.

    M holds the item columns (column index 4 onwards), nan_mask is 1 where an
    entry of M is available and 0 where it is missing, and confound holds the
    'age' and 'gender' columns, in that order.
    """
    items = frame.iloc[:, 4:].values
    # Derived from the data instead of all ones, so missing entries are
    # actually flagged; identical to the all-ones mask when nothing is missing.
    available = pd.notna(items).astype(items.dtype)
    confounders = frame[['age', 'gender']].values
    return items, available, confounders


# NOTE(review): this counts rows of `data`; the table has several runs per
# subject, so runs of one subject can land in both populations — confirm that
# a row-level (rather than subject-level) split is intended.
nsubj = data.shape[0]

os.makedirs('./output/population_split', exist_ok=True)

for split_idx in range(N_SPLITS):

    # Seeded shuffle: re-running the notebook reproduces the same splits.
    skf = StratifiedKFold(n_splits=2, shuffle=True,
                          random_state=SPLIT_SEED + split_idx)
    # With two folds, the "train" indices of fold 0 and fold 1 are
    # complementary; they are used as the two populations.
    population_list = [population
                       for population, _ in skf.split(np.arange(nsubj), classes)]
    population_1, population_2 = population_list

    # Check that the two populations cover every row before writing anything.
    tt_size = len(set(population_1) | set(population_2))
    assert nsubj == tt_size

    np.savez('./output/population_split/split_{}_idx.npz'.format(split_idx),
             population_1=population_1, population_2=population_2)

    p1_data = data.iloc[population_1].copy()
    p1_data.to_csv('./output/population_split/split_{}_p1_idx.csv'.format(split_idx), index=False)

    p2_data = data.iloc[population_2].copy()
    # NOTE(review): 'p2_dx' (vs 'p1_idx') looks like a typo; the name is kept
    # unchanged because other code may read this path — confirm before renaming.
    p2_data.to_csv('./output/population_split/split_{}_p2_dx.csv'.format(split_idx), index=False)

    M1, nan_mask1, p1_confound = _population_arrays(p1_data)
    M2, nan_mask2, p2_confound = _population_arrays(p2_data)

    np.savez('./output/population_split/split_{}_population_1.npz'.format(split_idx),
             M=M1, nan_mask=nan_mask1, confound=p1_confound)
    np.savez('./output/population_split/split_{}_population_2.npz'.format(split_idx),
             M=M2, nan_mask=nan_mask2, confound=p2_confound)




In [53]:
# Export the full data set (all rows) with its confounders.
os.makedirs('./output', exist_ok=True)  # make sure the target directory exists

confound = data[['age','gender']].values
M = data.iloc[:,4:].values
# 1 = available, 0 = missing.  Derived from the data rather than hard-coded to
# ones, so missing entries are flagged; identical to the all-ones mask when
# nothing is missing.
nan_mask = pd.notna(M).astype(M.dtype)

np.savez('./output/full_data.npz', M = M, nan_mask=nan_mask, confound=confound)

In [66]:
# Same export as the full data set, but without confounders.
# NOTE(review): `confound=None` is stored as an object array, so readers
# presumably need np.load(..., allow_pickle=True) to access it — confirm
# against the consumer.
np.savez(
    './output/full_data_noconfound.npz',
    M=M,
    nan_mask=nan_mask,
    confound=None,
)

In [34]:
# Print the ten smallest row indices of population 1 for each saved split.
for i in range(10):
    # The index files are written as 'split_{}_idx.npz' by the split cell;
    # the previous 'split_{}.npz' path is not produced anywhere in this notebook.
    # The context manager closes the archive once the array has been read.
    with np.load('./output/population_split/split_{}_idx.npz'.format(i)) as split:
        p1 = np.sort(split['population_1'])
    print(p1[:10])

[ 3  4  5  6 11 13 15 17 18 22]
[ 6  9 10 11 14 15 19 20 21 22]
[ 2  5  6  9 12 15 16 17 20 21]
[ 0  1  5  6 10 12 13 15 16 21]
[ 0  1  2  5  6  9 11 12 14 16]
[ 3  4  6  7 10 12 13 15 16 17]
[ 0  1  5  7  8  9 10 11 14 15]
[ 0  3  5  6  7  8  9 10 12 13]
[ 0  1  2  4  6 11 12 15 18 19]
[ 1  4 10 11 12 14 16 17 18 21]
