# Data Splits - Train, Validation, Test

In [1]:
import numpy as np
import pandas as pd

np.random.seed(42)

# Synthetic toy data: two uniform features in [0, 1) plus a binary label
# (column 2) that is 1 exactly when the two features sum to more than 1.
df = pd.DataFrame(np.random.rand(1000, 2))
df[2] = df[0].add(df[1]).gt(1).astype(int)

# Feature matrix and target vector used by every split below.
x, y = df[[0, 1]], df[2]

print(x.head(), y.head(), sep="\n\n")

          0         1
0  0.374540  0.950714
1  0.731994  0.598658
2  0.156019  0.155995
3  0.058084  0.866176
4  0.601115  0.708073

0    1
1    1
2    0
3    0
4    1
Name: 2, dtype: int32


### Hold-out Test Set

In [2]:
from sklearn.model_selection import train_test_split

# Reserve 30% of the rows as a test set that is never used for model
# selection; stratifying on y keeps the class ratio equal in both parts.
x_nontest, x_test, y_nontest, y_test = train_test_split(
    x,
    y,
    random_state=42,
    shuffle=True,
    stratify=y,
    test_size=0.3,
)
tuple(part.shape for part in (x_nontest, x_test, y_nontest, y_test))

((700, 2), (300, 2), (700,), (300,))

In [3]:
# Positive-label counts in the full data, the non-test part and the test part;
# stratification should keep the positive ratio roughly equal across them.
tuple(np.count_nonzero(labels) for labels in (y, y_nontest, y_test))

(504, 353, 151)

### Stratified k-fold Cross Validation

In [4]:
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# 10 folds over the non-test data: each iteration trains on 9/10 of it and
# validates on the remaining 1/10. Stratifying on the label keeps the number
# of positives per validation fold roughly equal.
for train_index, val_index in skf.split(x_nontest, y_nontest):
    x_train = x_nontest.iloc[train_index]
    y_train = y_nontest.iloc[train_index]
    x_val = x_nontest.iloc[val_index]
    y_val = y_nontest.iloc[val_index]

    # Fit model ...

    print(f"{x_train.shape} {x_val.shape} {y_train.shape} {y_val.shape}")
    print(np.count_nonzero(y_train), np.count_nonzero(y_val), end="\n\n")

(630, 2) (70, 2) (630,) (70,)
317 36

(630, 2) (70, 2) (630,) (70,)
317 36

(630, 2) (70, 2) (630,) (70,)
317 36

(630, 2) (70, 2) (630,) (70,)
318 35

(630, 2) (70, 2) (630,) (70,)
318 35

(630, 2) (70, 2) (630,) (70,)
318 35

(630, 2) (70, 2) (630,) (70,)
318 35

(630, 2) (70, 2) (630,) (70,)
318 35

(630, 2) (70, 2) (630,) (70,)
318 35

(630, 2) (70, 2) (630,) (70,)
318 35



### GroupShuffleSplit – Keep All Rows of a Group in the Same Partition

In [5]:
from sklearn.model_selection import GroupShuffleSplit

# Synthetic group ids standing in for a real grouping column (e.g. the author
# of each post): consecutive blocks of 10 rows share a group -> 100 groups.
groups = np.arange(len(x)) // 10

# test_size is a fraction of *groups*, not rows; with 100 equal-sized groups
# the resulting row split is 700/300.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
# With real data, pass the grouping column instead,
# e.g. splitter.split(posts, groups=posts[author_id_column]).
group_train_index, group_test_index = next(splitter.split(x, y, groups=groups))

# The defining guarantee: no group appears on both sides of the split.
assert set(groups[group_train_index]).isdisjoint(groups[group_test_index])

len(group_train_index), len(group_test_index)