In [1]:
# default_exp model.train_test_split

# Splitting data into train and test sets

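> Stratified train/test split of the churn model data, with optional downsampling of the non-churn class in the train set.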

In [2]:
#hide
from nbdev.showdoc import *

In [3]:
#export
import os
import pandas as pd
from sample_project import config
from sample_project.helper import write_to_csv, read_from_csv
from fastcore.utils import store_attr
import numpy as np
from sklearn.model_selection import train_test_split
from IPython.display import display


In [4]:
#hide
import warnings
warnings.filterwarnings("ignore")

In [17]:
#export
class Stratified_Split:
    '''
    Split the churn model data into stratified train and test sets.

    Args:
        model_data (str): Name of the csv file holding the model dataset. It is handed to
            `read_from_csv` and must contain the label column "churn_or_not"
            (alongside "client_id" and the feature columns).
        test_size (float): Share of the data to be used as test set.
        seed (integer): Random seed used for the train/test split and for the downsampling draw.
        do_downsample (boolean): If True, the non-churn rows of the train set are downsampled
            to the number of churn rows. The test set is never downsampled.

    Return (when the instance is called):
        train (pandas DataFrame): the train set which includes both features and label
        test (pandas DataFrame): the test set which includes both features and label
    '''
    def __init__(self, model_data = config.CSV_CHURN_MODEL, test_size=0.1, seed=42, do_downsample=False):
        # Keeps every constructor argument as an attribute of the same name.
        store_attr()

    def __call__(self):
        '''Read the model data, split it stratified on "churn_or_not" and return (train, test).'''
        df_model_data = read_from_csv(self.model_data)

        # Stratify on the label so both splits keep the original class proportions.
        train, test = train_test_split(
            df_model_data,
            stratify=df_model_data["churn_or_not"],
            test_size=self.test_size,
            random_state=self.seed,
        )

        # Only the train set is rebalanced; the test set keeps the original distribution.
        if self.do_downsample:
            train = self._downsample(train)

        return train, test

    def _downsample(self, df):
        '''
        Return all rows with churn_or_not == 1 plus an equally sized random sample of the
        rows with churn_or_not == 0.

        The sample is drawn with `random_state=self.seed`, so the result is reproducible
        for a fixed seed (previously the draw was unseeded and changed on every run).

        NOTE(review): sampling is without replacement, so pandas raises a ValueError when
        there are fewer rows with label 0 than with label 1.
        '''
        churned = df[df["churn_or_not"] == 1]
        not_churned = df[df["churn_or_not"] == 0]
        not_churned_sample = not_churned.sample(len(churned), random_state=self.seed)
        return pd.concat([churned, not_churned_sample], axis=0)
        

In [20]:
#hide
# Smoke run on the default dataset (config.CSV_CHURN_MODEL): 10% stratified hold-out,
# fixed seed, no downsampling of the train set.
splitter = Stratified_Split(test_size=0.1, seed=42, do_downsample=False)
train,test = splitter()

In [21]:
#hide
# Inspect the train split returned by the splitter (features plus the "churn_or_not" label).
train

Unnamed: 0,client_id,sum_bin_0,count_bin_0,sum_bin_1,count_bin_1,sum_bin_2,count_bin_2,sum_bin_3,count_bin_3,sum_bin_4,...,count_bin_5,sum_bin_6,count_bin_6,sum_bin_7,count_bin_7,sum_bin_8,count_bin_8,sum_bin_9,count_bin_9,churn_or_not
143,3114,0.0,0.0,0.0,0.0,0.0,0.0,35499.7,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,6.0,0.0,0.0,73.0,0
596,12102,511528.9,0.0,0.0,1253666.0,0.0,0.0,1290719.0,0.0,0.0,...,0.0,0.0,106.0,0.0,0.0,105.0,0.0,0.0,100.0,0
344,7299,364697.3,0.0,0.0,587844.1,0.0,0.0,564042.2,0.0,0.0,...,0.0,0.0,110.0,0.0,0.0,106.0,0.0,0.0,109.0,0
328,6829,0.0,0.0,0.0,0.0,0.0,0.0,599186.9,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,63.0,0.0,0.0,86.0,0
98,2470,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,45.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
428,9428,0.0,0.0,0.0,0.0,0.0,0.0,791594.4,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,51.0,0.0,0.0,82.0,0
163,3606,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,43.0,0
16,383,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,66.0,0
497,10351,14500.0,0.0,0.0,138434.9,0.0,0.0,1324942.7,0.0,0.0,...,0.0,0.0,23.0,0.0,0.0,122.0,0.0,0.0,88.0,0


In [22]:
#hide
# Inspect the held-out test split returned by the splitter.
test

Unnamed: 0,client_id,sum_bin_0,count_bin_0,sum_bin_1,count_bin_1,sum_bin_2,count_bin_2,sum_bin_3,count_bin_3,sum_bin_4,...,count_bin_5,sum_bin_6,count_bin_6,sum_bin_7,count_bin_7,sum_bin_8,count_bin_8,sum_bin_9,count_bin_9,churn_or_not
610,12397,732133.5,0.0,0.0,1415089.8,0.0,0.0,299002.2,0.0,0.0,...,0.0,0.0,103.0,0.0,0.0,59.0,0.0,0.0,71.0,0
101,2501,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,62.0,0
296,6301,179302.4,0.0,0.0,408785.9,0.0,0.0,408101.9,0.0,0.0,...,0.0,0.0,119.0,0.0,0.0,122.0,0.0,0.0,122.0,0
365,7765,0.0,0.0,0.0,0.0,0.0,0.0,65817.4,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,11.0,0.0,0.0,72.0,0
1,3,429529.9,0.0,0.0,524233.8,0.0,0.0,524215.4,0.0,0.0,...,0.0,0.0,85.0,0.0,0.0,85.0,0.0,0.0,85.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
550,11380,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,13.0,0
494,10287,383164.5,0.0,0.0,556851.0,0.0,0.0,560331.1,0.0,0.0,...,0.0,0.0,97.0,0.0,0.0,96.0,0.0,0.0,95.0,0
490,10269,0.0,0.0,0.0,476663.5,0.0,0.0,461565.8,0.0,0.0,...,0.0,0.0,42.0,0.0,0.0,76.0,0.0,0.0,95.0,0
495,10288,383164.5,0.0,0.0,556851.0,0.0,0.0,560331.1,0.0,0.0,...,0.0,0.0,97.0,0.0,0.0,96.0,0.0,0.0,95.0,0


In [23]:
#export
def check_class_balance(df):
    '''
    Summarise how the rows of `df` are distributed over the "churn_or_not" label.

    Args:
        df (pandas DataFrame): data containing a "churn_or_not" column

    Return:
        pandas DataFrame indexed by "churn_or_not" with two columns:
            count: number of rows per label value
            percentage: count divided by the total number of rows in `df`
                        (a fraction between 0 and 1, not multiplied by 100)
    '''
    total_rows = len(df)
    rows_per_class = df.groupby("churn_or_not").size()
    balance = rows_per_class.to_frame(name="count")
    return balance.assign(percentage=balance["count"] / total_rows)

In [24]:
#hide
# Compare the label proportions of both splits. display() is used for the first frame
# because only the last expression of a cell is rendered automatically.
display(check_class_balance(train))
check_class_balance(test)

Unnamed: 0_level_0,count,percentage
churn_or_not,Unnamed: 1_level_1,Unnamed: 2_level_1
0,566,0.920325
1,49,0.079675


Unnamed: 0_level_0,count,percentage
churn_or_not,Unnamed: 1_level_1,Unnamed: 2_level_1
0,64,0.927536
1,5,0.072464
