<a href="https://colab.research.google.com/github/is0280fp/google_colab/blob/main/StratifiedGroupKFold.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive so later cells can read and write
# files under /content/drive/MyDrive.
from google.colab import drive 
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import pickle
import numpy as np
import pandas as pd

# Load the per-row class labels that Read_data later attaches as `target`.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# The `with` block closes the file on exit, so no explicit f.close() is needed.
with open("/content/drive/MyDrive/pickle/label_five_class.pickle", mode="rb") as f:
    label = pickle.load(f)

In [3]:
import random
from sklearn.model_selection import GroupKFold
from collections import Counter, defaultdict


def Count_y(y, groups):
    """Tally how many samples of each class label belong to each group.

    Parameters
    ----------
    y : iterable of int
        Non-negative integer class labels.
    groups : iterable
        Group id for each sample, aligned with ``y``.

    Returns
    -------
    collections.defaultdict
        Maps group id -> float ndarray of length ``max(y) + 1`` whose entry
        ``c`` is the number of samples of class ``c`` in that group. Looking
        up an unseen group yields (and stores) an all-zero row.
    """
    n_labels = np.max(y) + 1

    def _empty_row():
        return np.zeros(n_labels)

    tally = defaultdict(_empty_row)
    for pair in zip(y, groups):
        cls, grp = pair
        tally[grp][cls] += 1
    return tally


def StratifiedGroupKFold(X, y, groups, features, k, seed = 42):
    # Preparation
    max_y = np.max(y)
    y_counts_per_group = Count_y(y, groups)
    kf = GroupKFold(n_splits=k)

    for train_idx, val_idx in kf.split(X, y, groups):
        # Training dataset and validation dataset
        # from IPython.core.debugger import Pdb; Pdb().set_trace()
        x_train = X.iloc[train_idx, :]
        id_train = x_train["SUBJECT_ID"].unique()
        x_train = x_train[features]

        x_val, y_val = X.iloc[val_idx, :], y.iloc[val_idx]
        id_val = x_val["SUBJECT_ID"].unique()
        x_val = x_val[features]

        # y counts of training dataset and validation dataset
        y_counts_train = np.zeros(max_y+1)
        y_counts_val = np.zeros(max_y+1)
        for id_ in id_train:
            y_counts_train += y_counts_per_group[id_]
        for id_ in id_val:
            y_counts_val += y_counts_per_group[id_]

        # Determination ratio of validation dataset
        numratio_train = y_counts_train / np.max(y_counts_train)
        stratified_count = np.ceil(y_counts_val[np.argmax(y_counts_train)] * numratio_train)
        stratified_count = stratified_count.astype(int)

        # Select validation dataset randomly
        val_idx = np.array([])
        np.random.seed(seed) 

        for num in range(max_y+1):
            val_idx = np.append(val_idx, np.random.choice(y_val[y_val==num].index, stratified_count[num]))
        val_idx = val_idx.astype(int)
        
        yield train_idx, val_idx

def Get_distribution(y_vals, n_classes=None):
    """Class distribution of ``y_vals`` as percentage strings.

    Parameters
    ----------
    y_vals : iterable of int
        Non-negative integer class labels.
    n_classes : int, optional
        Number of classes to report (labels ``0 .. n_classes-1``). Defaults to
        ``max(y_vals) + 1``, as before. Pass it explicitly so subsets missing
        the highest labels still produce rows of equal length.

    Returns
    -------
    list of str
        One ``"xx.xx%"`` entry per class. Empty input gives ``"0.00%"``
        entries, or ``[]`` when ``n_classes`` is not given.
    """
    counts = Counter(y_vals)
    total = sum(counts.values())
    if n_classes is None:
        n_classes = int(max(counts)) + 1 if counts else 0
    if total == 0:
        # Avoid division by zero for an empty selection.
        return [f"{0.0:.2%}" for _ in range(n_classes)]
    return [f"{counts[i] / total:.2%}" for i in range(n_classes)]

In [4]:
def Read_data(path='/content/drive/MyDrive/pickle/GRF_csv/GRF_F_V_PRO_right.csv', labels=None):
    """Load the GRF CSV, attach class labels as ``target``, list feature columns.

    Parameters
    ----------
    path : str
        CSV file to read (header on the first row).
    labels : numpy.ndarray or pandas.Series, optional
        Per-row class labels (must support ``.astype``). Defaults to the
        notebook-global ``label`` loaded from the pickle in an earlier cell.

    Returns
    -------
    df : pandas.DataFrame
        The CSV contents plus an integer ``target`` column.
    features : list of str
        Every column after ``"TRIAL_ID"``, i.e. the signal columns. The
        leading index/ID columns and the ``target`` column are excluded.
        (Previously ``df.columns[3]`` returned just the name ``"TRIAL_ID"``.)

    Raises
    ------
    KeyError
        If the CSV has no ``"TRIAL_ID"`` column.
    """
    df = pd.read_csv(path, header=0)
    if labels is None:
        labels = label

    # Determine feature columns before appending `target`, so the label can
    # never leak into the feature list.
    first_feature = df.columns.get_loc("TRIAL_ID") + 1
    features = list(df.columns[first_feature:])

    df["target"] = labels.astype(int)
    return df, features

# Load the labelled GRF table and the feature column name(s) Read_data reports.
df, features = Read_data()

In [5]:
# Inspect the loaded table (the rendered display truncates rows and columns).
df

Unnamed: 0,SUBJECT_ID,SESSION_ID,TRIAL_ID,F_V_PRO_1,F_V_PRO_2,F_V_PRO_3,F_V_PRO_4,F_V_PRO_5,F_V_PRO_6,F_V_PRO_7,F_V_PRO_8,F_V_PRO_9,F_V_PRO_10,F_V_PRO_11,F_V_PRO_12,F_V_PRO_13,F_V_PRO_14,F_V_PRO_15,F_V_PRO_16,F_V_PRO_17,F_V_PRO_18,F_V_PRO_19,F_V_PRO_20,F_V_PRO_21,F_V_PRO_22,F_V_PRO_23,F_V_PRO_24,F_V_PRO_25,F_V_PRO_26,F_V_PRO_27,F_V_PRO_28,F_V_PRO_29,F_V_PRO_30,F_V_PRO_31,F_V_PRO_32,F_V_PRO_33,F_V_PRO_34,F_V_PRO_35,F_V_PRO_36,F_V_PRO_37,...,F_V_PRO_63,F_V_PRO_64,F_V_PRO_65,F_V_PRO_66,F_V_PRO_67,F_V_PRO_68,F_V_PRO_69,F_V_PRO_70,F_V_PRO_71,F_V_PRO_72,F_V_PRO_73,F_V_PRO_74,F_V_PRO_75,F_V_PRO_76,F_V_PRO_77,F_V_PRO_78,F_V_PRO_79,F_V_PRO_80,F_V_PRO_81,F_V_PRO_82,F_V_PRO_83,F_V_PRO_84,F_V_PRO_85,F_V_PRO_86,F_V_PRO_87,F_V_PRO_88,F_V_PRO_89,F_V_PRO_90,F_V_PRO_91,F_V_PRO_92,F_V_PRO_93,F_V_PRO_94,F_V_PRO_95,F_V_PRO_96,F_V_PRO_97,F_V_PRO_98,F_V_PRO_99,F_V_PRO_100,F_V_PRO_101,target
0,510,413,1,0.022642,0.066304,0.123147,0.178585,0.229816,0.286079,0.353620,0.426446,0.491737,0.540526,0.572324,0.591803,0.604101,0.613784,0.625180,0.641615,0.664849,0.694741,0.729989,0.769572,0.812781,0.856937,0.897933,0.933673,0.963686,0.988120,1.007563,1.022383,1.031675,1.034030,1.028950,1.017599,1.002251,0.985146,0.967714,0.950556,0.933666,...,0.944331,0.957378,0.969673,0.979435,0.985461,0.987254,0.985060,0.979608,0.971624,0.961551,0.949647,0.936000,0.920341,0.902275,0.881277,0.856765,0.828360,0.796449,0.762198,0.727672,0.695103,0.664576,0.633634,0.599865,0.562673,0.522544,0.479455,0.433099,0.383930,0.333009,0.282119,0.233682,0.189883,0.151416,0.117502,0.087352,0.061159,0.039500,0.022633,4
1,510,413,2,0.022637,0.063175,0.114420,0.163122,0.207777,0.257460,0.317920,0.385202,0.449475,0.502569,0.542991,0.574322,0.600711,0.625123,0.649685,0.675617,0.703545,0.733478,0.765108,0.798315,0.833309,0.868610,0.900976,0.928785,0.951564,0.969677,0.984226,0.995895,1.004173,1.007840,1.006018,0.998984,0.988394,0.975896,0.962345,0.947913,0.932515,...,0.974916,0.982807,0.988364,0.991240,0.990379,0.984920,0.974894,0.961443,0.946102,0.930218,0.914373,0.898116,0.880641,0.861494,0.840307,0.816337,0.788608,0.757278,0.723950,0.691121,0.661090,0.633367,0.604795,0.573001,0.537845,0.499716,0.458520,0.414299,0.367469,0.319159,0.271192,0.225628,0.184277,0.147850,0.115619,0.086782,0.061318,0.039709,0.022630,4
2,510,413,3,0.022628,0.066277,0.123461,0.178248,0.225705,0.274257,0.330758,0.392945,0.452859,0.503764,0.544744,0.579051,0.609675,0.638582,0.667435,0.697532,0.729419,0.762559,0.796058,0.829714,0.863954,0.897789,0.928085,0.953208,0.973479,0.989105,1.000299,1.007460,1.010191,1.007481,0.999106,0.985895,0.969552,0.951921,0.934366,0.917419,0.901253,...,0.966896,0.974057,0.980367,0.985448,0.988716,0.989712,0.988231,0.984024,0.976889,0.967046,0.954969,0.941220,0.926005,0.909102,0.889945,0.867775,0.841827,0.811628,0.777423,0.740580,0.703603,0.668999,0.636491,0.603194,0.566852,0.526892,0.483504,0.437075,0.388001,0.336820,0.285054,0.235092,0.189484,0.149884,0.115968,0.086485,0.060872,0.039427,0.022631,4
3,510,413,4,0.022641,0.067567,0.130535,0.199452,0.267539,0.337289,0.411577,0.485777,0.551217,0.602309,0.639995,0.668927,0.694116,0.719583,0.747888,0.779536,0.813357,0.847533,0.880104,0.909800,0.936276,0.959890,0.981327,1.000406,1.015865,1.027714,1.037043,1.043614,1.045667,1.042110,1.033123,1.019829,1.003737,0.986166,0.968028,0.949724,0.931214,...,0.970274,0.982586,0.992746,1.000381,1.004948,1.005915,1.002867,0.995456,0.983848,0.968726,0.950822,0.930732,0.908520,0.884175,0.857689,0.828929,0.797767,0.764096,0.728299,0.691786,0.656788,0.625290,0.596772,0.567885,0.535631,0.498579,0.456490,0.409605,0.358515,0.305000,0.251956,0.202682,0.159656,0.124149,0.095263,0.071202,0.051078,0.034963,0.022631,4
4,510,413,6,0.022629,0.065415,0.122248,0.179967,0.232711,0.285586,0.346265,0.413813,0.480443,0.538298,0.583896,0.618488,0.645865,0.669319,0.691176,0.713584,0.738613,0.768327,0.803170,0.842064,0.881803,0.918061,0.948845,0.975186,0.997365,1.015405,1.029518,1.039255,1.043570,1.041564,1.033183,1.019272,1.001405,0.981793,0.962063,0.942874,0.924258,...,0.964664,0.982573,0.996633,1.005156,1.007113,1.002606,0.992580,0.978326,0.961664,0.944011,0.925839,0.906737,0.886066,0.863230,0.837778,0.809363,0.778125,0.745008,0.712142,0.681813,0.654262,0.627340,0.598441,0.566074,0.530219,0.490923,0.448102,0.401956,0.353288,0.304090,0.256634,0.212920,0.174113,0.139937,0.109247,0.081570,0.057560,0.037933,0.022633,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75727,93,999910375,6,0.037297,0.119416,0.240632,0.386556,0.520518,0.619456,0.688776,0.734850,0.761056,0.775709,0.788590,0.812037,0.848955,0.897069,0.951219,1.005436,1.056417,1.101251,1.137340,1.165444,1.186906,1.202519,1.213922,1.220860,1.222165,1.217774,1.207912,1.192228,1.170491,1.143634,1.112074,1.075773,1.036471,0.995271,0.953412,0.912188,0.872449,...,0.856724,0.881639,0.907329,0.933580,0.960102,0.986662,1.012977,1.038759,1.063662,1.087085,1.108978,1.129076,1.146736,1.161644,1.173571,1.181991,1.186343,1.186764,1.182946,1.173490,1.158589,1.137543,1.108489,1.070457,1.022510,0.962423,0.887904,0.799965,0.699453,0.588133,0.474230,0.365067,0.270073,0.194392,0.138276,0.099883,0.074006,0.054072,0.037274,0
75728,93,999910375,7,0.037268,0.114844,0.228392,0.364038,0.489199,0.578761,0.639066,0.680870,0.712451,0.741239,0.772606,0.812930,0.866817,0.930649,0.998994,1.066119,1.128673,1.184877,1.232686,1.270087,1.299495,1.322485,1.339628,1.351428,1.358262,1.360024,1.355417,1.345020,1.328737,1.306297,1.277881,1.244387,1.206272,1.163802,1.118614,1.071911,1.024853,...,0.790524,0.811361,0.833569,0.856937,0.880927,0.905119,0.929094,0.952233,0.974153,0.994494,1.012800,1.028910,1.042874,1.054638,1.063658,1.070126,1.073772,1.073983,1.070150,1.062080,1.049176,1.029788,1.003616,0.969731,0.926455,0.871551,0.805113,0.726574,0.635240,0.535408,0.433037,0.334974,0.250217,0.182052,0.131307,0.096683,0.072515,0.053597,0.037266,0
75729,93,999910375,9,0.037274,0.117238,0.234293,0.373251,0.497061,0.583862,0.640081,0.677136,0.705529,0.730081,0.752063,0.776437,0.808990,0.852039,0.903792,0.957940,1.010178,1.057885,1.098748,1.132324,1.158774,1.178496,1.194085,1.206195,1.214418,1.218637,1.218753,1.214150,1.204212,1.189546,1.170527,1.147167,1.120855,1.092098,1.061200,1.028935,0.995999,...,0.785953,0.802219,0.819903,0.838831,0.858851,0.879899,0.901948,0.924838,0.948383,0.972237,0.995862,1.018820,1.040539,1.060645,1.078864,1.094721,1.107689,1.117789,1.124779,1.127686,1.126819,1.121857,1.111646,1.095524,1.072707,1.041214,0.998374,0.943328,0.874354,0.788016,0.688269,0.578080,0.462636,0.350491,0.249202,0.166171,0.105906,0.064350,0.037271,0
75730,93,999910375,10,0.037270,0.119939,0.242145,0.389927,0.529120,0.634246,0.711091,0.771960,0.823174,0.865921,0.900948,0.932693,0.966857,1.004708,1.044050,1.081207,1.114647,1.143934,1.167985,1.186744,1.200485,1.209639,1.214840,1.216697,1.214700,1.207609,1.195956,1.180191,1.160667,1.138053,1.113167,1.086286,1.057174,1.026251,0.993858,0.960558,0.927287,...,0.833554,0.853633,0.875376,0.898549,0.922901,0.947983,0.973364,0.998746,1.023699,1.047882,1.070917,1.092162,1.111130,1.127537,1.140971,1.150821,1.157468,1.160789,1.160069,1.155117,1.145631,1.130718,1.108867,1.079885,1.042763,0.994885,0.935852,0.864897,0.780477,0.682359,0.574657,0.462263,0.353513,0.256518,0.176966,0.118950,0.080556,0.054963,0.037269,0


In [8]:
X = df.drop("target", axis=1)
y = df["target"]
groups = df["SUBJECT_ID"]
n_classes = np.max(y) + 1

# First row: the whole dataset; then one training row and one validation row per fold.
distrs = [Get_distribution(y)]
index = ["all dataset"]

train_idx_list = []
val_idx_list = []

for fold, (train_idx, val_idx) in enumerate(StratifiedGroupKFold(X, y, groups, features, k=5)):
    train_idx_list.append(train_idx)
    val_idx_list.append(val_idx)

    # Index by position: GroupKFold yields positional indices, whereas
    # y[...] is label-based and only coincides for a default RangeIndex.
    distrs.append(Get_distribution(y.iloc[train_idx]))
    index.append(f"training set - fold {fold}")
    distrs.append(Get_distribution(y.iloc[val_idx]))
    index.append(f"validation set - fold {fold}")

distribution_table = pd.DataFrame(
    distrs, index=index, columns=[f"Label {c}" for c in range(n_classes)]
)
distribution_table

                        Label 0 Label 1 Label 2 Label 3 Label 4
all dataset              10.24%  16.83%  26.24%  28.24%  18.45%
training set - fold 0    10.26%  16.87%  26.12%  28.06%  18.69%
validation set - fold 0  10.26%  16.87%  26.12%  28.06%  18.69%
training set - fold 1    10.13%  16.74%  24.77%  28.97%  19.39%
validation set - fold 1  10.13%  16.74%  24.77%  28.97%  19.39%
training set - fold 2    10.44%  17.22%  26.39%  27.94%  18.00%
validation set - fold 2  10.45%  17.22%  26.39%  27.94%  18.00%
training set - fold 3    10.42%  16.91%  26.98%  27.85%  17.83%
validation set - fold 3  10.42%  16.91%  26.98%  27.85%  17.83%
training set - fold 4     9.95%  16.44%  26.94%  28.37%  18.31%
validation set - fold 4   9.95%  16.44%  26.94%  28.36%  18.31%


In [10]:
# Pickle the StratifiedGroupKFold split results (one index array per fold).


def _to_object_array(arrays):
    """Pack a sequence of 1-D index arrays into a 1-D object ndarray.

    Folds generally differ in size, so ``np.array(list_of_arrays)`` is
    ragged. NumPy >= 1.24 raises ValueError for that, and older versions only
    built an object array with a deprecation warning. Filling the array
    element by element gives the same 1-D layout on every NumPy version, and
    also when all folds happen to have equal length. Re-running the cell on an
    already-packed array is a no-op.
    """
    packed = np.empty(len(arrays), dtype=object)
    for i, arr in enumerate(arrays):
        packed[i] = arr
    return packed


train_idx_list = _to_object_array(train_idx_list)
val_idx_list = _to_object_array(val_idx_list)

# The `with` blocks close the files; no explicit f.close() is needed.
with open("/content/drive/MyDrive/pickle/train_idx_list_StratifiedGroup5Fold.pickle", mode="wb") as f:
    pickle.dump(train_idx_list, f)
with open("/content/drive/MyDrive/pickle/val_idx_list_StratifiedGroup5Fold.pickle", mode="wb") as f:
    pickle.dump(val_idx_list, f)