In [50]:
import os

import pandas as pd
from constants import *
from sklearn.model_selection import train_test_split

In [51]:
def read_data(file_path):
    """Load a CSV file into a pandas DataFrame.

    Args:
        file_path: Location of the CSV file, passed straight to ``pd.read_csv``.

    Returns:
        The DataFrame produced by ``pd.read_csv`` with its default options.
    """
    frame = pd.read_csv(file_path)
    return frame


def _split_one_class(class_rows):
    """Split the rows of a single class into train, hold-out, test and validation parts.

    Returns:
        ``(train_part, holdout, test_part, val_part)`` where ``holdout`` is the
        union of ``test_part`` and ``val_part``.
    """
    # train_test_split returns (remainder, test_size portion). The test_size
    # half is used as the training part here, the remainder as the hold-out.
    # NOTE(review): stratifying a single-class frame looks redundant; it is kept
    # so the produced split is unchanged — confirm before removing.
    holdout, train_part = train_test_split(
        class_rows, test_size=0.5, random_state=42, stratify=class_rows["class"]
    )
    # 75% of the hold-out becomes the test part, 25% the validation part.
    test_part, val_part = train_test_split(
        holdout, test_size=0.25, random_state=42, stratify=holdout["class"]
    )
    return train_part, holdout, test_part, val_part


def split_data(data):
    """Split ``data`` into train, validation and test sets, separately per class.

    Rows with ``class == 1`` and rows with ``class == 0`` are split
    independently with the same fixed seed: about half of each class goes to
    training, and the remaining half is divided 75/25 into test and validation.
    The per-class parts are concatenated, shuffled with a fixed seed and
    re-indexed. Row counts are printed along the way.

    Args:
        data: DataFrame with at least a ``class`` column (values 0 and 1) and a
            ``domain`` column.

    Returns:
        Tuple ``(train, val, test)`` of DataFrames.

    Raises:
        AssertionError: If ``domain`` values are not all distinct (or are
            missing) within a split, or if the three splits together do not
            contain as many rows as ``data`` (e.g. rows whose ``class`` is
            neither 0 nor 1).
    """
    print(f"SIZE OF DATA = {data.shape[0]}")

    # Split positive data
    pos_train, pos_test_val, pos_test, pos_val = _split_one_class(
        data[data["class"] == 1]
    )
    print(f"POS TRAIN ROWS = {pos_train.shape[0]}")
    print(f"POS TEST_VAL ROWS = {pos_test_val.shape[0]}")
    print(f"POS TEST ROWS = {pos_test.shape[0]}")
    print(f"POS VAL ROWS = {pos_val.shape[0]}")

    # Split negative data
    neg_train, neg_test_val, neg_test, neg_val = _split_one_class(
        data[data["class"] == 0]
    )
    # This count was previously printed under the label "NEG TEST ROWS".
    print(f"NEG TRAIN ROWS = {neg_train.shape[0]}")
    print(f"NEG TEST_VAL ROWS = {neg_test_val.shape[0]}")
    print(f"NEG TEST ROWS = {neg_test.shape[0]}")
    print(f"NEG VALIDATION ROWS = {neg_val.shape[0]}")

    # Recombine both classes per split, then shuffle so the classes are
    # interleaved; the fixed seed keeps the row order reproducible.
    train = pd.concat([pos_train, neg_train])
    val = pd.concat([pos_val, neg_val])
    test = pd.concat([pos_test, neg_test])

    train = train.sample(frac=1, random_state=42).reset_index(drop=True)
    val = val.sample(frac=1, random_state=42).reset_index(drop=True)
    test = test.sample(frac=1, random_state=42).reset_index(drop=True)

    # Sanity checks: every domain appears once per split and no row was lost.
    for split_name, split in (("train", train), ("val", val), ("test", test)):
        assert split["domain"].nunique() == len(split), (
            f"duplicate or missing domains in {split_name} split"
        )
    assert data.shape[0] == len(train) + len(val) + len(test), (
        "train/val/test sizes do not add up to the input size"
    )

    return train, val, test


def save_data(train, val, test, path="data/"):
    """Write the three splits to ``train.csv``, ``val.csv`` and ``test.csv`` in ``path``.

    Args:
        train: DataFrame written to ``train.csv``.
        val: DataFrame written to ``val.csv``.
        test: DataFrame written to ``test.csv``.
        path: Output directory, as a string or path-like object. A trailing
            separator is optional; an empty string means the current directory.

    The files are written without the index column. The directory itself is
    not created here.
    """
    outputs = (("train.csv", train), ("val.csv", val), ("test.csv", test))
    for file_name, frame in outputs:
        # os.path.join inserts the separator only when it is missing, so
        # "data" and "data/" both resolve to "data/<file_name>".
        frame.to_csv(os.path.join(path, file_name), index=False)


def split_and_save_data(file_path, data_dir="data/"):
    """Read a CSV, split it into train/validation/test sets and save them.

    Args:
        file_path: CSV file handed to ``read_data``.
        data_dir: Destination forwarded to ``save_data`` as its ``path`` argument.
    """
    train_set, val_set, test_set = split_data(read_data(file_path))
    save_data(train_set, val_set, test_set, path=data_dir)

In [52]:
# Run the whole read -> split -> save pipeline on the file named by FileDef.ALL
# (presumably provided by the `constants` star import — confirm).
# NOTE(review): the output directory is an absolute path on one machine; a
# configurable or repository-relative location would make this cell portable.
split_and_save_data(
    FileDef.ALL.value,
    data_dir="/home/chance/GitHub/gradient_boosted_dns/data/",
)

SIZE OF DATA = 2136074
POS TRAIN ROWS = 573445
POS TEST_VAL ROWS = 573444
POS TEST ROWS = 430083
POS VAL ROWS = 143361
NEG TEST ROWS = 494593
NEG TEST_VAL ROWS = 494592
NEG TEST ROWS = 370944
NEG VALIDATION ROWS = 123648
