In [57]:
import pandas as pd
import numpy as np
from skmultilearn.model_selection import iterative_train_test_split
from sklearn.datasets import make_multilabel_classification
from skmultilearn.model_selection import IterativeStratification
from sklearn.model_selection import StratifiedGroupKFold
import random
import csv
import glob
import sklearn

In [58]:
from collections import defaultdict

from sklearn.utils import (
    #_approximate_mode,
    _safe_indexing,
    check_random_state,
    indexable,
    metadata_routing,
)

from sklearn.utils.validation import _num_samples, check_array, column_or_1d
from sklearn.utils.multiclass import type_of_target

# Load data

In [59]:
# Isomap coordinates table; the first CSV column is used as the index.
# Rows with any missing value are discarded right away.
central = pd.read_csv(
    "/neurospin/dico/data/deep_folding/current/datasets/hcp/hcp_isomap_labels_SC-sylv_left.csv",
    index_col=0,
).dropna()

In [60]:
# Participants table: only the subject ID and the gender are kept.
participants_file = "/neurospin/dico/data/bv_databases/human/not_labeled/hcp/participants.csv"
participants_unrestricted = pd.read_csv(participants_file).loc[:, ['Subject', "Gender"]]
print(len(participants_unrestricted))
participants_unrestricted.head()

1206


Unnamed: 0,Subject,Gender
0,100004,M
1,100206,M
2,100307,F
3,100408,M
4,100610,M


In [61]:
# Second participants table: subject ID, zygosity and family identifier.
# NOTE(review): the file name says RESTRICTED -- presumably these fields must
# not be shared; consider clearing the rendered outputs before committing.
participants_file = "/neurospin/dico/jchavas/RESTRICTED_jchavas_1_18_2022_3_17_51.csv"
participants = pd.read_csv(participants_file).loc[:, ['Subject', 'ZygosityGT', 'Family_ID']]
print(len(participants))
participants.head()

1206


Unnamed: 0,Subject,ZygosityGT,Family_ID
0,100004,,52259_82122
1,100206,,56037_85858
2,100307,MZ,51488_81352
3,100408,MZ,51730_81594
4,100610,DZ,52813_82634


In [62]:
# Join the two participant tables on the subject ID.  The key is spelled out
# (instead of letting pandas join on every column name the frames share) and
# the join is validated as one-to-one, so a duplicated subject raises instead
# of silently adding rows.
participants = pd.merge(
    participants_unrestricted,
    participants,
    on='Subject',
    how='inner',
    validate='one_to_one',
)
# Rows whose ZygosityGT is a single space are relabelled 'NotTwin'.
participants.loc[participants['ZygosityGT'] == " ", 'ZygosityGT'] = 'NotTwin'
# Subject IDs become strings so they can be compared with directory names
# and with the (string) index of the isomap table.
participants['Subject'] = participants['Subject'].astype('string')
print(len(participants))
participants.head()

1206


Unnamed: 0,Subject,Gender,ZygosityGT,Family_ID
0,100004,M,NotTwin,52259_82122
1,100206,M,NotTwin,56037_85858
2,100307,F,MZ,51488_81352
3,100408,M,MZ,51730_81594
4,100610,M,DZ,52813_82634


In [63]:
# Check the column dtypes after the merge (Subject was cast to string above).
participants.dtypes

Subject       string[python]
Gender                object
ZygosityGT            object
Family_ID             object
dtype: object

In [64]:
# List the entries of the HCP database directory (last path component only).
# NOTE: the former pattern "*[!.minf]" was a negated character class, not a
# suffix filter: it rejected every name ending in '.', 'm', 'i', 'n' or 'f'.
# The ".minf" files are now excluded explicitly by suffix.
treated_subjects = [
    path.split('/')[-1]
    for path in glob.glob("/neurospin/dico/data/bv_databases/human/not_labeled/hcp/hcp/*")
]
treated_subjects = [
    name for name in treated_subjects
    if not name.endswith('.minf') and 'database' not in name
]
print(treated_subjects[:5])
# Keep only the participants that have an entry in the database directory.
participants = participants[participants['Subject'].isin(treated_subjects)]
print(len(participants))
# Treated subjects that are missing from the participants table.
set(treated_subjects) - set(participants['Subject'])

['210112', '579665', '922854', '517239', '329440']
1113


{'142626'}

In [65]:
print(len(central))
print(central.dtypes)
# Cast the index to str so it can be compared with the string subject IDs
# (entries of treated_subjects and participants['Subject']).
central.index = central.index.astype(str)

883
Isomap_central_left_dim1      float64
Isomap_central_left_dim2      float64
Isomap_central_left_dim3      float64
Isomap_central_left_dim4      float64
Isomap_central_left_dim5      float64
Isomap_central_left_dim6      float64
Isomap_cingulate_left_dim1    float64
Isomap_cingulate_left_dim2    float64
Isomap_cingulate_left_dim3    float64
Isomap_cingulate_left_dim4    float64
Isomap_cingulate_left_dim5    float64
Isomap_cingulate_left_dim6    float64
dtype: object


In [66]:
# Subjects of the isomap table that are not among the treated subjects.
set(central.index) - set(treated_subjects)

set()

In [67]:
# Number of isomap rows whose subject is treated but absent from `participants`;
# such rows have no match in the merge with `participants` that follows.
central.index.isin(set(treated_subjects) - set(participants['Subject'])).sum()

1

In [68]:
# Attach the participant columns to each isomap row (participants['Subject']
# matched against the isomap index), then index the result by Subject.
central = participants.merge(
    central,
    left_on='Subject',
    right_index=True,
).set_index('Subject')

In [69]:
# Merged table: participant columns followed by the isomap coordinates, indexed by Subject.
central

Unnamed: 0_level_0,Gender,ZygosityGT,Family_ID,Isomap_central_left_dim1,Isomap_central_left_dim2,Isomap_central_left_dim3,Isomap_central_left_dim4,Isomap_central_left_dim5,Isomap_central_left_dim6,Isomap_cingulate_left_dim1,Isomap_cingulate_left_dim2,Isomap_cingulate_left_dim3,Isomap_cingulate_left_dim4,Isomap_cingulate_left_dim5,Isomap_cingulate_left_dim6
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
100206,M,NotTwin,56037_85858,-0.229650,7.389883,0.348397,-1.861678,0.842310,0.582308,2.254531,4.487012,-1.857427,2.059584,2.011947,3.053309
100307,F,MZ,51488_81352,1.573895,-2.751750,1.045617,-1.803867,0.096300,1.880110,0.271804,2.689471,2.810346,1.485902,0.797527,-5.502638
100408,M,MZ,51730_81594,-5.200937,-3.329636,-1.687213,-0.164033,-0.129700,0.668525,-3.087221,-2.509417,5.447575,0.570411,2.187741,5.075901
100610,M,DZ,52813_82634,2.131914,0.025349,-0.288740,-1.789072,-1.106825,-0.100832,4.672111,4.856687,1.527257,-0.033238,0.722929,-0.659395
101006,F,NotTwin,51283_52850_81149,0.411594,0.454800,-2.476906,1.865777,-1.723096,1.408883,11.017965,-1.182422,-4.016098,-3.912979,1.706526,-1.575350
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
992673,F,NotTwin,56143_85963,2.688382,7.092495,-2.360218,-0.831440,-1.927135,-0.885679,-6.090316,-2.531791,-1.879810,-0.426724,0.536265,2.508170
992774,M,NotTwin,51345_81210,3.216161,-1.978102,-3.380473,-0.248682,0.209391,-0.840647,1.895446,6.396108,0.609038,-3.159835,-4.150092,-1.000147
993675,F,NotTwin,55800_85621,1.654016,1.375164,0.233411,-0.785303,-1.868195,-2.888935,8.262977,4.142512,-2.316728,-2.291203,1.612961,0.830513
994273,M,NotTwin,52364_82227,1.177521,4.096813,-1.256984,3.123394,0.146602,1.708059,-4.200693,-4.050652,-0.542605,-3.780800,-2.132015,0.958736


# Split

In [70]:
def print_results(parent, folds, col, verbose=True):
    """Print per-fold row counts for every label combination, plus fold sizes.

    Parameters
    ----------
    parent : pandas.DataFrame
        Full table that was split into ``folds``.
    folds : sequence of pandas.DataFrame
        One DataFrame per fold; each must contain the column(s) named in ``col``.
    col : str or sequence of str
        Label column(s). Every distinct combination of values found in
        ``parent`` is reported. (The function used to look at ``col[0]`` only;
        with a single column the printed output is unchanged.)
    verbose : bool, default True
        If True, also print one line per label combination with its total and
        per-fold row counts. The summary lines are printed in every case.

    Notes
    -----
    A "stratification error" is a (combination, fold) pair whose row count is
    2 or more away from ``total / n_folds``. A "mismatched fold size" is a fold
    whose length is 2 or more away from ``len(parent) / n_folds``.
    """
    columns = [col] if isinstance(col, str) else list(col)
    n_splits = len(folds)

    def count_rows(frame, values):
        # Number of rows of `frame` whose label columns equal `values`.
        # Boolean masks are used (not DataFrame.query) so that column names
        # do not have to be valid Python identifiers.
        mask = np.ones(len(frame), dtype=bool)
        for column, value in zip(columns, values):
            mask &= (frame[column] == value).to_numpy(dtype=bool, na_value=False)
        return int(mask.sum())

    # For each combination of labels, prints the number of rows for each fold
    # having this combination
    total_errors = 0
    if verbose:
        print("query   : #rows      : #rows per fold\n")

    # Distinct combinations, in order of first appearance in `parent`.
    combinations = parent[columns].drop_duplicates().itertuples(index=False, name=None)
    for values in combinations:
        len_query = count_rows(parent, values)
        if verbose:
            label = ", ".join(f"{value}" for value in values)
            print(f"{label} : total = {len_query} : per fold =", end=' ')
        for fold in folds:
            len_query_fold = count_rows(fold, values)
            if abs(len_query_fold - len_query / n_splits) >= 2:
                total_errors += 1
            if verbose:
                print(f"{len_query_fold} -", end=' ')
        if verbose:
            print("")

    # Prints the statistics and the number of stratification errors
    expected_total_length = len(parent)
    total_length = 0
    total_mismatches = 0
    print("\nlengths of folds : ", end=' ')
    for fold in folds:
        len_fold = len(fold)
        print(len_fold, end=' ')
        total_length += len_fold
        if abs(len_fold - expected_total_length / n_splits) >= 2:
            total_mismatches += 1
    print(f"\nExpected total_length = {expected_total_length}")
    print(f"Effective total_length = {total_length}")

    print(f"total number of stratification errors: {total_errors}")
    print(f"total number of mismatched fold sizes : {total_mismatches}")

In [71]:
def how_many_common_families(total, folds):
    """Print how many Family_ID values are shared between pairs of folds.

    Every unordered pair of folds is compared once; each family present in
    both folds of a pair adds one to the count. The count is printed next to
    the number of distinct families in ``total``.
    """
    families_per_fold = [set(fold.Family_ID) for fold in folds]
    n_folds = len(families_per_fold)
    shared = 0
    for first in range(n_folds):
        for second in range(first + 1, n_folds):
            shared += len(families_per_fold[first] & families_per_fold[second])
    n_families = len(set(total.Family_ID))
    print(f"number of common families = {shared} over {n_families} total families")

# Step 1: split into 5 folds, keep the last one for test

In [72]:
# Directory that receives the split CSV files (the value ends with a slash).
save_path = "/neurospin/dico/data/deep_folding/current/datasets/hcp/Isomap/splits/"

In [73]:
n_splits = 5
# Only these two columns are needed to build the split.
df = central.loc[:, ['Gender', 'Family_ID']]

# labels are sex only
labels = df['Gender'].to_numpy()

# group based on family
groups = df['Family_ID'].to_numpy()

In [74]:
# Build one DataFrame per fold from the second element (test indices) of
# each (train, test) pair yielded by the splitter.
stratifier = StratifiedGroupKFold(n_splits=n_splits, shuffle=False)
results = [
    df.iloc[test_idx]
    for _train_idx, test_idx in stratifier.split(df.to_numpy(), labels, groups)
]

In [75]:
# Report the per-fold Gender counts and fold sizes, then the number of
# families shared between folds.
print_results(parent=central, folds=results, col=['Gender'], verbose=True)
how_many_common_families(total=central, folds=results)

query   : #rows      : #rows per fold

M : total = 388 : per fold = 78 - 78 - 78 - 77 - 77 - 
F : total = 494 : per fold = 99 - 99 - 99 - 98 - 99 - 

lengths of folds :  177 177 177 175 176 
Expected total_length = 882
Effective total_length = 882
total number of stratification errors: 0
total number of mismatched fold sizes : 0
number of common families = 0 over 378 total families


In [76]:
# The last fold is written out as the test split: subject IDs only,
# no header, no index column, every field quoted.
test_subjects = results[-1].reset_index()['Subject']
test_subjects.to_csv(
    f"{save_path}/test_split.csv",
    header=False,
    index=False,
    quoting=csv.QUOTE_ALL,
)

# Step 2: split the remaining 80% into 5 folds for cross-validation and train/eval

In [77]:
# Remove the subjects of the last fold; the rest is the train/validation pool.
# `results` must still hold the step-1 folds here: the name is rebound further
# down, so re-running this cell afterwards would drop a different fold.
held_out_subjects = results[-1].index.to_list()
central_train_val = central.drop(index=held_out_subjects)

In [78]:
# repeat the process
# repeat the process on the train/validation pool
n_splits = 5
df = central_train_val.loc[:, ['Gender', 'Family_ID']]

# labels are sex only
labels = df['Gender'].to_numpy()

# group based on family
groups = df['Family_ID'].to_numpy()

In [79]:
# Build one DataFrame per fold from the second element (test indices) of
# each (train, test) pair yielded by the splitter. This rebinds `results`.
stratifier = StratifiedGroupKFold(n_splits=n_splits, shuffle=False)
results = [
    df.iloc[test_idx]
    for _train_idx, test_idx in stratifier.split(df.to_numpy(), labels, groups)
]

In [80]:
# Report the per-fold Gender counts and fold sizes, then the number of
# families shared between folds, for the train/validation pool.
print_results(parent=central_train_val, folds=results, col=['Gender'], verbose=True)
how_many_common_families(total=central_train_val, folds=results)

query   : #rows      : #rows per fold

F : total = 395 : per fold = 79 - 79 - 79 - 79 - 79 - 
M : total = 311 : per fold = 62 - 62 - 62 - 62 - 63 - 

lengths of folds :  141 141 141 141 142 
Expected total_length = 706
Effective total_length = 706
total number of stratification errors: 0
total number of mismatched fold sizes : 0
number of common families = 0 over 302 total families


In [81]:
# save folds
# Write each fold's subject IDs to its own CSV file: no header,
# no index column, every field quoted.
for fold_number, fold in enumerate(results):
    fold_subjects = fold.reset_index()['Subject']
    fold_subjects.to_csv(
        f"{save_path}/train_val_split_{fold_number}.csv",
        header=False,
        index=False,
        quoting=csv.QUOTE_ALL,
    )