In [7]:
import pandas as pd
import numpy as np
from sklearn.utils import shuffle

# Columns removed before modelling because each one repeats information
# already carried by another column that is kept.
drop_feature = [
    'quantity_group',         # same values as 'quantity'
    'source_type',            # 1-to-1 subset of 'source'
    'source_class',           # presumably a grouping of 'source' (the original note pointed at itself) - TODO confirm
    'waterpoint_type_group',  # 1-to-1 subset of 'waterpoint_type'
    'quality_group',          # 1-to-1 subset of 'water_quality'
    'payment_type',           # 1-to-1 subset of 'payment'
    'management_group',       # 1-to-1 subset of 'management'
    'extraction_type_group',  # 1-to-1 subset of 'extraction_type'
    'extraction_type_class',  # 1-to-1 subset of 'extraction_type'
]
def _out_csv_path(path):
    # 'train.csv' -> 'train_out.csv'; same naming rule the function has always used.
    return path.split('.csv')[0] + '_out.csv'


def balance_class(traincsv, testcsv, label_name, drop_list=None, balance_rate=None, labelcsv=None,
                  random_state=None):
    """
    Drop unwanted columns from the train/test CSVs and oversample the smaller
    classes of the training set (sampling with replacement). Results are
    written next to the inputs as '<name>_out.csv'; nothing is returned.

    traincsv:
        type: str
        train csv file name
    testcsv:
        type: str
        test csv file name
    label_name:
        type: str
        column name of label in label csv or train csv
    drop_list:
        type: list or None
        a list of column names which will be dropped out of train and test;
        None (default) drops nothing
    balance_rate:
        type: float or None
        smaller classes will increase to balance_rate * number of biggest class
        ie, 250 data are class1, 100 data are class2, if balance_rate = 0.6
        class2 will increase to 150( = 250*0.6)
        None (default) disables oversampling
    labelcsv:
        type: str or None
        label csv file name; when given, its label column is attached to the
        train rows by position and the balanced labels are written to their
        own '<labelcsv>_out.csv' instead of staying in the train output
    random_state:
        type: int, numpy RandomState/Generator or None
        seed for the resampling and shuffling, for reproducible output;
        None (default) keeps the previous non-deterministic behaviour

    raises:
        ValueError if labelcsv and traincsv have a different number of rows
        KeyError if a name in drop_list is not a column of the train or test data
    """
    drop_cols = [] if drop_list is None else drop_list

    # One generator shared by every sampling step so a single seed fixes the whole run.
    if random_state is None or isinstance(random_state, (int, np.integer)):
        rng = np.random.RandomState(random_state)
    else:
        rng = random_state

    # DataFrame.from_csv was removed in pandas 1.0; read_csv is the replacement.
    testout = pd.read_csv(testcsv).drop(columns=drop_cols)
    testout.to_csv(_out_csv_path(testcsv), index=False)

    train = pd.read_csv(traincsv)
    if labelcsv is not None:
        label = pd.read_csv(labelcsv)
        if len(label) != len(train):
            # A silent index-aligned assignment would leave NaN labels behind.
            raise ValueError(
                "label csv has %d rows but train csv has %d rows" % (len(label), len(train))
            )
        train[label_name] = label[label_name]

    aug_list = []
    if balance_rate is not None:
        # value_counts() ignores missing labels, so NaN never becomes a "class".
        class_count = train[label_name].value_counts()
        if not class_count.empty:
            target = int(class_count.max() * balance_rate)
            for c, count in class_count.items():
                aug_num = target - int(count)
                if aug_num > 0:
                    members = train[train[label_name] == c]
                    aug_list.append(members.sample(n=aug_num, replace=True, random_state=rng))

    if aug_list:
        # sample(frac=1) is a full random permutation of the extra rows;
        # the original rows stay first, in their original order.
        train_aug_shuf = pd.concat(aug_list).sample(frac=1, random_state=rng)
        trainout = pd.concat([train, train_aug_shuf])
    else:
        # Nothing to add (already balanced enough, or balancing disabled).
        trainout = train

    trainout = trainout.drop(columns=drop_cols)

    if labelcsv is not None:
        # Labels came from their own file, so they go back out to their own file.
        labelout = trainout[[label_name]]
        trainout = trainout.drop(columns=[label_name])
        labelout.to_csv(_out_csv_path(labelcsv), index=False)

    trainout.to_csv(_out_csv_path(traincsv), index=False)

In [8]:
# example: write train_out.csv / test_out.csv / label_out.csv with the redundant
# columns removed and the smaller classes oversampled to 80% of the largest one.
balance_class(traincsv = "train.csv", testcsv = "test.csv", label_name = "status_group",drop_list = drop_feature, 
                       balance_rate = 0.8, labelcsv = "label.csv")

# DataFrame.from_csv was removed in pandas 1.0; read_csv is the replacement.
train = pd.read_csv("train_out.csv")
# NOTE: despite its name, `test` holds the balanced labels (label_out.csv),
# not the test-set features; the name is kept because later cells use it.
test = pd.read_csv("label_out.csv")

In [11]:
# Sanity check: size and first rows of the training features read back from train_out.csv.
print(train.shape)
train.head()

(83873, 31)


Unnamed: 0,id,amount_tsh,date_recorded,funder,gps_height,installer,longitude,latitude,wpt_name,num_private,...,scheme_name,permit,construction_year,extraction_type,management,payment,water_quality,quantity,source,waterpoint_type
0,69572,6000.0,2011-03-14,Roman,1390,Roman,34.938093,-9.856322,none,0,...,Roman,False,1999,gravity,vwc,pay annually,soft,enough,spring,communal standpipe
1,8776,0.0,2013-03-06,Grumeti,1399,GRUMETI,34.698766,-2.147466,Zahanati,0,...,,True,2010,gravity,wug,never pay,soft,insufficient,rainwater harvesting,communal standpipe
2,34310,25.0,2013-02-25,Lottery Club,686,World vision,37.460664,-3.821329,Kwa Mahundi,0,...,Nyumba ya mungu pipe scheme,True,2009,gravity,vwc,pay per bucket,soft,enough,dam,communal standpipe multiple
3,67743,0.0,2013-01-28,Unicef,263,UNICEF,38.486161,-11.155298,Zahanati Ya Nanyumbu,0,...,,True,1986,submersible,vwc,never pay,soft,dry,machine dbh,communal standpipe multiple
4,19728,0.0,2011-07-13,Action In A,0,Artisan,31.130847,-1.825359,Shuleni,0,...,,True,0,gravity,other,never pay,soft,seasonal,rainwater harvesting,communal standpipe


In [12]:
# `test` was loaded from label_out.csv above, so this shows the labels
# written by balance_class, not test-set features.
print(test.shape)
test.head()

(83873, 1)


Unnamed: 0,status_group
0,functional
1,functional
2,functional
3,non functional
4,functional
