

## Define Split Data Function

    - INPUT: processed data (path to a CSV file, or a pandas DataFrame)
    - Target column name
    - Number of folds
    - Random seed number
    - Output file path (pickle file in which the splits are saved)

In [45]:
import os
import pickle

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold

In [53]:
def split_data(df, target_column_name, number_folds, index_col=None, test_size=0.2, random_seed=0, save_in=None):
    """Split a dataset into train/test partitions and optionally pickle them.

    Parameters
    ----------
    df : str or pandas.DataFrame
        Path of a CSV file (loaded with ``pd.read_csv``) or an in-memory
        DataFrame. A DataFrame argument is copied, so the caller's object is
        not modified.
    target_column_name : str
        Column used as the target ``y``; all remaining columns become the
        feature matrix ``X``.
    number_folds : int
        ``<= 2`` produces one hold-out split via ``train_test_split`` (stored
        under key ``0``). ``> 2`` produces that many shuffled ``KFold``
        splits (keys ``0 .. number_folds - 1``).
    index_col : optional
        Column (or columns) passed to ``DataFrame.set_index`` before
        splitting. ``None`` leaves the index untouched.
    test_size : float, default 0.2
        Forwarded to ``train_test_split``. Not used by the K-fold branch.
    random_seed : int, default 0
        Seed forwarded as ``random_state`` to the splitter and also used to
        seed NumPy's global random generator.
    save_in : str, optional
        File path where the resulting dictionary is pickled. Missing parent
        directories are created. ``None`` skips saving.

    Returns
    -------
    dict
        ``{fold_index: {"X_train": ..., "X_test": ..., "y_train": ...,
        "y_test": ...}}``.

    Raises
    ------
    TypeError
        If ``number_folds`` is not an integer.
    ValueError
        If ``df`` is neither a string path nor a pandas DataFrame.
    """
    # Fail fast, before any file I/O: a non-integer fold count previously
    # skipped the split entirely and returned (and pickled) an empty dict.
    # NumPy integers are accepted in addition to the built-in int.
    if not isinstance(number_folds, (int, np.integer)):
        raise TypeError(
            f"number_folds should be an integer, got {type(number_folds).__name__}"
        )

    # Seeds NumPy's *global* generator. The splitters below receive the seed
    # explicitly, but this call is kept so code that runs after split_data
    # sees the same global random state as before.
    np.random.seed(random_seed)

    # Read DataFrame
    if isinstance(df, str):
        df = pd.read_csv(df)
    elif isinstance(df, pd.DataFrame):
        df = df.copy()
    else:
        # The exception used to be created without `raise`, so invalid input
        # fell through and failed later with an unrelated error.
        raise ValueError("df should be a path file or a pandas DataFrame")

    # Set index if requested
    if index_col is not None:
        df = df.set_index(index_col)

    # Split features and target
    y = df[target_column_name].copy()
    X = df.drop(columns=[target_column_name])

    # Split train and test
    data_ready_to_model = {}
    if number_folds <= 2:
        # Single hold-out split; train_test_split returns
        # [X_train, X_test, y_train, y_test] in that order.
        data_ready_to_model[0] = _pack_split(
            *train_test_split(X, y, test_size=test_size, random_state=random_seed)
        )
    else:
        # K-fold (more than 2)
        kf = KFold(n_splits=number_folds, shuffle=True, random_state=random_seed)
        for i, (train_index, test_index) in enumerate(kf.split(X)):
            data_ready_to_model[i] = _pack_split(
                X.iloc[train_index],
                X.iloc[test_index],
                y.iloc[train_index],
                y.iloc[test_index],
            )

    if save_in is not None:
        directory = os.path.dirname(save_in)
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(save_in, 'wb') as f:
            pickle.dump(data_ready_to_model, f)

    return data_ready_to_model


def _pack_split(X_train, X_test, y_train, y_test):
    """Bundle the four parts of one train/test split into the stored dict layout."""
    return {
        "X_train": X_train,
        "X_test": X_test,
        "y_train": y_train,
        "y_test": y_test,
    }

## Call in Main

In [54]:
# Processed input CSV and the pickle file that will receive the splits.
input_path_file = "/Users/lalachaimaenaciri/PycharmProjects/SCORE_LOW_HIGH_CAPSADSTR_INTENSITY/data/processed/clean_transcelerator_mar_2022.csv"
output_path_file = "/Users/lalachaimaenaciri/PycharmProjects/SCORE_LOW_HIGH_CAPSADSTR_INTENSITY/data/ready/splitted_data.pkl"

# Column to predict; every other column is treated as a feature.
target_column_name = "score_lowhigh_capsadstr_int1"

# Three folds, rows indexed by the "No" column, result pickled to output_path_file.
splitted_data = split_data(
    df=input_path_file,
    target_column_name=target_column_name,
    number_folds=3,
    index_col="No",
    random_seed=0,
    save_in=output_path_file,
)

In [56]:
# Show the training features stored under fold key 2.
splitted_data[2]["X_train"]

Unnamed: 0_level_0,sour_low_taste_correct,astring_low_int,score_lowhigh_capsadstr1,astring_high_taste_correct,salty_low_int,sweet_high_taste_correct,bitter_high_taste_correct,sweet_low_int,Age,BMI,...,bitter_low_taste_correct,bitter_high_int,salty_high_int,bitter_low_int,sour_low_int,depression,astring_low_taste_correct,salty_low_taste_correct,sex_f1_m2_1.0,sex_f1_m2_2.0
No,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2,0.0,2.0,7.0,1.0,3.0,1.0,0.0,3.0,24.0,24.897060,...,0.0,1.0,4.0,3.0,3.0,16.0,1.0,1.0,0,1
3,1.0,3.0,12.0,1.0,5.0,1.0,1.0,3.0,25.0,24.618104,...,1.0,5.0,5.0,4.0,3.0,17.0,1.0,1.0,0,1
4,0.0,2.0,9.0,1.0,4.0,1.0,1.0,0.0,24.0,20.756387,...,1.0,4.0,3.0,3.0,0.0,18.0,1.0,1.0,0,1
5,1.0,1.0,11.0,1.0,4.0,1.0,1.0,1.0,23.0,24.212293,...,1.0,4.0,3.0,4.0,4.0,16.0,1.0,1.0,1,0
7,1.0,4.0,11.0,1.0,3.0,1.0,1.0,1.0,28.0,22.647377,...,0.0,4.0,3.0,0.0,1.0,23.0,1.0,1.0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
C50,1.0,2.0,10.0,1.0,1.0,1.0,1.0,2.0,37.0,24.100000,...,0.0,2.0,2.0,0.0,1.0,12.0,1.0,1.0,1,0
C6,0.0,0.0,7.0,0.0,4.0,1.0,1.0,2.0,38.0,35.294118,...,0.0,1.0,5.0,0.0,0.0,5.0,0.0,1.0,1,0
C7,0.0,0.0,6.0,0.0,3.0,1.0,0.0,2.0,52.0,22.498174,...,0.0,0.0,4.0,0.0,0.0,10.0,0.0,1.0,0,1
C8,0.0,0.0,7.0,0.0,1.0,1.0,1.0,0.0,33.0,19.267171,...,1.0,2.0,2.0,1.0,0.0,14.0,0.0,1.0,1,0
