## Using Sturges' rule to calculate the number of bins

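Background:
Stratified k-fold cross-validation is not typically used for regression (other CV methods work fine), because scikit-learn's `StratifiedKFold` expects discrete class labels rather than a continuous target. To use stratified k-fold CV anyway, first divide the target into a small number of 'bins' and stratify on the bin labels. A common heuristic for the number of bins is Sturges' rule:

    Number of bins = 1 + log2(N)

where N is the number of samples (the result is rounded up or down to a whole number).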
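For the 15,000 samples used below, log2(15000) ≈ 13.87, so the rule gives roughly 14.87 bins, which rounds to 14 (floor) or 15 (ceiling). Below is a minimal sketch of that arithmetic. The helper `sturges_bins` is a hypothetical name, and the cross-check uses NumPy's built-in `"sturges"` option of `np.histogram_bin_edges` (which, as far as I can tell, rounds up) on a random normal sample standing in for a target column.

```python
import numpy as np

def sturges_bins(n_samples: int, round_up: bool = True) -> int:
    """Hypothetical helper: Sturges' rule, 1 + log2(N), rounded to an integer."""
    raw = 1 + np.log2(n_samples)
    return int(np.ceil(raw)) if round_up else int(np.floor(raw))

n = 15000
print(1 + np.log2(n))                                     # roughly 14.87
print(sturges_bins(n), sturges_bins(n, round_up=False))   # 15 14

# Cross-check against NumPy's own Sturges estimator (stand-in target column)
rng = np.random.default_rng(0)
target = rng.normal(size=n)
edges = np.histogram_bin_edges(target, bins="sturges")
print(len(edges) - 1)                                     # number of bins NumPy chose
```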

In [1]:
import pandas as pd
import numpy as np

from sklearn import model_selection
from sklearn import datasets

In [2]:
# Create dummy input and output values
X, y = datasets.make_regression(n_samples=15000, n_targets=1, n_features=100)

# Create a dataframe from the input and target arrays
df = pd.DataFrame(X, columns=[f"f_{i}" for i in range(X.shape[1])])

df["target"] = y

In [3]:
# Create a placeholder fold column (-1 = not yet assigned)
df["kfold"] = -1

# Shuffle the rows; the result must be assigned back, otherwise df is unchanged
df = df.sample(frac=1).reset_index(drop=True)

# Calculate the number of bins using Sturges' rule
bins = int(np.ceil(1 + np.log2(len(df))))  # Ceiling or floor both work; the rule is only an approximation

# Bin the targets into equal-width intervals; labels=False returns integer bin indices
df["bins"] = pd.cut(df["target"], bins=bins, labels=False)

# Initialise the stratified k-fold splitter
kf = model_selection.StratifiedKFold(n_splits=5)

# Fill the kfold column: stratify on the bin labels, not on the raw target
for fold, (train_idx, val_idx) in enumerate(kf.split(X=df, y=df.bins.values)):
    df.loc[val_idx, "kfold"] = fold

# Drop the bins column
df.drop(['bins'], axis=1, inplace=True)
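To see why the binning step is needed, the sketch below uses a smaller synthetic dataset (the sizes and `random_state` are arbitrary assumptions). It first passes the raw continuous target to `StratifiedKFold`, which raises a `ValueError` because it only supports binary or multiclass targets, and then repeats the split on Sturges-binned labels.

```python
import numpy as np
import pandas as pd
from sklearn import datasets, model_selection

# Small synthetic regression problem (sizes and random_state are arbitrary)
X, y = datasets.make_regression(n_samples=1000, n_features=5, random_state=0)
kf = model_selection.StratifiedKFold(n_splits=5)

# 1) The raw continuous target is rejected: StratifiedKFold needs class labels
rejected = False
try:
    next(kf.split(X, y))
except ValueError as err:
    rejected = True
    print("continuous target rejected:", err)

# 2) Binning the target with Sturges' rule gives integer labels that are accepted
n_bins = int(np.ceil(1 + np.log2(len(y))))
y_binned = pd.cut(y, bins=n_bins, labels=False)

fold_sizes = []
for fold, (trn_idx, val_idx) in enumerate(kf.split(X, y_binned)):
    fold_sizes.append(len(val_idx))
    print(f"fold {fold}: {len(trn_idx)} train rows, {len(val_idx)} validation rows")
```

With equal-width bins the extreme bins may hold only a handful of samples; scikit-learn then warns that the least populated class has fewer members than `n_splits`, but the split still runs.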

In [4]:
df.tail(50)

| | f_0 | f_1 | f_2 | f_3 | f_4 | f_5 | f_6 | f_7 | f_8 | f_9 | ... | f_92 | f_93 | f_94 | f_95 | f_96 | f_97 | f_98 | f_99 | target | kfold |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 14950 | 0.372518 | 0.563671 | -0.666906 | 0.912871 | -0.744076 | 0.321511 | -2.842873 | -0.446016 | 0.370458 | -0.364127 | ... | 0.48398 | -1.762496 | 0.913773 | -0.754995 | 0.025514 | 0.29222 | 1.48664 | -1.327782 | -42.630091 | 4 |
| 14951 | -1.659365 | -1.14673 | 0.637508 | 0.611174 | -0.702268 | 0.8669 | 0.850733 | 1.376983 | -0.685474 | 0.645368 | ... | 1.374035 | -0.488398 | 0.549114 | 0.233011 | -0.944768 | -0.343452 | 0.289982 | -0.625535 | -243.745392 | 4 |
| 14952 | -0.513404 | 1.913298 | -1.058433 | 0.651829 | -0.128734 | 1.031519 | -0.607771 | 0.321517 | 0.109069 | 0.193519 | ... | -0.522092 | -0.286793 | 1.772576 | -0.851661 | -1.196551 | 1.473177 | -1.224553 | 0.538565 | 19.297297 | 4 |
| 14953 | 0.359106 | 0.447694 | -0.56916 | -0.127194 | -1.113238 | -0.675038 | -1.250532 | -0.481556 | -2.038413 | 1.053992 | ... | 1.921838 | 0.99758 | -0.395666 | -1.611333 | -0.579812 | -0.046189 | 1.228548 | -0.206577 | -94.498826 | 4 |
| 14954 | 0.087013 | -1.359483 | -0.564225 | 0.25041 | 0.830192 | -0.728663 | -0.654639 | 1.279279 | -0.135438 | -0.449493 | ... | -0.218978 | -0.204899 | -0.830342 | 0.858213 | 1.348992 | -0.17913 | -0.331093 | 1.484092 | -61.930703 | 4 |
| 14955 | 0.122791 | -0.748581 | 0.092903 | -0.202781 | 1.240351 | -0.200436 | -0.533484 | 0.646984 | 0.275455 | -0.288577 | ... | 0.018714 | 0.241467 | -0.138591 | -0.70822 | 0.141917 | -1.150785 | 0.548922 | 0.257285 | -191.699173 | 4 |
| 14956 | 0.592128 | 0.186438 | -2.05719 | 0.002763 | -0.647374 | 0.406347 | 0.235986 | 0.342507 | -2.184248 | 0.469424 | ... | -1.179084 | 1.176186 | 0.265047 | 0.998504 | -0.103359 | -0.176702 | 1.983139 | 0.4351 | -203.760735 | 4 |
| 14957 | -0.83579 | 0.048754 | 0.332741 | -1.106676 | 0.791721 | 0.008645 | 1.832099 | -0.414537 | 1.759333 | -0.441312 | ... | -0.556317 | 0.080479 | 0.445804 | -0.170944 | 0.893135 | 0.399043 | 0.088229 | 0.74547 | -111.517766 | 4 |
| 14958 | -0.748564 | 0.796371 | 0.015674 | 0.439978 | -0.742984 | -1.380572 | 0.220275 | -1.237797 | -0.23237 | 0.430666 | ... | 0.420896 | 0.976894 | 0.658158 | 0.043314 | -0.048503 | 0.001127 | -1.220138 | -0.253831 | 112.761925 | 4 |
| 14959 | 0.44104 | -0.363766 | 0.136802 | 0.780112 | 0.334575 | 0.131689 | -0.525071 | -0.495669 | -0.463855 | 0.59393 | ... | 0.680188 | -1.978721 | -1.324441 | -0.294587 | -1.20658 | 0.520369 | -0.734388 | -1.689716 | -14.575158 | 4 |
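Once every row carries a fold number, each fold can serve once as the validation set while the other four are used for training. The sketch below is an assumption about typical downstream use, not part of the original notebook. It wraps the fold-creation steps in a hypothetical helper `create_folds` and runs them on a smaller synthetic dataset. It then fits scikit-learn's `LinearRegression` (an arbitrary model choice) on the training folds and reports R² on each held-out fold.

```python
import numpy as np
import pandas as pd
from sklearn import datasets, linear_model, metrics, model_selection

def create_folds(df, target_col="target", n_splits=5):
    """Hypothetical helper: add a 'kfold' column stratified on Sturges-binned targets."""
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    df["kfold"] = -1
    n_bins = int(np.ceil(1 + np.log2(len(df))))
    bins = pd.cut(df[target_col], bins=n_bins, labels=False)
    kf = model_selection.StratifiedKFold(n_splits=n_splits)
    for fold, (_, val_idx) in enumerate(kf.split(X=df, y=bins)):
        df.loc[val_idx, "kfold"] = fold
    return df

# Smaller synthetic dataset (sizes, noise and random_state are arbitrary)
X, y = datasets.make_regression(n_samples=2000, n_features=10, noise=10.0, random_state=0)
df = pd.DataFrame(X, columns=[f"f_{i}" for i in range(X.shape[1])])
df["target"] = y
df = create_folds(df)

# Train on four folds, validate on the remaining one, for each fold in turn
features = [c for c in df.columns if c.startswith("f_")]
scores = []
for fold in range(5):
    train = df[df.kfold != fold]
    valid = df[df.kfold == fold]
    model = linear_model.LinearRegression().fit(train[features], train["target"])
    score = metrics.r2_score(valid["target"], model.predict(valid[features]))
    scores.append(score)
    print(f"fold {fold}: R^2 = {score:.3f}")
```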
