# Train-Test Split

### Comparing random split and stratified shuffle split

In [1]:
# Imports

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit

In [2]:
# Load the pre-encoded dataset that already carries the categorized price column.
data = pd.read_csv('../data/pre_encoded_price_categorized.csv')

# Baseline: a purely random 80/20 partition; the fixed seed makes the split repeatable.
train_set, test_set = train_test_split(
    data,
    random_state=42,
    test_size=0.2,
)

In [3]:
# Row counts of the random split: training set first, then test set, one per line.
print(len(train_set), len(test_set), sep='\n')

74036
18510


In [4]:
# Stratify on `price_category` so that the training and test sets both contain
# representative data for every identified price category (low, mid, high).

stratified_split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

# Only one split is configured, so pull the single (train, test) pair of positional indices.
train_index, test_index = next(stratified_split.split(data, data['price_category']))
train_set_sss = data.iloc[train_index]
test_set_sss = data.iloc[test_index]

In [5]:
# Row counts of the stratified split: training set first, then test set, one per line.
print(len(train_set_sss), len(test_set_sss), sep='\n')

74036
18510


### Comparison

In [6]:
# Every frame to inspect, keyed by the label printed in the report below.
sets_dict = {
    'original_data': data,
    'test_set': test_set,
    'stratified_test_set': test_set_sss,
    'train_set': train_set,
    'stratified_train_set': train_set_sss,
}

for df_name, df in sets_dict.items():

    print(f"Set name: {df_name}")

    # Tally the rows per price category code (0, 1 and 2) by summing boolean masks.
    count_0, count_1, count_2 = (
        int((df['price_category'] == code).sum()) for code in (0, 1, 2)
    )

    print(f"Number of rows where price_category = 0: {count_0}")
    print(f"Number of rows where price_category = 1: {count_1}")
    print(f"Number of rows where price_category = 2: {count_2}")

Set name: original_data
Number of rows where price_category = 0: 22908
Number of rows where price_category = 1: 69476
Number of rows where price_category = 2: 162
Set name: test_set
Number of rows where price_category = 0: 4643
Number of rows where price_category = 1: 13833
Number of rows where price_category = 2: 34
Set name: stratified_test_set
Number of rows where price_category = 0: 4582
Number of rows where price_category = 1: 13896
Number of rows where price_category = 2: 32
Set name: train_set
Number of rows where price_category = 0: 18265
Number of rows where price_category = 1: 55643
Number of rows where price_category = 2: 128
Set name: stratified_train_set
Number of rows where price_category = 0: 18326
Number of rows where price_category = 1: 55580
Number of rows where price_category = 2: 130


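#### Stratified Shuffle Split gives more representative data in both training and test sets: its `price_category` counts stay closer to the proportions of the original data than those of the random split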

In [7]:
# Persist the stratified split; index=False keeps the row index out of the CSV files.
for split_frame, csv_path in (
    (train_set_sss, '../data/training_set.csv'),
    (test_set_sss, '../data/test_set.csv'),
):
    split_frame.to_csv(csv_path, index=False)