In [1]:
import pandas as pd
import numpy as np
from skmultilearn.model_selection import iterative_train_test_split
from sklearn.datasets import make_multilabel_classification
from skmultilearn.model_selection import IterativeStratification
# from sklearn.model_selection import StratifiedGroupKFold
import random
import csv
import glob
import sklearn
import os

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [4]:
from collections import defaultdict

from sklearn.utils import (
    _safe_indexing,
    check_random_state,
    indexable,
    metadata_routing,
)

from sklearn.utils.validation import _num_samples, check_array, column_or_1d
from sklearn.utils.multiclass import type_of_target

In [5]:
# FIP labels: read the spreadsheet, drop incomplete rows, rename the four
# columns and index the table by subject.
n_splits = 10
n_test_splits = 2
fip = (
    pd.read_excel('/neurospin/dico/data/bv_databases/human/partially_labeled/FIP_patterns/IPS_labels_390.xlsx')
    # rows with any missing value are discarded before the columns are renamed
    .dropna()
    .set_axis(['Subject', 'Sex', 'Left', 'Right'], axis=1)
    .set_index('Subject')
)

In [6]:
def print_frequencies(df):
    """Print the row count of every (Right, Left, Sex) value combination.

    The combinations are enumerated from the distinct values of each column,
    so a combination that never occurs in ``df`` is still listed, with a
    count of 0. One line is printed per combination as
    ``"<right>, <left>, <sex>: <count>"``.
    """
    right_values = df['Right'].unique()
    left_values = df['Left'].unique()
    sex_values = df['Sex'].unique()

    for right in right_values:
        is_right = df['Right'] == right
        for left in left_values:
            is_left = df['Left'] == left
            for sex in sex_values:
                is_sex = df['Sex'] == sex
                count = len(df[is_right & is_left & is_sex])
                print(f"{right}, {left}, {sex}: {count}")

In [7]:
# Show how many rows of `fip` fall into each (Right, Left, Sex) combination.
print_frequencies(fip)

1, 0, M: 44
1, 0, F: 74
1, 1, M: 96
1, 1, F: 56
0, 0, M: 21
0, 0, F: 66
0, 1, M: 12
0, 1, F: 21


In [12]:
def print_results(parent, folds, col, verbose=True):
    """Print how the label combinations of ``parent`` are spread over ``folds``.

    For every combination of the distinct values found in the ``col`` columns
    of ``parent``, the matching rows are counted in ``parent`` and in each
    fold. A fold whose count differs from the ideal share
    (``total / len(folds)``) by 2 rows or more counts as one stratification
    error. A fold whose size differs from ``len(parent) / len(folds)`` by
    2 rows or more counts as one mismatched fold size.

    Parameters
    ----------
    parent : pandas.DataFrame
        Frame the folds were drawn from.
    folds : sequence of pandas.DataFrame
        Folds to check; each must contain the ``col`` columns.
    col : sequence of str
        Names of the stratification columns. Any number of columns is
        accepted, and the names do not have to be valid Python identifiers.
    verbose : bool, default True
        When True, also print the per-combination table. The summary
        (fold lengths and error counts) is always printed.

    Returns
    -------
    None
        Everything is reported on standard output.
    """
    # Imported locally so the function does not depend on another cell's imports.
    from itertools import product

    columns = list(col)
    n_splits = len(folds)

    def _count_rows(frame, values):
        # Narrow the frame one column at a time. Boolean filtering (instead of
        # a query string) keeps arbitrary column names usable.
        for name, value in zip(columns, values):
            frame = frame[frame[name] == value]
        return len(frame)

    # For each combination of labels, print the number of rows of each fold
    # having this combination.
    total_errors = 0
    if verbose:
        print("query   : #rows      : #rows per fold\n")

    # product() varies the last column fastest, i.e. the same order as nested
    # loops over the columns taken in order.
    distinct_values = [parent[name].unique() for name in columns]
    for values in product(*distinct_values):
        len_query = _count_rows(parent, values)
        if verbose:
            label = ", ".join(str(value) for value in values)
            print(f"{label} : total = {len_query} : per fold =", end=' ')
        for fold in folds:
            len_query_fold = _count_rows(fold, values)
            if abs(len_query_fold - len_query / n_splits) >= 2:
                total_errors += 1
            if verbose:
                print(f"{len_query_fold} -", end=' ')
        if verbose:
            print("")

    # Print the fold sizes and the number of stratification errors.
    expected_total_length = len(parent)
    total_length = 0
    total_mismatches = 0
    print("\nlengths of folds : ", end=' ')
    for fold in folds:
        len_fold = len(fold)
        print(len_fold, end=' ')
        total_length += len_fold
        if abs(len_fold - expected_total_length / n_splits) >= 2:
            total_mismatches += 1
    print(f"\nExpected total_length = {expected_total_length}")
    print(f"Effective total_length = {total_length}")

    print(f"total number of stratification errors: {total_errors}")
    print(f"total number of mismatched fold sizes : {total_mismatches}")

In [9]:
def iterative_split_through_sorting_shuffle(df, n_splits, stratify_columns, random_state):
    """Split ``df`` into ``n_splits`` folds balanced on ``stratify_columns``.

    The rows are shuffled, sorted on the stratification columns and then
    dealt round-robin to the folds (fold ``i`` receives rows ``i``,
    ``i + n_splits``, ...). Each fold is then shuffled internally and the
    list of folds is shuffled as well. ``random_state`` seeds every one of
    these shuffles.

    Returns the list of ``n_splits`` fold DataFrames.
    """
    # Random row order first, then rows with equal labels are made contiguous.
    ordered = df.sample(frac=1, random_state=random_state)
    ordered = ordered.sort_values(stratify_columns)

    folds = []
    for offset in range(n_splits):
        # One row out of every n_splits, starting at `offset`.
        fold = ordered.iloc[offset::n_splits, :]
        # Shuffle the rows inside the fold.
        folds.append(fold.sample(frac=1, random_state=random_state))

    # Shuffle the order of the folds themselves.
    random.Random(random_state).shuffle(folds)
    return folds

# Step 1: split into 5 stratified folds and keep the last one as the test set

In [23]:
# Label column to stratify on (`fip` has 'Left' and 'Right'); the same value
# selects the output sub-directory in `save_path`.
side = 'Left'

In [24]:
# Directory that receives the split CSV files for the selected `side`.
save_path = f"/neurospin/dico/data/deep_folding/current/datasets/hcp/FIP/{side}"

In [25]:
# Outer split: 5 folds stratified on the selected side and on sex (seed 1),
# followed by a balance report.
n_splits = 5
results = iterative_split_through_sorting_shuffle(
    df=fip,
    n_splits=n_splits,
    stratify_columns=[side, 'Sex'],
    random_state=1,
)
print_results(parent=fip, folds=results, col=[side, 'Sex'])

query   : #rows      : #rows per fold

0, M : total = 65 : per fold = 13 - 13 - 13 - 13 - 13 - 
0, F : total = 140 : per fold = 28 - 28 - 28 - 28 - 28 - 
1, M : total = 108 : per fold = 22 - 22 - 22 - 21 - 21 - 
1, F : total = 77 : per fold = 15 - 15 - 15 - 16 - 16 - 

lengths of folds :  78 78 78 78 78 
Expected total_length = 390
Effective total_length = 390
total number of stratification errors: 0
total number of mismatched fold sizes : 0


In [26]:
# Write the subject IDs of the held-out test fold (the last fold of the
# step-1 split), one quoted ID per line, without header.
# NOTE: `results` is rebound by the step-2 split further down. Run this cell
# before that one, otherwise a train/val fold is written as the test split.
os.makedirs(save_path, exist_ok=True)  # to_csv does not create missing directories
results[-1].reset_index()['Subject'].to_csv(
    f"{save_path}/test_split.csv",
    header=False,
    index=False,
    quoting=csv.QUOTE_ALL)

# Step 2: split the remaining 80% into 5 folds for cross-validation and train/eval

In [27]:
# Remove the held-out test fold (last fold of the step-1 split) from the pool.
# NOTE: this reads `results` before the next cell rebinds it; re-running this
# cell afterwards would drop a train/val fold instead of the test fold.
fip_train_val = fip.drop(index=results[-1].index.to_list())

In [28]:
# Inner split: the train/val pool is cut into 5 folds with the same
# stratification columns and seed as the outer split, then checked.
n_splits = 5
results = iterative_split_through_sorting_shuffle(
    df=fip_train_val,
    n_splits=n_splits,
    stratify_columns=[side, 'Sex'],
    random_state=1,
)
print_results(parent=fip_train_val, folds=results, col=[side, 'Sex'])

query   : #rows      : #rows per fold

0, M : total = 52 : per fold = 11 - 11 - 10 - 10 - 10 - 
0, F : total = 112 : per fold = 22 - 22 - 22 - 23 - 23 - 
1, M : total = 87 : per fold = 17 - 17 - 17 - 18 - 18 - 
1, F : total = 61 : per fold = 12 - 12 - 13 - 12 - 12 - 

lengths of folds :  62 62 62 63 63 
Expected total_length = 312
Effective total_length = 312
total number of stratification errors: 0
total number of mismatched fold sizes : 0


In [29]:
# Save each train/val fold as its own CSV (train_val_split_<i>.csv) holding
# one quoted subject ID per line, without header.
os.makedirs(save_path, exist_ok=True)  # to_csv does not create missing directories
for i, fold in enumerate(results):
    fold.reset_index()['Subject'].to_csv(
        f"{save_path}/train_val_split_{i}.csv",
        header=False,
        index=False,
        quoting=csv.QUOTE_ALL)