# Methods for Customized Train, Validation, and Test sets
#### based on the train-test and minority-majority class proportions


In [18]:
import numpy as np

def odd_split(X, y, minority_class=0, minority_test_size=0.1, *, verbose=True):
    """Split ``X``/``y`` into a training set and a class-balanced test set.

    A fraction ``minority_test_size`` of the minority-class samples (at least
    one) is drawn at random, without replacement, for the test set, together
    with the same number of randomly drawn majority-class samples.  Every
    remaining sample goes to the training set in its original order.
    Sampling uses NumPy's global random state (``np.random.choice``).

    Args:
        X: Array of samples supporting NumPy integer/boolean array indexing
            along its first axis.
        y: Array of labels aligned with ``X`` (compared element-wise against
            ``minority_class``).
        minority_class: Label value treated as the minority class; every
            other label counts as the majority class.
        minority_test_size: Fraction of the minority samples moved to the
            test set.
        verbose: When true (the default), print the intermediate index
            arrays and counts.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``.  The test arrays list
        the drawn minority samples first, followed by the majority samples.

    Raises:
        ValueError: If either class holds fewer samples than the number that
            must be drawn for the test set.
    """

    def _debug(label, value):
        # Diagnostics are gated so callers can silence them with verbose=False.
        if verbose:
            print("{}: {}".format(label, value))

    minority_indices = np.where(y == minority_class)[0]
    majority_indices = np.where(y != minority_class)[0]
    _debug("minority_indices", minority_indices)
    _debug("n_minority", len(minority_indices))
    _debug("majority_indices", majority_indices)
    _debug("n_majority", len(majority_indices))

    # The same count is drawn from both classes, so the test set is balanced.
    n = max(1, int(minority_test_size * len(minority_indices)))
    _debug("n", n)
    if n > min(len(minority_indices), len(majority_indices)):
        raise ValueError(
            "cannot draw {} test samples per class: only {} minority and {} "
            "majority samples are available".format(
                n, len(minority_indices), len(majority_indices)
            )
        )

    selected = np.random.choice(len(minority_indices), n, replace=False)
    _debug("selected", selected)
    test_minority_indices = minority_indices[selected]
    _debug("test_minority_indices", test_minority_indices)

    selected = np.random.choice(len(majority_indices), n, replace=False)
    _debug("selected", selected)
    test_majority_indices = majority_indices[selected]
    _debug("test_majority_indices", test_majority_indices)

    test_indices = np.concatenate((test_minority_indices, test_majority_indices))

    # A boolean mask selects the complement in O(n) and, unlike an empty
    # float-typed index array, still works when no training sample remains.
    train_mask = np.ones(len(y), dtype=bool)
    train_mask[test_indices] = False

    return X[train_mask], y[train_mask], X[test_indices], y[test_indices]


In [42]:
def train_test_split(X, y, minority_class=0, test_size=0.2, minority_test_size=0.1):
    """Split ``X``/``y`` into train and test sets with a fixed test class mix.

    The test set targets ``test_size * X.shape[0]`` samples.  A fraction
    ``minority_test_size`` of that target is drawn from the minority class
    and the remaining fraction from the majority class; each count is
    truncated to an integer and is never less than one.  Samples are drawn
    without replacement using NumPy's global random state
    (``np.random.choice``).  Every sample not drawn goes to the training
    set in its original order.

    Args:
        X: Array of samples exposing ``shape`` and supporting NumPy
            integer/boolean array indexing along its first axis.
        y: Array of labels aligned with ``X`` (compared element-wise against
            ``minority_class``).
        minority_class: Label value treated as the minority class; every
            other label counts as the majority class.
        test_size: Fraction of all samples targeted for the test set.
        minority_test_size: Fraction of the test set taken from the minority
            class.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``.  The test arrays list
        the drawn minority samples first, followed by the majority samples.

    Raises:
        ValueError: If a class holds fewer samples than must be drawn from it.
    """
    minority_indices = np.where(y == minority_class)[0]
    majority_indices = np.where(y != minority_class)[0]

    n_test = X.shape[0] * test_size
    n_minority = max(1, int(minority_test_size * n_test))
    n_majority = max(1, int((1 - minority_test_size) * n_test))

    if n_minority > len(minority_indices):
        raise ValueError(
            "test set needs {} minority samples but only {} are available".format(
                n_minority, len(minority_indices)
            )
        )
    if n_majority > len(majority_indices):
        raise ValueError(
            "test set needs {} majority samples but only {} are available".format(
                n_majority, len(majority_indices)
            )
        )

    # Minority is drawn before majority, matching the original draw order.
    selected = np.random.choice(len(minority_indices), n_minority, replace=False)
    test_minority_indices = minority_indices[selected]

    selected = np.random.choice(len(majority_indices), n_majority, replace=False)
    test_majority_indices = majority_indices[selected]

    test_indices = np.concatenate((test_minority_indices, test_majority_indices))

    # A boolean mask selects the complement in O(n) and, unlike an empty
    # float-typed index array, still works when no training sample remains.
    train_mask = np.ones(len(y), dtype=bool)
    train_mask[test_indices] = False

    return X[train_mask], y[train_mask], X[test_indices], y[test_indices]


def train_val_split(X, y, minority_class = 0, val_size = 0.1, minority_val_size = 0.1, minority_train_size = 0.5):
    """Split ``X``/``y`` into a training set with a fixed class mix and a validation set.

    The training set targets ``(1 - val_size) * X.shape[0]`` samples:
    a fraction ``minority_train_size`` of that target comes from the minority
    class and the rest from the majority class (each count truncated to an
    integer, never below one).  Samples are drawn without replacement via
    ``np.random.choice``.  The validation set is simply every sample that was
    not drawn for training.

    Note: ``minority_val_size`` is accepted but never read by this function;
    the validation class mix is whatever is left over after the training draw.

    Args:
        X: Array of samples exposing ``shape`` and supporting NumPy integer
            array indexing along its first axis.
        y: Array of labels aligned with ``X``.
        minority_class: Label value treated as the minority class.
        val_size: Fraction of samples *not* targeted for training.
        minority_val_size: Unused.
        minority_train_size: Fraction of the training set taken from the
            minority class.

    Returns:
        Tuple ``(X_train, y_train, X_val, y_val)``; within each set the
        minority samples come first, followed by the majority samples.
    """

    def _draw(pool, count):
        # Pick `count` entries of `pool` at random; the rest is the leftover.
        positions = np.random.choice(np.arange(len(pool)), count, replace=False)
        picked = pool[positions]
        return picked, np.setdiff1d(pool, picked)

    minority_pool = np.where(y == minority_class)[0]
    majority_pool = np.where(y != minority_class)[0]

    target_train = X.shape[0] * (1 - val_size)
    minority_quota = max(1, int(minority_train_size * target_train))
    majority_quota = max(1, int((1 - minority_train_size) * target_train))

    # The minority draw happens first so the random stream is consumed in the
    # same order as before.
    minority_train, minority_val = _draw(minority_pool, minority_quota)
    assert (y[minority_train] == minority_class).all()

    majority_train, majority_val = _draw(majority_pool, majority_quota)
    assert (y[majority_train] != minority_class).all()

    train_rows = np.concatenate((minority_train, majority_train))
    val_rows = np.concatenate((minority_val, majority_val))

    return X[train_rows], y[train_rows], X[val_rows], y[val_rows]




In [45]:
from collections import Counter 
 
# Demo: build a synthetic two-feature data set whose labels are drawn with
# probabilities 0.4 (class 0) and 0.6 (class 1), then show the class counts
# after each split.
X = np.random.normal(size=(1000, 2))
y = np.random.choice([0, 1], size=1000, p=[0.4, 0.6])
print('Whole', Counter(y))
print(type(X), type(y), sep='\n')

# First carve a test set out of the full data.
X_train, y_train, X_test, y_test = train_test_split(X, y)
print(f'Train {Counter(y_train)}')
print(f'Test {Counter(y_test)}')

# Then split what remains into the final training set and a validation set.
X_train, y_train, X_val, y_val = train_val_split(X_train, y_train)
print(f'Train {Counter(y_train)}')
print(f'Validation {Counter(y_val)}')

Whole Counter({1: 612, 0: 388})
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Train Counter({1: 432, 0: 368})
Test Counter({1: 180, 0: 20})
Train Counter({0: 360, 1: 360})
Validation Counter({1: 72, 0: 8})


1000