In [None]:
import pandas as pd
import numpy as np
import os
from src.utilities.plot_functions import format_2d_plotly, format_3d_plotly, rotate_figure
from ultrack.tracks.graph import get_paths_to_roots, tracks_df_forest, inv_tracks_df_forest
from glob2 import glob
from tqdm import tqdm
import zarr

## Load tracking data

In [None]:
# --- project configuration ---------------------------------------------------
# NOTE(review): machine-specific absolute path; change it when running elsewhere.
root = "E:\\Nick\\Cole Trapnell's Lab Dropbox\\Nick Lammers\\Nick\\killi_tracker\\"

# Point the Qt-related environment variables at PyQt5 (each one set exactly once).
os.environ["QT_API"] = "pyqt5"
os.environ["PYQTGRAPH_QT_LIB"] = "PyQt5"

project_name = "20250311_LCP1-NLSMSC"
stitch_suffix = ""

# Figure output folder; the trailing "" makes the joined path end with a separator.
fig_path = os.path.join(root, "figures", project_name, "pipeline_figs", "")
os.makedirs(fig_path, exist_ok=True)

# --- image data ---------------------------------------------------------------
# Fused image store, opened read-only.
zpath = os.path.join(root, "built_data", "zarr_image_files", project_name + "_fused.zarr")
fused_image_zarr = zarr.open(zpath, mode="r")

# --- nuclear (nls) tracks -----------------------------------------------------
print("Loading tracking data for project:", project_name)
nls_track_path = os.path.join(root, "tracking", project_name, "tracking_20250328_redux", "well0000", "track_0000_2339_cb", "")
nls_tracks_df = pd.read_csv(os.path.join(nls_track_path, "tracks" + stitch_suffix + "_fluo.csv"))
nucleus_class_df = pd.read_csv(os.path.join(nls_track_path, "track_class_df_full.csv"))

# Attach the per-(track, frame) class labels; the left join keeps every tracked row,
# so rows without a class entry get NaN in track_class / frame_class.
nls_tracks_df = nls_tracks_df.merge(
    nucleus_class_df.loc[:, ["track_id", "t", "track_class", "frame_class"]],
    on=["track_id", "t"],
    how="left",
)
# z multiplied by 3 -- presumably the z/xy anisotropy factor; TODO confirm.
nls_tracks_df["z_scaled"] = nls_tracks_df["z"] * 3

# --- lcp tracks and manual curation table -------------------------------------
print("Loading lcp tracking data for project:", project_name)
lcp_track_path = os.path.join(root, "built_data", "tracking", project_name, "")
lcp_tracks_df = pd.read_csv(os.path.join(lcp_track_path, "lcp_tracks_df.csv"))

lcp_curation_df = pd.read_csv(os.path.join(lcp_track_path, "20250311_lcp_track_curation.csv"))

## Load mask data

In [None]:
# Read every per-frame CSV in the project's fluorescence folder (sorted by file name)
# and stack them into a single table with a fresh 0..N-1 index.
full_mask_fluo_dir = os.path.join(root, "built_data", "fluorescence_data", project_name, "")
fluo_frames = sorted(glob(full_mask_fluo_dir + "*.csv"))
fluo_df_list = [pd.read_csv(csv_path) for csv_path in tqdm(fluo_frames)]

fluo_df_full = pd.concat(fluo_df_list, axis=0, ignore_index=True)


## Combine

In [None]:
# make figure directory
# NOTE(review): this rebinds fig_path to <root>/figures/<project_name>, replacing the
# ".../pipeline_figs/" folder assigned in the data-loading cell; every figure saved
# after this point is written here -- confirm that this is the intended location.
fig_path = os.path.join(root, "figures", project_name)
os.makedirs(fig_path, exist_ok=True)

In [None]:
# Quick size check: (rows, columns) of the concatenated fluorescence table.
fluo_df_full.shape

In [None]:
import plotly.express as px

# Subsample rows for plotting. Each row's sampling probability is proportional to
# mean_fluo**6, so bright nuclei dominate the sample.
fluo_pow = fluo_df_full["mean_fluo"].to_numpy() ** 6
weights = fluo_pow / np.sum(fluo_pow)

# np.random.choice(replace=False, p=...) raises when asked for more rows than have a
# non-zero probability, so cap the request instead of crashing on smaller datasets.
n_plot = min(100000, int(np.count_nonzero(weights)))
np.random.seed(134)
plot_indices = np.random.choice(fluo_df_full.shape[0], n_plot, replace=False, p=weights)

# Frame index -> stage (presumably 1.5 min per frame starting at 26 hpf; TODO confirm).
fluo_df_full["stage"] = fluo_df_full["frame"]*1.5/60 + 26

# plot_indices are row positions, so select positionally.
fig = px.scatter(fluo_df_full.iloc[plot_indices], x="stage", y="mean_fluo")
axis_labels = ["stage (hpf)", "nuclear lcp-gfp intensity (au)"]
fig = format_2d_plotly(fig, axis_labels=axis_labels, font_size=18)

fig.update_traces(line=dict(width=5))

fig.write_image(os.path.join(fig_path, "lcp_intensity_by_stage.png"), scale=2)

fig.show()

## Plot numbers of lcp+ cells over time according to our various datasets

In [None]:
fluo_thresh = 115   # mean_fluo above this value flags a nucleus as lcp+
window_size = 50    # trailing rolling-mean window, in stage bins

# 1.0 for every nucleus above threshold, 0.0 otherwise
above_thresh = fluo_df_full["mean_fluo"] > fluo_thresh
fluo_df_full["fluo_flag"] = above_thresh.astype(float)

# number of flagged nuclei per stage value
master_df_g = (
    fluo_df_full[["stage", "fluo_flag"]]
    .groupby("stage")
    .sum()
    .reset_index()
)

# smooth the per-stage counts with a trailing rolling mean
master_df_g["fluo_trend"] = (
    master_df_g["fluo_flag"]
    .rolling(window=window_size, min_periods=1)
    .mean()
)

axis_labels = ["stage (hpf)", "number of lcp+ cells"]
fig = format_2d_plotly(
    px.line(master_df_g, x="stage", y="fluo_trend"),
    axis_labels=axis_labels,
    font_size=18,
)

fig.update_traces(line=dict(width=5))

fig.write_image(os.path.join(fig_path, "lcp_cell_count_by_stage.png"), scale=2)

fig.show()


## li threshold

In [None]:
# Build the CSV location from `root` / `project_name` like every other path in this
# notebook, instead of repeating the absolute path as a second hard-coded literal.
li_path = os.path.join(root, "built_data", "mask_stacks", project_name + "side1_li_thresh_trend.csv")
li_df = pd.read_csv(li_path)

# Same frame -> stage conversion used for the other plots.
li_df["stage"] = li_df["frame"]*1.5/60 + 26

fig = px.line(li_df, x="stage", y="li_thresh")
axis_labels = ["stage (hpf)", "inferred segmentation threshold"]
fig = format_2d_plotly(fig, axis_labels=axis_labels, font_size=18)

fig.update_traces(line=dict(width=6, color="coral"))

fig.write_image(os.path.join(fig_path, "li_thresh_by_stage.png"), scale=2)

fig.show()

## number of cells (total) and by type

In [None]:
window_size = 25   # centred rolling-mean window, in frames

# rows per frame = number of tracked nuclei in that frame
rows_per_frame = nls_tracks_df.groupby(["t"]).size()
total_counts_df = rows_per_frame.reset_index(name="number of cells")

# replace the raw counts with their centred rolling mean
smoothed_counts = (
    total_counts_df["number of cells"]
    .rolling(window=window_size, center=True, min_periods=1)
    .mean()
)
total_counts_df["number of cells"] = smoothed_counts
total_counts_df["stage"] = total_counts_df["t"]*1.5/60 + 26

axis_labels = ["stage (hpf)", "number of cells"]
fig = format_2d_plotly(
    px.line(total_counts_df, x="stage", y="number of cells"),
    axis_labels=axis_labels,
    font_size=18,
)
fig.update_traces(line=dict(width=5, color="white"))
fig.write_image(os.path.join(fig_path, "total_cell_count_by_stage.png"), scale=2)
fig.show()

In [None]:
# ---------------------------------------------------------------
# 0.  raw number of tracked nuclei per (frame, class)
cell_type_counts = (
    nls_tracks_df
    .groupby(["t", "track_class"])
    .size()
    .reset_index(name="n_cells")
)

# ---------------------------------------------------------------
# 1.  make sure every (frame, class) pair has a row; absent pairs count as 0
all_t = np.arange(nls_tracks_df["t"].min(), nls_tracks_df["t"].max() + 1)
all_classes = cell_type_counts["track_class"].unique()

full_index = pd.MultiIndex.from_product([all_t, all_classes], names=["t", "track_class"])

cell_type_counts = (
    cell_type_counts
    .set_index(["t", "track_class"])
    .reindex(full_index, fill_value=0)
    .reset_index()
)

# ---------------------------------------------------------------
# 2.  centred rolling mean within each class (window = 50 frames)
window_size = 50


def _centred_mean(counts):
    """Centred rolling mean of one class's per-frame counts."""
    return counts.rolling(window_size, center=True, min_periods=1).mean()


cell_type_counts = (
    cell_type_counts
    .sort_values(["track_class", "t"])     # time-ordered within each class
    .assign(**{
        "number of cells": lambda d: d.groupby("track_class")["n_cells"].transform(_centred_mean)
    })
)

# ---------------------------------------------------------------
# 3.  stage (hpf) and human-readable class names
cell_type_counts["stage"] = cell_type_counts["t"] * 1.5 / 60 + 26

name_map = {
    0: "deep",
    1: "EVL",
    2: "YSN",
}

cell_type_counts = cell_type_counts.replace({"track_class": name_map})

# ---------------------------------------------------------------
# 4.  one line per cell type
axis_labels = ["stage (hpf)", "number of cells"]
fig = px.line(
    cell_type_counts,
    x="stage",
    y="number of cells",
    color="track_class",
    labels={"track_class": "cell type"},
    color_discrete_map={"deep": "blueviolet"},
)
fig = format_2d_plotly(fig, axis_labels=axis_labels, font_size=18)
fig.update_traces(line=dict(width=5))
fig.write_image(os.path.join(fig_path, "cell_type_count_by_stage.png"), scale=2)
fig.show()

## Mask type segmentation figure

In [None]:
frame = 1500   # time point shown in the segmentation panels

# Image and label volumes for the chosen frame.
seg_zarr = zarr.open(nls_track_path + "segments.zarr", mode="r")
seg_frame = seg_zarr[frame]
im_frame = fused_image_zarr[frame, 1]   # second index 1 -- presumably the nuclear channel; TODO confirm

# Track ids present in this frame, split by class (0 / 1 / 2).
seg_df = nls_tracks_df.loc[nls_tracks_df["t"]==frame, ["track_id", "track_class"]]
deep_ids, evl_ids, ysn_ids = (
    seg_df.loc[seg_df["track_class"]==class_id, "track_id"].to_numpy()
    for class_id in (0, 1, 2)
)

# Class mask: 0 = unclassified / background, 1 = class 0, 2 = class 1, 3 = class 2.
# Assumes the segment labels equal the track ids -- verify against the tracker output.
mask = np.zeros_like(seg_frame)
for mask_value, class_ids in enumerate((deep_ids, evl_ids, ysn_ids), start=1):
    mask[np.isin(seg_frame, class_ids)] = mask_value

# Maximum projections along the first axis.
im_max = im_frame.max(axis=0) if False else np.max(im_frame, axis=0)
seg_max = np.max(seg_frame, axis=0)
mask_max = np.max(mask, axis=0)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.colors import ListedColormap, BoundaryNorm
from napari.utils.colormaps import label_colormap
from pathlib import Path

def save_fig(img, path, *, dpi=600):
    """Render ``img`` on a 5x5-inch, axis-free figure and save it under ``fig_path``.

    Parameters
    ----------
    img : array-like
        Image data passed straight to ``Axes.imshow``.
    path : str or os.PathLike
        File name (or relative path) of the output, resolved against the
        notebook-level ``fig_path``.
    dpi : int, keyword-only
        Resolution passed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(img)
    ax.axis("off")
    fig.tight_layout()
    # fig_path is a plain string built with os.path.join, so ``fig_path / path``
    # raised TypeError; join the two the same way the rest of the notebook does.
    fig.savefig(os.path.join(fig_path, path), dpi=dpi, bbox_inches="tight")
    plt.close(fig)

clim = [30, 1200]   # display range (lower, upper) for the grayscale image
# ------------------------------------------------------------------
# 1.  panel A – maximum projection of the image alone, in grayscale
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(im_max, cmap="gray", vmin=clim[0], vmax=clim[1])
ax.set_axis_off()
panel_a_file = os.path.join(fig_path, "im_max.png")
fig.savefig(panel_a_file, dpi=600, bbox_inches="tight", pad_inches=0)
plt.show()
# ------------------------------------------------------------------


In [None]:
# 2.  panel B – seg_max overlay (one random RGBA colour per label, label 0 transparent)
# Helper that builds the ListedColormap used for the overlay
def discrete_cmap(n, *, seed=0):
    """Build a reproducible colormap with ``n + 1`` entries for labels ``0..n``.

    Every entry is a random RGBA quadruple drawn from ``np.random.default_rng(seed)``
    (the alpha channel is random as well); entry 0 is then overwritten with a
    fully transparent colour so that label 0 is invisible.

    Parameters
    ----------
    n : int
        Highest label value that needs its own colour.
    seed : int, keyword-only
        Seed of the random generator; the same seed gives the same palette.

    Returns
    -------
    ListedColormap
    """
    palette = np.random.default_rng(seed).random((n + 1, 4))   # RGBA in [0, 1)
    palette[0] = (0, 0, 0, 0)
    return ListedColormap(palette)

# -------------------------------------------------------------------
n_labels = int(seg_max.max())                  # highest label present in seg_max
seg_cmap = discrete_cmap(n_labels, seed=42)    # colour table with entries 0..n_labels

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(im_max, cmap="gray")                 # background image
# Pin the colour range to 0..n_labels (as panel C does with vmin/vmax). Without it
# imshow scales from the smallest value actually present, so a projection with no
# background (0) pixels would map its lowest label onto the transparent entry.
ax.imshow(seg_max, cmap=seg_cmap, vmin=0, vmax=n_labels, interpolation="nearest")
ax.axis("off")
fig.savefig(os.path.join(fig_path, "im_seg_max_overlay.png"), dpi=600, bbox_inches="tight", pad_inches=0)
plt.show()
plt.close(fig)
#
# # ---------------------------------------------------------------

In [None]:
# 3.  panel C – mask_max overlay (1 → magenta, 2 → red, 3 → green)
mask_cmap = ListedColormap([
    (0, 0, 0, 0),        # 0 → fully transparent
    (1, 0, 1, 0.5),      # 1 → magenta, 50 % α
    (1, 0, 0, 0.5),      # 2 → red,     50 % α
    (0, 1, 0, 0.5),      # 3 → green,   50 % α
])

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(im_max, cmap="gray")
# mask_max holds categorical codes, so use nearest-neighbour resampling (as in the
# seg_max overlay) to avoid blending class colours across boundaries. The keyword is
# `interpolation`; `interpolate` is not an imshow argument.
ax.imshow(mask_max, cmap=mask_cmap, vmin=0, vmax=3, interpolation="nearest")
ax.axis("off")
plt.tight_layout()
fig.savefig(os.path.join(fig_path, "im_mask_max_overlay.png"), dpi=600, bbox_inches="tight", pad_inches=0)

plt.show()

## Sphere fitting (only the density plot here)

In [None]:
# get spherical coordinates oriented relative to high dome position
from scipy.spatial.transform import Rotation as R

sphere_df = pd.read_csv(os.path.join(nls_track_path, "sphere_fit.csv"))
start_frame_max = 25                    # frames t <= 25 define the initial centre of mass
phi_shift_manual = 60 / 180 * np.pi     # manual azimuth offset: 60 degrees in radians

# Positions of class-0 ("deep") tracks only.
deep_tracks_df = nls_tracks_df.loc[nls_tracks_df["track_class"]==0, ["t", "x", "y", "z_scaled"]].copy()

# Rows belonging to the first frames (mask computed once and reused below).
start_filter = deep_tracks_df["t"] <= start_frame_max
start_indices = np.where(start_filter)[0]

# 1) sphere fit at t == 0: centre in (x, y, z) order, and radius
sphere_fit_t0 = sphere_df.loc[sphere_df["t"]==0]
sphere_center = sphere_fit_t0[["xs","ys","zs"]].iloc[0].to_numpy()    # [x0, y0, z0]
sphere_radius = sphere_fit_t0[["r"]].iloc[0].to_numpy()

# 2) centre of mass of the deep cells over the first frames, also in (x, y, z)
deep_cm = ( deep_tracks_df
            .loc[start_filter, ["x","y","z_scaled"]]
            .to_numpy()
            .mean(axis=0) )

# 3) pole direction = unit vector from the sphere centre to that centre of mass
orientation_vec = deep_cm - sphere_center
v = orientation_vec / np.linalg.norm(orientation_vec)

# Rotation.align_vectors(a, b) returns the rotation that transforms b onto a
# (it minimises sum ||a_i - R b_i||^2). With a = +z and b = v, `rot` therefore
# carries v onto the z-axis, which is what the rotation below relies on.
rot, _ = R.align_vectors([[0,0,1]], [v])

Rmat = rot.as_matrix()

# 4) centre all points on the sphere and rotate them so that v becomes the pole (+z)
pts = deep_tracks_df.loc[:, ["x","y","z_scaled"]].to_numpy() - sphere_center
pts_rot = (Rmat @ pts.T).T

# 5) standard spherical coordinates
x, y, z = pts_rot[:,0], pts_rot[:,1], pts_rot[:,2]
r   = np.linalg.norm(pts_rot, axis=1)
theta = np.arccos(np.clip(z/r, -1, 1))   # polar angle, 0…π (0 at the pole)
phi = np.arctan2(y, x)                   # azimuth, –π…π
phi += phi_shift_manual
phi_wrapped = (phi + np.pi) % (2*np.pi) - np.pi   # wrap the shifted azimuth back into [–π, π)
deep_tracks_df[["r","theta","phi"]] = np.column_stack([r, theta, phi_wrapped])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter

# Spherical coordinates and frame index of every deep-cell detection, as plain arrays.
r = deep_tracks_df["r"].to_numpy()
theta = deep_tracks_df["theta"].to_numpy()
phi = deep_tracks_df["phi"].to_numpy()
t_vec = deep_tracks_df["t"].to_numpy()

# ------------------------------------------------------------
# 1) spherical -> geographic angles
#    phi was wrapped to [-pi, pi), so lon lies in [-180, 180);
#    theta is the polar angle, so lat is 90 at the pole and -90 opposite it.
lon = np.degrees(phi)
lat = 90 - np.degrees(theta)

# ------------------------------------------------------------
# 2) lon/lat -> Robinson map coordinates with Cartopy
proj = ccrs.Robinson()
xy = proj.transform_points(ccrs.PlateCarree(), lon, lat)
x_proj = xy[:,0]
y_proj = xy[:,1]


In [None]:
from scipy.spatial import cKDTree
import matplotlib as mpl
from cycler import cycler

# White-on-black matplotlib theme, applied globally to all figures drawn after this point.
dark_theme = {
    # backgrounds
    "figure.facecolor":  "black",
    "axes.facecolor":    "black",
    "savefig.facecolor": "black",
    "savefig.edgecolor": "black",
    # text, axes and ticks
    "text.color":        "white",
    "axes.edgecolor":    "white",
    "axes.labelcolor":   "white",
    "xtick.color":       "white",
    "ytick.color":       "white",
    "grid.color":        "0.5",
    # tab10 as the default colour cycle so traces stay visible on black
    "axes.prop_cycle":   cycler(color=plt.cm.tab10.colors),
}
mpl.rcParams.update(dark_theme)

# Recompute the projected coordinates so this cell does not depend on arrays left
# over from the previous one (same steps as there).
r, theta, phi = (deep_tracks_df[col].to_numpy() for col in ("r", "theta", "phi"))
t_vec = deep_tracks_df["t"].to_numpy()

# ------------------------------------------------------------
# 1) geographic angles: lon in [-180, 180) from the wrapped azimuth,
#    lat = 90 at the pole (theta = 0) down to -90
lon = np.degrees(phi)
lat = 90 - np.degrees(theta)

# ------------------------------------------------------------
# 2) Robinson projection with Cartopy
proj = ccrs.Robinson()
xy = proj.transform_points(ccrs.PlateCarree(), lon, lat)
x_proj, y_proj = xy[:,0], xy[:,1]



nbins = 75          # hexbin gridsize
t_window = 25       # half-width, in frames, of the window pooled into each map
hm_path = os.path.join(fig_path, "density_hexbin", "")
os.makedirs(hm_path, exist_ok=True)

# ------------------------------------------------------------------
# Constant map extent (Robinson globe bounds) so that every frame is binned on
# exactly the same hexagon grid.
proj   = ccrs.Robinson()
xlim   = proj.x_limits
ylim   = proj.y_limits
extent = (*xlim, *ylim)              # (xmin, xmax, ymin, ymax)

# ------------------------------------------------------------------
# One density map every 25 frames.
for frame in tqdm(range(0, 2339, 25)):

    # pool every detection within +/- t_window frames of the current frame
    t_filter = (t_vec >= frame - t_window) & (t_vec <= frame + t_window)
    if not t_filter.any():
        continue                     # no detections in this window: nothing to draw
    n_frames = len(np.unique(t_vec[t_filter]))   # frames actually present in the window
    xp_frame = x_proj[t_filter]
    yp_frame = y_proj[t_filter]

    # ------------------------------------------------------------------
    # Throw-away hexbin on plain axes: only used to obtain the raw counts and the
    # bin centres of the occupied hexagons.
    fig_tmp, ax_tmp = plt.subplots()
    hb_tmp = ax_tmp.hexbin(
        xp_frame, yp_frame,
        gridsize=nbins,
        extent=extent,
        bins=None,                   # raw counts
        mincnt=1,
        cmap="hot_r",
    )
    plt.close(fig_tmp)

    counts  = hb_tmp.get_array()     # (M,)
    centres = hb_tmp.get_offsets()   # (M, 2)
    if len(centres) == 0:
        continue                     # no occupied bins

    # ------------------------------------------------------------------
    # Smooth each bin with the mean over its nearest occupied bins (itself included),
    # then normalise by the number of pooled frames. k is capped so that a window
    # with fewer than 6 occupied bins does not index past the end of `counts`;
    # the reshape keeps the (M, k) layout even when k == 1.
    n_neighbours  = min(6, len(centres))
    tree          = cKDTree(centres)
    _, idxs       = tree.query(centres, k=n_neighbours)
    idxs          = np.asarray(idxs).reshape(len(centres), -1)
    counts_smooth = counts[idxs].mean(axis=1) / n_frames

    # ------------------------------------------------------------------
    # Draw the smoothed densities on the globe.
    stage = frame * 1.5 / 60 + 26    # hpf

    fig, ax = plt.subplots(figsize=(10, 5),
                           subplot_kw=dict(projection=proj))

    hb = ax.hexbin(
        centres[:, 0], centres[:, 1],
        C=counts_smooth,
        gridsize=nbins,
        extent=extent,
        reduce_C_function=np.mean,
        cmap="inferno",
        vmin=0, vmax=1.5,
        alpha=0.8,
        linewidths=0,                # no edge lines
        edgecolors='none',
    )

    ax.set_global()

    # cosmetics
    fig.colorbar(hb, ax=ax, label="cell count")
    ax.set_title(f"Deep cell density on embryonic surface ({stage:.2f} hpf)")
    ax.gridlines(draw_labels=True,
                 xformatter=LongitudeFormatter(),
                 yformatter=LatitudeFormatter())

    plt.tight_layout()

    # Save once, then release the figure: pyplot keeps every figure alive until it
    # is closed, and this loop would otherwise accumulate one per iteration.
    fig.savefig(
        os.path.join(hm_path, f"deep_cell_density_f{frame:04}.png"),
        dpi=600, bbox_inches="tight"
    )
    plt.close(fig)


## Tracking QC figures

In [None]:
from ultrack.tracks.graph import inv_tracks_df_forest
from ultrack.tracks.gap_closing import tracks_starts, tracks_ends
from tqdm import tqdm

# Calculate the root (founding) track of every track.
#
# Start from the raw track table so that the cell gives the same result each time it
# is run (its last line filters nls_tracks_df). Note that this drops the columns
# merged in while loading (track_class, frame_class, z_scaled).
nls_tracks_df = pd.read_csv(os.path.join(nls_track_path, "tracks" + stitch_suffix + "_fluo.csv"))

# child -> parent lookup from ultrack; presumably only tracks that have a parent
# appear as keys -- verify against the ultrack documentation.
child_to_parent_dict = inv_tracks_df_forest(nls_tracks_df)

track_index = nls_tracks_df["track_id"].unique()
track_parent_dict = (
    nls_tracks_df.loc[:, ["track_id", "parent_track_id"]]
    .drop_duplicates()
    .set_index("track_id")["parent_track_id"]
    .to_dict()
)

# np.asarray(dict.keys()) wraps the whole keys view in a 0-d object array, so the
# membership test below was always False and the walk stopped at the immediate
# parent. A set gives the intended (and O(1)) lookup.
child_keys = set(child_to_parent_dict)

track_root_dict = {}
map_ids = []   # tracks that have a parent, i.e. whose root differs from themselves
for track_id in tqdm(track_index):
    root_id = track_id
    parent_id = track_parent_dict[track_id]
    if parent_id != -1:
        map_ids.append(track_id)
    # climb the lineage until we reach a track without a parent (-1 marks "no parent")
    while parent_id != -1:
        root_id = parent_id
        if root_id not in child_keys:   # this ancestor has no parent entry: it is the root
            break
        parent_id = track_parent_dict[root_id]
    track_root_dict[track_id] = root_id

# track_id -> root_id as a two-column table
df = pd.DataFrame.from_dict(
    track_root_dict,
    orient='index',
    columns=['root_id']
)
df.index.name = 'track_id'
df = df.reset_index()

nls_tracks_df = nls_tracks_df.merge(df, on="track_id", how="left")
# keep only rows whose root id is strictly positive
nls_tracks_df = nls_tracks_df.loc[nls_tracks_df["root_id"] > 0, :].copy()

In [None]:
# Number of table rows that share each root_id (one row per track and time point, so
# tracks of the same lineage that coexist in a frame are each counted).
track_counts = nls_tracks_df.groupby("root_id").size().reset_index(name="n_time_points")

# Drop any n_time_points column left by a previous run before merging; otherwise a
# re-run yields n_time_points_x / n_time_points_y and the line below raises KeyError.
nls_tracks_df = (
    nls_tracks_df
    .drop(columns="n_time_points", errors="ignore")
    .merge(track_counts, on="root_id", how="left")
)

# Mean of that per-root row count over all rows present in each frame.
track_length = nls_tracks_df.groupby("t")["n_time_points"].mean().reset_index()

In [None]:
# Peek at the first rows of the per-root row counts (columns: root_id, n_time_points).
track_counts.head()

In [None]:
# Convert frame index -> stage and frame count -> hours using the notebook's
# frame * 1.5 / 60 scaling.
track_length = track_length.assign(**{
    "stage": lambda d: d["t"] * 1.5 / 60 + 26,
    "track length": lambda d: d["n_time_points"] * 1.5 / 60,
})

axis_labels = ["stage (hpf)", "average track length (hours)"]
fig = format_2d_plotly(
    px.line(track_length, x="stage", y="track length"),
    axis_labels=axis_labels,
    font_size=18,
)
fig.update_traces(line=dict(width=5, color="white"))
fig.write_image(os.path.join(fig_path, "average_track_length_by_stage.png"), scale=2)
fig.show()

In [None]:
# 1) for each track, find its first frame and its total span
import plotly.express as px

# First frame, last frame and number of distinct frames covered by each root_id.
track_stats = (
    nls_tracks_df
    .groupby("root_id")["t"]
    .agg(start_frame="min", end_frame="max", n_time_points="nunique")
    .reset_index()
)
# ignore lineages seen in a single frame only
track_stats = track_stats.loc[track_stats["n_time_points"]>1, :].copy()
#    columns = ['root_id','start_frame','end_frame','n_time_points']

# 2) average the lifespans of all lineages that start in the same frame
avg_by_start = (
    track_stats
    .groupby("start_frame")["n_time_points"]
    .mean()
    .reset_index(name="avg_length")
)

# 3) start frame -> stage, smooth over neighbouring start frames (35-row centred
#    window), then frames -> hours
avg_by_start["start_hpf"] = avg_by_start["start_frame"] * 1.5/60 + 26
avg_by_start["avg_length"] = avg_by_start["avg_length"].rolling(window=35, center=True, min_periods=1).mean()
avg_by_start["avg_length_hr"] = avg_by_start["avg_length"] * 1.5/60

# 4) plot it
fig = px.line(
    avg_by_start,
    x="start_hpf",
    y="avg_length_hr",
    # the label key must name the plotted column (avg_length_hr, not avg_length)
    labels={"start_hpf":"track start (hpf)", "avg_length_hr":"mean track length (hours)"},
)

fig = format_2d_plotly(fig, font_size=18, axis_labels=["track start (hpf)", "mean track length (hours)"])

fig.update_traces(line=dict(width=5, color="white"))

fig.show()

fig.write_image(os.path.join(fig_path, "avg_track_length_by_start.png"), scale=2)

In [None]:
import plotly.graph_objects as go

# ── 1) per‐track stats ─────────────────────────────────────────────────────
# Per root_id: first frame it appears in and number of distinct frames it covers.
frames_by_root = nls_tracks_df.groupby("root_id")["t"]
track_stats = frames_by_root.agg(
    start_frame="min",
    n_time_points="nunique",   # total duration in frames
).reset_index()

# keep only lineages observed in more than one frame
track_stats = track_stats.query("n_time_points > 1").copy()

# ── 2) summary by start_frame ────────────────────────────────────────────
def p10(x):
    """Return the 10th percentile of ``x`` (delegates to ``x.quantile(0.10)``)."""
    lower_decile = x.quantile(0.10)
    return lower_decile
def p90(x):
    """Return the 90th percentile of ``x`` (delegates to ``x.quantile(0.90)``)."""
    upper_decile = x.quantile(0.90)
    return upper_decile

# Mean, 10th / 90th percentile and number of lineages for every start frame.
lengths_by_start = track_stats.groupby("start_frame")["n_time_points"]
summary = lengths_by_start.agg(
    mean_length="mean",
    p10_length=p10,
    p90_length=p90,
    count="size",
).reset_index()

# ── 3) convert to hpf / hours ─────────────────────────────────────────────
summary["start_hpf"] = summary["start_frame"] * 1.5/60 + 26
for stat in ("mean", "p10", "p90"):
    summary[stat + "_hr"] = summary[stat + "_length"] * 1.5/60

# smooth each curve over neighbouring start frames (35-row centred window)
for col in ["mean_hr", "p10_hr", "p90_hr"]:
    smoothed = summary[col].rolling(window=35, center=True, min_periods=1).mean()
    summary[col] = smoothed

# ── 4) plotly: mean + 10th/90th percentile ───────────────────────────────
start_hpf = summary["start_hpf"]

mean_trace = go.Scatter(
    x=start_hpf,
    y=summary["mean_hr"],
    mode="lines",
    name="Mean",
    line=dict(color="white", width=4),
)
p10_trace = go.Scatter(
    x=start_hpf,
    y=summary["p10_hr"],
    mode="lines",
    name="10th percentile",
    line=dict(color="lightblue", width=2, dash="dash"),
)
p90_trace = go.Scatter(
    x=start_hpf,
    y=summary["p90_hr"],
    mode="lines",
    name="90th percentile",
    line=dict(color="orange", width=2, dash="dash"),
)
# closed polygon: along the 90th-percentile curve, then back along the 10th
band_trace = go.Scatter(
    x=pd.concat([start_hpf, start_hpf[::-1]]),
    y=pd.concat([summary["p90_hr"], summary["p10_hr"][::-1]]),
    fill="toself",
    fillcolor="rgba(255,165,0,0.2)",
    line=dict(width=0),
    hoverinfo="skip",
    showlegend=True,
    name="10–90 percentile band",
)

fig = go.Figure([mean_trace, p10_trace, p90_trace, band_trace])

fig.update_layout(
    template="plotly_dark",
    xaxis_title="Track start (hpf)",
    yaxis_title="Track duration (hours)",
    font=dict(size=16, color="white"),
    legend=dict(bgcolor="rgba(0,0,0,0.5)"),
)

fig.show()
summary_png = os.path.join(fig_path, "track_length_summary_by_start.png")
fig.write_image(summary_png, scale=2)