# EDA (exploratory data analysis)

# In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go


# 1) Helper: compute bin statistics
def _bin_stats(df: pd.DataFrame,x_col: str, y_col: str, lq: float, uq: float, n_bins: int, exclude_outliers: bool = False
    
    if exclude_outliers:
        lo, hi = df[y_col].quantile([lq, uq])
        df = df[(df[y_col] >= lo) & (df[y_col] <= hi)]

    df = df.copy()
    df["_bin"] = pd.qcut(df[x_col], q=n_bins, duplicates="drop")

    stats = (
        df.groupby("_bin", observed=True).agg(x_ctr=(x_col, "mean"),y_mean=(y_col, "mean"),y_std=(y_col, "std"),n=(y_col, "size"),
        ).reset_index(drop=True)
    )

    stats["y_se"] = stats["y_std"] / np.sqrt(stats["n"])
    return stats



# 2) Animated single-target plot
def animate_quant_bins_v3(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    lq: float,
    uq: float,
    period: str,
    n_bins: int,
    show_outliers: bool = True,
    template: str = "simple_white",
    *,
    title_suffix: str = "",
) -> "go.Figure":
    """Animate binned means of ``y_col`` against ``x_col``, one frame per period.

    Rows are grouped by the period of ``df``'s index. For each period, the
    mean of ``y_col`` is computed in ``n_bins`` quantile bins of ``x_col``
    (via ``_bin_stats``). The result is drawn as a line with standard-error
    bars.

    Parameters
    ----------
    df : data with ``x_col`` and ``y_col`` columns. The index must support
        ``.dt.to_period`` (i.e. be datetime-like).
    x_col, y_col : binning column and target column.
    lq, uq : lower / upper quantiles of ``y_col`` (computed within each
        period) that delimit outliers.
    period : pandas period alias passed to ``to_period`` (e.g. ``"M"``).
    n_bins : requested number of quantile bins per period.
    show_outliers : if True, rows outside the ``[lq, uq]`` quantile range are
        excluded from the bin statistics and drawn as separate open markers.
        If False, all rows feed the bins and no outlier markers are drawn.
    template : plotly layout template name.
    title_suffix : optional text placed in front of the figure title.

    Returns
    -------
    plotly Figure with one frame per period, Play/Pause buttons and a
    period slider.

    Raises
    ------
    ValueError
        If ``df`` produces no periods to plot (e.g. it is empty).
    """
    df2 = df.copy()
    df2["period"] = df2.index.to_series().dt.to_period(period).astype(str)

    frames = []
    ymins, ymaxs = [], []

    for p, df_p in df2.groupby("period"):
        stats = _bin_stats(
            df_p, x_col, y_col, lq, uq, n_bins, exclude_outliers=show_outliers
        )

        # Axis limits come from the bin means +/- SE only; outlier points are
        # not included. A missing SE (single-value bin) counts as zero width
        # so that bin's mean still constrains the range.
        se = stats["y_se"].fillna(0.0)
        ymins.append((stats["y_mean"] - se).min())
        ymaxs.append((stats["y_mean"] + se).max())

        traces = [
            go.Scatter(
                x=stats["x_ctr"],
                y=stats["y_mean"],
                mode="lines+markers",
                name="Bins",
                line=dict(shape="linear", width=2, color="#1f77b4"),
                marker=dict(size=6, color="#1f77b4"),
                error_y=dict(
                    type="data", array=stats["y_se"], visible=True, thickness=1.2
                ),
            )
        ]

        if show_outliers:
            lo, hi = df_p[y_col].quantile([lq, uq])
            mask_out = (df_p[y_col] < lo) | (df_p[y_col] > hi)
            # Always add this trace, even when empty. Plotly animates traces
            # by position, so every frame must carry the same number of them.
            traces.append(
                go.Scatter(
                    x=df_p.loc[mask_out, x_col],
                    y=df_p.loc[mask_out, y_col],
                    mode="markers",
                    name="Outliers",
                    marker=dict(
                        size=5,
                        symbol="circle-open",
                        opacity=0.4,
                        color="#d62728",
                    ),
                    hoverinfo="skip",
                )
            )

        frames.append(go.Frame(data=traces, name=str(p)))

    if not frames:
        raise ValueError("animate_quant_bins_v3: `df` has no rows to plot")

    # Y-axis range shared by all frames, padded by 5% of each limit's magnitude.
    pad = 0.05
    lows = [v for v in ymins if np.isfinite(v)]
    highs = [v for v in ymaxs if np.isfinite(v)]
    if lows and highs:
        ymin, ymax = min(lows), max(highs)
        yrange = [ymin - abs(ymin) * pad, ymax + abs(ymax) * pad]
        if yrange[0] >= yrange[1]:
            # Degenerate span (e.g. every bin mean is exactly 0): widen it.
            yrange = [ymin - 1.0, ymax + 1.0]
    else:
        # No finite statistics anywhere: let plotly choose the range.
        yrange = None

    fig = go.Figure(
        data=frames[0].data,
        layout=dict(
            title=f"{title_suffix}mean {y_col} in {n_bins} quantile bins of {x_col}",
            xaxis=dict(title=x_col, autorange=True),
            yaxis=dict(title=f"Mean {y_col}", range=yrange),
            template=template,
            # Zero baseline spanning the full plot width.
            shapes=[
                dict(
                    type="line",
                    xref="paper",
                    yref="y",
                    x0=0,
                    x1=1,
                    y0=0,
                    y1=0,
                    line=dict(color="black", width=1),
                )
            ],
            # Play / Pause buttons.
            updatemenus=[
                dict(
                    type="buttons",
                    showactive=False,
                    x=0.02,
                    y=1.15,
                    buttons=[
                        dict(
                            label="Play",
                            method="animate",
                            args=[None, dict(frame=dict(duration=500), fromcurrent=True)],
                        ),
                        dict(
                            label="Pause",
                            method="animate",
                            args=[[None], dict(frame=dict(duration=0), mode="immediate")],
                        ),
                    ],
                )
            ],
            # One slider step per period; step args name the matching frame.
            sliders=[
                dict(
                    active=0,
                    currentvalue=dict(prefix="Period: "),
                    steps=[
                        dict(
                            method="animate",
                            label=f.name,
                            args=[[f.name], dict(frame=dict(duration=500), mode="immediate")],
                        )
                        for f in frames
                    ],
                )
            ],
        ),
        # Frames belong to the Figure itself; Layout has no "frames" property.
        frames=frames,
    )

    return fig


# 3) Animated dual-target (true vs pred) plot
def animate_quant_bins_dual(
    df: pd.DataFrame,
    x_col: str,
    y_true_col: str,
    y_pred_col: str,
    lq: float,
    uq: float,
    period: str,
    n_bins: int,
    *,
    title_suffix: str = "",
    show_outliers: bool = True,
    template: str = "simple_white",
) -> "go.Figure":
    """Like *animate_quant_bins_v3*, but draws **two** bin curves per frame.

    One curve is for the "true" target (``y_true_col``) and one is for the
    "predicted" target (``y_pred_col``). Both are binned on the same
    ``n_bins`` quantile bins of ``x_col`` within each index period.

    Parameters
    ----------
    df : data with ``x_col``, ``y_true_col`` and ``y_pred_col`` columns. The
        index must support ``.dt.to_period`` (i.e. be datetime-like).
    x_col : binning column.
    y_true_col, y_pred_col : the two target columns to summarise.
    lq, uq : lower / upper quantiles (within each period) that delimit
        outliers, applied separately to each target when computing its bins.
    period : pandas period alias passed to ``to_period`` (e.g. ``"M"``).
    n_bins : requested number of quantile bins per period.
    title_suffix : optional text placed in front of the figure title.
    show_outliers : if True, outlier rows are excluded from each target's
        bin statistics. Outliers of the *true* target only are drawn as grey
        open markers. If False, all rows feed the bins.
    template : plotly layout template name.

    Returns
    -------
    plotly Figure with one frame per period, Play/Pause buttons and a
    period slider.

    Raises
    ------
    ValueError
        If ``df`` produces no periods to plot (e.g. it is empty).
    """
    df2 = df.copy()
    df2["period"] = df2.index.to_series().dt.to_period(period).astype(str)

    frames = []
    ymins, ymaxs = [], []

    colors = {"True": "#1f77b4", "Pred": "#ff7f0e"}

    def _bin_trace(stats: pd.DataFrame, name: str):
        """Line + markers with SE error bars for one target's bin stats."""
        return go.Scatter(
            x=stats["x_ctr"],
            y=stats["y_mean"],
            mode="lines+markers",
            name=name,
            line=dict(shape="linear", width=2, color=colors[name]),
            marker=dict(size=6, color=colors[name]),
            error_y=dict(
                type="data", array=stats["y_se"], visible=True, thickness=1.2
            ),
        )

    for p, df_p in df2.groupby("period"):
        # --- stats for each target -------------------------------------------
        stats_true = _bin_stats(
            df_p, x_col, y_true_col, lq, uq, n_bins, exclude_outliers=show_outliers
        )
        stats_pred = _bin_stats(
            df_p, x_col, y_pred_col, lq, uq, n_bins, exclude_outliers=show_outliers
        )

        # Per-target limits are collected separately and filtered for NaN
        # after the loop. Python's min/max over NaN is order-dependent.
        # A missing SE (single-value bin) counts as zero width.
        for stats in (stats_true, stats_pred):
            se = stats["y_se"].fillna(0.0)
            ymins.append((stats["y_mean"] - se).min())
            ymaxs.append((stats["y_mean"] + se).max())

        # --- traces ----------------------------------------------------------
        traces = [_bin_trace(stats_true, "True"), _bin_trace(stats_pred, "Pred")]

        # --- Optional outlier trace (true target only, coloured grey) --------
        if show_outliers:
            lo, hi = df_p[y_true_col].quantile([lq, uq])
            mask_out = (df_p[y_true_col] < lo) | (df_p[y_true_col] > hi)
            # Always add this trace, even when empty. Plotly animates traces
            # by position, so every frame must carry the same number of them.
            traces.append(
                go.Scatter(
                    x=df_p.loc[mask_out, x_col],
                    y=df_p.loc[mask_out, y_true_col],
                    mode="markers",
                    name="Outliers",
                    marker=dict(
                        size=5,
                        symbol="circle-open",
                        opacity=0.4,
                        color="grey",
                    ),
                    hoverinfo="skip",
                )
            )

        frames.append(go.Frame(data=traces, name=str(p)))

    if not frames:
        raise ValueError("animate_quant_bins_dual: `df` has no rows to plot")

    # Y-axis range shared by all frames, padded by 5% of each limit's magnitude.
    pad = 0.05
    lows = [v for v in ymins if np.isfinite(v)]
    highs = [v for v in ymaxs if np.isfinite(v)]
    if lows and highs:
        ymin, ymax = min(lows), max(highs)
        yrange = [ymin - abs(ymin) * pad, ymax + abs(ymax) * pad]
        if yrange[0] >= yrange[1]:
            # Degenerate span (e.g. every bin mean is exactly 0): widen it.
            yrange = [ymin - 1.0, ymax + 1.0]
    else:
        # No finite statistics anywhere: let plotly choose the range.
        yrange = None

    fig = go.Figure(
        data=frames[0].data,
        layout=dict(
            title=(
                f"{title_suffix}mean {y_true_col}/{y_pred_col} "
                f"in {n_bins} quantile bins of {x_col}"
            ),
            xaxis=dict(title=x_col, autorange=True),
            yaxis=dict(title="Mean target", range=yrange),
            template=template,
            # Zero baseline spanning the full plot width.
            shapes=[
                dict(
                    type="line",
                    xref="paper",
                    yref="y",
                    x0=0,
                    x1=1,
                    y0=0,
                    y1=0,
                    line=dict(color="black", width=1),
                )
            ],
            # Play / Pause buttons.
            updatemenus=[
                dict(
                    type="buttons",
                    showactive=False,
                    x=0.02,
                    y=1.15,
                    buttons=[
                        dict(
                            label="Play",
                            method="animate",
                            args=[None, dict(frame=dict(duration=500), fromcurrent=True)],
                        ),
                        dict(
                            label="Pause",
                            method="animate",
                            args=[[None], dict(frame=dict(duration=0), mode="immediate")],
                        ),
                    ],
                )
            ],
            # One slider step per period; step args name the matching frame.
            sliders=[
                dict(
                    active=0,
                    currentvalue=dict(prefix="Period: "),
                    steps=[
                        dict(
                            method="animate",
                            label=f.name,
                            args=[[f.name], dict(frame=dict(duration=500), mode="immediate")],
                        )
                        for f in frames
                    ],
                )
            ],
        ),
        # Frames belong to the Figure itself; Layout has no "frames" property.
        frames=frames,
    )

    return fig
