In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from src.data_io.track_io import _load_track_data
from src.tracking.track_processing import preprocess_tracks
from pathlib import Path

### Process tracks

In [None]:
# --- parameters ---
# NOTE(review): absolute, machine-specific data root; adjust when running elsewhere.
root = Path(r"Y:\killi_dynamics")
project_name = "20251019_BC1-NLS_52-80hpf"
tracking_config = "tracking_20251102"

# Load the cell tracks with prefer_smoothed=False; the loader returns two
# objects, kept here as tracks_df_raw and sphere_df.
tracks_df_raw, sphere_df = _load_track_data(
    root=root,
    project_name=project_name,
    tracking_config=tracking_config,
    prefer_smoothed=False,
)

tracks_df_raw.head()

### Clean tracks
To-dos:
1. Add in confirmed nucleus masks that are dropped by tracking
2. Re-integrate logic to fuse tracks that appear to be duplicates

In [None]:
# Rows with t >= max_frame are dropped below (strict "<").
max_frame = 930 # filters for through to early aggregation stages
track_class = np.asarray([0]) # deep cells only

# The fluorescence column is renamed from "mean_fluo" to "BC1" for the rest of the notebook.
tracks_df = tracks_df_raw.rename(columns={"mean_fluo": "BC1"})
# NOTE: the class filter is currently disabled, so every track_class is kept.
# tracks_df = tracks_df[tracks_df['track_class'].isin(track_class)]
tracks_df = tracks_df[tracks_df["t"] < max_frame]

tracks_df = preprocess_tracks(tracks_df)

# go ahead and remove stationary tracks
# .copy() gives an independent frame, so adding columns below does not write
# into a slice of the preprocessed frame (avoids SettingWithCopyWarning).
tracks_df = tracks_df[~tracks_df["track_mostly_stationary"]].copy()

# calculate smoothed BC1 values: centered 3-row rolling mean within each track;
# min_periods=1 keeps a value at the track ends.
tracks_df["BC1_sm"] = (
    tracks_df.groupby("track_id")["BC1"]
             .rolling(window=3, center=True, min_periods=1)
             .mean()
             .reset_index(level=0, drop=True)  # drop the track_id level to align on the row index
)


### Add smoothed tracks and calculate spherical coordinates

In [None]:
# Load the tracks again, this time preferring the smoothed version.
tracks_df_sm, _ = _load_track_data(root=root,
                                   project_name=project_name,
                                   tracking_config=tracking_config,
                                   prefer_smoothed=True)

# Smoothed positions are stored next to the raw ones as xs / ys / zs.
smoothed_cols = ["xs", "ys", "zs"]
tracks_df_sm = tracks_df_sm.rename(columns={"x": "xs", "y": "ys", "z": "zs"})

# Drop any xs/ys/zs left over from a previous run of this cell before merging;
# otherwise a re-run would produce suffixed duplicates (xs_x, xs_y, ...).
tracks_df = (
    tracks_df.drop(columns=smoothed_cols, errors="ignore")
             .merge(tracks_df_sm.loc[:, ["t", "track_id"] + smoothed_cols],
                    how="left", on=["track_id", "t"])
)
tracks_df.head()

### Cell number over time

In [None]:
# Non-null counts of every column per (t, track_class); the track_id column of
# the result is therefore the number of track rows in that frame and class.
cell_count_table = tracks_df.groupby(["t", "track_class"]).count().reset_index()
fig = px.line(
    cell_count_table,
    x="t",
    y="track_id",
    color="track_class",
    # the plotted quantity is a count, not an id
    labels={"track_id": "number of tracks"},
)
fig.show()

### Is there a correlation between BC1 intensity and local cell density?

In [None]:
import numpy as np
import pandas as pd
from scipy.spatial import KDTree
from tqdm import tqdm

def compute_surface_density(df, xcol="x", ycol="y", zcol="z", 
                            groupcol="t", k=10, min_points=5):
    """
    Computes approximate local *surface* density for each cell at each timepoint.
    Uses kNN search radius but divides by disk area (π r^2).
    Units: 1 / length^2

    Parameters
    ----------
    df : DataFrame containing the coordinate columns and ``groupcol``.
    xcol, ycol, zcol : names of the coordinate columns.
    groupcol : column whose values define the groups (one KD-tree per group).
    k : number of neighbours used for the search radius.
    min_points : groups with fewer rows than this are left as NaN.

    Returns
    -------
    numpy.ndarray of length ``len(df)`` holding the density of every row,
    **in the row order of ``df``**, so it can be assigned directly as a new
    column. Rows in skipped groups, rows whose group key is NaN, and rows
    whose k-th neighbour is at distance 0 are NaN.

    Notes
    -----
    When a group has fewer than ``k + 1`` points, the number of neighbours is
    reduced to ``len(group) - 1`` for that group (and used in the numerator),
    instead of querying neighbours that do not exist.
    """
    coords = df[[xcol, ycol, zcol]].to_numpy(dtype=float)
    density = np.full(len(df), np.nan)

    # GroupBy.indices maps each group key to *positional* row indices, so the
    # results can be written back without reordering the input.
    positions_by_group = df.groupby(groupcol).indices

    for _, pos in tqdm(positions_by_group.items(), desc="Calculating NN densities..."):
        pos = np.asarray(pos)
        n_pts = len(pos)

        # Never ask for more neighbours than the group can provide.
        k_eff = min(k, n_pts - 1)
        if n_pts < min_points or k_eff < 1:
            continue

        pts = coords[pos]
        tree = KDTree(pts)

        # Query k_eff+1 nearest since first entry is the point itself
        dists, _ = tree.query(pts, k=k_eff + 1)
        r_k = dists[:, -1].astype(float)  # radius to the k_eff-th neighbor

        # Avoid divide-by-zero (coincident points)
        r_k[r_k == 0] = np.nan

        # Surface area of local patch ~ π r^2; density = neighbours / area
        density[pos] = k_eff / (np.pi * r_k ** 2)

    return density


In [None]:
# NOTE(review): the function returns a plain NumPy array, which is assigned
# positionally here, so it must be in the same row order as tracks_df.
# Confirm that compute_surface_density preserves the input row order.
tracks_df["nn_sa_density"] = compute_surface_density(tracks_df)

In [None]:
nbins = 15
start_t = 850
stop_t = 930

# Restrict to the frame window (start_t, stop_t].
time_filter = (tracks_df["t"] > start_t) & (tracks_df["t"] <= stop_t)
plot_df = tracks_df.loc[time_filter].copy()

# quantize BC1_sm into equal-count bins (integer codes); duplicates="drop"
# merges repeated quantile edges instead of raising when many values are tied.
plot_df["BC1_bin"] = pd.qcut(
    plot_df["BC1_sm"],
    q=nbins,
    labels=False,
    duplicates="drop",
)

# One row per bin: median BC1_sm (x position) and mean NN surface density (y).
plot_df = (
    plot_df.groupby("BC1_bin")
           .agg(BC1_median=("BC1_sm", "median"),
                nn_sa_density=("nn_sa_density", "mean"))
           .reset_index()
)

#fig = px.box(
#    plot_df,
#    x="BC1_bin",
#    y="nn_sa_density",
#    title="Surface density by BC1 intensity bin",
#    width=800,
#    height=600,
#    points=False
#)
fig = px.scatter(plot_df, x="BC1_median", y="nn_sa_density")

fig.update_layout(
    # x is the per-bin median of BC1_sm, y the per-bin mean density
    xaxis_title="Median BC1_sm of intensity bin",
    yaxis_title="Mean NN surface density",
    template="plotly_white",
    width=800,
    height=600
)

fig.show()


### BC1 intensity distribution

In [None]:
# fig = px.histogram(tracks_df, x='BC1', nbins=100, title='BC1 mean intensity distribution')
# fig.show()

In [None]:
# fig = px.density_heatmap(
#     tracks_df,
#     x="t",
#     y="BC1",
#     nbinsx=15,
#     nbinsy=15,
#     color_continuous_scale="Viridis"
# )
# fig.show()

In [None]:
# fig = px.scatter(tracks_df, x="t", y="BC1", title='BC1 mean intensity over time', opacity=0.5)
# fig.show()

## Is there a link between BC1 and proliferation?

In [None]:
# get IDs of tracks that produced offpsring
parent_table = tracks_df.loc[tracks_df["parent_track_id"]>-1, ["t", "parent_track_id", "track_id"]].drop_duplicates()
parent_table = parent_table.groupby(["track_id", "parent_track_id"]).min().reset_index().sort_values(["t", "parent_track_id"])

# label tracks that produce offspring
parent_ids = parent_table["parent_track_id"].to_numpy()
child_ids = parent_table["track_id"].to_numpy()
tracks_df.loc[:, "parent_flag"] = 0
tracks_df.loc[:, "child_flag"] = 0
tracks_df.loc[tracks_df["track_id"].isin(parent_ids), "parent_flag"] = 1
tracks_df.loc[tracks_df["track_id"].isin(parent_ids), "child_flag"] = 1
tracks_df["proliferative_flag"] = tracks_df["child_flag"] | tracks_df["parent_flag"]

# check consistency
# print(len(tracks_df.loc[tracks_df["parent_flag"]==1, "track_id"].unique()))
# print(len(parent_table["parent_track_id"].unique()))

#### Plot locations

In [None]:
# Rows of all tracks flagged as parents.
parent_track_df = tracks_df[tracks_df["parent_flag"] == 1]

# Last row of every parent track after sorting by (track_id, t, x, y, z),
# i.e. the final observation of the track.
div_df = (
    parent_track_df
    .sort_values(["track_id", "t", "x", "y", "z"])
    .groupby("track_id")
    .tail(1)
)

# Only final observations after frame 500 are shown, coloured by frame.
fig = px.scatter_3d(
    div_df.loc[div_df["t"] > 500],
    x="x",
    y="y",
    z="z",
    color="t",
    opacity=0.75,
)
fig.show()

In [None]:
# For every parent id: the earliest frame at which a row referencing it appears.
prolif_df = (
    tracks_df.loc[tracks_df["parent_track_id"] > -1, ["t", "parent_track_id"]]
    .drop_duplicates()
    .groupby(["parent_track_id"])
    .min()
    .reset_index()
    .sort_values(["t", "parent_track_id"])
)

# Number of such parents per frame (dN).
prolif_df = (
    prolif_df.loc[:, ["t", "parent_track_id"]]
    .drop_duplicates()
    .groupby("t")
    .count()
    .reset_index()
    .rename(columns={"parent_track_id": "dN"})
)

# Running total on top of a fixed offset.
# NOTE(review): 888 is presumably the initial cell count — confirm and name it.
prolif_df["N"] = 888 + prolif_df["dN"].cumsum()

fig = px.line(prolif_df, x="t", y="N")
fig.show()

In [None]:
# Unique parent ids referenced by the cleaned tracks, and the unique track ids
# present in the cleaned and in the raw table.
p_tids = np.unique(
    tracks_df.loc[tracks_df["parent_track_id"] > -1, "parent_track_id"].to_numpy()
)
tids = np.unique(tracks_df["track_id"].to_numpy())
tids_raw = np.unique(tracks_df_raw["track_id"].to_numpy())

# 1) number of referenced parents, 2) how many of them exist as tracks in the
# cleaned table, 3) how many exist in the raw table
print(len(p_tids))
print(np.isin(p_tids, tids).sum())
print(np.isin(p_tids, tids_raw).sum())
# fig = px.histogram(parent_table, x='t', nbins=25, title='Division times')
# fig.show()

In [None]:


# time bins
nbins = 5
time_bins = np.linspace(100, 900, nbins+1)

# get last obs prior to division for parents: final 10 rows of each parent track
# (.copy() so that later cells can add columns to an independent frame)
parent_track_df = tracks_df.loc[tracks_df["parent_flag"]==1]
pre_div_df = (
    parent_track_df.sort_values(["track_id", "t"])
             .groupby("track_id")
             .tail(10)
             .copy()
)

# first obs after division for children: first 10 rows of each child track.
# This must be built from child_track_df (it previously reused parent_track_df,
# so the "Child" group contained the start of the parent tracks).
child_track_df = tracks_df.loc[tracks_df["child_flag"]==1]
post_div_df = (
    child_track_df.sort_values(["track_id", "t"])
             .groupby("track_id")
             .head(10)
             .copy()
)


# generate our null: first and last 10 rows of every non-proliferative track
null_df = tracks_df.loc[tracks_df["proliferative_flag"]==0]
pre_null_df = (
     null_df.sort_values(["track_id", "t"])
             .groupby("track_id")
             .head(10)
)
post_null_df = (
     null_df.sort_values(["track_id", "t"])
             .groupby("track_id")
             .tail(10)
)

# NOTE: for tracks with fewer than 20 rows the head and tail selections overlap,
# so those rows appear twice in null_df.
null_df = pd.concat([pre_null_df, post_null_df], ignore_index=True)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import UnivariateSpline


def bootstrap_rolling_mean(df, xcol, ycol, window=51, B=200, random_state=None):
    """
    Computes:
      - rolling mean (centered)
      - bootstrap CI for the rolling mean

    df must contain xcol and ycol.

    Parameters
    ----------
    df : DataFrame with at least ``xcol`` and ``ycol``; rows with NaN in either
        column are dropped and the rest is sorted by ``xcol``.
    xcol, ycol : column names of the ordering variable and the smoothed value.
    window : rolling window length in rows (centered, ``min_periods=1``).
    B : number of bootstrap resamples.
    random_state : None, int seed, ``numpy.random.RandomState`` or
        ``numpy.random.Generator``. None (default) keeps the previous
        behaviour of drawing from NumPy's global random state; pass a seed
        for reproducible intervals.

    Returns
    -------
    DataFrame with columns ``xcol``, "mean", "low", "high" (2.5th / 97.5th
    percentiles across resamples), with rows containing NaN dropped.

    Notes
    -----
    Each resample is sorted by ``xcol`` and smoothed by row position, so the
    percentiles at row i are taken over the i-th ordered row of every
    resample, which need not have the same ``xcol`` value as row i of the
    output.
    """

    def _rolling_mean(values):
        # centered rolling mean over row positions
        return (
            pd.Series(values)
              .rolling(window, center=True, min_periods=1)
              .mean()
              .to_numpy()
        )

    # Sort once
    df = df[[xcol, ycol]].dropna().sort_values(xcol).reset_index(drop=True)

    x = df[xcol].to_numpy()
    y = df[ycol].to_numpy()

    # Base rolling mean
    mean = _rolling_mean(y)

    # One shared random source, so successive resamples differ but the whole
    # sequence is reproducible for a given seed.
    if random_state is None or isinstance(
        random_state, (np.random.RandomState, np.random.Generator)
    ):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    # Collect bootstrap smooths
    boot = np.zeros((B, len(y)))

    for b in range(B):
        # Sort the bootstrap sample so smoothing makes sense
        resample = df.sample(len(df), replace=True, random_state=rng).sort_values(xcol)
        boot[b] = _rolling_mean(resample[ycol].to_numpy())

    # Compute percentile CI
    lower = np.nanpercentile(boot, 2.5, axis=0)
    upper = np.nanpercentile(boot, 97.5, axis=0)

    # Output dataframe
    out = pd.DataFrame({
        xcol: x,
        "mean": mean,
        "low":  lower,
        "high": upper
    })

    return out.dropna()





In [None]:
import plotly.graph_objects as go
import pandas as pd

# pos_smooth  = bootstrap_rolling_mean(pre_div_df,  "t", "BC1_sm", window=11, B=200)
# null_smooth = bootstrap_rolling_mean(pre_null_df, "t", "BC1_sm", window=11, B=200)

# fig = go.Figure()

# # Parent
# fig.add_trace(go.Scatter(
#     x=pos_smooth["t"],
#     y=pos_smooth["mean"],
#     mode="lines",
#     line=dict(color="red"),
#     name="Parent (mean)"
# ))
# fig.add_trace(go.Scatter(
#     x=pd.concat([pos_smooth["t"], pos_smooth["t"][::-1]]),
#     y=pd.concat([pos_smooth["high"], pos_smooth["low"][::-1]]),
#     fill='toself',
#     fillcolor='rgba(255,0,0,0.25)',
#     line=dict(color="rgba(255,0,0,0)"),
#     showlegend=False
# ))

# # Null
# fig.add_trace(go.Scatter(
#     x=null_smooth["t"],
#     y=null_smooth["mean"],
#     mode="lines",
#     line=dict(color="blue"),
#     name="Null (mean)"
# ))
# fig.add_trace(go.Scatter(
#     x=pd.concat([null_smooth["t"], null_smooth["t"][::-1]]),
#     y=pd.concat([null_smooth["high"], null_smooth["low"][::-1]]),
#     fill='toself',
#     fillcolor='rgba(0,0,255,0.25)',
#     line=dict(color="rgba(0,0,255,0)"),
#     showlegend=False
# ))

# fig.update_layout(
#     title="Bootstrap Rolling Mean with CI",
#     xaxis_title="Time (t)",
#     yaxis_title="BC1_sm",
#     template="plotly_white"
# )

# fig.show()




In [None]:
import plotly.express as px

# time bins
nbins = 5
time_bins = np.linspace(100, 900, nbins + 1)

# label bins nicely
bin_labels = [f"bin {i+1}" for i in range(nbins)]

# Add the time_bin label to each group. DataFrame.assign returns a new frame,
# so nothing is written into a slice of tracks_df (no SettingWithCopyWarning),
# and the three frames are binned by the same expression. Rows whose t falls
# outside [100, 900] get a missing time_bin.
pre_div_df, post_div_df, null_df = (
    frame.assign(
        time_bin=pd.cut(
            frame["t"], bins=time_bins, labels=bin_labels, include_lowest=True
        )
    )
    for frame in (pre_div_df, post_div_df, null_df)
)

fig = px.box(
    pd.concat([
        pre_div_df.assign(group="Parent"),
        post_div_df.assign(group="Child"),
        null_df.assign(group="Null")
    ]),
    x="time_bin",
    y="BC1_sm",
    color="group",
    points="outliers",   # or "all" or False
    title="BC1_sm by time bin"
)

fig.update_layout(
    xaxis_title="Time bin",
    yaxis_title="BC1_sm",
    boxmode="group",   # draw the groups side-by-side within each time bin
    template="plotly_white"
)

fig.show()


## Basic track plots

In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm

def make_tracklets(
    tracks_df: pd.DataFrame,
    deltaT: int,
    stride: int,
    fluo_col: str = "gene_level",
):
    """
    Cut every track into fixed-length, possibly overlapping tracklets.

    For each track, a window covering the frames ``t0 <= t < t0 + deltaT`` is
    slid from the track's first frame in steps of ``stride``. A window becomes
    a tracklet only if it contains exactly ``deltaT`` rows.

    Parameters
    ----------
    tracks_df : DataFrame with columns "track_id", "t", "x", "y", "z" and
        ``fluo_col``. "t" must hold integer-valued frame numbers.
    deltaT : window length in frames (>= 1).
    stride : offset in frames between successive window starts (>= 1).
    fluo_col : name of the fluorescence column; it is copied into the long
        table and averaged (under the same name) in the summary table.

    Returns
    -------
    (long_df, summary_df)
        long_df has one row per tracklet timepoint with columns tracklet_id,
        parent_track_id (the source track_id), t, x, y, z, ``fluo_col``.
        summary_df has one row per tracklet with tracklet_id, parent_track_id,
        t_start, t_end (exclusive: t_start + deltaT) and the mean of
        ``fluo_col``.

    Raises
    ------
    ValueError if ``deltaT`` or ``stride`` is smaller than 1.
    """
    if deltaT < 1 or stride < 1:
        raise ValueError("deltaT and stride must both be >= 1")

    long_rows = []     # long-format rows (one per timepoint)
    summary_rows = []  # one per tracklet

    for tid, g in tqdm(tracks_df.groupby("track_id"), desc="Generating tracklets..."):
        g = g.sort_values("t").reset_index(drop=True)

        t = g["t"].to_numpy()
        fl = g[fluo_col].to_numpy()

        x_arr = g["x"].to_numpy()
        y_arr = g["y"].to_numpy()
        z_arr = g["z"].to_numpy()

        # slide window over actual time values; int() so that range() also
        # works when t is stored with a float dtype
        tmin, tmax = int(t.min()), int(t.max())

        # The last window that still fits ends on tmax, i.e. starts at
        # tmax - deltaT + 1. range() excludes its stop value, hence the "+ 1"
        # (the previous bound silently dropped that final window).
        last_start = tmax - deltaT + 1

        for t0 in range(tmin, last_start + 1, stride):
            t1 = t0 + deltaT

            idx = np.flatnonzero((t >= t0) & (t < t1))

            # require exact deltaT frames in window
            if len(idx) != deltaT:
                continue

            # create new tracklet id (running count over all tracks)
            tracklet_id = len(summary_rows)

            # ---- (1) Append long-format rows ----
            for ii in idx:
                long_rows.append({
                    "tracklet_id": tracklet_id,
                    "parent_track_id": tid,
                    "t": int(t[ii]),
                    "x": float(x_arr[ii]),
                    "y": float(y_arr[ii]),
                    "z": float(z_arr[ii]),
                    fluo_col: float(fl[ii]),
                })

            # ---- (2) Append summary row ----
            summary_rows.append({
                "tracklet_id": tracklet_id,
                "parent_track_id": tid,
                "t_start": t0,
                "t_end": t1,
                fluo_col: float(fl[idx].mean()),
                # (optional: mean positions)
                # "mean_x": float(x_arr[idx].mean()),
                # "mean_y": float(y_arr[idx].mean()),
                # "mean_z": float(z_arr[idx].mean()),
            })

    long_df = pd.DataFrame(long_rows)
    summary_df = pd.DataFrame(summary_rows)

    return long_df, summary_df


In [None]:
# Window length and stride in frames.
# NOTE(review): presumably a 60-minute window at 90 s per frame (40 frames) — confirm.
dT = int(60*60 / 90)
stride = dT // 2 + 1

tracklets_long, tracklet_summary = make_tracklets(tracks_df, deltaT=dT, stride=stride, fluo_col="BC1")

# make_tracklets names the per-tracklet mean after fluo_col ("BC1"), but the
# cells below read it as "mean_fluo"; rename it so those lookups succeed.
tracklet_summary = tracklet_summary.rename(columns={"BC1": "mean_fluo"})

In [None]:
# Scatter of the tracklet_summary columns "mean_fluo" against "t_start".
# NOTE(review): this needs a "mean_fluo" column; make_tracklets names its
# summary fluorescence column after fluo_col — confirm the column exists.
fig = px.scatter(tracklet_summary, x="t_start", y="mean_fluo")
fig.show()

In [None]:
# Preview of the long-format tracklet table (one row per tracklet timepoint).
tracklets_long.head()

In [None]:
n_samples = 25

# Tracklets whose (exclusive) end frame lies strictly between 800 and 930.
time_filter = (tracklet_summary["t_end"] < 930) & (tracklet_summary["t_end"] > 800)
# Intensity classes on the tracklet mean: "bright" above 400, "dim" within (150, 250).
high_filter = tracklet_summary["mean_fluo"] > 400
low_filter = (tracklet_summary["mean_fluo"] > 150) & (tracklet_summary["mean_fluo"] < 250)
bright_tracks = tracklet_summary.loc[time_filter & high_filter, "tracklet_id"].to_numpy()
dim_tracks = tracklet_summary.loc[time_filter & low_filter, "tracklet_id"].to_numpy()

# choose candidates
# A seeded generator makes the selection reproducible across kernel restarts;
# the sample size is capped because sampling without replacement raises when
# fewer than n_samples candidates are available.
rng = np.random.default_rng(0)
dim_ids = rng.choice(dim_tracks, size=min(n_samples, len(dim_tracks)), replace=False)
bright_ids = rng.choice(bright_tracks, size=min(n_samples, len(bright_tracks)), replace=False)

In [None]:
def _center_on_first_point(tracklet_ids):
    """Return the rows of tracklets_long for the given ids with xc/yc/zc added.

    xc/yc/zc are x/y/z minus the first x/y/z value of the same tracklet, so
    every tracklet starts at the origin.
    """
    subset = tracklets_long[tracklets_long["tracklet_id"].isin(tracklet_ids)].copy()
    origin = subset.groupby("tracklet_id")[["x", "y", "z"]].transform("first")
    subset[["xc", "yc", "zc"]] = subset[["x", "y", "z"]] - origin
    return subset


dim_tracklets = _center_on_first_point(dim_ids)
bright_tracklets = _center_on_first_point(bright_ids)
bright_tracklets.head()

In [None]:
import plotly.express as px
# Columns to draw and the table to draw them from.
x = "xc"
y = "yc"
z = "zc"
df = bright_tracklets

palette = px.colors.qualitative.Plotly # or Set3, Plotly, Alphabet, etc.

# One 3-D line per tracklet, in order of first appearance; colours cycle
# through the qualitative palette.
traces = [
    go.Scatter3d(
        x=g[x],
        y=g[y],
        z=g[z],
        mode="lines",
        name=str(tid),
        line=dict(width=3, color=palette[i % len(palette)]),  # categorical color
    )
    for i, (tid, g) in enumerate(df.groupby("tracklet_id", sort=False))
]

fig = go.Figure(data=traces)
fig.update_layout(
    scene=dict(xaxis_title=x, yaxis_title=y, zaxis_title=z),
    legend_title_text="tracklet_id",
)

fig.show()

## Look at basic speed and velocity metrics

In [None]:
import pandas as pd
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def add_short_scale_metrics(df, scale_short, dt, track_col="track_id"):
    """
    Add per-timepoint rolling path speed and rolling net velocity to a copy of df.

    For every track (rows sharing ``track_col``, ordered by "t") and every row
    that has ``half = scale_short // 2`` rows on each side within the track,
    a centered window of ``2 * half`` steps is used:

      - v_path_short      : summed step lengths in the window / duration
      - v_net_short_x/y/z : displacement between the window ends / duration
      - v_net_short_mag   : Euclidean norm of the net velocity

    with ``duration = 2 * half * dt``. Both metrics use the same window and
    the same duration; for odd ``scale_short`` this equals
    ``(scale_short - 1) * dt``.

    Parameters
    ----------
    df : DataFrame with columns ``track_col``, "t", "x", "y", "z".
    scale_short : window length in rows (points); values below 2 give all-NaN output.
    dt : time between consecutive rows of a track. Gaps in "t" are not taken
        into account: consecutive rows are treated as ``dt`` apart.
    track_col : column identifying the track.

    Returns
    -------
    A copy of ``df`` with the five columns above added. Rows too close to the
    start or end of their track, rows of tracks shorter than the window, and
    rows whose ``track_col`` is NaN are NaN.
    """

    df = df.copy()

    n_rows = len(df)
    v_path_all = np.full(n_rows, np.nan)
    v_net_all = np.full((n_rows, 3), np.nan)

    half = scale_short // 2
    n_steps = 2 * half  # steps spanned by the centered window

    if n_steps >= 1:
        xyz_all = df[["x", "y", "z"]].to_numpy(dtype=float)
        t_all = df["t"].to_numpy()
        duration = n_steps * dt

        # Positional row indices per track: results are written by position,
        # so the frame's index does not need to be unique.
        positions_by_track = df.groupby(track_col).indices

        for _, pos in tqdm(positions_by_track.items(), desc="Calculating velocity metrics..."):
            pos = np.asarray(pos)
            n = len(pos)
            if n <= n_steps:
                continue  # track too short for a single full window

            # rows of this track in time order
            order = pos[np.argsort(t_all[pos], kind="stable")]
            xyz = xyz_all[order]

            # ---------- STEP-LENGTHS ----------
            d = np.diff(xyz, axis=0)
            steps = np.sqrt((d * d).sum(axis=1))  # shape (n-1,)

            # ---------- ROLLING PATH LENGTH ----------
            # window j covers steps[j : j + n_steps], i.e. points j .. j + n_steps,
            # and belongs to the center point j + half
            path_len = sliding_window_view(steps, window_shape=n_steps).sum(axis=1)

            centers = order[half:n - half]
            v_path_all[centers] = path_len / duration

            # ---------- NET VELOCITY ----------
            # displacement between the two end points of the same window
            v_net_all[centers] = (xyz[n_steps:] - xyz[:n - n_steps]) / duration

    # store outputs
    df["v_path_short"] = v_path_all
    df["v_net_short_x"] = v_net_all[:, 0]
    df["v_net_short_y"] = v_net_all[:, 1]
    df["v_net_short_z"] = v_net_all[:, 2]
    df["v_net_short_mag"] = np.sqrt(
        v_net_all[:, 0] ** 2 + v_net_all[:, 1] ** 2 + v_net_all[:, 2] ** 2
    )

    return df

In [None]:
# Rolling speed / velocity metrics over a 7-row window, returned as a new frame.
# NOTE(review): dt=1.5 is presumably minutes per frame (90 s) — confirm the unit.
tracks_df_vel = add_short_scale_metrics(tracks_df, 
                                        scale_short=7,
                                        dt=1.5)
                                    

In [None]:
# Frames strictly between 800 and 930, and BC1 above 150.
time_filter = (tracks_df_vel["t"] < 930) & (tracks_df_vel["t"] > 800)
bc1_filter = tracks_df_vel["BC1"] > 150

# vi = rolling path speed / rolling net speed. A zero net speed would give an
# infinite ratio; those are stored as NaN so they do not distort the box plot.
tracks_df_vel["vi"] = (
    tracks_df_vel["v_path_short"] / tracks_df_vel["v_net_short_mag"]
).replace([np.inf, -np.inf], np.nan)

# .copy() so that the new column is added to an independent frame rather than
# to a slice of tracks_df_vel (avoids SettingWithCopyWarning).
tracks_plot = tracks_df_vel.loc[time_filter & bc1_filter].copy()
# BC1 quintile of each row (integer codes 0-4)
tracks_plot["fluo_bin"] = pd.qcut(tracks_plot["BC1"], q=5, labels=False)

In [None]:
# Distribution of vi (v_path_short / v_net_short_mag) per BC1 quantile bin.
fig = px.box(
    tracks_plot,
    x="fluo_bin",
    y="vi",
    points=False,      # no individual dots
)

fig.update_layout(
    xaxis_title="BC1 bin (quantiles)",
    # vi is a ratio of two speeds, not a speed
    yaxis_title="path speed / net speed (vi)",
)
fig.show()