In [1]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [2]:
# Participant metadata table (tab-separated)
participants_tsv = '~/Desktop/TDBRAIN-dataset-derivatives/TDBRAIN_participants_V2_data/TDBRAIN_participants_V2.tsv'
df = pd.read_table(participants_tsv)
# Strip leading/trailing whitespace from the indication labels
df['indication'] = df['indication'].str.strip()

In [3]:
indications = df.indication.unique()
# Start from the 'NaN' placeholder, then keep every string label that names a
# single diagnosis (labels containing "/" list multiple diagnoses)
unique_indications = ['NaN'] + [
    label for label in indications
    if type(label) is str and "/" not in label
]

In [None]:
# Count subjects per single diagnosis. A combined label (e.g. "A/B") counts
# towards every diagnosis whose name appears in it.
# value_counts() is computed once up front; it excludes NaN by default, so the
# missing-value case is handled separately through isna().
indication_counts = df.indication.value_counts()
forty_or_more = ["NaN"]
for d in unique_indications:
    if d == "NaN":
        num_nan = df.indication.isna().sum()
        print(f'Number of NaN instances: {num_nan}')
    elif "/" not in d:
        # Sum the row counts of every label that contains this diagnosis name
        run_sum = sum(count for label, count in indication_counts.items() if d in label)
        print(f'Number of {d} instances: {run_sum}')
        if run_sum >= 40:
            forty_or_more.append(d)

# Diagnoses with forty or more subjects
Note: NaN and REPLICATION categories are either unknown or withheld diagnoses

In [5]:
# Show the kept labels: the "NaN" placeholder plus every indication whose count reached 40
print(forty_or_more)

['NaN', 'REPLICATION', 'SMC', 'HEALTHY', 'MDD', 'ADHD', 'OCD']


# The only columns we care about: 
## \['participants_ID', 'indication', 'age', 'gender', 'sessID'\]
Two additional columns will be added:
1) `multi`: whether the subject has multiple diagnoses <br>
2) `data_loc`: the folder location with the subject's EEG data

# Six Subject Groupings
### \['UNKNOWN', 'SMC', 'HEALTHY', 'MDD', 'ADHD', 'OCD'\]

### Raw Participant Separation

In [7]:
df_copy = df.copy()
# Swap commas for periods in the age strings, then parse them as numbers
df_copy['age'] = pd.to_numeric(df_copy['age'].str.replace(',', '.'))

In [8]:
# Drop subjects whose age is unknown (NaN), reporting the size before and after
print(f"Number of entries: {len(df_copy)}")
df_copy = df_copy.dropna(subset=['age'])
print(f"Number of entries after removing unknown age entries: {len(df_copy)}")

Number of entries: 1350
Number of entries after removing unknown age entries: 1326


In [9]:
# Convert remaining NaN values to the string 'NaN' so that the substring checks
# on `indication` further down always operate on a str.
# regex=True was dropped: pandas documents it for string patterns only, and
# np.nan is not one, so a plain value replacement is what actually runs.
df_copy = df_copy.replace(np.nan, 'NaN')

In [10]:
# One independent, empty frame per diagnosis group, all sharing df_copy's columns
df_unknown, df_smc, df_healthy, df_mdd, df_adhd, df_ocd = (
    pd.DataFrame(columns=df_copy.columns) for _ in range(6)
)

In [11]:
# Map each kept indication label to its group frame.
# 'NaN' and 'REPLICATION' both feed the unknown-diagnosis frame.
diag_dict = dict(zip(
    ['NaN', 'REPLICATION', 'SMC', 'HEALTHY', 'MDD', 'ADHD', 'OCD'],
    [df_unknown, df_unknown, df_smc, df_healthy, df_mdd, df_adhd, df_ocd],
))

In [12]:
# Append every participant row to each diagnosis group whose label it contains
for row in df_copy.itertuples(index=False, name='Pandas'):
    for diag in forty_or_more:
        if diag not in row.indication:
            continue
        cur_df = diag_dict[diag]
        cur_df.loc[len(cur_df)] = row
        # Exact match: the entry has a single diagnosis, so stop scanning labels
        if row.indication == diag:
            break

In [13]:
# Row count of each group, in the order: unknown, SMC, healthy, MDD, ADHD, OCD
for group_df in (df_unknown, df_smc, df_healthy, df_mdd, df_adhd, df_ocd):
    print(len(group_df))

377
115
37
360
232
57


In [14]:
# Write each diagnosis frame to <folder>/participants.csv under raw_splitting
raw_split_dirs = {
    'unknown_diagnosis': df_unknown,
    'smc': df_smc,
    'healthy': df_healthy,
    'mdd': df_mdd,
    'adhd': df_adhd,
    'ocd': df_ocd,
}
for folder, group_df in raw_split_dirs.items():
    group_df.to_csv('Patient_Groups/raw_splitting/' + folder + '/participants.csv', index=False)

### For Model Building

In [15]:
# Keep only the columns needed for model building
df_copy = df[['participants_ID', 'indication', 'age', 'gender', 'sessID']].copy()
# Swap commas for periods in the age strings, then parse them as numbers
df_copy['age'] = pd.to_numeric(df_copy['age'].str.replace(',', '.'))

In [16]:
# Drop subjects whose age is unknown (NaN), reporting the size before and after
print(f"Number of entries: {len(df_copy)}")
df_copy = df_copy.dropna(subset=['age'])
print(f"Number of entries after removing unknown age entries: {len(df_copy)}")

Number of entries: 1350
Number of entries after removing unknown age entries: 1326


In [17]:
# Convert remaining NaN values to the string 'NaN' so that the substring checks
# on `indication` further down always operate on a str.
# regex=True was dropped: pandas documents it for string patterns only, and
# np.nan is not one, so a plain value replacement is what actually runs.
df_copy = df_copy.replace(np.nan, 'NaN')

In [18]:
# Column layout shared by every model-building group table
group_columns = ['participants_ID', 'indication', 'multi', 'age', 'gender', 'sessID', 'data_loc']
# One independent, empty frame per diagnosis group
df_unknown, df_smc, df_healthy, df_mdd, df_adhd, df_ocd = (
    pd.DataFrame(columns=group_columns) for _ in range(6)
)

In [19]:
# Map each kept indication label to its group frame.
# 'NaN' and 'REPLICATION' both feed the unknown-diagnosis frame.
diag_dict = dict(zip(
    ['NaN', 'REPLICATION', 'SMC', 'HEALTHY', 'MDD', 'ADHD', 'OCD'],
    [df_unknown, df_unknown, df_smc, df_healthy, df_mdd, df_adhd, df_ocd],
))
# Subject-independent data path; the participant ID is appended to it per row
sub_ind_path = '~/Desktop/TDBRAIN-dataset-derivatives/derivatives/'

In [20]:
# Append one record per participant row to each diagnosis group whose label it contains
for row in df_copy.itertuples(index=True, name='Pandas'):
    for diag in forty_or_more:
        if diag not in row.indication:
            continue
        # True when the indication holds more than just this diagnosis label
        has_multiple = bool(row.indication != diag)
        record = [
            row.participants_ID,
            row.indication,
            has_multiple,
            row.age,
            row.gender,
            row.sessID,
            sub_ind_path + row.participants_ID,
        ]
        cur_df = diag_dict[diag]
        cur_df.loc[len(cur_df)] = record
        # Exact match: the entry has a single diagnosis, so stop scanning labels
        if not has_multiple:
            break

In [25]:
# Row count of each group, in the order: unknown, SMC, healthy, MDD, ADHD, OCD
for group_df in (df_unknown, df_smc, df_healthy, df_mdd, df_adhd, df_ocd):
    print(len(group_df))

377
115
37
360
232
57


In [None]:
# Show full cell contents (e.g. the long data_loc paths) without truncation.
# None is the documented "no limit" value for this option; the previous False
# only passed validation because bool is a subclass of int.
pd.set_option('display.max_colwidth', None)
display(df_healthy)

In [23]:
# Save created diagnosis dataframes as csv files
df_unknown.to_csv('Patient_Groups/for_model_building/unknown_diagnosis_subjects.csv', index=False)
df_smc.to_csv('Patient_Groups/for_model_building/smc_subjects.csv', index=False)
df_healthy.to_csv('Patient_Groups/for_model_building/healthy_subjects.csv', index=False)
df_mdd.to_csv('Patient_Groups/for_model_building/mdd_subjects.csv', index=False)
df_adhd.to_csv('Patient_Groups/for_model_building/adhd_subjects.csv', index=False)
df_ocd.to_csv('Patient_Groups/for_model_building/ocd_subjects.csv', index=False)

In [None]:
# End

In [34]:
# Generate CF Splits
def get_splits(group, df_group, n_splits=5):
    """Partition a group's participants into contiguous splits and save them.

    The 'sub-' text is removed from each participant ID, repeated IDs are
    collapsed to their first occurrence, and the resulting list is cut into
    ``n_splits`` consecutive chunks whose sizes differ by at most one (the
    earliest chunks receive the extras). Each chunk is printed and written to
    ``Patient_Groups/for_model_building/<group>_group/split_<k>.npy`` with
    ``k`` running from 1 to ``n_splits``.

    Parameters
    ----------
    group : str
        Group name used to build the output folder name.
    df_group : pandas.DataFrame
        Frame with a string 'participants_ID' column.
    n_splits : int, default 5
        Number of splits to produce.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``n_splits`` is smaller than 1.
    """
    if n_splits < 1:
        raise ValueError('n_splits must be a positive integer')

    participants = (
        df_group['participants_ID']
        .str.replace('sub-', '', regex=False)  # literal text, not a pattern
        .tolist()
    )
    # A participant can occupy several rows of the group frame; keep only the
    # first occurrence so the same ID can never end up in two different splits.
    participants = list(dict.fromkeys(participants))

    min_per_group, extras = divmod(len(participants), n_splits)

    # np.save does not create missing folders, so make sure the target exists
    out_dir = 'Patient_Groups/for_model_building/' + group + '_group'
    os.makedirs(out_dir, exist_ok=True)

    index = 0
    for i in range(n_splits):
        # The first `extras` splits take one additional participant
        span = min_per_group + 1 if i < extras else min_per_group
        cur_split = participants[index:index + span]
        index += span
        print(cur_split)
        np.save(out_dir + '/split_' + str(i + 1) + '.npy', cur_split)

In [35]:
# Print and save the splits for the healthy group
get_splits('healthy', df_healthy)

['87974665', '87974709', '87974841', '87974973', '87976193', '87976369', '87976457', '87976505']
['87976817', '87976953', '87977045', '87980241', '87980373', '87980417', '87980869', '87982225']
['87982849', '88008997', '88041893', '88041941', '88048729', '88049585', '88051073']
['88053453', '88053545', '88055121', '88055301', '88057461', '88057869', '88058001']
['88058633', '88059397', '88067357', '88067853', '88068841', '88073029', '88075053']
