This notebook builds the extended test set by adding the PreCatatoes subjects to the test set of Aymeric's stratification.
It also re-stratifies the remaining subjects more evenly: the train, val and test_intra sets are concatenated and split again into folds balanced on sex, diagnosis and site.

In [62]:
import csv
import itertools
import random

import pandas as pd

# Initialization

In [63]:
# Source stratification: headerless subject lists, one per split
src_path = "/neurospin/dico/data/deep_folding/current/datasets/schiz/aymeric_stratification/with_sub"
(src_test_file, src_test_intra_file, src_val_file, src_train_file) = (
    f"{src_path}/test_subjects.csv",
    f"{src_path}/test_intra_subjects.csv",
    f"{src_path}/val_subjects.csv",
    f"{src_path}/train_subjects.csv",
)

In [64]:
# PreCatatoes subject table, whose subjects are appended to the test set
precatatoes_file = "/neurospin/dico/data/bv_databases/human/partially_labeled/orbital_patterns/PreCatatoes/OFC_sulcal_type_data_186-subjects_only-schiz-and-control.csv"

# Destination of the extended stratification and of the label table
tgt_path = "/neurospin/dico/data/deep_folding/current/datasets/schiz_extended"
(tgt_test_file, tgt_test_intra_file, tgt_val_file, tgt_train_file, label_file) = (
    f"{tgt_path}/test_subjects.csv",
    f"{tgt_path}/test_intra_subjects.csv",
    f"{tgt_path}/val_subjects.csv",
    f"{tgt_path}/train_subjects.csv",
    f"{tgt_path}/used_schiz_subjects.csv",
)

# Create the new test file

In [65]:
# Headerless list of source test subjects; column 0 becomes "participant_id"
test_src = (
    pd.read_csv(src_test_file, header=None)
    .rename(mapper={0: "participant_id"}, axis="columns")
)
test_src

Unnamed: 0,participant_id
0,sub-A00014522_ses-v1
1,sub-A00001243_ses-v1
2,sub-A00028405_ses-v1
3,sub-A00028408_ses-v1
4,sub-A00020968_ses-v1
...,...
158,sub-A00027537_ses-v1
159,sub-A00026945_ses-v1
160,sub-A00014636_ses-v1
161,sub-A00022915_ses-v1


In [66]:
# Keep only the subject ids, as strings (they are numeric in this file: 1, 2, ...)
precatatoes = (
    pd.read_csv(precatatoes_file)
    .loc[:, ["participant_id"]]
    .astype(str)
)
precatatoes

Unnamed: 0,participant_id
0,1
1,2
2,3
3,4
4,5
...,...
157,182
158,183
159,184
160,185


In [67]:
# New test set: source test subjects followed by all PreCatatoes subjects
test_tgt = pd.concat((test_src, precatatoes), ignore_index=True)

In [68]:
# Same headerless, index-free layout as the source split files
test_tgt.to_csv(tgt_test_file, header=False, index=False)

# Re-stratify the train/val/test_intra subjects

In [69]:
# Read the three non-test source splits (headerless; column 0 renamed to participant_id)
train_src, val_src, test_intra_src = (
    pd.read_csv(split_file, header=None).rename(columns={0: "participant_id"})
    for split_file in (src_train_file, src_val_file, src_test_intra_file)
)

In [70]:
# Pool of all non-test subjects, to be re-split below
all_src = pd.concat((train_src, val_src, test_intra_src), ignore_index=True)

In [71]:
# Display the pooled train/val/test_intra subjects
all_src

Unnamed: 0,participant_id
0,sub-ESOC10104_ses-v1
1,sub-NM2020_ses-v1
2,sub-or130001_ses-v1
3,sub-ESOC10112_ses-v1
4,sub-CC6287_ses-v1
...,...
1124,sub-10696_ses-1
1125,sub-SS093_ses-1
1126,sub-SS100_ses-1
1127,sub-HC021_ses-1


In [72]:
# Read participant_id explicitly as str. The PreCatatoes ids are purely numeric
# (1, 2, ...) and are compared as strings against `precatatoes`/`test_tgt`
# (cast with .astype(str) above). Forcing the dtype keeps that match independent
# of whatever type pandas would otherwise infer for the column.
label = pd.read_csv(label_file, dtype={"participant_id": str})[["participant_id", "sex", "diagnosis", "site"]]
label

Unnamed: 0,participant_id,sex,diagnosis,site
0,sub-INV07WT2ZL3,M,control,Dallas
1,sub-INV0AL14J6U,M,schizophrenia,Dallas
2,sub-INV14XK7P6E,M,control,Dallas
3,sub-INV1HXNTXYF,F,control,Dallas
4,sub-INV1XCNF4J5,F,control,Dallas
...,...,...,...,...
1449,182,M,schizophrenia,Sainte-Anne
1450,183,F,schizophrenia,Sainte-Anne
1451,184,M,control,Sainte-Anne
1452,185,M,schizophrenia,Sainte-Anne


In [73]:
# Keep only the labels of the pooled train/val/test_intra subjects: the test
# subjects (source test set + PreCatatoes) stay out of the re-split.
label_except_test = label[label.participant_id.isin(all_src.participant_id)]

# Pooled subjects with no row in the label file would silently vanish from the
# new splits; report them instead of dropping them unnoticed.
missing_labels = all_src.loc[~all_src.participant_id.isin(label.participant_id), "participant_id"]
if len(missing_labels):
    print(f"WARNING: {len(missing_labels)} train/val/test_intra subjects have no label "
          f"and are left out of the new splits: {missing_labels.head(10).tolist()}")

# Guard against train/test leakage: no test subject may enter the re-split pool.
assert not label_except_test.participant_id.isin(test_tgt.participant_id).any(), \
    "some test subjects are also in the train/val/test_intra pool"

In [74]:
def iterative_split_through_sorting_shuffle(df, n_splits, stratify_columns, random_state):
    """Split ``df`` into ``n_splits`` folds balanced on ``stratify_columns``.

    Rows are first shuffled, then sorted by the stratification columns so that
    rows sharing the same label combination become contiguous. Dealing the rows
    out round-robin (sorted row ``k`` goes to fold ``k % n_splits``) then spreads
    every label combination as evenly as possible over the folds.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to split.
    n_splits : int
        Number of folds.
    stratify_columns : list of str
        Columns whose value combinations must be balanced across folds.
    random_state : int or None
        Seed passed to every ``DataFrame.sample`` call and to the
        ``random.Random`` that shuffles the fold order.

    Returns
    -------
    list of pandas.DataFrame
        ``n_splits`` folds, each with its rows shuffled, in shuffled order.
    """
    shuffled = df.sample(frac=1, random_state=random_state)
    ordered = shuffled.sort_values(stratify_columns)
    folds = []
    for start in range(n_splits):
        # Every n_splits-th sorted row, starting at `start`
        fold = ordered.iloc[start::n_splits, :]
        # Shuffle inside the fold so rows are no longer grouped by label
        folds.append(fold.sample(frac=1, random_state=random_state))
    fold_order_rng = random.Random(random_state)
    fold_order_rng.shuffle(folds)
    return folds

In [75]:
def print_results(parent, folds, col):
    """Report how well ``folds`` preserve the label distribution of ``parent``.

    For every combination of the unique values of the columns listed in
    ``col`` (the Cartesian product, so combinations absent from ``parent`` are
    listed with a total of 0), prints the number of rows of ``parent`` having
    that combination and, for each fold, the number of fold rows having it.
    A (combination, fold) pair whose count differs by 1 or more from the ideal
    ``total / len(folds)`` counts as a stratification error. Then prints the
    fold lengths, their sum, and the number of folds whose length differs by
    1 or more from ``len(parent) / len(folds)``.

    Parameters
    ----------
    parent : pandas.DataFrame
        Frame that was split into ``folds``.
    folds : list of pandas.DataFrame
        Folds to check; each must contain the ``col`` columns.
    col : sequence of str
        Stratification columns. Any number of columns is accepted; the
        original version only supported exactly three.

    Returns
    -------
    None
        Everything is printed to stdout.
    """
    n_splits = len(folds)

    def _count(frame, combo):
        # Number of rows of `frame` whose `col` values equal `combo`. A boolean
        # mask is used instead of DataFrame.query so that column names which
        # are not valid Python identifiers (e.g. containing spaces) also work.
        mask = pd.Series(True, index=frame.index)
        for name, value in zip(col, combo):
            mask = mask & (frame[name] == value).to_numpy()
        return int(mask.sum())

    # For each combination of labels, prints the number of rows for each fold
    # having this combination
    total_errors = 0
    print("query   : #rows      : #rows per fold\n")

    uniques = [parent[name].unique() for name in col]
    for combo in itertools.product(*uniques):
        len_query = _count(parent, combo)
        labels = ", ".join(f"{value}" for value in combo)
        print(f"{labels} : total = {len_query} : per fold =", end=' ')
        for fold in folds:
            len_query_fold = _count(fold, combo)
            if abs(len_query_fold - len_query / n_splits) >= 1:
                total_errors += 1
            print(f"{len_query_fold} -", end=' ')
        print("")

    # Prints the statistics and the number of stratification errors
    expected_total_length = len(parent)
    fold_lengths = [len(fold) for fold in folds]
    total_length = sum(fold_lengths)
    # Generator is empty when there are no folds, so no division by zero
    total_mismatches = sum(
        abs(len_fold - expected_total_length / n_splits) >= 1
        for len_fold in fold_lengths
    )
    print("\nlengths of folds : ", end=' ')
    for len_fold in fold_lengths:
        print(len_fold, end=' ')
    print(f"\nExpected total_length = {expected_total_length}")
    print(f"Effective total_length = {total_length}")

    print(f"total number of stratification errors: {total_errors}")
    print(f"total number of mismatched fold sizes : {total_mismatches}")

In [76]:
# 10 folds balanced on sex x diagnosis x site, with a fixed seed for reproducibility
folds = iterative_split_through_sorting_shuffle(
    df=label_except_test,
    n_splits=10,
    stratify_columns=['sex', 'diagnosis', 'site'],
    random_state=1,
)

In [77]:
# Per-combination fold counts, stratification errors and fold-size mismatches
print_results(parent=label_except_test, folds=folds, col=['sex', 'diagnosis', 'site'])

query   : #rows      : #rows per fold

M, control, Dallas : total = 18 : per fold = 2 - 1 - 1 - 2 - 2 - 2 - 2 - 2 - 2 - 2 - 
M, control, Detroit : total = 12 : per fold = 1 - 2 - 2 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 
M, control, Baltimore : total = 22 : per fold = 2 - 3 - 2 - 3 - 2 - 2 - 2 - 2 - 2 - 2 - 
M, control, Boston : total = 12 : per fold = 1 - 1 - 2 - 1 - 1 - 1 - 2 - 1 - 1 - 1 - 
M, control, Hartford : total = 18 : per fold = 2 - 1 - 1 - 2 - 2 - 2 - 2 - 2 - 2 - 2 - 
M, control, CANDI : total = 14 : per fold = 1 - 1 - 1 - 1 - 1 - 2 - 1 - 2 - 2 - 2 - 
M, control, CNP : total = 65 : per fold = 7 - 7 - 7 - 7 - 7 - 6 - 6 - 6 - 6 - 6 - 
M, control, NU : total = 18 : per fold = 1 - 2 - 2 - 1 - 2 - 2 - 2 - 2 - 2 - 2 - 
M, control, WUSTL : total = 68 : per fold = 7 - 7 - 7 - 7 - 6 - 7 - 7 - 6 - 7 - 7 - 
M, control, vip : total = 23 : per fold = 3 - 2 - 2 - 2 - 3 - 2 - 2 - 3 - 2 - 2 - 
M, control, PRAGUE : total = 40 : per fold = 4 - 4 - 4 - 4 - 4 - 4 - 4 - 4 - 4 - 4 - 
M, schizophrenia, Dal

In [78]:
# One headerless file of participant ids per fold: split_<i>.csv
for i, fold in enumerate(folds):
    print(i)
    fold["participant_id"].to_csv(f"{tgt_path}/split_{i}.csv", header=False, index=False)

0
1
2
3
4
5
6
7
8
9


In [79]:
# The last two folds become the new validation and test_intra sets
for fold, out_file in ((folds[-2], tgt_val_file), (folds[-1], tgt_test_intra_file)):
    fold["participant_id"].to_csv(out_file, header=False, index=False)

In [80]:
# All folds except the last two (val, test_intra) form the new training set
train_tgt = pd.concat([fold["participant_id"] for fold in folds[:-2]], ignore_index=True)

In [81]:
# Headerless, index-free, like the other split files
train_tgt.to_csv(tgt_train_file, header=False, index=False)

In [82]:
# Display the source test subjects again (as read from the source test split file)
test_src

Unnamed: 0,participant_id
0,sub-A00014522_ses-v1
1,sub-A00001243_ses-v1
2,sub-A00028405_ses-v1
3,sub-A00028408_ses-v1
4,sub-A00020968_ses-v1
...,...
158,sub-A00027537_ses-v1
159,sub-A00026945_ses-v1
160,sub-A00014636_ses-v1
161,sub-A00022915_ses-v1


In [83]:
# Label rows (sex, diagnosis, site) of the source test subjects, for inspection
label.loc[label["participant_id"].isin(test_src["participant_id"])]

Unnamed: 0,participant_id,sex,diagnosis,site
602,sub-A00000300_ses-v1,M,control,MRN
603,sub-A00000368_ses-v1,M,schizophrenia,MRN
604,sub-A00000456_ses-v1,M,schizophrenia,MRN
605,sub-A00000838_ses-v1,M,schizophrenia,MRN
606,sub-A00000909_ses-v1,M,schizophrenia,MRN
...,...,...,...,...
760,sub-A00037649_ses-v1,F,schizophrenia,MRN
761,sub-A00037665_ses-v1,M,control,MRN
762,sub-A00037854_ses-v1,M,schizophrenia,MRN
763,sub-A00038441_ses-v1,M,schizophrenia,MRN
