In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit

In [2]:
# Generate the raw data: three independently shuffled columns for 200 students
np.random.seed(42)

# Unshuffled pools; the class sizes are deliberately imbalanced
student_labels = [f"Student {i}" for i in range(200)]
gender_pool = np.repeat(["boy", "girl", "other"], [150, 40, 10])
place_pool = np.repeat(["hbh", "mbr", "blr"], [180, 15, 5])

# Shuffle in the order names -> genders -> places so the seeded RNG stream
# is consumed in a fixed sequence
names = np.random.permutation(student_labels)
genders = np.random.permutation(gender_pool)
places = np.random.permutation(place_pool)

In [3]:
# Assemble the three shuffled arrays into one table (one row per student)
df_students = pd.DataFrame(
    dict(name=names, gender=genders, place=places)
)

df_students

Unnamed: 0,name,gender,place
0,Student 95,girl,hbh
1,Student 15,boy,hbh
2,Student 30,boy,hbh
3,Student 158,girl,hbh
4,Student 128,boy,hbh
...,...,...,...
195,Student 106,girl,hbh
196,Student 14,boy,hbh
197,Student 92,boy,hbh
198,Student 179,boy,hbh


In [4]:
# Baseline: plain random 80/20 split, with no stratification at all.
# NOTE(review): no random_state is passed; per scikit-learn's docs the shuffle
# then falls back to NumPy's global RNG (seeded in the data-generation cell),
# so the exact split depends on the cells being run in order.
train_set, test_set = train_test_split(
    df_students,
    test_size=0.2,
)

train_set.shape
test_set.shape

(160, 3)

(40, 3)

In [5]:
def compare_proportions(df, train, test, col):
  print(f"\n--- Distribution for Column: '{col}' ---")

  # Calculate percentages
  overall = df[col].value_counts(normalize=True).sort_index()
  train_dist = train[col].value_counts(normalize=True).sort_index()
  test_dist = test[col].value_counts(normalize=True).sort_index()

  # Combine into a single dataframe for easy comparison
  comparison = pd.DataFrame({
      "Overall %": overall,
      "Train %": train_dist,
      "Test %": test_dist
  })

  # Fill NaN with 0 (in case a category is missing entirely from a set)
  comparison = comparison.fillna(0)

  # Display as percentage strings
  print(comparison.map(lambda x: f"{x:.1%}"))

In [6]:
# Check how well the plain random split preserved each column's proportions
for column in ("gender", "place"):
  compare_proportions(df_students, train_set, test_set, column)


--- Distribution for Column: 'gender' ---
       Overall % Train % Test %
gender                         
boy        75.0%   75.0%  75.0%
girl       20.0%   20.0%  20.0%
other       5.0%    5.0%   5.0%

--- Distribution for Column: 'place' ---
      Overall % Train % Test %
place                         
blr        2.5%    3.1%   0.0%
hbh       90.0%   88.8%  95.0%
mbr        7.5%    8.1%   5.0%


In [7]:
# Stratified 80/20 split that preserves the 'gender' proportions.
# `split` is kept as a notebook-level name because a later cell reuses it.
# NOTE(review): no random_state is passed; per scikit-learn's docs the shuffle
# then falls back to NumPy's global RNG (seeded in the data-generation cell).
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2)

# n_splits=1, so take the single (train, test) pair directly instead of looping
train_index, test_index = next(split.split(df_students, df_students['gender']))

# split() yields positional row indices, so select with .iloc; .loc would only
# give the same rows while the frame keeps its default 0..n-1 index
strat_train_set = df_students.iloc[train_index]
strat_test_set = df_students.iloc[test_index]

strat_train_set.shape
strat_test_set.shape

(160, 3)

(40, 3)

In [8]:
# Check both columns again, now for the split stratified on 'gender'
for column in ("gender", "place"):
  compare_proportions(df_students, strat_train_set, strat_test_set, column)


--- Distribution for Column: 'gender' ---
       Overall % Train % Test %
gender                         
boy        75.0%   75.0%  75.0%
girl       20.0%   20.0%  20.0%
other       5.0%    5.0%   5.0%

--- Distribution for Column: 'place' ---
      Overall % Train % Test %
place                         
blr        2.5%    3.1%   0.0%
hbh       90.0%   88.1%  97.5%
mbr        7.5%    8.8%   2.5%


Stratified sampling works on a single column only: above, the `gender` proportions are preserved in both sets, but the `place` proportions are not.

If you want to stratify on multiple columns, you have to create a new column that combines those columns.

Then you perform the stratified sampling on that new column.

⚠️ Also ensure that no category has fewer than two members.

In [9]:
# Combined stratification key, e.g. "boy_hbh": gender and place joined by "_"
df_students['combo'] = df_students['gender'].str.cat(df_students['place'], sep='_')

# Size of every gender/place combination
df_students['combo'].value_counts()

combo
boy_hbh      135
girl_hbh      38
boy_mbr       11
other_hbh      7
boy_blr        4
girl_mbr       2
other_mbr      2
other_blr      1
Name: count, dtype: int64

⚠️ The split below is expected to fail with a `ValueError`, because one category (`other_blr`) has only one member.

In [10]:
# Try to stratify on the combined gender/place column.
# 'other_blr' has a single member, so the split is expected to raise ValueError.
try:
  train_index, test_index = next(split.split(df_students, df_students['combo']))
except ValueError as err:
  # The error is the point of this cell: show it without aborting "Run All"
  print(f"ValueError: {err}")
else:
  # split() yields positional row indices, so select with .iloc
  strat_train_set = df_students.iloc[train_index]
  strat_test_set = df_students.iloc[test_index]

  # Bare expressions are not displayed inside a block, so print the shapes
  print(strat_train_set.shape)
  print(strat_test_set.shape)

  # Only compare when the split succeeded; otherwise the stale sets from the
  # earlier gender-stratified split would be reported
  compare_proportions(df_students, strat_train_set, strat_test_set, "gender")
  compare_proportions(df_students, strat_train_set, strat_test_set, "place")

ValueError: The least populated classes in y have only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2. Classes with too few members are: ['other_blr']