In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

# Legacy global NumPy seed; each StratifiedKFold below still gets its own explicit random_state.
np.random.seed(seed=2112)

***
Five different stratified 5-fold splits (one per seed) for repeated cross-validation

In [2]:
# Load the training labels: one row per customer with columns customer_ID and target.
train_labels = pd.read_csv("../data/raw/train_labels.csv")
train_labels.info()

# Folds are assigned per row, so a repeated customer_ID could be placed in several
# folds and leak between training and validation; fail fast if that ever happens.
assert train_labels["customer_ID"].is_unique, "duplicate customer_ID in train_labels"

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 458913 entries, 0 to 458912
Data columns (total 2 columns):
 #   Column       Non-Null Count   Dtype 
---  ------       --------------   ----- 
 0   customer_ID  458913 non-null  object
 1   target       458913 non-null  int64 
dtypes: int64(1), object(1)
memory usage: 7.0+ MB


In [3]:
# One random_state per repetition: each seed yields an independently shuffled StratifiedKFold split.
super_seeds = [
    2,
    7,
    11,
    23,
    2112,
]

In [4]:
# Build one stratified fold assignment per seed, print the per-fold class balance,
# keep it in `all_splits`, and persist it to ../data/processed/cv{it}.parquet.
#
# StratifiedKFold.split yields *positional* row indices, so fold ids are written
# positionally into a NumPy array. This no longer depends on train_labels having a
# 0..n-1 RangeIndex, which label-based `.loc[valid_idx, ...]` silently required.
N_SPLITS = 5
targets = train_labels["target"].to_numpy()
all_splits = list()

for it, seed in enumerate(super_seeds):
    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=seed)
    skf_split = skf.split(train_labels, targets)

    # -1 marks "not yet assigned"; every row must end up in exactly one validation fold
    fold_ids = np.full(len(train_labels), -1, dtype=np.int64)
    for fold, (_, valid_idx) in enumerate(skf_split):
        fold_ids[valid_idx] = fold

    # assert that all samples are assigned to one fold
    assert (fold_ids >= 0).all()

    folds = train_labels[["customer_ID"]].copy()
    folds["fold"] = fold_ids

    # Per-fold class balance. `folds` is row-aligned with train_labels, so the target is
    # attached by position instead of merging on customer_ID (a merge would fan out
    # rows if an ID were ever duplicated).
    print("-"*90)
    print(folds.assign(target=targets).groupby("fold")["target"].value_counts())

    all_splits.append(folds)
    folds.to_parquet(f"../data/processed/cv{it}.parquet", index=False)

------------------------------------------------------------------------------------------
fold  target
0     0         68017
      1         23766
1     0         68017
      1         23766
2     0         68017
      1         23766
3     0         68017
      1         23765
4     0         68017
      1         23765
Name: target, dtype: int64
------------------------------------------------------------------------------------------
fold  target
0     0         68017
      1         23766
1     0         68017
      1         23766
2     0         68017
      1         23766
3     0         68017
      1         23765
4     0         68017
      1         23765
Name: target, dtype: int64
------------------------------------------------------------------------------------------
fold  target
0     0         68017
      1         23766
1     0         68017
      1         23766
2     0         68017
      1         23766
3     0         68017
      1         23765
4     0         68

In [5]:
# False means seeds 0 and 1 disagree on the fold of at least one row.
split_a, split_b = all_splits[0], all_splits[1]
(split_a["fold"] == split_b["fold"]).all()

False

In [6]:
# False means seeds 1 and 2 disagree on the fold of at least one row.
split_a, split_b = all_splits[1], all_splits[2]
(split_a["fold"] == split_b["fold"]).all()

False

In [7]:
# False means seeds 2 and 3 disagree on the fold of at least one row.
split_a, split_b = all_splits[2], all_splits[3]
(split_a["fold"] == split_b["fold"]).all()

False

In [8]:
# False means seeds 3 and 4 disagree on the fold of at least one row.
split_a, split_b = all_splits[3], all_splits[4]
(split_a["fold"] == split_b["fold"]).all()

False

***