In [4]:
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os

def split_and_visualize_dataset(file_path, output_dir):
    """
    Splits the dataset into training, validation, and test sets,
    and saves a visualization of the split sizes.

    Parameters:
    - file_path: str, path to the CSV file.
    - output_dir: str, directory where the visualization will be saved.

    Returns:
    - None, but saves the visualization as a PNG file in the specified directory.
    """
    # Load the dataset
    data = pd.read_csv(file_path)

    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Separate features and target variable
    X = data.drop(columns=["Market_Label", "Date"])
    y = data["Market_Label"]

    # First, split into training (70%) and temp (30% for validation and test)
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, test_size=0.3, stratify=y, random_state=42
    )

    # Split the temp set further into validation (15%) and test (15%)
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42
    )

    # Visualize the split sizes
    split_sizes = [len(X_train), len(X_val), len(X_test)]
    labels = ['Training Set', 'Validation Set', 'Test Set']

    plt.figure(figsize=(8, 5))
    plt.bar(labels, split_sizes, color=['blue', 'orange', 'green'])
    plt.title('Dataset Split Visualization', fontsize=14)
    plt.ylabel('Number of Entries', fontsize=12)
    plt.xlabel('Dataset Splits', fontsize=12)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Save the plot
    output_file = os.path.join(output_dir, "dataset_split_visualization.png")
    plt.savefig(output_file)
    plt.close()

    print(f"Visualization saved to {output_file}")
    # Return the splits
    return X_train, X_val, X_test, y_train, y_val, y_test

# split the minmax nasdaq data
csv_file_path_minmax_nasdaq = "../data/min_max_scaling/cleaned_minmax_scaled_combined_data_nasdaq.csv"
output_directory_minmax_nasdaq = "../output/data_split/min_max_nasdaq"
split_and_visualize_dataset(csv_file_path_minmax_nasdaq, output_directory_minmax_nasdaq)

# split the minmax sp500 data
csv_file_path_minmax_sp500 = "../data/min_max_scaling/cleaned_minmax_scaled_combined_data_sp500.csv"
output_directory_minmax_sp500 = "../putout/data_split/min_max_sp500"
split_and_visualize_dataset(csv_file_path_minmax_sp500, output_directory_minmax_sp500)

Visualization saved to ../output/data_split/min_max_nasdaq\dataset_split_visualization.png
Visualization saved to ../putout/data_split/min_max_sp500\dataset_split_visualization.png


(      GDP Growth       CPI  Interest Rate  M2 Money Supply       PPI  \
 1539    0.787424  0.336773       0.640106         0.139852  0.371002   
 7494    0.852177  0.672548       0.471448         0.184823  0.625264   
 6565    0.504875  0.317634       0.281541         0.129007  0.413172   
 2128    0.800662  0.470632       0.653386         0.132127  0.447259   
 2653    0.441610  0.381315       0.512616         0.202239  0.542044   
 ...          ...       ...            ...              ...       ...   
 6091    0.410874  0.401877       0.241700         0.132772  0.547698   
 4283    0.083338  0.375403       0.419655         0.050818  0.592834   
 5937    0.410874  0.275296       0.140770         0.162658  0.343183   
 5835    0.430224  0.286010       0.168659         0.201407  0.310361   
 2599    0.362626  0.351912       0.374502         0.174354  0.530888   
 
       Unemployment Rate  VIX_Close  
 1539           0.061947   0.203127  
 7494           0.008850   0.226649  
 6565   