In [1]:
import numpy as np
from collections import Counter

In [2]:
def stratified_kfold(x, y, n_splits=3):
    """Yield (train_idx, test_idx) pairs for stratified K-fold cross-validation.

    For every class label in ``y`` (in ``np.unique`` order), the positions of
    that label are cut, in their original order, into ``n_splits`` contiguous
    chunks whose sizes differ by at most one. Fold ``k`` is the union of chunk
    ``k`` of every class, so each fold roughly preserves the class proportions
    of ``y``. Each fold serves as the test set exactly once, while the remaining
    folds are concatenated into the training set. No shuffling is performed, so
    the folds depend on the order of ``y``.

    Parameters
    ----------
    x : array-like
        Feature matrix. Not used by the split; accepted so the call mirrors the
        usual ``(x, y)`` convention.
    y : array-like of shape (n_samples,)
        Class labels to stratify on.
    n_splits : int, default 3
        Number of folds. Must satisfy ``2 <= n_splits <= len(y)``.

    Yields
    ------
    train_idx : np.ndarray of int
        Positional indices of the training samples for this split.
    test_idx : np.ndarray of int
        Positional indices of the test samples for this split.

    Raises
    ------
    ValueError
        If ``y`` is empty or ``n_splits`` is outside ``[2, len(y)]``. Because
        this is a generator, the error surfaces on the first ``next()`` call.
    """
    y = np.asarray(y)
    if y.size == 0:
        raise ValueError("y must contain at least one sample")
    # Fewer than 2 folds leaves no training data; more folds than samples
    # guarantees at least one empty test fold.
    if not 2 <= n_splits <= y.size:
        raise ValueError(
            f"n_splits must be between 2 and the number of samples ({y.size}), got {n_splits}"
        )

    # fold_chunks[k] collects the k-th chunk of every class label.
    fold_chunks = [[] for _ in range(n_splits)]
    for label in np.unique(y):
        label_idx = np.where(y == label)[0]
        # n_splits + 1 evenly spaced boundaries (floored to int) -> n_splits chunks.
        bounds = np.linspace(0, len(label_idx), n_splits + 1).astype(int)
        for k in range(n_splits):
            fold_chunks[k].append(label_idx[bounds[k]:bounds[k + 1]])
    # Concatenate once per fold rather than np.append-ing per class (each append copies).
    folds = [np.concatenate(chunks) for chunks in fold_chunks]

    # Generate splits lazily; only one training index array is built at a time.
    for k in range(n_splits):
        train_idx = np.concatenate(folds[:k] + folds[k + 1:])
        yield train_idx, folds[k]

### TESTING ON BREAST CANCER DATA

In [3]:
import pandas as pd
from sklearn.datasets import load_breast_cancer
# Load features and labels with a single call; return_X_y=True yields (data, target).
x, y = load_breast_cancer(return_X_y=True)

In [4]:
# Label distribution check: print the overall label proportions as a reference,
# then the proportions inside each train/test split (stratification should keep them close).
print(f'Y DISTRIBUTION\n{pd.Series(y).value_counts(normalize=True)}\n')
cv_samples = stratified_kfold(x, y)
for n, split in enumerate(cv_samples):
    train_idx, test_idx = split
    print(f'TRAIN SPLIT {n}\n{pd.Series(y[train_idx]).value_counts(normalize=True)}\n')
    print(f'TEST SPLIT {n}\n{pd.Series(y[test_idx]).value_counts(normalize=True)}\n')

Y DISTRIBUTION
1    0.627417
0    0.372583
dtype: float64

TRAIN SPLIT 0
1    0.626316
0    0.373684
dtype: float64

TEST SPLIT 0
1    0.62963
0    0.37037
dtype: float64

TRAIN SPLIT 1
1    0.627968
0    0.372032
dtype: float64

TEST SPLIT 1
1    0.626316
0    0.373684
dtype: float64

TRAIN SPLIT 2
1    0.627968
0    0.372032
dtype: float64

TEST SPLIT 2
1    0.626316
0    0.373684
dtype: float64



In [5]:
# Leakage check: within each split, no sample index may appear in both train and test.
cv_samples = stratified_kfold(x, y)
for train_idx, test_idx in cv_samples:
    assert set(train_idx).isdisjoint(test_idx)

In [6]:
# Size check: in every split, the train and test sizes must add up to len(y).
len_ = len(y)
cv_samples = stratified_kfold(x, y)
for n, (train_idx, test_idx) in enumerate(cv_samples):
    assert len(train_idx) + len(test_idx) == len_

In [7]:
# Uniqueness check: neither the train nor the test indices of a split may repeat.
cv_samples = stratified_kfold(x, y)
for n, (train_idx, test_idx) in enumerate(cv_samples):
    for idx in (train_idx, test_idx):
        assert len(set(idx)) == len(idx)

### TESTING ON IRIS DATA

In [8]:
import pandas as pd
from sklearn.datasets import load_iris
# Load features and labels with a single call; return_X_y=True yields (data, target).
x, y = load_iris(return_X_y=True)

In [9]:
# Label distribution check (5 folds): print the overall label proportions as a reference,
# then the proportions inside each train/test split (stratification should keep them close).
print(f'Y DISTRIBUTION\n{pd.Series(y).value_counts(normalize=True)}\n')
cv_samples = stratified_kfold(x, y, n_splits=5)
for n, split in enumerate(cv_samples):
    train_idx, test_idx = split
    print(f'TRAIN SPLIT {n}\n{pd.Series(y[train_idx]).value_counts(normalize=True)}\n')
    print(f'TEST SPLIT {n}\n{pd.Series(y[test_idx]).value_counts(normalize=True)}\n')

Y DISTRIBUTION
0    0.333333
1    0.333333
2    0.333333
dtype: float64

TRAIN SPLIT 0
0    0.333333
1    0.333333
2    0.333333
dtype: float64

TEST SPLIT 0
0    0.333333
1    0.333333
2    0.333333
dtype: float64

TRAIN SPLIT 1
0    0.333333
1    0.333333
2    0.333333
dtype: float64

TEST SPLIT 1
0    0.333333
1    0.333333
2    0.333333
dtype: float64

TRAIN SPLIT 2
0    0.333333
1    0.333333
2    0.333333
dtype: float64

TEST SPLIT 2
0    0.333333
1    0.333333
2    0.333333
dtype: float64

TRAIN SPLIT 3
0    0.333333
1    0.333333
2    0.333333
dtype: float64

TEST SPLIT 3
0    0.333333
1    0.333333
2    0.333333
dtype: float64

TRAIN SPLIT 4
0    0.333333
1    0.333333
2    0.333333
dtype: float64

TEST SPLIT 4
0    0.333333
1    0.333333
2    0.333333
dtype: float64



In [10]:
# Leakage check (5 folds): within each split, no sample index may appear in both train and test.
cv_samples = stratified_kfold(x, y, n_splits=5)
for train_idx, test_idx in cv_samples:
    assert set(train_idx).isdisjoint(test_idx)

In [11]:
# Size check (5 folds): in every split, the train and test sizes must add up to len(y).
len_ = len(y)
cv_samples = stratified_kfold(x, y, n_splits=5)
for n, (train_idx, test_idx) in enumerate(cv_samples):
    assert len(train_idx) + len(test_idx) == len_

In [12]:
# Uniqueness check (5 folds): neither the train nor the test indices of a split may repeat.
cv_samples = stratified_kfold(x, y, n_splits=5)
for n, (train_idx, test_idx) in enumerate(cv_samples):
    for idx in (train_idx, test_idx):
        assert len(set(idx)) == len(idx)