# `fit_spline` mode comparison

Compares `fit_spline` with `smooth` set to `"fast"`, `"slow"`, and `0.5`, each with and without OD weighting (`use_weights`). Includes example fit overlays, μ_max accuracy comparison plots, and runtime swarm plots.


In [None]:
import importlib.util
import inspect
import sys
import time
import types
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.interpolate import BSpline

# Locate the repository's `src` directory whether the notebook is launched from
# the repo root or from a direct subdirectory of it.
src_dir = None
for base in (Path.cwd(), Path.cwd().parent):
    candidate = base / "src"
    if (candidate / "growthcurves" / "non_parametric.py").exists():
        src_dir = candidate
        break
if src_dir is None:
    raise RuntimeError("Could not locate src/growthcurves/non_parametric.py")

# Force fresh load from disk to avoid stale kernel definitions. Purge the
# package AND every cached submodule: if the modules loaded below import other
# growthcurves submodules, stale copies left in sys.modules would be reused.
stale_modules = [
    name for name in sys.modules if name == "growthcurves" or name.startswith("growthcurves.")
]
for name in stale_modules:
    sys.modules.pop(name, None)

# Register a bare package stub whose __path__ points at src/growthcurves, so
# `growthcurves.*` names (and relative imports inside the submodules) resolve
# there. The package's __init__.py is deliberately NOT executed.
pkg = types.ModuleType("growthcurves")
pkg.__path__ = [str(src_dir / "growthcurves")]
sys.modules["growthcurves"] = pkg

# Load inference.py before non_parametric.py (presumably the latter imports the
# former — confirm). Each module is registered in sys.modules before it is
# executed, following the importlib "import a source file directly" recipe.
inf_path = src_dir / "growthcurves" / "inference.py"
inf_spec = importlib.util.spec_from_file_location("growthcurves.inference", inf_path)
inf_mod = importlib.util.module_from_spec(inf_spec)
sys.modules["growthcurves.inference"] = inf_mod
inf_spec.loader.exec_module(inf_mod)

np_path = src_dir / "growthcurves" / "non_parametric.py"
np_spec = importlib.util.spec_from_file_location("growthcurves.non_parametric", np_path)
np_mod = importlib.util.module_from_spec(np_spec)
sys.modules["growthcurves.non_parametric"] = np_mod
np_spec.loader.exec_module(np_mod)

fit_spline = np_mod.fit_spline
print("Loaded fit_spline signature:", inspect.signature(fit_spline))


In [None]:
# Candidate locations of the CSV produced by spline_test_data_generation.ipynb.
# Each relative path is tried against the current directory first and then its
# parent (the same two bases used for the src/ lookup). This lets the notebook
# work whether it is launched from the repo root or from inside its own folder.
csv_rel_paths = [
    Path("exploratory notebooks/generated_data/spline_eval_curves.csv"),
    Path("exploratory notebooks/exploratory notebooks/generated_data/spline_eval_curves.csv"),
    Path("generated_data/spline_eval_curves.csv"),
]
candidate_paths = [base / rel for base in (Path.cwd(), Path.cwd().parent) for rel in csv_rel_paths]
csv_path = next((p for p in candidate_paths if p.exists()), None)
if csv_path is None:
    raise RuntimeError("Missing generated CSV. Run spline_test_data_generation.ipynb first.")

# Long-format table: one row per (curve_id, t) sample. Columns read here are
# curve_id, t, N_obs, mu_true, and optionally N_true.
df = pd.read_csv(csv_path)
curves = []
for curve_id, g in df.groupby("curve_id", sort=False):
    g = g.sort_values("t")
    t = g["t"].to_numpy(dtype=float)
    N_obs = g["N_obs"].to_numpy(dtype=float)
    N_true = g["N_true"].to_numpy(dtype=float) if "N_true" in g.columns else np.full_like(t, np.nan, dtype=float)
    # mu_true is read from the first row only; assumed constant per curve — TODO confirm.
    mu_true = float(g["mu_true"].iloc[0])

    # Keep only finite samples with strictly positive OD.
    mask = np.isfinite(t) & np.isfinite(N_obs) & (N_obs > 0)
    t, N_obs, N_true = t[mask], N_obs[mask], N_true[mask]
    # Require enough points, a non-degenerate time span, and a usable reference rate.
    if len(t) >= 10 and np.ptp(t) > 0 and np.isfinite(mu_true) and mu_true > 0:
        curves.append(
            {
                "curve_id": str(curve_id),
                "t": t,
                "N_obs": N_obs,
                "N_true": N_true,
                "mu_true": mu_true,
            }
        )

print(f"Loaded {len(curves)} valid curves from {csv_path}")


In [None]:
# Spline configurations under comparison: each `smooth` mode with and without
# OD weighting. `label` and `color` are reused by the plotting cells.
method_specs = [
    {"smooth": "fast", "use_weights": True, "label": "fast + weights", "color": "#1f77b4"},
    {"smooth": "fast", "use_weights": False, "label": "fast", "color": "#17becf"},
    {"smooth": "slow", "use_weights": True, "label": "slow + weights", "color": "#2ca02c"},
    {"smooth": "slow", "use_weights": False, "label": "slow", "color": "#98df8a"},
    {"smooth": 0.5, "use_weights": True, "label": "0.5 + weights", "color": "#d62728"},
    {"smooth": 0.5, "use_weights": False, "label": "0.5", "color": "#ff9896"},
]
# Stable identifier per configuration, e.g. "fast|w=1" or "0.5|w=0".
for spec in method_specs:
    spec["key"] = f"{spec['smooth']}|w={int(spec['use_weights'])}"

# Explicit column lists keep the filtering below working when no records are
# produced (no valid curves, or every fit fails). A DataFrame built from an empty
# list has no columns, so `df["method_key"]` would otherwise raise KeyError.
runtime_columns = ["curve_id", "method_key", "method", "time_ms"]
fit_columns = [
    "curve_id", "method_key", "method", "smooth", "use_weights",
    "mu_true", "mu_hat", "rel_err", "spline_s",
]

runtime_records = []
fit_records = []
for rec in curves:
    t = rec["t"]
    N_obs = rec["N_obs"]
    mu_true = rec["mu_true"]
    cid = rec["curve_id"]

    for spec in method_specs:
        # Wall-clock time of one fit; recorded whether or not the fit succeeds.
        t0 = time.perf_counter()
        fit = fit_spline(t, N_obs, smooth=spec["smooth"], use_weights=spec["use_weights"])
        dt_ms = (time.perf_counter() - t0) * 1000.0
        runtime_records.append({"curve_id": cid, "method_key": spec["key"], "method": spec["label"], "time_ms": dt_ms})

        # A fit only counts as successful if it yields a finite mu_max estimate.
        if fit is None:
            continue
        params = fit.get("params", {})
        mu_hat = float(params.get("mu_max", np.nan))
        if not np.isfinite(mu_hat):
            continue

        fit_records.append(
            {
                "curve_id": cid,
                "method_key": spec["key"],
                "method": spec["label"],
                "smooth": str(spec["smooth"]),
                "use_weights": bool(spec["use_weights"]),
                "mu_true": float(mu_true),
                "mu_hat": float(mu_hat),
                "rel_err": float(abs(mu_hat - mu_true) / mu_true),
                "spline_s": float(params.get("spline_s", np.nan)),
            }
        )

runtime_df = pd.DataFrame(runtime_records, columns=runtime_columns)
eval_df = pd.DataFrame(fit_records, columns=fit_columns)

# Per-method summary. A curve counts as "failed" when it produced no fit record
# (fit returned None or mu_max was non-finite).
rows = []
n_total = len(curves)
for spec in method_specs:
    key = spec["key"]
    sub_eval = eval_df[eval_df["method_key"] == key]
    sub_time = runtime_df[runtime_df["method_key"] == key]["time_ms"].to_numpy(dtype=float)
    n_ok = len(sub_eval)
    n_failed = max(n_total - n_ok, 0)
    rel = sub_eval["rel_err"].to_numpy(dtype=float)
    rows.append(
        {
            "method": spec["label"],
            "smooth": str(spec["smooth"]),
            "use_weights": bool(spec["use_weights"]),
            "n_curves": int(n_total),
            "n_failed": int(n_failed),
            "failure_rate_%": 100.0 * n_failed / n_total if n_total else np.nan,
            "median_rel_err_%": 100.0 * float(np.median(rel)) if len(rel) else np.nan,
            "mean_rel_err_%": 100.0 * float(np.mean(rel)) if len(rel) else np.nan,
            "median_time_ms": float(np.median(sub_time)) if len(sub_time) else np.nan,
            "mean_time_ms": float(np.mean(sub_time)) if len(sub_time) else np.nan,
        }
    )

# Keep the table in the same order as method_specs.
results = pd.DataFrame(rows)
method_order = [spec["label"] for spec in method_specs]
results["method"] = pd.Categorical(results["method"], categories=method_order, ordered=True)
results = results.sort_values("method").reset_index(drop=True)

# Example curves for the overlay plot: lowest, median and highest mu_true
# (duplicates collapse when there are fewer than three curves).
if curves:
    mu_sorted_idx = np.argsort([rec["mu_true"] for rec in curves])
    n_sorted = len(mu_sorted_idx)
    example_indices = sorted({0, n_sorted // 2, n_sorted - 1})
    example_curves = [curves[int(mu_sorted_idx[i])] for i in example_indices]
else:
    mu_sorted_idx = np.array([], dtype=int)
    example_indices = []
    example_curves = []

results.round(3)


In [None]:
if not example_curves:
    raise RuntimeError("No example curves available to plot.")

fig = make_subplots(
    rows=1,
    cols=len(example_curves),
    subplot_titles=[f"{rec['curve_id']}<br>μ_true={rec['mu_true']:.3f}" for rec in example_curves],
    horizontal_spacing=0.04,
)

# Legend names already emitted. Each name is shown once, on the first trace
# actually drawn with it, rather than only for column 1. Otherwise a method
# that fails on the first example curve (or a missing N_true there) would never
# get a legend entry. Traces share a `legendgroup` so one click toggles the
# name in every subplot.
legend_seen = set()


def _show_in_legend(name):
    """Return True only the first time *name* is requested."""
    if name in legend_seen:
        return False
    legend_seen.add(name)
    return True


for cidx, rec in enumerate(example_curves, start=1):
    t = rec["t"]
    N_obs = rec["N_obs"]
    N_true = rec["N_true"]
    t_eval = np.linspace(float(np.min(t)), float(np.max(t)), 300)

    fig.add_trace(
        go.Scatter(
            x=t, y=N_obs, mode="markers", name="Observed",
            marker=dict(size=5, color="#7f7f7f"),
            legendgroup="Observed", showlegend=_show_in_legend("Observed"),
        ),
        row=1, col=cidx,
    )

    # N_true is all-NaN when the CSV has no N_true column.
    if np.any(np.isfinite(N_true)):
        fig.add_trace(
            go.Scatter(
                x=t, y=N_true, mode="lines", name="True",
                line=dict(color="#111111", width=2),
                legendgroup="True", showlegend=_show_in_legend("True"),
            ),
            row=1, col=cidx,
        )

    for spec in method_specs:
        fit = fit_spline(t, N_obs, smooth=spec["smooth"], use_weights=spec["use_weights"])
        if fit is None:
            continue
        p = fit.get("params", {})
        # The fitted spline is rebuilt from its knots/coefficients/degree.
        if not {"tck_t", "tck_c", "tck_k"}.issubset(p.keys()):
            continue
        spline = BSpline(np.asarray(p["tck_t"], dtype=float), np.asarray(p["tck_c"], dtype=float), int(p["tck_k"]))
        # exp() back-transform assumes the spline models ln(N) — confirm against fit_spline.
        N_fit = np.exp(np.asarray(spline(t_eval), dtype=float))
        fig.add_trace(
            go.Scatter(
                x=t_eval,
                y=N_fit,
                mode="lines",
                name=spec["label"],
                line=dict(width=2, color=spec["color"]),
                legendgroup=spec["label"],
                showlegend=_show_in_legend(spec["label"]),
            ),
            row=1, col=cidx,
        )

fig.update_layout(template="plotly_white", height=420, width=1700, title="Example curve fits by method")
for cidx in range(1, len(example_curves) + 1):
    fig.update_xaxes(title_text="Time", row=1, col=cidx)
    fig.update_yaxes(title_text="OD", row=1, col=cidx)
fig.show()


In [None]:
if eval_df.empty:
    raise RuntimeError("No successful fits available for comparison plots.")

method_order = [spec["label"] for spec in method_specs]
label_to_color = {spec["label"]: spec["color"] for spec in method_specs}
# The first configuration serves as the reference in the paired-error panel.
baseline_method = method_specs[0]["label"]


def _finite_values(column, method):
    """Return the finite entries of eval_df[column] for the rows of *method*."""
    arr = eval_df.loc[eval_df["method"] == method, column].to_numpy(dtype=float)
    return arr[np.isfinite(arr)]


fig_cmp = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=[
        "Parity: μ_true vs μ_hat",
        "Mean relative error (95% CI)",
        f"Paired errors vs baseline ({baseline_method})",
        "Smoothing value (spline_s) distribution",
    ],
    horizontal_spacing=0.12,
    vertical_spacing=0.18,
)

# Panel (1, 1): estimated vs true mu_max per method, plus the y = x line.
for method in method_order:
    per_method = eval_df[eval_df["method"] == method]
    if per_method.empty:
        continue
    fig_cmp.add_trace(
        go.Scatter(
            x=per_method["mu_true"],
            y=per_method["mu_hat"],
            mode="markers",
            name=method,
            marker=dict(size=5, color=label_to_color[method], opacity=0.5),
            showlegend=True,
        ),
        row=1,
        col=1,
    )
# eval_df is known to be non-empty here, so both extremes are defined.
lo = float(min(eval_df["mu_true"].min(), eval_df["mu_hat"].min()))
hi = float(max(eval_df["mu_true"].max(), eval_df["mu_hat"].max()))
fig_cmp.add_trace(
    go.Scatter(
        x=[lo, hi],
        y=[lo, hi],
        mode="lines",
        line=dict(color="black", dash="dash"),
        name="Parity",
        showlegend=False,
    ),
    row=1,
    col=1,
)

# Panel (1, 2): mean relative error (%) with a normal-approximation 95% CI.
bar_x = list(method_order)
bar_y, bar_err = [], []
for method in bar_x:
    errs = _finite_values("rel_err", method)
    if len(errs) == 0:
        bar_y.append(np.nan)
        bar_err.append(0.0)
        continue
    bar_y.append(float(np.mean(errs) * 100.0))
    if len(errs) > 1:
        bar_err.append(float(1.96 * np.std(errs, ddof=1) / np.sqrt(len(errs)) * 100.0))
    else:
        bar_err.append(0.0)
bar_color = [label_to_color[name] for name in bar_x]
fig_cmp.add_trace(
    go.Bar(
        x=bar_x,
        y=bar_y,
        error_y=dict(type="data", array=bar_err, visible=True),
        marker=dict(color=bar_color),
        showlegend=False,
    ),
    row=1,
    col=2,
)

# Panel (2, 1): per-curve error of each method against the baseline's error.
base = (
    eval_df.loc[eval_df["method"] == baseline_method, ["curve_id", "rel_err"]]
    .rename(columns={"rel_err": "base_rel"})
)
for method in method_order[1:]:
    other = (
        eval_df.loc[eval_df["method"] == method, ["curve_id", "rel_err"]]
        .rename(columns={"rel_err": "method_rel"})
    )
    merged = base.merge(other, on="curve_id", how="inner")
    if merged.empty:
        continue
    fig_cmp.add_trace(
        go.Scatter(
            x=100.0 * merged["base_rel"],
            y=100.0 * merged["method_rel"],
            mode="markers",
            marker=dict(size=5, color=label_to_color[method], opacity=0.5),
            name=method,
            showlegend=False,
        ),
        row=2,
        col=1,
    )
pair_max = max(float(np.nanmax(100.0 * eval_df["rel_err"])), 1.0)
fig_cmp.add_trace(
    go.Scatter(
        x=[0, pair_max],
        y=[0, pair_max],
        mode="lines",
        line=dict(color="black", dash="dash"),
        showlegend=False,
    ),
    row=2,
    col=1,
)

# Panel (2, 2): jittered swarm of spline_s per method; "x" marks the mean.
rng = np.random.default_rng(20260224)
for i, method in enumerate(method_order):
    svals = _finite_values("spline_s", method)
    if len(svals) == 0:
        continue
    xj = rng.normal(loc=float(i), scale=0.07, size=len(svals))
    colour = label_to_color[method]
    fig_cmp.add_trace(
        go.Scatter(x=xj, y=svals, mode="markers", marker=dict(size=4, color=colour, opacity=0.45), showlegend=False),
        row=2,
        col=2,
    )
    fig_cmp.add_trace(
        go.Scatter(
            x=[float(i)],
            y=[float(np.mean(svals))],
            mode="markers",
            marker=dict(size=9, symbol="x", color=colour),
            showlegend=False,
        ),
        row=2,
        col=2,
    )

fig_cmp.update_layout(template="plotly_white", width=1400, height=950, title="μ_max comparison across spline modes")
axis_titles = [
    (1, 1, "μ_true", "μ_hat"),
    (1, 2, "Method", "Relative error (%)"),
    (2, 1, f"{baseline_method} relative error (%)", "Method relative error (%)"),
]
for row_i, col_i, x_title, y_title in axis_titles:
    fig_cmp.update_xaxes(title_text=x_title, row=row_i, col=col_i)
    fig_cmp.update_yaxes(title_text=y_title, row=row_i, col=col_i)
fig_cmp.update_xaxes(
    title_text="Method",
    row=2,
    col=2,
    tickmode="array",
    tickvals=[float(i) for i in range(len(method_order))],
    ticktext=method_order,
    range=[-0.6, len(method_order) - 0.4],
)
fig_cmp.update_yaxes(title_text="spline_s", row=2, col=2)
fig_cmp.show()


In [None]:
if runtime_df.empty:
    raise RuntimeError("No runtime data available.")

method_order = [spec["label"] for spec in method_specs]
label_to_color = {spec["label"]: spec["color"] for spec in method_specs}


def _runtime_ms(method):
    """Return the finite per-fit runtimes (ms) recorded for *method*."""
    arr = runtime_df.loc[runtime_df["method"] == method, "time_ms"].to_numpy(dtype=float)
    return arr[np.isfinite(arr)]


fig_rt = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=["Mean runtime per fit (95% CI)", "Runtime swarm per fit"],
    horizontal_spacing=0.14,
)

# Left panel: mean runtime per method with a normal-approximation 95% CI.
# Runtimes of failed fits are included (they are timed before success is checked).
bar_x = list(method_order)
bar_y, bar_err = [], []
for method in bar_x:
    times = _runtime_ms(method)
    n_times = len(times)
    bar_y.append(float(np.mean(times)) if n_times else np.nan)
    bar_err.append(float(1.96 * np.std(times, ddof=1) / np.sqrt(n_times)) if n_times > 1 else 0.0)
bar_color = [label_to_color[name] for name in bar_x]
fig_rt.add_trace(
    go.Bar(
        x=bar_x,
        y=bar_y,
        error_y=dict(type="data", array=bar_err, visible=True),
        marker=dict(color=bar_color),
        showlegend=False,
    ),
    row=1,
    col=1,
)

# Right panel: jittered swarm of individual runtimes; "x" marks the median.
rng = np.random.default_rng(20260224)
for i, method in enumerate(method_order):
    times = _runtime_ms(method)
    if len(times) == 0:
        continue
    xj = rng.normal(loc=float(i), scale=0.07, size=len(times))
    colour = label_to_color[method]
    fig_rt.add_trace(
        go.Scatter(x=xj, y=times, mode="markers", marker=dict(size=4, color=colour, opacity=0.35), showlegend=False),
        row=1,
        col=2,
    )
    fig_rt.add_trace(
        go.Scatter(
            x=[float(i)],
            y=[float(np.median(times))],
            mode="markers",
            marker=dict(size=10, symbol="x", color=colour),
            showlegend=False,
        ),
        row=1,
        col=2,
    )

fig_rt.update_layout(template="plotly_white", width=1450, height=520, title="Runtime comparison across spline modes")
fig_rt.update_xaxes(title_text="Method", row=1, col=1)
fig_rt.update_yaxes(title_text="Runtime (ms)", row=1, col=1)
fig_rt.update_xaxes(
    title_text="Method",
    row=1,
    col=2,
    tickmode="array",
    tickvals=[float(i) for i in range(len(method_order))],
    ticktext=method_order,
    range=[-0.6, len(method_order) - 0.4],
)
# Log scale: per-fit runtimes can span orders of magnitude across modes.
fig_rt.update_yaxes(title_text="Runtime (ms)", row=1, col=2, type="log")
fig_rt.show()
