In [30]:
import polars as pl 
import holoviews as hv
import hvplot.polars

In [2]:
# Load the raw development data set from the local parquet file.
df = pl.read_parquet('./plotting_dev_parquet.parquet')

# Collapse to one row per (symbol, date):
#   * mean of `max_tick_to_query_lag`
#   * sum of `return`
#   * `null_frac`: share of rows in the group whose `return` is null
# then attach `weight`, the number of characters in the symbol name
# (the symbol is cast to String first so the string namespace applies).
plot_df = (
    df.group_by(['symbol', 'date'])
    .agg(
        pl.col('max_tick_to_query_lag').mean(),
        pl.col('return').sum(),
        (pl.col('return').null_count() / pl.len()).alias('null_frac'),
    )
    .with_columns(
        pl.col('symbol').cast(pl.String).str.len_chars().alias('weight')
    )
)
plot_df

symbol,date,max_tick_to_query_lag,return,null_frac,weight
enum,date,duration[μs],f64,f64,u32
"""ACA""",2025-02-18,25s 275121µs,-0.054059,0.319444,3
"""GALA""",2023-08-27,27s 176972µs,-0.002665,0.333333,4
"""OAX""",2024-04-04,13s 88659µs,0.212224,0.125,3
"""LUMIA""",2025-05-05,39s 890015µs,-0.006051,0.465278,5
"""TLM""",2025-04-29,42s 875152µs,-0.051738,0.5,3
…,…,…,…,…,…
"""MAGIC""",2023-03-30,5s 254902µs,-0.014546,0.013889,5
"""XTZ""",2025-10-01,25s 773774µs,0.048898,0.229167,3
"""PHA""",2023-03-02,39s 899277µs,0.004831,0.444444,3
"""GTC""",2022-06-18,26s 996451µs,-0.07094,0.333333,3


In [58]:
from typing import List, Union, Dict, Callable, Any
from datetime import (
    datetime as Datetime, 
    date as Date, 
    timedelta as Timedelta
)

def quick_format(content: Union[float, Timedelta, str]) -> str:
    """
    Turn a value into a compact display label.

    * ``float``     -> fixed-point with two decimal places
    * ``timedelta`` -> total seconds with three decimal places
    * anything else -> ``str(content)``
    """
    if isinstance(content, Timedelta):
        return format(content.total_seconds(), ".3f")
    if isinstance(content, float):
        return format(content, ".2f")
    return str(content)
    
def slice_plots(
    plot_df: pl.DataFrame,
    accum_cols: List[pl.Expr],
    cont_feature_cols: List[pl.Expr],
    weight_col: pl.Expr,
    n_ticks: int = 5, # Number of ticks along x-axis
    plot_subsample: int = 10_000, # Subsampling after computing full sort
    format_fns: Union[Dict[str, Callable[[Any], str]], None] = None
):
    """
    Build one cumulative line plot per continuous feature.

    For every expression in ``cont_feature_cols`` the rows of ``plot_df`` are
    sorted ascending by that feature. The weight column is replaced by its
    running sum divided by its total (the cumulative weight fraction) and each
    accumulated column is replaced by its running sum. The result is then
    subsampled (fixed seed) and re-sorted by the feature before plotting.

    The x-axis is the cumulative weight fraction, but the tick labels show the
    feature value: for ``n_ticks`` evenly spaced quantiles of the subsampled
    feature, the row whose feature value is closest to that quantile supplies
    the tick position (its cumulative weight) and the tick label (its
    formatted feature value).

    Parameters
    ----------
    plot_df:
        Frame that all the expressions below are evaluated against.
    accum_cols:
        Expressions whose running sums are drawn as lines.
    cont_feature_cols:
        Expressions to sort by; one plot is produced per entry.
    weight_col:
        Expression providing the per-row weight.
    n_ticks:
        Number of x-axis ticks (evenly spaced quantiles from 0 to 1).
    plot_subsample:
        Maximum number of rows kept for plotting; sampling happens after the
        cumulative sums are computed on the full frame.
    format_fns:
        Optional mapping from a feature's output name to a function that
        formats a tick value; features without an entry use ``quick_format``.

    Returns
    -------
    list
        The hvplot line plots, in the order of ``cont_feature_cols``.
    """
    if format_fns is None:
        format_fns = {}

    # After the initial ``select`` the frame only contains the *output*
    # columns, so everything downstream refers to them by name instead of
    # re-applying the original expressions.
    weight_name = weight_col.meta.output_name()
    accum_names = [c.meta.output_name() for c in accum_cols]

    # Evenly spaced quantiles in [0, 1], computed without numpy (which this
    # notebook never imports).
    if n_ticks > 1:
        quantiles = [i / (n_ticks - 1) for i in range(n_ticks)]
    else:
        quantiles = [0.0] if n_ticks == 1 else []

    plots = []
    for cont_feature in cont_feature_cols:
        cont_feature_name = cont_feature.meta.output_name()
        sorted_df = plot_df.select(
            cont_feature, weight_col, *accum_cols
        ).sort(cont_feature_name, descending=False).with_columns(
            pl.col(weight_name).cum_sum() / pl.col(weight_name).sum(),
            *[pl.col(name).cum_sum() for name in accum_names]
        ).sample(min(plot_subsample, len(plot_df)), seed=0).sort(cont_feature_name)

        feature_values = sorted_df[cont_feature_name]
        format_fn = format_fns.get(cont_feature_name, quick_format)
        xticks = []
        for q in quantiles:
            q_value = feature_values.quantile(q)
            # Snap the (possibly interpolated) quantile to an actual row so
            # the tick has a well-defined x position.
            closest_idx = (feature_values - q_value).abs().arg_min()
            xticks.append((
                sorted_df[weight_name][closest_idx],
                f"{format_fn(feature_values[closest_idx])}"
            ))

        plot = sorted_df.hvplot.line(
            x=weight_name,
            y=accum_names,
            hover_cols=[cont_feature_name]
        ).opts(
            xticks=xticks,
            xlabel=cont_feature_name
        )
        plots.append(plot)
    return plots

# Settings for the slice plots; edit these and re-run the cell.
plot_subsample = 10_000
accum_cols = [pl.col('null_frac')]
weight_col = pl.col('weight')
n_ticks = 5
cont_feature_cols = [pl.col('return'), pl.col('max_tick_to_query_lag')]

# Pass the tick count and subsample size explicitly so the values above are
# actually used rather than silently falling back to the function defaults.
plots = slice_plots(
    plot_df,
    accum_cols,
    cont_feature_cols,
    weight_col,
    n_ticks=n_ticks,
    plot_subsample=plot_subsample,
)

# Plain loop: display() is called for its side effect only, so a list
# comprehension would just emit a list of None as the cell output.
for p in plots:
    display(p)

[None, None]