In [None]:
# default_exp training.balance

# Balance training data

> Methods for balancing training data across target classes (labels) before model training

In [None]:
# hide
from nbdev.showdoc import *
from fastcore.test import *
import dev_tools as dt

In [None]:
# export 
import pandas as pd
import numpy as np

## Downsample 

In [None]:
# export
def downsample(df:pd.DataFrame, y_column:str, random_state:int, min_size:int=None, **kwargs) -> pd.DataFrame:
    '''Balance classes of the target variable by downsampling all classes to be equal to or smaller than "min_size".
    
    Classes with at most "min_size" rows are not affected and keep their current size. If "min_size" is omitted,
    the size of the smallest current class is used as "min_size".
    
    Parameters
    
    df : pandas dataframe
         Dataframe containing column "y_column". It is not modified.
            
    y_column : str
               Name of df column containing the target variable (label). Rows with a missing (NaN) label
               are not grouped by pandas and are therefore kept unchanged.
    
    random_state : int
                   Random state for reproducibility, passed to `DataFrame.sample` for every class.
    
    min_size : int, default=None
               Maximum number of rows kept per class. If no value is supplied, min_size will be set to
               the size of the smallest current class.
    
    **kwargs
               Additional keyword arguments forwarded to `DataFrame.sample`, which selects the rows to drop.
    
    Returns
    
    df_new : pandas dataframe
             Has the same columns as the input dataframe, but classes were balanced by downsampling.
             Remaining rows keep their original order; the index is reset to 0..n-1.
    
    Raises
    
    ValueError
        If "min_size" is negative.
    '''
    # Rows are removed by index label below. With duplicate labels (e.g. after pd.concat without
    # ignore_index) dropping one label would remove every row sharing it, so switch to a unique
    # positional index first. The result is reset anyway, so nothing observable is lost.
    df_new = df if df.index.is_unique else df.reset_index(drop=True)
    
    # get smallest current class if not supplied
    if min_size is None:
        min_size = df_new[y_column].value_counts().min()
    elif min_size < 0:
        raise ValueError(f"min_size must be non-negative, got {min_size}")
    
    # collect the surplus rows of every class larger than min_size, then drop them in one pass
    drop_labels = []
    for _, group in df_new.groupby(y_column):
        surplus = len(group) - min_size
        if surplus > 0:
            drop_labels.extend(group.sample(surplus, random_state=random_state, **kwargs).index)
            
    return df_new.drop(index=drop_labels).reset_index(drop=True)

In [None]:
# toy frame: label column "y" with classes of size 1 (x="A"), 2 (x="B") and 3 (x="C")
df = pd.DataFrame({"x": list("ABBCCC"), "y": [0, 1, 1, 2, 2, 2]})
df.groupby(by="x").count()

Unnamed: 0_level_0,y
x,Unnamed: 1_level_1
A,1
B,2
C,3


In [None]:
# without min_size every class is cut down to the size of the smallest class
new_df = downsample(df, "y", dt.random_state)
new_df.groupby(by="x").count()

Unnamed: 0_level_0,y
x,Unnamed: 1_level_1
A,1
B,1
C,1


In [None]:
# hide

# downsample unit tests
# Each case is (min_size, expected rows per class for x = A, B, C). The toy frame has class
# sizes 1, 2 and 3; classes reduced to zero rows disappear from the groupby result.
downsample_cases = [
    (0, []),            # test 0: every row is dropped
    (None, [1, 1, 1]),  # complete downsample to the smallest class
    (2, [1, 2, 2]),     # partial downsample
    (3, [1, 2, 3]),     # no class exceeds min_size
    (4, [1, 2, 3]),     # min_size larger than every class
]
for case_min_size, expected in downsample_cases:
    balanced = downsample(df=df, y_column="y", random_state=dt.random_state, min_size=case_min_size)
    new_df = balanced.groupby("x").count()
    test_eq(list(new_df.y), expected)
    # the result always carries a fresh 0..n-1 index
    test_eq(list(balanced.index), list(range(len(balanced))))

# the input frame is left untouched
test_eq(list(df.y), [0, 1, 1, 2, 2, 2])