## Train, Validation, and Test Set Split

In [1]:
import numpy as np
from sklearn.model_selection import train_test_split

In [2]:
# Fractions of each class's samples assigned to the train / validation /
# test subsets (intended to sum to 1), plus the seed for the random split.
train_size, valid_size, test_size = 0.8, 0.1, 0.1
random_state = 888
# Number of classes, i.e. the width of every one-hot label row.
total_labels = 4

In [3]:
def _one_hot_labels(n_rows, label, total_labels):
    """Return an ``(n_rows, total_labels)`` int matrix whose column ``label`` is all ones."""
    labels = np.zeros((n_rows, total_labels), dtype=int)
    labels[:, label] = 1
    return labels


def train_valid_test_split(input_matrix, label, train_size, valid_size, test_size,
                           random_state, total_labels):
    """Split one class's samples into train/validation/test subsets with one-hot labels.

    All rows of ``input_matrix`` are treated as belonging to class ``label``.
    The row indices are randomly split into three disjoint subsets, and each
    subset is paired with a one-hot label matrix of width ``total_labels``
    whose column ``label`` is set to 1.

    Parameters
    ----------
    input_matrix : array supporting ``.shape[0]`` and list-of-indices row
        selection (e.g. ``np.ndarray``); rows are samples.
    label : int
        Class index for every row; must lie in ``[0, total_labels)``.
    train_size, valid_size, test_size : float
        Fractions of rows for each subset; must sum to 1. The split is driven
        by ``valid_size`` and ``test_size`` (the train subset is what remains),
        so ``train_size`` is only checked for consistency.
    random_state : int
        Seed passed to both ``train_test_split`` calls for reproducibility.
    total_labels : int
        Number of classes, i.e. the width of each one-hot label row.

    Returns
    -------
    dict
        ``{"X_train", "Y_train", "X_valid", "Y_valid", "X_test", "Y_test"}``.

    Raises
    ------
    ValueError
        If the three fractions do not sum to 1, or if ``label`` is outside
        ``[0, total_labels)``. A negative label would otherwise silently mark
        a column counted from the end.
    """
    if not np.isclose(train_size + valid_size + test_size, 1.0):
        raise ValueError(
            f"train_size + valid_size + test_size must sum to 1, got "
            f"{train_size} + {valid_size} + {test_size}"
        )
    if not 0 <= label < total_labels:
        raise ValueError(
            f"label must be in [0, {total_labels}), got {label}"
        )

    n = input_matrix.shape[0]
    holdout_size = valid_size + test_size
    # First carve off the combined validation+test holdout, then split that
    # holdout in the ratio valid_size : test_size.
    train_idx, holdout_idx = train_test_split(list(range(n)),
                                              test_size=holdout_size,
                                              random_state=random_state)
    valid_idx, test_idx = train_test_split(holdout_idx,
                                           test_size=test_size / holdout_size,
                                           random_state=random_state)
    return {"X_train": input_matrix[train_idx],
            "Y_train": _one_hot_labels(len(train_idx), label, total_labels),
            "X_valid": input_matrix[valid_idx],
            "Y_valid": _one_hot_labels(len(valid_idx), label, total_labels),
            "X_test": input_matrix[test_idx],
            "Y_test": _one_hot_labels(len(test_idx), label, total_labels)}

    

In [4]:
# One .npy file per class; a file's position in this tuple is its class label.
splitted_data_0, splitted_data_1, splitted_data_2, splitted_data_3 = (
    train_valid_test_split(np.load(path), label, train_size, valid_size,
                           test_size, random_state, total_labels)
    for label, path in enumerate((
        "data/X_positive.npy",
        "data/X_slightly_positive.npy",
        "data/X_slightly_negative.npy",
        "data/X_negative.npy",
    ))
)

In [5]:
# Stack the per-class subsets (class 0, then 1, 2, 3) and save each split's
# samples and one-hot labels as .npy files.
# NOTE(review): rows end up grouped by class, not interleaved — shuffle
# before training if the training loop does not shuffle itself.
_per_class_splits = (splitted_data_0, splitted_data_1,
                     splitted_data_2, splitted_data_3)

train_data = np.concatenate([part.get("X_train") for part in _per_class_splits], axis=0)
train_label = np.concatenate([part.get("Y_train") for part in _per_class_splits], axis=0)
np.save("data/train_data.npy", train_data)
np.save("data/train_label.npy", train_label)

valid_data = np.concatenate([part.get("X_valid") for part in _per_class_splits], axis=0)
valid_label = np.concatenate([part.get("Y_valid") for part in _per_class_splits], axis=0)
np.save("data/valid_data.npy", valid_data)
np.save("data/valid_label.npy", valid_label)

test_data = np.concatenate([part.get("X_test") for part in _per_class_splits], axis=0)
test_label = np.concatenate([part.get("Y_test") for part in _per_class_splits], axis=0)
np.save("data/test_data.npy", test_data)
np.save("data/test_label.npy", test_label)
