In [70]:
import pandas as pd
import numpy as np
import skmultilearn.model_selection
from skmultilearn.model_selection import iterative_train_test_split
from sklearn.datasets import make_multilabel_classification
from skmultilearn.model_selection import IterativeStratification
# from sklearn.model_selection import StratifiedGroupKFold
import random
import csv
import glob
import sklearn

In [71]:
from collections import defaultdict

from sklearn.utils import (
    check_random_state
)

from sklearn.utils.validation import _num_samples, check_array, column_or_1d
from sklearn.utils.multiclass import type_of_target

In [72]:
class StratifiedGroupKFold_WithoutCheck(sklearn.model_selection._split.GroupsConsumerMixin, sklearn.model_selection._split._BaseKFold):
    """Stratified K-Fold iterator variant with non-overlapping groups.

    This cross-validation object is a variation of StratifiedKFold attempts to
    return stratified folds with non-overlapping groups. The folds are made by
    preserving the percentage of samples for each class.

    Each group will appear exactly once in the test set across all folds (the
    number of distinct groups has to be at least equal to the number of folds).

    The difference between :class:`~sklearn.model_selection.GroupKFold`
    and :class:`~sklearn.model_selection.StratifiedGroupKFold` is that
    the former attempts to create balanced folds such that the number of
    distinct groups is approximately the same in each fold, whereas
    StratifiedGroupKFold attempts to create folds which preserve the
    percentage of samples for each class as much as possible given the
    constraint of non-overlapping groups between splits.

    Read more in the :ref:`User Guide <cross_validation>`.

    For visualisation of cross-validation behaviour and
    comparison between common scikit-learn split methods
    refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    shuffle : bool, default=False
        Whether to shuffle each class's samples before splitting into batches.
        Note that the samples within each split will not be shuffled.
        This implementation can only shuffle groups that have approximately the
        same y distribution, no global shuffle will be performed.

    random_state : int or RandomState instance, default=None
        When `shuffle` is True, `random_state` affects the ordering of the
        indices, which controls the randomness of each fold for each class.
        Otherwise, leave `random_state` as `None`.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import StratifiedGroupKFold
    >>> X = np.ones((17, 2))
    >>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    >>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8])
    >>> sgkf = StratifiedGroupKFold(n_splits=3)
    >>> sgkf.get_n_splits(X, y)
    3
    >>> print(sgkf)
    StratifiedGroupKFold(n_splits=3, random_state=None, shuffle=False)
    >>> for i, (train_index, test_index) in enumerate(sgkf.split(X, y, groups)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"         group={groups[train_index]}")
    ...     print(f"  Test:  index={test_index}")
    ...     print(f"         group={groups[test_index]}")
    Fold 0:
      Train: index=[ 0  1  2  3  7  8  9 10 11 15 16]
             group=[1 1 2 2 4 5 5 5 5 8 8]
      Test:  index=[ 4  5  6 12 13 14]
             group=[3 3 3 6 6 7]
    Fold 1:
      Train: index=[ 4  5  6  7  8  9 10 11 12 13 14]
             group=[3 3 3 4 5 5 5 5 6 6 7]
      Test:  index=[ 0  1  2  3 15 16]
             group=[1 1 2 2 8 8]
    Fold 2:
      Train: index=[ 0  1  2  3  4  5  6 12 13 14 15 16]
             group=[1 1 2 2 3 3 3 6 6 7 8 8]
      Test:  index=[ 7  8  9 10 11]
             group=[4 5 5 5 5]

    Notes
    -----
    The implementation is designed to:

    * Mimic the behavior of StratifiedKFold as much as possible for trivial
      groups (e.g. when each group contains only one sample).
    * Be invariant to class label: relabelling ``y = ["Happy", "Sad"]`` to
      ``y = [1, 0]`` should not change the indices generated.
    * Stratify based on samples as much as possible while keeping
      non-overlapping groups constraint. That means that in some cases when
      there is a small number of groups containing a large number of samples
      the stratification will not be possible and the behavior will be close
      to GroupKFold.

    See also
    --------
    StratifiedKFold: Takes class information into account to build folds which
        retain class distributions (for binary or multiclass classification
        tasks).

    GroupKFold: K-fold iterator variant with non-overlapping groups.
    """

    def __init__(self, n_splits=5, shuffle=False, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def _iter_test_indices(self, X, y, groups):
        # Implementation is based on this kaggle kernel:
        # https://www.kaggle.com/jakubwasikowski/stratified-group-k-fold-cross-validation
        # and is a subject to Apache 2.0 License. You may obtain a copy of the
        # License at http://www.apache.org/licenses/LICENSE-2.0
        # Changelist:
        # - Refactored function to a class following scikit-learn KFold
        #   interface.
        # - Added heuristic for assigning group to the least populated fold in
        #   cases when all other criteria are equal
        # - Swtch from using python ``Counter`` to ``np.unique`` to get class
        #   distribution
        # - Added scikit-learn checks for input: checking that target is binary
        #   or multiclass, checking passed random state, checking that number
        #   of splits is less than number of members in each class, checking
        #   that least populated class has more members than there are splits.
        rng = check_random_state(self.random_state)
        y = np.asarray(y)
        type_of_target_y = type_of_target(y)
        allowed_target_types = ("binary", "multiclass")
        if type_of_target_y not in allowed_target_types:
            raise ValueError(
                "Supported target types are: {}. Got {!r} instead.".format(
                    allowed_target_types, type_of_target_y
                )
            )

        y = column_or_1d(y)
        _, y_inv, y_cnt = np.unique(y, return_inverse=True, return_counts=True)
        # if np.all(self.n_splits > y_cnt):
        #     raise ValueError(
        #         "n_splits=%d cannot be greater than the"
        #         " number of members in each class." % (self.n_splits)
        #     )
        # n_smallest_class = np.min(y_cnt)
        # if self.n_splits > n_smallest_class:
        #     warnings.warn(
        #         "The least populated class in y has only %d"
        #         " members, which is less than n_splits=%d."
        #         % (n_smallest_class, self.n_splits),
        #         UserWarning,
        #     )
        n_classes = len(y_cnt)

        _, groups_inv, groups_cnt = np.unique(
            groups, return_inverse=True, return_counts=True
        )
        y_counts_per_group = np.zeros((len(groups_cnt), n_classes))
        for class_idx, group_idx in zip(y_inv, groups_inv):
            y_counts_per_group[group_idx, class_idx] += 1

        y_counts_per_fold = np.zeros((self.n_splits, n_classes))
        groups_per_fold = defaultdict(set)

        if self.shuffle:
            rng.shuffle(y_counts_per_group)

        # Stable sort to keep shuffled order for groups with the same
        # class distribution variance
        sorted_groups_idx = np.argsort(
            -np.std(y_counts_per_group, axis=1), kind="mergesort"
        )

        for group_idx in sorted_groups_idx:
            group_y_counts = y_counts_per_group[group_idx]
            best_fold = self._find_best_fold(
                y_counts_per_fold=y_counts_per_fold,
                y_cnt=y_cnt,
                group_y_counts=group_y_counts,
            )
            y_counts_per_fold[best_fold] += group_y_counts
            groups_per_fold[best_fold].add(group_idx)

        for i in range(self.n_splits):
            test_indices = [
                idx
                for idx, group_idx in enumerate(groups_inv)
                if group_idx in groups_per_fold[i]
            ]
            yield test_indices

    def _find_best_fold(self, y_counts_per_fold, y_cnt, group_y_counts):
        best_fold = None
        min_eval = np.inf
        min_samples_in_fold = np.inf
        for i in range(self.n_splits):
            y_counts_per_fold[i] += group_y_counts
            # Summarise the distribution over classes in each proposed fold
            std_per_class = np.std(y_counts_per_fold / y_cnt.reshape(1, -1), axis=0)
            y_counts_per_fold[i] -= group_y_counts
            fold_eval = np.mean(std_per_class)
            samples_in_fold = np.sum(y_counts_per_fold[i])
            is_current_fold_better = (
                fold_eval < min_eval
                or np.isclose(fold_eval, min_eval)
                and samples_in_fold < min_samples_in_fold
            )
            if is_current_fold_better:
                min_eval = fold_eval
                min_samples_in_fold = samples_in_fold
                best_fold = i
        return best_fold

# Loads and manipulates

In [73]:
orb = pd.read_csv("/neurospin/dico/data/deep_folding/current/datasets/hcp/Handedness/handedness_labels.csv", index_col=0)

In [74]:
orb

Unnamed: 0_level_0,Handedness,Gender,isRightHanded,isStronglyRightHanded
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
100004,95,M,1,1.0
100206,65,M,1,
100307,95,F,1,1.0
100408,55,M,1,
100610,85,M,1,1.0
...,...,...,...,...
992774,100,M,1,1.0
993675,85,F,1,1.0
994273,60,M,1,
995174,100,M,1,1.0


In [75]:
participants_file = "/neurospin/dico/jchavas/RESTRICTED_jchavas_1_18_2022_3_17_51.csv"
participants = pd.read_csv(participants_file, index_col=0)
participants = participants[['ZygosityGT', 'Family_ID']]
print(len(participants))
participants.head()

1206


Unnamed: 0_level_0,ZygosityGT,Family_ID
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1
100004,,52259_82122
100206,,56037_85858
100307,MZ,51488_81352
100408,MZ,51730_81594
100610,DZ,52813_82634


In [76]:
participants = pd.merge(orb, participants, left_index=True, right_index=True)
participants.loc[(participants['ZygosityGT']== " "), 'ZygosityGT'] = 'NotTwin'  
print(len(participants))
participants.head()

1206


Unnamed: 0_level_0,Handedness,Gender,isRightHanded,isStronglyRightHanded,ZygosityGT,Family_ID
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
100004,95,M,1,1.0,NotTwin,52259_82122
100206,65,M,1,,NotTwin,56037_85858
100307,95,F,1,1.0,MZ,51488_81352
100408,55,M,1,,MZ,51730_81594
100610,85,M,1,1.0,DZ,52813_82634


In [77]:
participants

Unnamed: 0_level_0,Handedness,Gender,isRightHanded,isStronglyRightHanded,ZygosityGT,Family_ID
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
100004,95,M,1,1.0,NotTwin,52259_82122
100206,65,M,1,,NotTwin,56037_85858
100307,95,F,1,1.0,MZ,51488_81352
100408,55,M,1,,MZ,51730_81594
100610,85,M,1,1.0,DZ,52813_82634
...,...,...,...,...,...,...
992774,100,M,1,1.0,NotTwin,51345_81210
993675,85,F,1,1.0,NotTwin,55800_85621
994273,60,M,1,,NotTwin,52364_82227
995174,100,M,1,1.0,MZ,55923_85743


In [78]:
participants.index = participants.index.astype(str)

In [79]:
participants.loc[["100004"],:]

Unnamed: 0_level_0,Handedness,Gender,isRightHanded,isStronglyRightHanded,ZygosityGT,Family_ID
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
100004,95,M,1,1.0,NotTwin,52259_82122


In [80]:
participants

Unnamed: 0_level_0,Handedness,Gender,isRightHanded,isStronglyRightHanded,ZygosityGT,Family_ID
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
100004,95,M,1,1.0,NotTwin,52259_82122
100206,65,M,1,,NotTwin,56037_85858
100307,95,F,1,1.0,MZ,51488_81352
100408,55,M,1,,MZ,51730_81594
100610,85,M,1,1.0,DZ,52813_82634
...,...,...,...,...,...,...
992774,100,M,1,1.0,NotTwin,51345_81210
993675,85,F,1,1.0,NotTwin,55800_85621
994273,60,M,1,,NotTwin,52364_82227
995174,100,M,1,1.0,MZ,55923_85743


In [81]:
treated_subjects = glob.glob("/neurospin/dico/data/bv_databases/human/not_labeled/hcp/hcp/*[!.minf]")
treated_subjects = [x.split('/')[-1] for x in treated_subjects]
treated_subjects = [x for x in treated_subjects if 'database' not in x]
treated_subjects = [x for x in treated_subjects if x in participants.index]
print(treated_subjects[:5])
len(treated_subjects)
participants = participants.loc[treated_subjects,:]
print(len(participants))
set(treated_subjects) - set(participants.index)

['210112', '579665', '922854', '517239', '329440']
1113


set()

In [82]:
participants = participants.sort_index()

In [83]:
participants.head()

Unnamed: 0_level_0,Handedness,Gender,isRightHanded,isStronglyRightHanded,ZygosityGT,Family_ID
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
100206,65,M,1,,NotTwin,56037_85858
100307,95,F,1,1.0,MZ,51488_81352
100408,55,M,1,,MZ,51730_81594
100610,85,M,1,1.0,DZ,52813_82634
101006,90,F,1,1.0,NotTwin,51283_52850_81149


In [84]:
orb = participants.copy(deep=True)

# Makes split without considering twins

In [85]:
orb[["Gender"]].value_counts()

Gender
F         606
M         507
Name: count, dtype: int64

In [86]:
orb.query("isRightHanded==1 and isStronglyRightHanded==1 and Gender=='M'")

Unnamed: 0_level_0,Handedness,Gender,isRightHanded,isStronglyRightHanded,ZygosityGT,Family_ID
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
100610,85,M,1,1.0,DZ,52813_82634
101410,75,M,1,1.0,NotTwin,52198_82061
102008,80,M,1,1.0,NotTwin,52018_81882
102614,80,M,1,1.0,NotTwin,52838_82659
102715,95,M,1,1.0,NotTwin,51469_81334
...,...,...,...,...,...,...
989987,95,M,1,1.0,NotTwin,52040_81904
990366,95,M,1,1.0,NotTwin,56064_85885
991267,75,M,1,1.0,NotTwin,51639_81503
992774,100,M,1,1.0,NotTwin,51345_81210


In [87]:
orb["isStronglyRightHanded"] = orb["isStronglyRightHanded"].astype('Int64').fillna(-1)

In [88]:
orb

Unnamed: 0_level_0,Handedness,Gender,isRightHanded,isStronglyRightHanded,ZygosityGT,Family_ID
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
100206,65,M,1,-1,NotTwin,56037_85858
100307,95,F,1,1,MZ,51488_81352
100408,55,M,1,-1,MZ,51730_81594
100610,85,M,1,1,DZ,52813_82634
101006,90,F,1,1,NotTwin,51283_52850_81149
...,...,...,...,...,...,...
992774,100,M,1,1,NotTwin,51345_81210
993675,85,F,1,1,NotTwin,55800_85621
994273,60,M,1,-1,NotTwin,52364_82227
995174,100,M,1,1,MZ,55923_85743


In [89]:
def print_frequencies(df):
    for right in df['isRightHanded'].unique():
        for strong_right in df['isStronglyRightHanded'].unique():
            for gender in df['Gender'].unique():
                freq = df.query("isRightHanded==@right and isStronglyRightHanded==@strong_right and Gender==@gender")
                print(f"{right}, {strong_right}, {gender}: {len(freq)}")

In [90]:
print_frequencies(orb)

1, -1, M: 184
1, -1, F: 138
1, 1, M: 269
1, 1, F: 421
1, 0, M: 0
1, 0, F: 0
0, -1, M: 40
0, -1, F: 27
0, 1, M: 0
0, 1, F: 0
0, 0, M: 14
0, 0, F: 20


  freq = df.query("isRightHanded==@right and isStronglyRightHanded==@strong_right and Gender==@gender")
  freq = df.query("isRightHanded==@right and isStronglyRightHanded==@strong_right and Gender==@gender")
  freq = df.query("isRightHanded==@right and isStronglyRightHanded==@strong_right and Gender==@gender")
  freq = df.query("isRightHanded==@right and isStronglyRightHanded==@strong_right and Gender==@gender")
  freq = df.query("isRightHanded==@right and isStronglyRightHanded==@strong_right and Gender==@gender")
  freq = df.query("isRightHanded==@right and isStronglyRightHanded==@strong_right and Gender==@gender")
  freq = df.query("isRightHanded==@right and isStronglyRightHanded==@strong_right and Gender==@gender")
  freq = df.query("isRightHanded==@right and isStronglyRightHanded==@strong_right and Gender==@gender")
  freq = df.query("isRightHanded==@right and isStronglyRightHanded==@strong_right and Gender==@gender")
  freq = df.query("isRightHanded==@right and isStronglyRightHand

In [91]:
def print_results(parent, folds, col, verbose=True):

    # For each conbination of labels, prints the number of rows for each fold
    # having this combination
    total_errors = 0
    n_splits = len(folds)
    if verbose:
        print("query   : #rows      : #rows per fold\n")

    for col0 in parent[col[0]].unique():
        for col1 in parent[col[1]].unique():
            for col2 in parent[col[2]].unique():
                df = parent.query(f"{col[0]}==@col0 and {col[1]}==@col1 and {col[2]}==@col2", engine='python')
                len_query = len(df)
                if verbose:
                    print(f"{col0}, {col1}, {col2} : total = {len_query} : per fold =", end = ' ')
                for fold in folds:
                    df0 = fold.query(f"{col[0]}==@col0 and {col[1]}==@col1 and {col[2]}==@col2", engine='python')
                    len_query_fold = len(df0)
                    if abs(len_query_fold-len_query/n_splits) >= 2:
                        total_errors += 1
                    if verbose:
                        print(f"{len_query_fold} -", end= ' ')
                if verbose:
                    print("")

    # Prints the statistics and the number of stratification errors
    expected_total_length = len(parent)
    total_length = 0
    total_mismatches = 0
    print("\nlengths of folds : ", end = ' ')
    for fold in folds:
        len_fold = len(fold)
        print(len_fold, end=' ')
        total_length += len_fold
        if abs(len_fold-expected_total_length/n_splits) >= 2:
            total_mismatches += 1
    print(f"\nExpected total_length = {expected_total_length}")
    print(f"Effective total_length = {total_length}")

    print(f"total number of stratification errors: {total_errors}")
    print(f"total number of mismatched fold sizes : {total_mismatches}")

In [92]:
def iterative_split(df, folds, n_splits, stratify_columns):
    """Custom iterative train test split which
    'maintains balanced representation with respect
    to order-th label combinations.'

    From https://madewithml.com/courses/mlops/splitting/#stratified-split
    """
    # One-hot encode the stratify columns and concatenate them
    one_hot_cols = [pd.get_dummies(df[col]) for col in stratify_columns]
    one_hot_cols = pd.concat(one_hot_cols, axis=1).to_numpy()
    stratifier = IterativeStratification(
        n_splits=n_splits, order=len(stratify_columns), sample_distribution_per_fold=folds)
    folds = []
    for indices in stratifier.split(df.to_numpy(), one_hot_cols):
        folds.append(df.iloc[indices[1]])
    return folds

In [93]:
folds = iterative_split(orb, [0.2,]*5, 5, ['isRightHanded', 'isStronglyRightHanded', 'Gender'])

In [94]:
print(len(folds))

5


In [95]:
print_results(orb, folds, ['isRightHanded', 'isStronglyRightHanded', 'Gender'])

query   : #rows      : #rows per fold

1, -1, M : total = 184 : per fold = 37 - 36 - 37 - 37 - 37 - 
1, -1, F : total = 138 : per fold = 28 - 29 - 27 - 27 - 27 - 
1, 1, M : total = 269 : per fold = 54 - 55 - 54 - 53 - 53 - 
1, 1, F : total = 421 : per fold = 84 - 82 - 85 - 85 - 85 - 
1, 0, M : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
1, 0, F : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, -1, M : total = 40 : per fold = 8 - 8 - 8 - 8 - 8 - 
0, -1, F : total = 27 : per fold = 5 - 5 - 6 - 5 - 6 - 
0, 1, M : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, 1, F : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, 0, M : total = 14 : per fold = 3 - 3 - 3 - 2 - 3 - 
0, 0, F : total = 20 : per fold = 4 - 4 - 4 - 4 - 4 - 

lengths of folds :  223 222 224 221 223 
Expected total_length = 1113
Effective total_length = 1113
total number of stratification errors: 1
total number of mismatched fold sizes : 0


In [96]:
def iterative_split_through_sorting_shuffle(df, n_splits, stratify_columns, random_state):
    """Custom iterative train test split which
    maintains balanced representation.
    """
    # Dataframe random row shuffle + sorting according to stratify_columns
    sorted = df.sample(frac=1, random_state=random_state).sort_values(stratify_columns)
    # for each fold, we take one row every n_splits rows
    folds = [sorted.iloc[i::n_splits, :] for i in range(n_splits)]
    # Further shuffling
    folds = [fold.sample(frac=1, random_state=random_state) for fold in folds]
    random.Random(random_state).shuffle(folds)
    return folds

In [97]:
folds = iterative_split_through_sorting_shuffle(orb, 5, ['isRightHanded', 'isStronglyRightHanded', 'Gender'], 1)

In [98]:
print_results(orb, folds, ['isRightHanded', 'isStronglyRightHanded', 'Gender'])

query   : #rows      : #rows per fold

1, -1, M : total = 184 : per fold = 37 - 36 - 37 - 37 - 37 - 
1, -1, F : total = 138 : per fold = 28 - 28 - 27 - 27 - 28 - 
1, 1, M : total = 269 : per fold = 54 - 53 - 54 - 54 - 54 - 
1, 1, F : total = 421 : per fold = 84 - 85 - 84 - 84 - 84 - 
1, 0, M : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
1, 0, F : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, -1, M : total = 40 : per fold = 8 - 8 - 8 - 8 - 8 - 
0, -1, F : total = 27 : per fold = 5 - 5 - 5 - 6 - 6 - 
0, 1, M : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, 1, F : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, 0, M : total = 14 : per fold = 3 - 3 - 3 - 3 - 2 - 
0, 0, F : total = 20 : per fold = 4 - 4 - 4 - 4 - 4 - 

lengths of folds :  223 222 222 223 223 
Expected total_length = 1113
Effective total_length = 1113
total number of stratification errors: 0
total number of mismatched fold sizes : 0


# Makes splits by considering zygosity

In [99]:
set(orb['ZygosityGT'].tolist())

{'DZ', 'MZ', 'NotTwin'}

In [100]:
orb.sort_values("Family_ID")

Unnamed: 0_level_0,Handedness,Gender,isRightHanded,isStronglyRightHanded,ZygosityGT,Family_ID
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
257946,100,F,1,1,NotTwin,50263_80216
962058,75,M,1,1,NotTwin,50263_80216
571144,70,M,1,-1,NotTwin,50371_80310
517239,45,M,1,-1,NotTwin,50373_80312
213017,45,M,1,-1,NotTwin,50373_80312
...,...,...,...,...,...,...
618952,100,M,1,1,NotTwin,56202_86021
650746,95,M,1,1,NotTwin,99987_99988
516742,100,M,1,1,NotTwin,99989_99990
114823,40,F,1,-1,NotTwin,99996_99997


In [101]:
l1 = [2, 3, 4]
l2 = [3, 4, 5]
print(set(l1)-set(l2))
print(set(l2)-set(l1))
print(set(l1).intersection(set(l2)))

{2}
{5}
{3, 4}


In [102]:
def how_many_common_families(total, folds):
    nb_common = 0
    for i, source in enumerate(folds):
        for j, target in enumerate(folds):
            if j > i:
                nb_common += len(set(target.Family_ID).intersection(set(source.Family_ID)))
    print(f"number of common families = {nb_common} over {len(set(total.Family_ID))} total families")


In [103]:
print_results(orb, folds, ['isRightHanded', 'isStronglyRightHanded', 'Gender'], False)
how_many_common_families(orb, folds)


lengths of folds :  223 222 222 223 223 
Expected total_length = 1113
Effective total_length = 1113
total number of stratification errors: 0
total number of mismatched fold sizes : 0
number of common families = 659 over 445 total families


In [104]:
def bring_together_families(folds):
    nb_common = 0
    results = [fold.copy() for fold in folds]
    for i in range(len(results)):
        for j in range(len(results)):
            if j > i:
                source = results[i]
                target = results[j]
                common = set(target.Family_ID).intersection(set(source.Family_ID))
                if np.random.randint(2):
                    results[i] = pd.concat([source, target[target.Family_ID.isin(common)]], axis=0)
                    results[j] = target[~target.Family_ID.isin(common)]
                else:
                    results[i] = source[~source.Family_ID.isin(common)]
                    results[j] = pd.concat([target, source[source.Family_ID.isin(common)]], axis=0)
    return results

In [105]:
results = bring_together_families(folds)

In [106]:
print_results(orb, results, ['isRightHanded', 'isStronglyRightHanded', 'Gender'], True)
how_many_common_families(orb, results)

query   : #rows      : #rows per fold

1, -1, M : total = 184 : per fold = 54 - 90 - 7 - 7 - 26 - 
1, -1, F : total = 138 : per fold = 44 - 61 - 8 - 4 - 21 - 
1, 1, M : total = 269 : per fold = 74 - 120 - 12 - 7 - 56 - 
1, 1, F : total = 421 : per fold = 133 - 203 - 22 - 5 - 58 - 
1, 0, M : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
1, 0, F : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, -1, M : total = 40 : per fold = 15 - 14 - 0 - 1 - 10 - 
0, -1, F : total = 27 : per fold = 10 - 10 - 3 - 1 - 3 - 
0, 1, M : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, 1, F : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, 0, M : total = 14 : per fold = 4 - 6 - 1 - 0 - 3 - 
0, 0, F : total = 20 : per fold = 7 - 10 - 1 - 0 - 2 - 

lengths of folds :  341 514 54 25 179 
Expected total_length = 1113
Effective total_length = 1113
total number of stratification errors: 37
total number of mismatched fold sizes : 5
number of common families = 0 over 445 total families


In [107]:
def stratified_group_kfold_split(df, n_splits, stratify_columns, group_columns, random_state):
    """Custom iterative train test split which
    'maintains balanced representation with respect
    to order-th label combinations.'

    From https://madewithml.com/courses/mlops/splitting/#stratified-split
    """
    # One-hot encode the stratify columns and concatenate them
    one_hot_cols = df[group_columns].to_numpy()
    one_hot_grps = df[group_columns].to_numpy()
    print(one_hot_grps.shape)
    stratifier = StratifiedGroupKFold_WithoutCheck(
        n_splits=n_splits, shuffle=True, random_state=random_state)
    folds = []
    for indices in stratifier.split(df.to_numpy(), one_hot_cols, one_hot_grps):
        folds.append(df.iloc[indices[1]])
    return folds

In [108]:
results = stratified_group_kfold_split(orb, 5, ['isRightHanded', 'isStronglyRightHanded', 'Gender'], ['Family_ID'], 4)

(1113, 1)


In [109]:
print_results(orb, results, ['isRightHanded', 'isStronglyRightHanded', 'Gender'], True)
how_many_common_families(orb, results)

query   : #rows      : #rows per fold

1, -1, M : total = 184 : per fold = 38 - 34 - 36 - 39 - 37 - 
1, -1, F : total = 138 : per fold = 27 - 25 - 33 - 31 - 22 - 
1, 1, M : total = 269 : per fold = 40 - 59 - 64 - 44 - 62 - 
1, 1, F : total = 421 : per fold = 94 - 77 - 84 - 79 - 87 - 
1, 0, M : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
1, 0, F : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, -1, M : total = 40 : per fold = 8 - 11 - 8 - 5 - 8 - 
0, -1, F : total = 27 : per fold = 6 - 4 - 6 - 3 - 8 - 
0, 1, M : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, 1, F : total = 0 : per fold = 0 - 0 - 0 - 0 - 0 - 
0, 0, M : total = 14 : per fold = 4 - 3 - 2 - 1 - 4 - 
0, 0, F : total = 20 : per fold = 5 - 6 - 3 - 4 - 2 - 

lengths of folds :  222 219 236 206 230 
Expected total_length = 1113
Effective total_length = 1113
total number of stratification errors: 21
total number of mismatched fold sizes : 4
number of common families = 0 over 445 total families


In [110]:
results[0].head()

Unnamed: 0_level_0,Handedness,Gender,isRightHanded,isStronglyRightHanded,ZygosityGT,Family_ID
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
100307,95,F,1,1,MZ,51488_81352
100408,55,M,1,-1,MZ,51730_81594
101107,5,M,1,-1,NotTwin,51969_81833
101410,75,M,1,1,NotTwin,52198_82061
101915,-75,F,0,0,NotTwin,51977_81841


# Save results

In [111]:
save_path = "/neurospin/dico/data/deep_folding/current/datasets/hcp/Handedness"

In [112]:
for i in range(len(results)):
    results[i].reset_index()['Subject'].to_csv(
        f"{save_path}/split_{i}.csv",
        header=False,
        index=False,
        quoting=csv.QUOTE_ALL)