In [1]:
# default_exp model.train_test_split

# Splitting data into train and test sets

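> Stratified train/test split of the churn model dataset, with optional downsampling of non-churners in the training set.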

In [2]:
#hide
from nbdev.showdoc import *

In [27]:
#export
import os
import pandas as pd
from sample_project import config
from sample_project.helper import write_to_csv, read_from_csv
from fastcore.utils import store_attr
import numpy as np
from sklearn.model_selection import train_test_split
from IPython.display import display


In [2]:
#hide
import warnings
warnings.filterwarnings("ignore")

In [20]:
#export
class Stratified_Split:
    '''
    Split the churn model dataset into stratified train and test sets.

    The split is stratified on the "churn_or_not" label, so both sets keep
    (approximately) the same class proportions.

    Args:
        model_data (str or pandas DataFrame, optional): either the CSV file name
            passed to `read_from_csv`, or an already-loaded DataFrame. The data
            must contain a "churn_or_not" label column. Defaults to
            `config.CSV_CHURN_MODEL`.
        test_size (float): proportion of rows used for the test set (passed to
            sklearn's `train_test_split`).
        seed (int): random seed used both for the split and for downsampling.
        do_downsample (bool): if True, downsample the non-churn rows (label 0)
            of the training set to the number of churn rows (label 1).

    Calling the instance returns:
        train (pandas DataFrame): the train set, including features and label.
        test (pandas DataFrame): the test set, including features and label
            (never downsampled).
    '''
    def __init__(self, model_data=None, test_size=0.1, seed=42, do_downsample=False):
        # fastcore: store every constructor argument as an attribute
        # (self.test_size, self.seed, self.do_downsample, ...).
        store_attr()

        # `is None` rather than `== None`: comparing a DataFrame with `==` is
        # element-wise and would raise "truth value is ambiguous" here.
        if model_data is None:
            model_data = config.CSV_CHURN_MODEL
        # Accept an already-loaded frame as well as a CSV file name.
        if isinstance(model_data, pd.DataFrame):
            self.model_data = model_data
        else:
            self.model_data = read_from_csv(model_data)

    def __call__(self):
        '''Return `(train, test)`, downsampling the training set when `do_downsample` is True.'''
        train, test = train_test_split(
            self.model_data,
            stratify=self.model_data["churn_or_not"],
            test_size=self.test_size,
            random_state=self.seed,
        )

        # Only the training set is rebalanced; the test set keeps the original
        # class distribution.
        if self.do_downsample:
            train = self._downsample(train)

        return train, test

    def _downsample(self, df):
        '''
        Downsample the non-churn rows (label 0) of `df` to the number of churn rows (label 1).

        Args:
            df (pandas DataFrame): data with a "churn_or_not" column.

        Returns:
            pandas DataFrame: all churn rows followed by a seeded random sample of
            non-churn rows. Rows with any other label value are dropped.
        '''
        churners = df[df["churn_or_not"] == 1]
        non_churners = df[df["churn_or_not"] == 0]
        # Cap at the available rows so `sample` cannot raise when churners
        # outnumber non-churners.
        n_keep = min(len(churners), len(non_churners))
        # Seeded so the downsampled training set is reproducible.
        sampled = non_churners.sample(n=n_keep, random_state=self.seed)
        return pd.concat([churners, sampled])
        

In [22]:
#hide
# Split the default churn model dataset (config.CSV_CHURN_MODEL) with the
# class defaults spelled out: 10% test set, fixed seed, no downsampling.
splitter = Stratified_Split(test_size=0.1, seed=42, do_downsample=False)
train, test = splitter()

In [23]:
#hide
# Inspect the training split returned by the splitter above.
train

Unnamed: 0,client_id,sum_bin_0,count_bin_0,sum_bin_1,count_bin_1,sum_bin_2,count_bin_2,sum_bin_3,count_bin_3,sum_bin_4,...,count_bin_5,sum_bin_6,count_bin_6,sum_bin_7,count_bin_7,sum_bin_8,count_bin_8,sum_bin_9,count_bin_9,churn_or_not
265,4672,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,134957.3,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0,66.0,0
583,10209,0.0,0.0,0.0,0.0,0.0,0.0,65309.4,257224.1,77933.4,...,0.0,0.0,0.0,0.0,0.0,13.0,71.0,24.0,102.0,0
120,2342,0.0,0.0,0.0,0.0,0.0,0.0,0.0,181311.4,111537.9,...,0.0,0.0,0.0,0.0,0.0,0.0,29.0,18.0,72.0,0
443,7683,107328.5,73797.8,233675.1,120752.4,195477.0,171614.5,128632.4,207424.5,75244.8,...,17.0,50.0,26.0,40.0,37.0,29.0,50.0,18.0,68.0,0
651,11286,0.0,0.0,0.0,0.0,0.0,0.0,0.0,240785.9,199364.3,...,0.0,0.0,0.0,0.0,0.0,0.0,26.0,32.0,132.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
564,9973,426316.2,118728.2,235872.8,499191.2,557721.6,788374.7,292461.5,844229.0,422139.4,...,22.0,54.0,44.0,63.0,62.0,38.0,79.0,30.0,105.0,0
659,11362,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,19612.3,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8.0,92.0,0
393,6691,0.0,0.0,0.0,27690.9,114428.0,154649.2,76945.6,189520.7,54629.9,...,0.0,0.0,7.0,42.0,60.0,39.0,74.0,22.0,99.0,0
576,10095,411446.3,245865.4,842593.4,425699.1,643893.7,614890.2,446013.9,826855.9,254600.7,...,23.0,57.0,38.0,46.0,49.0,31.0,63.0,20.0,75.0,0


In [24]:
#hide
# Inspect the test split returned by the splitter above.
test

Unnamed: 0,client_id,sum_bin_0,count_bin_0,sum_bin_1,count_bin_1,sum_bin_2,count_bin_2,sum_bin_3,count_bin_3,sum_bin_4,...,count_bin_5,sum_bin_6,count_bin_6,sum_bin_7,count_bin_7,sum_bin_8,count_bin_8,sum_bin_9,count_bin_9,churn_or_not
552,9844,10500.0,113540.3,638437.0,151882.7,192088.0,507273.2,332163.0,443075.0,165240.0,...,9.0,60.0,29.0,35.0,55.0,40.0,63.0,18.0,95.0,0
518,9419,87122.8,105615.2,310880.8,163079.3,241292.1,250419.4,155712.6,292434.3,82249.7,...,20.0,62.0,39.0,45.0,51.0,32.0,68.0,18.0,86.0,0
272,4786,0.0,0.0,0.0,0.0,0.0,0.0,243327.7,412655.5,204117.6,...,0.0,0.0,0.0,0.0,0.0,14.0,55.0,20.0,100.0,0
81,1594,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,200.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,56.0,0
419,7195,214184.1,261979.0,720611.9,362464.3,550375.1,533405.3,399946.5,672040.8,251319.7,...,31.0,93.0,52.0,67.0,71.0,55.0,94.0,30.0,123.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
407,6895,84602.6,69952.9,191724.4,109476.3,179550.4,174593.5,119437.4,197800.8,65300.8,...,22.0,48.0,32.0,40.0,48.0,28.0,56.0,18.0,44.0,1
176,3072,0.0,0.0,0.0,0.0,0.0,0.0,15297.2,259340.2,81282.1,...,0.0,0.0,0.0,0.0,0.0,3.0,63.0,25.0,101.0,0
697,11866,398854.9,165662.3,416581.4,210791.2,325099.5,310503.5,216778.3,567034.4,164096.6,...,27.0,69.0,40.0,54.0,53.0,38.0,77.0,23.0,96.0,0
206,3608,0.0,0.0,0.0,0.0,0.0,0.0,17827.1,1015742.8,271477.4,...,0.0,0.0,0.0,0.0,0.0,7.0,70.0,22.0,69.0,0


In [25]:
#export
def check_class_balance(df, label_col="churn_or_not"):
    '''
    Count the rows per class label and each class's share of all rows.

    Args:
        df (pandas DataFrame): data containing the label column.
        label_col (str): name of the label column to group by
            (default "churn_or_not").

    Returns:
        pandas DataFrame indexed by label value, with columns "count" (rows per
        label) and "percentage" (count / len(df), a fraction between 0 and 1).
        Rows with a missing label are not counted (groupby drops NaN keys) but
        still contribute to len(df).
    '''
    grouped = df.groupby(label_col).size().rename("count").to_frame()
    grouped["percentage"] = grouped["count"] / len(df)
    return grouped

In [29]:
#hide
# Class balance of each split: because the split is stratified on "churn_or_not",
# the churn share should be approximately the same in train and test.
# NOTE(review): the saved output below shows a `count_total_pct` column, but
# check_class_balance names it `percentage`, so that output is stale; re-run this cell.
display(check_class_balance(train))
check_class_balance(test)

Unnamed: 0_level_0,count,count_total_pct
churn_or_not,Unnamed: 1_level_1,Unnamed: 2_level_1
0,689,0.926075
1,55,0.073925


Unnamed: 0_level_0,count,count_total_pct
churn_or_not,Unnamed: 1_level_1,Unnamed: 2_level_1
0,77,0.927711
1,6,0.072289
