---
date:
  created: 2025-04-27
  updated: 2025-04-27

categories:
- Data preparation

tags:
- Polars

slug: polars-stratified-train-test-split
---

# Stratified train-test split with Polars

<a href="https://colab.research.google.com/github/dd-n-kk/notebooks/blob/main/blog/polars-stratified-train-test-split.ipynb" target="_parent">
    :simple-googlecolab: Colab notebook
</a>

<!-- more -->

## Preparations

In [2]:
!uv pip install -Uq polars

In [3]:
import numpy as np
import polars as pl

In [4]:
# Seeded generator used to build the dummy data set.
rng = np.random.default_rng(seed=777)

# Notebook-wide display settings for printed frames; the result is bound to
# `_` so the cell itself prints nothing.
_ = pl.Config(
    # Table layout
    tbl_rows=100,
    tbl_cols=-1,
    tbl_width_chars=100,
    # Value formatting
    float_precision=3,
    fmt_str_lengths=200,
    fmt_table_cell_list_len=-1,
)

## Dummy data set

In [6]:
# 1000 class labels in {0, 1, 2, 3}, drawn with deliberately imbalanced
# probabilities.
labels = rng.choice(4, size=1000, p=[0.1, 0.2, 0.3, 0.4])
# Two features per sample: the label value itself plus Gaussian noise scaled
# by 0.1.
features = labels[:, None] + 0.1 * rng.standard_normal((labels.size, 2))

# Assemble the frame: feature columns, then the label, then a leading row
# index named "id" that identifies each sample.
data = (
    pl.from_numpy(features, schema=["feat_1", "feat_2"])
    .with_columns(pl.Series("label", labels))
    .with_row_index(name="id")
)

In [15]:
# Peek at five random rows (no seed is given here, so the rows shown differ
# between runs).
data.sample(n=5)

id,feat_1,feat_2,label
u32,f64,f64,i64
697,1.121,0.946,1
174,2.938,3.047,3
941,2.881,3.01,3
867,1.995,2.098,2
765,2.82,3.101,3


In [16]:
# Number of rows per label in the full data set, ordered by label.
data["label"].value_counts().sort(by="label")

label,count
i64,u32
0,98
1,211
2,297
3,394


## Stratified train-test split

### Train split

In [19]:
# Stratified sampling: draw 90% of the rows from every label group, so the
# train split keeps the label proportions of the full data set.
train_split = data.select(
    pl.all()
    # `pl.all()` expands to one expression per column, and each column is
    # sampled on its own. Giving every one of them the same fixed seed is what
    # makes all columns pick the same rows within a group -- without a seed
    # the columns would be sampled independently and rows would be scrambled.
    .sample(fraction=0.9, shuffle=True, seed=777)
    # Apply the sampling per label group. "explode" is needed because the
    # sampled groups are shorter than the input; the per-group results are
    # stacked, so rows come out grouped by label.
    .over("label", mapping_strategy="explode")
)

In [20]:
# Peek at five random rows of the train split (unseeded, so the rows shown
# differ between runs).
train_split.sample(n=5)

id,feat_1,feat_2,label
u32,f64,f64,i64
860,2.946,2.852,3
876,1.893,1.875,2
327,3.035,3.023,3
657,-0.17,0.106,0
607,-0.188,0.241,0


In [18]:
# (rows, columns) of the train split. The row count is 898 rather than 900
# because each group's 90% share is truncated to a whole number:
# 88 + 189 + 267 + 354 = 898.
train_split.shape

(898, 4)

In [23]:
# Share of each label in the train split, ordered by label; these should
# match the proportions of the full data set.
train_split["label"].value_counts(normalize=True).sort(by="label")

label,proportion
i64,f64
0,0.098
1,0.21
2,0.297
3,0.394


### Test (or validation) split

In [24]:
# Anti join on the row index: keep exactly the rows of `data` whose `id` does
# not occur in `train_split`, i.e. every sample not drawn for training.
test_split = data.join(train_split, on="id", how="anti")

In [25]:
# (rows, columns) of the test split: the 1000 - 898 = 102 rows left over
# after the train split was drawn.
test_split.shape

(102, 4)

In [26]:
# Share of each label in the test split, ordered by label; stratification
# keeps these close to the train-split proportions.
test_split["label"].value_counts(normalize=True).sort(by="label")

label,proportion
i64,f64
0,0.098
1,0.216
2,0.294
3,0.392
