In [1]:
from typing import Any

import polars as pl
from great_tables import GT, loc, style


@pl.api.register_expr_namespace("spt")
class DiscreteSplitter:
    """Polars expression namespace (``.spt``) that labels rows cyclically.

    Every method builds its result from the row position only
    (``row_number % number_of_labels``): the first row receives the first
    label, the second row the second label, and so on, starting over after
    the last label. The expression the namespace is attached to is stored
    but is not read by any method of this class.
    """

    def __init__(self, expr: pl.Expr) -> None:
        # Stored as handed over by the namespace machinery; unused below.
        self._expr = expr

    def _mod_expr(self, n: int) -> pl.Expr:
        """Return the 0-based row counter (built as UInt32) modulo ``n``."""
        return pl.int_range(pl.len(), dtype=pl.UInt32).mod(n)

    def binarize(
        self, lit1: str, lit2: str, name: str = "binarized"
    ) -> pl.Expr:
        """Alternate between two labels, starting with ``lit1``.

        Args:
            lit1: Label for rows whose 0-based position is even.
            lit2: Label for all remaining rows.
            name: Alias given to the resulting expression.

        Returns:
            An expression yielding ``lit1, lit2, lit1, ...`` by row position.
        """
        # Two-label case of the general cycle; builds the same when/otherwise.
        return self.bucketize([lit1, lit2], name=name)

    def trinarize(
        self, lit1: str, lit2: str, lit3: str, name: str = "trinarized"
    ) -> pl.Expr:
        """Cycle through three labels, starting with ``lit1``.

        Args:
            lit1: Label for rows where ``position % 3 == 0``.
            lit2: Label for rows where ``position % 3 == 1``.
            lit3: Label for all remaining rows.
            name: Alias given to the resulting expression.

        Returns:
            An expression yielding ``lit1, lit2, lit3, lit1, ...``.
        """
        # Three-label case of the general cycle.
        return self.bucketize([lit1, lit2, lit3], name=name)

    def bucketize(
        self, lits: list[Any], name: str = "bucketized"
    ) -> pl.Expr:
        """Cycle through an arbitrary list of labels by row position.

        Args:
            lits: Labels to assign in order; must contain at least one item.
            name: Alias given to the resulting expression.

        Returns:
            An expression yielding ``lits[position % len(lits)]`` per row.

        Raises:
            IndexError: If ``lits`` is empty.
        """
        if len(lits) == 0:
            # Same exception type an empty list produced before (from
            # ``lits[0]``), but raised up front with an explanatory message.
            raise IndexError("bucketize() needs at least one literal in `lits`")

        mod_expr = self._mod_expr(len(lits))

        # First label opens the when/then chain.
        expr = pl.when(mod_expr.eq(0)).then(pl.lit(lits[0]))

        # Every label strictly between the first and last gets its own branch.
        for i, one_lit in enumerate(lits[1:-1], start=1):
            expr = expr.when(mod_expr.eq(i)).then(pl.lit(one_lit))

        # The last label is the fall-through branch.
        return expr.otherwise(pl.lit(lits[-1])).alias(name)


# Nine-row demo frame with a 1-based row index column.
df = pl.DataFrame({"n": [100, 50, 72, 83, 97, 42, 20, 51, 77]}).with_row_index(
    offset=1
)

# Append the three cyclic label columns from the custom "spt" namespace.
# The empty column name is only a handle to reach the namespace: the methods
# derive their labels from the row position, not from the column's values.
df = df.with_columns(
    [
        pl.col("").spt.binarize("lightblue", "papayawhip"),
        pl.col("").spt.trinarize("one", "two", "three"),
        pl.col("").spt.bucketize([1, 2, 3, 4]),
    ]
)

print(df)

shape: (9, 5)
┌───────┬─────┬────────────┬────────────┬────────────┐
│ index ┆ n   ┆ binarized  ┆ trinarized ┆ bucketized │
│ ---   ┆ --- ┆ ---        ┆ ---        ┆ ---        │
│ u32   ┆ i64 ┆ str        ┆ str        ┆ i32        │
╞═══════╪═════╪════════════╪════════════╪════════════╡
│ 1     ┆ 100 ┆ lightblue  ┆ one        ┆ 1          │
│ 2     ┆ 50  ┆ papayawhip ┆ two        ┆ 2          │
│ 3     ┆ 72  ┆ lightblue  ┆ three      ┆ 3          │
│ 4     ┆ 83  ┆ papayawhip ┆ one        ┆ 4          │
│ 5     ┆ 97  ┆ lightblue  ┆ two        ┆ 1          │
│ 6     ┆ 42  ┆ papayawhip ┆ three      ┆ 2          │
│ 7     ┆ 20  ┆ lightblue  ┆ one        ┆ 3          │
│ 8     ┆ 51  ┆ papayawhip ┆ two        ┆ 4          │
│ 9     ┆ 77  ┆ lightblue  ┆ three      ┆ 1          │
└───────┴─────┴────────────┴────────────┴────────────┘


In [2]:
# Render the frame with Great Tables: body cells are filled with the colour
# named in the "binarized" column, then preset style 6 is applied.
GT(df).tab_style(
    style=style.fill(color=pl.col("binarized")),
    locations=loc.body(),
).opt_stylize(style=6)

index,n,binarized,trinarized,bucketized
1,100,lightblue,one,1
2,50,papayawhip,two,2
3,72,lightblue,three,3
4,83,papayawhip,one,4
5,97,lightblue,two,1
6,42,papayawhip,three,2
7,20,lightblue,one,3
8,51,papayawhip,two,4
9,77,lightblue,three,1
