In [5]:
import pandas as pd
import numpy as np
from skmultilearn.model_selection import iterative_train_test_split
from sklearn.datasets import make_multilabel_classification
from skmultilearn.model_selection import IterativeStratification
from sklearn.model_selection import StratifiedGroupKFold
import random
import csv
import glob
import sklearn

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [6]:
from collections import defaultdict

from sklearn.utils import (
    #_approximate_mode,
    _safe_indexing,
    check_random_state,
    indexable,
    metadata_routing,
)

from sklearn.utils.validation import _num_samples, check_array, column_or_1d
from sklearn.utils.multiclass import type_of_target

In [7]:
class StratifiedGroupKFold_WithoutCheck(sklearn.model_selection._split.GroupsConsumerMixin, sklearn.model_selection._split._BaseKFold):
    """Stratified K-Fold iterator variant with non-overlapping groups.

    This cross-validation object is a variation of StratifiedKFold attempts to
    return stratified folds with non-overlapping groups. The folds are made by
    preserving the percentage of samples for each class.

    Each group will appear exactly once in the test set across all folds (the
    number of distinct groups has to be at least equal to the number of folds).

    The difference between :class:`~sklearn.model_selection.GroupKFold`
    and :class:`~sklearn.model_selection.StratifiedGroupKFold` is that
    the former attempts to create balanced folds such that the number of
    distinct groups is approximately the same in each fold, whereas
    StratifiedGroupKFold attempts to create folds which preserve the
    percentage of samples for each class as much as possible given the
    constraint of non-overlapping groups between splits.

    Read more in the :ref:`User Guide <cross_validation>`.

    For visualisation of cross-validation behaviour and
    comparison between common scikit-learn split methods
    refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    shuffle : bool, default=False
        Whether to shuffle each class's samples before splitting into batches.
        Note that the samples within each split will not be shuffled.
        This implementation can only shuffle groups that have approximately the
        same y distribution, no global shuffle will be performed.

    random_state : int or RandomState instance, default=None
        When `shuffle` is True, `random_state` affects the ordering of the
        indices, which controls the randomness of each fold for each class.
        Otherwise, leave `random_state` as `None`.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import StratifiedGroupKFold
    >>> X = np.ones((17, 2))
    >>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    >>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8])
    >>> sgkf = StratifiedGroupKFold(n_splits=3)
    >>> sgkf.get_n_splits(X, y)
    3
    >>> print(sgkf)
    StratifiedGroupKFold(n_splits=3, random_state=None, shuffle=False)
    >>> for i, (train_index, test_index) in enumerate(sgkf.split(X, y, groups)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"         group={groups[train_index]}")
    ...     print(f"  Test:  index={test_index}")
    ...     print(f"         group={groups[test_index]}")
    Fold 0:
      Train: index=[ 0  1  2  3  7  8  9 10 11 15 16]
             group=[1 1 2 2 4 5 5 5 5 8 8]
      Test:  index=[ 4  5  6 12 13 14]
             group=[3 3 3 6 6 7]
    Fold 1:
      Train: index=[ 4  5  6  7  8  9 10 11 12 13 14]
             group=[3 3 3 4 5 5 5 5 6 6 7]
      Test:  index=[ 0  1  2  3 15 16]
             group=[1 1 2 2 8 8]
    Fold 2:
      Train: index=[ 0  1  2  3  4  5  6 12 13 14 15 16]
             group=[1 1 2 2 3 3 3 6 6 7 8 8]
      Test:  index=[ 7  8  9 10 11]
             group=[4 5 5 5 5]

    Notes
    -----
    The implementation is designed to:

    * Mimic the behavior of StratifiedKFold as much as possible for trivial
      groups (e.g. when each group contains only one sample).
    * Be invariant to class label: relabelling ``y = ["Happy", "Sad"]`` to
      ``y = [1, 0]`` should not change the indices generated.
    * Stratify based on samples as much as possible while keeping
      non-overlapping groups constraint. That means that in some cases when
      there is a small number of groups containing a large number of samples
      the stratification will not be possible and the behavior will be close
      to GroupKFold.

    See also
    --------
    StratifiedKFold: Takes class information into account to build folds which
        retain class distributions (for binary or multiclass classification
        tasks).

    GroupKFold: K-fold iterator variant with non-overlapping groups.
    """

    def __init__(self, n_splits=5, shuffle=False, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def _iter_test_indices(self, X, y, groups):
        # Implementation is based on this kaggle kernel:
        # https://www.kaggle.com/jakubwasikowski/stratified-group-k-fold-cross-validation
        # and is a subject to Apache 2.0 License. You may obtain a copy of the
        # License at http://www.apache.org/licenses/LICENSE-2.0
        # Changelist:
        # - Refactored function to a class following scikit-learn KFold
        #   interface.
        # - Added heuristic for assigning group to the least populated fold in
        #   cases when all other criteria are equal
        # - Swtch from using python ``Counter`` to ``np.unique`` to get class
        #   distribution
        # - Added scikit-learn checks for input: checking that target is binary
        #   or multiclass, checking passed random state, checking that number
        #   of splits is less than number of members in each class, checking
        #   that least populated class has more members than there are splits.
        rng = check_random_state(self.random_state)
        y = np.asarray(y)
        type_of_target_y = type_of_target(y)
        allowed_target_types = ("binary", "multiclass")
        if type_of_target_y not in allowed_target_types:
            raise ValueError(
                "Supported target types are: {}. Got {!r} instead.".format(
                    allowed_target_types, type_of_target_y
                )
            )

        y = column_or_1d(y)
        _, y_inv, y_cnt = np.unique(y, return_inverse=True, return_counts=True)
        # if np.all(self.n_splits > y_cnt):
        #     raise ValueError(
        #         "n_splits=%d cannot be greater than the"
        #         " number of members in each class." % (self.n_splits)
        #     )
        # n_smallest_class = np.min(y_cnt)
        # if self.n_splits > n_smallest_class:
        #     warnings.warn(
        #         "The least populated class in y has only %d"
        #         " members, which is less than n_splits=%d."
        #         % (n_smallest_class, self.n_splits),
        #         UserWarning,
        #     )
        n_classes = len(y_cnt)

        _, groups_inv, groups_cnt = np.unique(
            groups, return_inverse=True, return_counts=True
        )
        y_counts_per_group = np.zeros((len(groups_cnt), n_classes))
        for class_idx, group_idx in zip(y_inv, groups_inv):
            y_counts_per_group[group_idx, class_idx] += 1

        y_counts_per_fold = np.zeros((self.n_splits, n_classes))
        groups_per_fold = defaultdict(set)

        if self.shuffle:
            rng.shuffle(y_counts_per_group)

        # Stable sort to keep shuffled order for groups with the same
        # class distribution variance
        sorted_groups_idx = np.argsort(
            -np.std(y_counts_per_group, axis=1), kind="mergesort"
        )

        for group_idx in sorted_groups_idx:
            group_y_counts = y_counts_per_group[group_idx]
            best_fold = self._find_best_fold(
                y_counts_per_fold=y_counts_per_fold,
                y_cnt=y_cnt,
                group_y_counts=group_y_counts,
            )
            y_counts_per_fold[best_fold] += group_y_counts
            groups_per_fold[best_fold].add(group_idx)

        for i in range(self.n_splits):
            test_indices = [
                idx
                for idx, group_idx in enumerate(groups_inv)
                if group_idx in groups_per_fold[i]
            ]
            yield test_indices

    def _find_best_fold(self, y_counts_per_fold, y_cnt, group_y_counts):
        best_fold = None
        min_eval = np.inf
        min_samples_in_fold = np.inf
        for i in range(self.n_splits):
            y_counts_per_fold[i] += group_y_counts
            # Summarise the distribution over classes in each proposed fold
            std_per_class = np.std(y_counts_per_fold / y_cnt.reshape(1, -1), axis=0)
            y_counts_per_fold[i] -= group_y_counts
            fold_eval = np.mean(std_per_class)
            samples_in_fold = np.sum(y_counts_per_fold[i])
            is_current_fold_better = (
                fold_eval < min_eval
                or np.isclose(fold_eval, min_eval)
                and samples_in_fold < min_samples_in_fold
            )
            if is_current_fold_better:
                min_eval = fold_eval
                min_samples_in_fold = samples_in_fold
                best_fold = i
        return best_fold

# Loads and manipulates

In [8]:
orb = pd.read_csv("/neurospin/dico/data/bv_databases/human/partially_labeled/orbital_patterns/Troiani/HCP_OFC_for_JFM - Types_and_Subtypes.csv", index_col=0)
orb = orb.dropna()
orb["Right"] = orb["Right"].astype(int)
orb["Left"] = orb["Left"].astype(int)

In [9]:
participants_file = "/neurospin/dico/data/bv_databases/human/not_labeled/hcp/participants.csv"
participants_unrestricted = pd.read_csv(participants_file)
participants_unrestricted = participants_unrestricted[['Subject', "Gender"]]
print(len(participants_unrestricted))
participants_unrestricted.head()

1206


Unnamed: 0,Subject,Gender
0,100004,M
1,100206,M
2,100307,F
3,100408,M
4,100610,M


In [10]:
participants_file = "/neurospin/dico/jchavas/RESTRICTED_jchavas_1_18_2022_3_17_51.csv"
participants = pd.read_csv(participants_file)
participants = participants[['Subject', 'ZygosityGT', 'Family_ID']]
print(len(participants))
participants.head()

1206


Unnamed: 0,Subject,ZygosityGT,Family_ID
0,100004,,52259_82122
1,100206,,56037_85858
2,100307,MZ,51488_81352
3,100408,MZ,51730_81594
4,100610,DZ,52813_82634


In [11]:
participants = pd.merge(participants_unrestricted, participants)
participants.loc[(participants['ZygosityGT']== " "), 'ZygosityGT'] = 'NotTwin'  
participants['Subject'] = participants['Subject'].astype('string')
print(len(participants))
participants.head()

1206


Unnamed: 0,Subject,Gender,ZygosityGT,Family_ID
0,100004,M,NotTwin,52259_82122
1,100206,M,NotTwin,56037_85858
2,100307,F,MZ,51488_81352
3,100408,M,MZ,51730_81594
4,100610,M,DZ,52813_82634


In [12]:
participants.dtypes

Subject       string[python]
Gender                object
ZygosityGT            object
Family_ID             object
dtype: object

In [13]:
treated_subjects = glob.glob("/neurospin/dico/data/bv_databases/human/not_labeled/hcp/hcp/*[!.minf]")
treated_subjects = [x.split('/')[-1] for x in treated_subjects]
treated_subjects = [x for x in treated_subjects if 'database' not in x]
print(treated_subjects[:5])
len(treated_subjects)
participants = participants[participants['Subject'].isin(treated_subjects)]
print(len(participants))
set(treated_subjects) - set(participants['Subject'])

['210112', '579665', '922854', '517239', '329440']
1113


{'142626'}

In [14]:
print(len(orb))
print(orb.dtypes)
print(orb[["Right"]].value_counts(dropna=False))
print(orb[["Left"]].value_counts(dropna=False))
orb.index = orb.index.astype(str)

577
Right             int64
Right Subtype    object
Left              int64
Left Subtype     object
dtype: object
Right
1        339
2        120
3         99
4         19
Name: count, dtype: int64
Left
1       281
2       159
3       101
4        36
Name: count, dtype: int64


In [15]:
orb.head()

Unnamed: 0,Right,Right Subtype,Left,Left Subtype
100307,1,B,2,F
100408,1,N,2,E
100610,1,A,1,A
101006,1,B,1,A
101410,2,A,2,I


In [16]:
orb.index

Index(['100307', '100408', '100610', '101006', '101410', '102311', '102816',
       '103010', '103515', '103818',
       ...
       '966975', '969476', '971160', '972566', '973770', '984472', '987983',
       '990366', '995174', '996782'],
      dtype='object', length=577)

In [17]:
set(orb.index) - set(treated_subjects)

set()

In [18]:
orb.index.isin(set(treated_subjects) - set(participants['Subject'])).sum()

0

In [19]:
orb.head()

Unnamed: 0,Right,Right Subtype,Left,Left Subtype
100307,1,B,2,F
100408,1,N,2,E
100610,1,A,1,A
101006,1,B,1,A
101410,2,A,2,I


In [20]:
participants.head()

Unnamed: 0,Subject,Gender,ZygosityGT,Family_ID
1,100206,M,NotTwin,56037_85858
2,100307,F,MZ,51488_81352
3,100408,M,MZ,51730_81594
4,100610,M,DZ,52813_82634
5,101006,F,NotTwin,51283_52850_81149


In [21]:
orb = pd.merge(participants, orb, left_on='Subject', right_index=True).set_index('Subject')

# Makes split without considering twins

In [22]:
orb[["Gender"]].value_counts()

Gender
F         349
M         228
Name: count, dtype: int64

In [23]:
orb.query("Right==1 and Left==1 and Gender=='M'")

Unnamed: 0_level_0,Gender,ZygosityGT,Family_ID,Right,Right Subtype,Left,Left Subtype
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
100610,M,DZ,52813_82634,1,A,1,A
103010,M,MZ,55895_85715,1,A,1,A
106824,M,DZ,56183_86002,1,I,1,A
107422,M,MZ,54628_84450,1,B,1,A
108020,M,DZ,56200_86019,1,C,1,A
...,...,...,...,...,...,...,...
832651,M,DZ,55904_85724,1,N,1,E
871964,M,NotTwin,51852_81716,1,N,1,C
898176,M,DZ,51749_81613,1,F,1,N
926862,M,MZ,55763_85584,1,C,1,I


In [24]:
def print_frequencies(df):
    for right in df['Right'].unique():
        for left in df['Left'].unique():
            for gender in df['Gender'].unique():
                freq = df.query("Right==@right and Left==@left and Gender==@gender")
                print(f"{right}, {left}, {gender}: {len(freq)}")

In [25]:
print_frequencies(orb)

1, 2, F: 50
1, 2, M: 36
1, 1, F: 111
1, 1, M: 67
1, 3, F: 49
1, 3, M: 14
1, 4, F: 7
1, 4, M: 5
2, 2, F: 16
2, 2, M: 28
2, 1, F: 27
2, 1, M: 27
2, 3, F: 7
2, 3, M: 2
2, 4, F: 5
2, 4, M: 8
3, 2, F: 13
3, 2, M: 11
3, 1, F: 27
3, 1, M: 14
3, 3, F: 17
3, 3, M: 7
3, 4, F: 8
3, 4, M: 2
4, 2, F: 3
4, 2, M: 2
4, 1, F: 5
4, 1, M: 3
4, 3, F: 3
4, 3, M: 2
4, 4, F: 1
4, 4, M: 0


In [26]:
def print_results(parent, folds, col, verbose=True):

    # For each conbination of labels, prints the number of rows for each fold
    # having this combination
    total_errors = 0
    n_splits = len(folds)
    if verbose:
        print("query   : #rows      : #rows per fold\n")

    for col0 in parent[col[0]].unique():
        for col1 in parent[col[1]].unique():
            #for col2 in parent[col[2]].unique():
            #df = parent.query(f"{col[0]}==@col0 and {col[1]}==@col1 and {col[2]}==@col2")
            df = parent.query(f"{col[0]}==@col0 and {col[1]}==@col1")
            len_query = len(df)
            if verbose:
                #print(f"{col0}, {col1}, {col2} : total = {len_query} : per fold =", end = ' ')
                print(f"{col0}, {col1} : total = {len_query} : per fold =", end = ' ')
            for fold in folds:
                #df0 = fold.query(f"{col[0]}==@col0 and {col[1]}==@col1 and {col[2]}==@col2")
                df0 = fold.query(f"{col[0]}==@col0 and {col[1]}==@col1")
                len_query_fold = len(df0)
                if abs(len_query_fold-len_query/n_splits) >= 2:
                    total_errors += 1
                if verbose:
                    print(f"{len_query_fold} -", end= ' ')
            if verbose:
                print("")

    # Prints the statistics and the number of stratification errors
    expected_total_length = len(parent)
    total_length = 0
    total_mismatches = 0
    print("\nlengths of folds : ", end = ' ')
    for fold in folds:
        len_fold = len(fold)
        print(len_fold, end=' ')
        total_length += len_fold
        if abs(len_fold-expected_total_length/n_splits) >= 2:
            total_mismatches += 1
    print(f"\nExpected total_length = {expected_total_length}")
    print(f"Effective total_length = {total_length}")

    print(f"total number of stratification errors: {total_errors}")
    print(f"total number of mismatched fold sizes : {total_mismatches}")

In [27]:
def iterative_split(df, folds, n_splits, stratify_columns):
    """Custom iterative train test split which
    'maintains balanced representation with respect
    to order-th label combinations.'

    From https://madewithml.com/courses/mlops/splitting/#stratified-split
    """
    # One-hot encode the stratify columns and concatenate them
    one_hot_cols = [pd.get_dummies(df[col]) for col in stratify_columns]
    one_hot_cols = pd.concat(one_hot_cols, axis=1).to_numpy()
    stratifier = IterativeStratification(
        n_splits=n_splits, order=len(stratify_columns), sample_distribution_per_fold=folds)
    folds = []
    for indices in stratifier.split(df.to_numpy(), one_hot_cols):
        folds.append(df.iloc[indices[1]])
    return folds

In [28]:
folds = iterative_split(orb, [0.1,]*10, 10, ['Right', 'Left', 'Gender'])

In [29]:
print(len(folds))

10


In [30]:
print_results(orb, folds, ['Right', 'Left', 'Gender'])

query   : #rows      : #rows per fold

1, 2 : total = 86 : per fold = 9 - 8 - 9 - 8 - 9 - 9 - 9 - 8 - 9 - 8 - 
1, 1 : total = 178 : per fold = 19 - 18 - 19 - 17 - 16 - 18 - 18 - 18 - 18 - 17 - 
1, 3 : total = 63 : per fold = 6 - 6 - 6 - 7 - 7 - 6 - 6 - 6 - 6 - 7 - 
1, 4 : total = 12 : per fold = 0 - 1 - 0 - 2 - 2 - 1 - 1 - 2 - 1 - 2 - 
2, 2 : total = 44 : per fold = 5 - 4 - 4 - 5 - 5 - 4 - 4 - 4 - 4 - 5 - 
2, 1 : total = 54 : per fold = 6 - 6 - 5 - 5 - 6 - 6 - 5 - 5 - 4 - 6 - 
2, 3 : total = 9 : per fold = 1 - 0 - 1 - 0 - 1 - 0 - 2 - 2 - 1 - 1 - 
2, 4 : total = 13 : per fold = 0 - 2 - 2 - 1 - 1 - 1 - 2 - 1 - 1 - 2 - 
3, 2 : total = 24 : per fold = 3 - 2 - 2 - 2 - 2 - 3 - 2 - 3 - 3 - 2 - 
3, 1 : total = 41 : per fold = 4 - 3 - 4 - 4 - 4 - 5 - 4 - 5 - 4 - 4 - 
3, 3 : total = 24 : per fold = 2 - 3 - 3 - 2 - 3 - 2 - 2 - 2 - 2 - 3 - 
3, 4 : total = 10 : per fold = 1 - 0 - 1 - 1 - 1 - 1 - 1 - 1 - 2 - 1 - 
4, 2 : total = 5 : per fold = 0 - 1 - 1 - 0 - 0 - 0 - 1 - 1 - 1 - 0 - 
4, 1 : total = 8

In [31]:
def iterative_split_through_sorting_shuffle(df, n_splits, stratify_columns, random_state):
    """Custom iterative train test split which
    maintains balanced representation.
    """
    # Dataframe random row shuffle + sorting according to stratify_columns
    sorted = df.sample(frac=1, random_state=random_state).sort_values(stratify_columns)
    # for each fold, we take one row every n_splits rows
    folds = [sorted.iloc[i::n_splits, :] for i in range(n_splits)]
    # Further shuffling
    folds = [fold.sample(frac=1, random_state=random_state) for fold in folds]
    random.Random(random_state).shuffle(folds)
    return folds

In [32]:
folds = iterative_split_through_sorting_shuffle(orb, 10, ['Right', 'Left', 'Gender'], 1)

In [33]:
print_results(orb, folds, ['Right', 'Left', 'Gender'])

query   : #rows      : #rows per fold

1, 2 : total = 86 : per fold = 8 - 9 - 9 - 8 - 8 - 9 - 9 - 8 - 9 - 9 - 
1, 1 : total = 178 : per fold = 18 - 17 - 17 - 18 - 18 - 18 - 18 - 18 - 18 - 18 - 
1, 3 : total = 63 : per fold = 7 - 6 - 6 - 6 - 7 - 6 - 6 - 7 - 6 - 6 - 
1, 4 : total = 12 : per fold = 1 - 2 - 1 - 2 - 1 - 1 - 1 - 1 - 1 - 1 - 
2, 2 : total = 44 : per fold = 5 - 4 - 4 - 4 - 5 - 5 - 4 - 5 - 4 - 4 - 
2, 1 : total = 54 : per fold = 5 - 5 - 6 - 5 - 5 - 5 - 6 - 5 - 6 - 6 - 
2, 3 : total = 9 : per fold = 0 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 
2, 4 : total = 13 : per fold = 2 - 2 - 1 - 2 - 1 - 1 - 1 - 1 - 1 - 1 - 
3, 2 : total = 24 : per fold = 2 - 2 - 2 - 2 - 2 - 3 - 3 - 2 - 3 - 3 - 
3, 1 : total = 41 : per fold = 4 - 4 - 5 - 4 - 4 - 4 - 4 - 4 - 4 - 4 - 
3, 3 : total = 24 : per fold = 3 - 2 - 2 - 3 - 3 - 2 - 2 - 3 - 2 - 2 - 
3, 4 : total = 10 : per fold = 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 
4, 2 : total = 5 : per fold = 1 - 1 - 1 - 1 - 0 - 0 - 1 - 0 - 0 - 0 - 
4, 1 : total = 8

# Makes splits by considering zygosity

In [34]:
set(orb['ZygosityGT'].tolist())

{'DZ', 'MZ', 'NotTwin'}

In [35]:
orb.sort_values("Family_ID")

Unnamed: 0_level_0,Gender,ZygosityGT,Family_ID,Right,Right Subtype,Left,Left Subtype
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
473952,F,DZ,51090_80960_99983,1,A,1,A
289555,F,DZ,51090_80960_99983,3,Y,2,A
123723,F,MZ,51106_80975,1,N,1,A
677968,F,MZ,51279_81145,3,T,4,A
139637,F,MZ,51279_81145,2,A,2,A
...,...,...,...,...,...,...,...
108020,M,DZ,56200_86019,1,C,1,A
728454,M,DZ,56200_86019,1,A,2,A
618952,M,NotTwin,56202_86021,2,C,1,B
650746,M,NotTwin,99987_99988,2,I,1,S


In [36]:
l1 = [2, 3, 4]
l2 = [3, 4, 5]
print(set(l1)-set(l2))
print(set(l2)-set(l1))
print(set(l1).intersection(set(l2)))

{2}
{5}
{3, 4}


In [37]:
def how_many_common_families(total, folds):
    nb_common = 0
    for i, source in enumerate(folds):
        for j, target in enumerate(folds):
            if j > i:
                nb_common += len(set(target.Family_ID).intersection(set(source.Family_ID)))
    print(f"number of common families = {nb_common} over {len(set(total.Family_ID))} total families")


In [38]:
print_results(orb, folds, ['Right', 'Left', 'Gender'], False)
how_many_common_families(orb, folds)


lengths of folds :  58 57 57 57 58 58 58 58 58 58 
Expected total_length = 577
Effective total_length = 577
total number of stratification errors: 0
total number of mismatched fold sizes : 0
number of common families = 221 over 335 total families


In [39]:
def bring_together_families(folds):
    nb_common = 0
    results = [fold.copy() for fold in folds]
    for i in range(len(results)):
        for j in range(len(results)):
            if j > i:
                source = results[i]
                target = results[j]
                common = set(target.Family_ID).intersection(set(source.Family_ID))
                if np.random.randint(2):
                    results[i] = pd.concat([source, target[target.Family_ID.isin(common)]], axis=0)
                    results[j] = target[~target.Family_ID.isin(common)]
                else:
                    results[i] = source[~source.Family_ID.isin(common)]
                    results[j] = pd.concat([target, source[source.Family_ID.isin(common)]], axis=0)
    return results

In [40]:
results = bring_together_families(folds)

In [41]:
print_results(orb, results, ['Right', 'Left', 'Gender'], True)
how_many_common_families(orb, results)

query   : #rows      : #rows per fold

1, 2 : total = 86 : per fold = 9 - 11 - 11 - 13 - 9 - 4 - 8 - 4 - 9 - 8 - 
1, 1 : total = 178 : per fold = 23 - 19 - 19 - 22 - 21 - 6 - 17 - 18 - 16 - 17 - 
1, 3 : total = 63 : per fold = 9 - 7 - 6 - 9 - 8 - 3 - 7 - 2 - 10 - 2 - 
1, 4 : total = 12 : per fold = 2 - 2 - 4 - 2 - 1 - 0 - 1 - 0 - 0 - 0 - 
2, 2 : total = 44 : per fold = 7 - 4 - 3 - 5 - 9 - 4 - 1 - 2 - 4 - 5 - 
2, 1 : total = 54 : per fold = 10 - 6 - 4 - 7 - 5 - 4 - 6 - 4 - 3 - 5 - 
2, 3 : total = 9 : per fold = 1 - 1 - 1 - 2 - 1 - 0 - 1 - 0 - 0 - 2 - 
2, 4 : total = 13 : per fold = 5 - 3 - 1 - 1 - 2 - 0 - 0 - 0 - 0 - 1 - 
3, 2 : total = 24 : per fold = 5 - 1 - 3 - 3 - 2 - 0 - 2 - 0 - 5 - 3 - 
3, 1 : total = 41 : per fold = 6 - 5 - 7 - 5 - 3 - 1 - 4 - 2 - 4 - 4 - 
3, 3 : total = 24 : per fold = 2 - 2 - 4 - 3 - 2 - 1 - 1 - 3 - 5 - 1 - 
3, 4 : total = 10 : per fold = 1 - 2 - 0 - 1 - 2 - 1 - 1 - 0 - 1 - 1 - 
4, 2 : total = 5 : per fold = 1 - 1 - 1 - 2 - 0 - 0 - 0 - 0 - 0 - 0 - 
4, 1 : total

In [42]:
orb

Unnamed: 0_level_0,Gender,ZygosityGT,Family_ID,Right,Right Subtype,Left,Left Subtype
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
100307,F,MZ,51488_81352,1,B,2,F
100408,M,MZ,51730_81594,1,N,2,E
100610,M,DZ,52813_82634,1,A,1,A
101006,F,NotTwin,51283_52850_81149,1,B,1,A
101410,M,NotTwin,52198_82061,2,A,2,I
...,...,...,...,...,...,...,...
984472,F,DZ,51455_81320,3,T,1,B
987983,F,NotTwin,51493_81357,2,C,3,E
990366,M,NotTwin,56064_85885,4,G,3,A
995174,M,MZ,55923_85743,2,A,1,A


In [43]:
## why is group columns not used ?
def stratified_group_kfold_split(df, n_splits, stratify_columns, group_columns, random_state):
    """Custom iterative train test split which
    'maintains balanced representation with respect
    to order-th label combinations.'

    From https://madewithml.com/courses/mlops/splitting/#stratified-split
    """
    # One-hot encode the stratify columns and concatenate them
    one_hot_cols = df[group_columns].to_numpy()
    one_hot_grps = df[group_columns].to_numpy()
    print(one_hot_grps.shape)
    stratifier = StratifiedGroupKFold_WithoutCheck(
        n_splits=n_splits, shuffle=False)
    folds = []
    for indices in stratifier.split(df.to_numpy(), one_hot_cols, one_hot_grps):
        folds.append(df.iloc[indices[1]])
    return folds

In [44]:
results = stratified_group_kfold_split(orb, 10, ['Left', 'Gender'], ['Family_ID'], 4)
#results = stratified_group_kfold_split(orb, 3, ['Left', 'Gender'], ['Family_ID'], 2)

(577, 1)


In [45]:
#print_results(orb, results, ['Right', 'Left', 'Gender'], True)
print_results(orb, results, ['Left', 'Gender'], True)
how_many_common_families(orb, results)

query   : #rows      : #rows per fold

2, F : total = 82 : per fold = 9 - 4 - 13 - 11 - 9 - 5 - 6 - 10 - 9 - 6 - 
2, M : total = 77 : per fold = 12 - 7 - 7 - 13 - 8 - 6 - 4 - 3 - 6 - 11 - 
1, F : total = 170 : per fold = 15 - 16 - 15 - 14 - 18 - 17 - 16 - 23 - 16 - 20 - 
1, M : total = 111 : per fold = 7 - 16 - 11 - 7 - 13 - 15 - 10 - 11 - 12 - 9 - 
3, F : total = 76 : per fold = 9 - 12 - 7 - 7 - 4 - 10 - 9 - 6 - 7 - 5 - 
3, M : total = 25 : per fold = 3 - 2 - 4 - 5 - 2 - 2 - 6 - 0 - 0 - 1 - 
4, F : total = 21 : per fold = 1 - 1 - 1 - 0 - 2 - 2 - 4 - 2 - 5 - 3 - 
4, M : total = 15 : per fold = 2 - 0 - 0 - 1 - 2 - 1 - 3 - 2 - 2 - 2 - 

lengths of folds :  58 58 58 58 58 58 58 57 57 57 
Expected total_length = 577
Effective total_length = 577
total number of stratification errors: 31
total number of mismatched fold sizes : 0
number of common families = 0 over 335 total families


In [46]:
results[0].head()

Unnamed: 0_level_0,Gender,ZygosityGT,Family_ID,Right,Right Subtype,Left,Left Subtype
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
100610,M,DZ,52813_82634,1,A,1,A
105923,F,NotTwin,52925_82747,2,E,2,G
108121,F,NotTwin,51566_81430,1,F,1,E
115825,M,MZ,56016_85837_99974,1,B,2,A
118225,M,DZ,52813_82634,1,N,2,F


In [129]:
n_splits = 10
side = 'Right'
df = orb[[side, 'Gender', 'Family_ID']]

# group based on family
groups = df['Family_ID'].to_numpy()

# labels are combinations of Sex and actual label
labels = df[[side, 'Gender']]
labels['label'] = labels[labels.columns[:]].apply(
    lambda x: ''.join(x.dropna().astype(str)),
    axis=1
)
labels = labels['label'].to_numpy()

In [130]:
stratifier = StratifiedGroupKFold(
    n_splits=n_splits, shuffle=False)
results = []
for indices in stratifier.split(df.to_numpy(), labels, groups):
    results.append(df.iloc[indices[1]])



In [131]:
print_results(orb, results, [side, 'Gender'], True)
how_many_common_families(orb, results)

query   : #rows      : #rows per fold

1, F : total = 217 : per fold = 22 - 22 - 22 - 22 - 22 - 21 - 21 - 21 - 22 - 22 - 
1, M : total = 122 : per fold = 12 - 12 - 12 - 12 - 12 - 12 - 13 - 12 - 12 - 13 - 
2, F : total = 55 : per fold = 5 - 5 - 6 - 5 - 6 - 6 - 6 - 6 - 5 - 5 - 
2, M : total = 65 : per fold = 7 - 6 - 6 - 7 - 7 - 7 - 7 - 6 - 6 - 6 - 
3, F : total = 65 : per fold = 7 - 6 - 7 - 6 - 6 - 7 - 7 - 6 - 6 - 7 - 
3, M : total = 34 : per fold = 3 - 4 - 3 - 4 - 3 - 3 - 3 - 4 - 4 - 3 - 
4, F : total = 12 : per fold = 1 - 2 - 1 - 1 - 1 - 1 - 1 - 2 - 1 - 1 - 
4, M : total = 7 : per fold = 1 - 1 - 1 - 1 - 1 - 1 - 0 - 0 - 1 - 0 - 

lengths of folds :  58 58 58 58 58 58 58 57 57 57 
Expected total_length = 577
Effective total_length = 577
total number of stratification errors: 0
total number of mismatched fold sizes : 0
number of common families = 0 over 335 total families


# Save results

In [132]:
save_path = f"/neurospin/dico/data/deep_folding/current/datasets/orbital_patterns/Troiani/{side}"

In [133]:
for i in range(len(results)):
    results[i].reset_index()['Subject'].to_csv(
        f"{save_path}/split_{i}.csv",
        header=False,
        index=False,
        quoting=csv.QUOTE_ALL)

# Step 1 : split in 5 folds, keep last for test

In [None]:
side = 'Right'

In [60]:
save_path = f"/neurospin/dico/data/deep_folding/current/datasets/orbital_patterns/Troiani/{side}"

In [61]:
n_splits = 5
df = orb[[side, 'Gender', 'Family_ID']]

# group based on family
groups = df['Family_ID'].to_numpy()

# labels are combinations of Sex and actual label
labels = df[[side, 'Gender']]
labels['label'] = labels[labels.columns[:]].apply(
    lambda x: ''.join(x.dropna().astype(str)),
    axis=1
)
labels = labels['label'].to_numpy()

In [62]:
stratifier = StratifiedGroupKFold(
    n_splits=n_splits, shuffle=False)
results = []
for indices in stratifier.split(df.to_numpy(), labels, groups):
    results.append(df.iloc[indices[1]])

In [63]:
print_results(orb, results, [side, 'Gender'], True)
how_many_common_families(orb, results)

query   : #rows      : #rows per fold

2, F : total = 82 : per fold = 16 - 16 - 17 - 16 - 17 - 
2, M : total = 77 : per fold = 16 - 15 - 15 - 15 - 16 - 
1, F : total = 170 : per fold = 34 - 34 - 34 - 34 - 34 - 
1, M : total = 111 : per fold = 22 - 22 - 22 - 23 - 22 - 
3, F : total = 76 : per fold = 16 - 15 - 15 - 15 - 15 - 
3, M : total = 25 : per fold = 5 - 5 - 5 - 5 - 5 - 
4, F : total = 21 : per fold = 4 - 5 - 4 - 4 - 4 - 
4, M : total = 15 : per fold = 3 - 3 - 3 - 3 - 3 - 

lengths of folds :  116 115 115 115 116 
Expected total_length = 577
Effective total_length = 577
total number of stratification errors: 0
total number of mismatched fold sizes : 0
number of common families = 0 over 335 total families


In [64]:
results[-1].reset_index()['Subject'].to_csv(
    f"{save_path}/test_split.csv",
    header=False,
    index=False,
    quoting=csv.QUOTE_ALL)

# Step 2 : split remaining 80% in 5 for cross val and train eval

In [69]:
orb_train_val = orb.drop(index=results[-1].index.to_list())

In [72]:
# repeat the process
n_splits = 5
df = orb_train_val[[side, 'Gender', 'Family_ID']]

# group based on family
groups = df['Family_ID'].to_numpy()

# labels are combinations of Sex and actual label
labels = df[[side, 'Gender']]
labels['label'] = labels[labels.columns[:]].apply(
    lambda x: ''.join(x.dropna().astype(str)),
    axis=1
)
labels = labels['label'].to_numpy()

In [73]:
stratifier = StratifiedGroupKFold(
    n_splits=n_splits, shuffle=False)
results = []
for indices in stratifier.split(df.to_numpy(), labels, groups):
    results.append(df.iloc[indices[1]])

In [74]:
print_results(orb_train_val, results, [side, 'Gender'], True)
how_many_common_families(orb_train_val, results)

query   : #rows      : #rows per fold

2, F : total = 77 : per fold = 15 - 15 - 16 - 16 - 15 - 
2, M : total = 64 : per fold = 13 - 13 - 13 - 12 - 13 - 
1, F : total = 139 : per fold = 28 - 28 - 28 - 27 - 28 - 
1, M : total = 93 : per fold = 18 - 19 - 19 - 19 - 18 - 
3, F : total = 60 : per fold = 12 - 12 - 12 - 12 - 12 - 
3, M : total = 21 : per fold = 5 - 4 - 4 - 4 - 4 - 
4, F : total = 17 : per fold = 3 - 4 - 3 - 4 - 3 - 
4, M : total = 13 : per fold = 3 - 2 - 2 - 3 - 3 - 

lengths of folds :  97 97 97 97 96 
Expected total_length = 484
Effective total_length = 484
total number of stratification errors: 0
total number of mismatched fold sizes : 0
number of common families = 0 over 281 total families


In [75]:
# save folds
for i in range(len(results)):
    results[i].reset_index()['Subject'].to_csv(
        f"{save_path}/train_val_split_{i}.csv",
        header=False,
        index=False,
        quoting=csv.QUOTE_ALL)