In [None]:
import pandas as pd
import numpy as np
import os
from ultrack.tracks.graph import get_paths_to_roots, tracks_df_forest

## Load tracking data

In [2]:
# load tracks dataset
# Data root. Set the KILLI_TRACKER_ROOT environment variable to run on another machine;
# the default is the original author's local Dropbox path.
root = os.environ.get(
    "KILLI_TRACKER_ROOT",
    "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/killi_tracker/",
)
project_name = "20250311_LCP1-NLSMSC_local"
tracking_config = "tracking_20250328_redux"
tracking_name = "track_0000_2339_cb"

# Both track tables (raw and stitched) live in the same tracking-run directory.
track_dir = os.path.join(root, "tracking", project_name, tracking_config, "well0000", tracking_name)
track_path = os.path.join(track_dir, "tracks_fluo.csv")
tracks_df_raw = pd.read_csv(track_path)
track_path_s = os.path.join(track_dir, "tracks_fluo_stitched.csv")
tracks_df = pd.read_csv(track_path_s)

FileNotFoundError: [Errno 2] No such file or directory: "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/killi_tracker/tracking\\20250311_LCP1-NLSMSC_local\\tracking_20250328_redux\\well0000\\track_0000_2339_cb\\tracks_fluo.csv"

In [None]:
# Output directory for figures. The trailing "" makes os.path.join end the path with a separator,
# so later cells that build file names as `fig_path + "name.png"` write inside this directory.
fig_path = os.path.join(root, "figures", "tracking", project_name, tracking_config, "")
os.makedirs(fig_path, exist_ok=True)

## Load raw nucleus mask data

In [None]:
from glob2 import glob
from tqdm import tqdm

# Concatenate every fluorescence CSV for this project into a single table.
# The trailing "" adds a path separator so the glob pattern matches files inside the directory.
fluo_path = os.path.join(root, "built_data", "fluorescence_data", project_name, "")
fluo_df_path_list = sorted(glob(fluo_path + "*.csv"))
if not fluo_df_path_list:
    # pd.concat([]) only says "No objects to concatenate"; report where we looked instead.
    raise FileNotFoundError(f"No fluorescence CSVs found in {fluo_path}")

fluo_df_list = [pd.read_csv(fluo_p) for fluo_p in tqdm(fluo_df_path_list)]
fluo_df = pd.concat(fluo_df_list, axis=0, ignore_index=True)

### Plot numbers of cells over time

In [None]:
import plotly.express as px
from src.utilities.plot_functions import format_2d_plotly

# Nuclei per frame from tracking (frame column "t") and from raw segmentation (column "frame").
track_frame_counts = (
    tracks_df.groupby("t").size().rename("n_nuclei_track").rename_axis("frame").reset_index()
)
seg_frame_counts = fluo_df.groupby("frame").size().rename("n_nuclei_seg").reset_index()

# Join on the frame number rather than pairing arrays by position: the two sources need not
# cover exactly the same frames, and positional assignment would misalign rows or fail outright.
counts_df = (
    track_frame_counts.merge(seg_frame_counts, on="frame", how="outer")
    .fillna({"n_nuclei_track": 0, "n_nuclei_seg": 0})
    .astype({"n_nuclei_track": int, "n_nuclei_seg": int})
    .sort_values("frame")
    .reset_index(drop=True)
)

# Frame -> developmental stage: 26 hpf at frame 0, 1.5 minutes per frame.
counts_df["stage"] = 26 + counts_df["frame"] * 1.5 / 60

fig = px.line(counts_df, x="stage", y="n_nuclei_seg")

axis_labels = ["stage (hpf)", "number of nuclei"]

fig = format_2d_plotly(fig, axis_labels=axis_labels, font_size=18)

fig.update_traces(line=dict(width=4))

# os.path.join works whether or not fig_path already ends with a separator.
fig.write_image(os.path.join(fig_path, "n_cells_seg.png"))
fig.show()

In [None]:
import plotly.graph_objects as go

# Per-frame nucleus counts: raw segmentation vs. tracking. The x axis here is the frame index,
# not developmental stage, so this plot uses its own labels and leaves `axis_labels` untouched.
frame_axis_labels = ["frame", "number of nuclei"]

fig = go.Figure()
fig.add_trace(
    go.Scatter(x=counts_df["frame"], y=counts_df["n_nuclei_seg"], mode="lines", name="segmentation")
)
fig.add_trace(
    go.Scatter(x=counts_df["frame"], y=counts_df["n_nuclei_track"], mode="lines", name="tracking")
)

fig = format_2d_plotly(fig, axis_labels=frame_axis_labels, font_size=18)

fig.update_traces(line=dict(width=4))

fig.show()

### Look at emergence of lcp+ cells

In [None]:
fluo_thresh = 115  # mean_fluo above this value counts a nucleus as lcp+

# Count lcp+ nuclei per frame in the tracks and in the raw segmentation masks.
# The frame axis spans both sources so neither count can index past the end of the vector.
n_frames = int(max(tracks_df["t"].max(), fluo_df["frame"].max())) + 1
full_frame_vec = np.arange(n_frames)

# tracking data
track_df_ft = tracks_df["mean_fluo"] > fluo_thresh
track_counts = np.bincount(tracks_df.loc[track_df_ft, "t"].astype(int), minlength=n_frames)

# mask data
mask_df_ft = fluo_df["mean_fluo"] > fluo_thresh
seg_counts = np.bincount(fluo_df.loc[mask_df_ft, "frame"].astype(int), minlength=n_frames)

# Long-format table: one row per (frame, data type). The next cell plots a smoothed trend of it.
lcp_df = pd.concat(
    [
        pd.DataFrame({"frame": full_frame_vec, "n_lcp_cells": track_counts, "data type": "tracking"}),
        pd.DataFrame({"frame": full_frame_vec, "n_lcp_cells": seg_counts, "data type": "segmentation"}),
    ],
    ignore_index=True,
)
# Frame -> developmental stage: 26 hpf at frame 0, 1.5 minutes per frame.
lcp_df["stage"] = 26 + lcp_df["frame"] * 1.5 / 60

In [None]:

# Rolling window size: number of consecutive rows within each data type (adjust based on your data).
window_size = 40

# Sort by stage so the rolling windows run in temporal order.
lcp_df = lcp_df.sort_values("stage")

# Centered rolling mean and SD within each "data type"; min_periods=1 keeps values at the series edges.
# groupby().transform keeps the "data type" column in the result (the trace loop filters on it),
# whereas groupby().apply operating on grouping columns is deprecated in pandas.
n_lcp_by_type = lcp_df.groupby("data type")["n_lcp_cells"]
df_trend = lcp_df.assign(
    moving_avg=n_lcp_by_type.transform(
        lambda s: s.rolling(window=window_size, center=True, min_periods=1).mean()
    ),
    moving_std=n_lcp_by_type.transform(
        lambda s: s.rolling(window=window_size, center=True, min_periods=1).std()
    ),
).reset_index(drop=True)

# One Plotly default color per data type (sorted for a stable assignment across runs).
color_sequence = px.colors.qualitative.Plotly
unique_types = sorted(lcp_df["data type"].unique())
color_map = dict(zip(unique_types, color_sequence))

# Empty figure; traces are added per data type below. Labels are set here so this figure
# does not inherit the "number of nuclei" labels from the earlier count plots.
fig = go.Figure()
axis_labels = ["stage (hpf)", "number of detected lcp+ cells"]
fig = format_2d_plotly(fig, axis_labels=axis_labels, font_size=18)

def hex_to_rgba(hex_color, alpha=0.2):
    """Convert a hex color string to a Plotly ``rgba(r,g,b,alpha)`` string (used for translucent fills).

    Args:
        hex_color: "#RRGGBB" or shorthand "#RGB"; the leading "#" is optional.
        alpha: opacity written into the output string.

    Returns:
        A string of the form "rgba(r,g,b,alpha)" with integer 0-255 channels.
    """
    digits = hex_color.lstrip("#")
    if len(digits) == 3:
        # Shorthand notation: each digit is doubled ("fa0" -> "ffaa00").
        digits = "".join(ch * 2 for ch in digits)
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"

# For each data type add the moving-average line plus a translucent ±1 SD band around it.
for dt in unique_types:
    df_sub = df_trend[df_trend["data type"] == dt]
    base_color = color_map[dt]

    fig.add_trace(
        go.Scatter(
            x=df_sub["stage"],
            y=df_sub["moving_avg"],
            mode="lines",
            name=f"{dt}",
            line=dict(color=base_color, width=3),
        )
    )

    stage = df_sub["stage"].to_numpy()
    upper_bound = (df_sub["moving_avg"] + df_sub["moving_std"]).to_numpy()
    lower_bound = (df_sub["moving_avg"] - df_sub["moving_std"]).to_numpy()

    # Closed polygon for fill="toself": upper bound left-to-right, then lower bound right-to-left.
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([stage, stage[::-1]]),
            y=np.concatenate([upper_bound, lower_bound[::-1]]),
            fill="toself",
            fillcolor=hex_to_rgba(base_color, alpha=0.4),
            line=dict(color="rgba(255,255,255,0)"),
            hoverinfo="skip",
            showlegend=False,
            name=f"{dt} ±1 SD",
        )
    )

# Set axis titles on this figure directly (title text only, so existing title styling is kept).
axis_labels = ["stage (hpf)", "number of detected lcp+ cells"]
fig.update_xaxes(title_text=axis_labels[0])
fig.update_yaxes(title_text=axis_labels[1])

# os.path.join works whether or not fig_path already ends with a separator.
fig.write_image(os.path.join(fig_path, "n_lcp_cells_vs_stage.png"))
fig.show()

### Segments look poor. What about the raw masks?
Manual inspection indicates that a number of raw masks corresponding to lcp+ nuclei are being dropped during tracking, which is frustrating.

In [None]:
# fluo_df was already assembled from these same CSVs in "Load raw nucleus mask data" above,
# so re-reading every file is redundant; summarize what is loaded instead.
print(
    f"{len(fluo_df_path_list)} CSV files -> {len(fluo_df)} masks "
    f"across {fluo_df['frame'].nunique()} frames"
)

In [None]:
# How many entries exceed the lcp+ threshold in the raw masks vs. in the tracks?
n_hi_masks = int((fluo_df["mean_fluo"] > fluo_thresh).sum())
n_hi_tracked = int((tracks_df["mean_fluo"] > fluo_thresh).sum())
print(f"raw masks with mean_fluo > {fluo_thresh}: {n_hi_masks}")
print(f"track entries with mean_fluo > {fluo_thresh}: {n_hi_tracked}")

The raw masks contain substantially more high-fluorescence entries than the tracks. Let's look at trends over time.

In [None]:
N = 50  # number of brightest masks (by mean_fluo) to keep per frame

# For each frame, keep the N rows with the highest mean_fluo. Sorting once and taking
# groupby().head(N) avoids a per-group apply; rows come out ordered by frame, then by
# descending mean_fluo. NaN mean_fluo rows are dropped first, as nlargest would ignore them.
top_fluo_df = (
    fluo_df.dropna(subset=["mean_fluo"])
    .sort_values(["frame", "mean_fluo"], ascending=[True, False])
    .groupby("frame")
    .head(N)
    .reset_index(drop=True)
)

In [None]:
# Brightest N masks in each frame, to see how the high-fluorescence tail evolves over time.
fig = px.scatter(
    top_fluo_df,
    x="frame",
    y="mean_fluo",
    labels={"mean_fluo": f"mean_fluo (top {N} masks per frame)"},
)
fig.show()

In [None]:
# Number of raw masks above the lcp+ threshold in each frame (frames with none are not plotted).
hi_frames, hi_counts = np.unique(
    fluo_df.loc[fluo_df["mean_fluo"] > fluo_thresh, "frame"], return_counts=True
)

fig = px.scatter(
    x=hi_frames,
    y=hi_counts,
    labels={"x": "frame", "y": f"masks with mean_fluo > {fluo_thresh}"},
)
fig.show()

Clearly, we are losing a large number of high-fluorescence nuclei during the tracking process.

### Assess overall quality of the tracks. Can we reconstruct lineage trees?

In [None]:
from ultrack.tracks.graph import inv_tracks_df_forest

# Lineage graphs over track ids, built by ultrack from the stitched tracks table.
# inv_forest_graph is used below as a child -> parent lookup when walking up to lineage roots.
inv_forest_graph = inv_tracks_df_forest(tracks_df)
forest_graph = tracks_df_forest(tracks_df)

In [None]:
def get_root(cell, parent_map):
    """
    Follow child -> parent links until reaching a cell that has no parent.

    Args:
        cell: starting id.
        parent_map: mapping from a child id to its single parent id (a scalar value, not a list).

    Returns:
        The first id on the chain that is not a key of parent_map
        (``cell`` itself when it has no parent).

    Raises:
        ValueError: if the parent links form a cycle, which would otherwise loop forever.
    """
    visited = set()
    while cell in parent_map:
        if cell in visited:
            raise ValueError(f"cycle detected in parent_map at id {cell!r}")
        visited.add(cell)
        cell = parent_map[cell]
    return cell

# For every track: walk child -> parent links up to its lineage root, then record the frame where
# the root first appears and the frame where this track last appears.
# Per-track first/last frame and row count, computed once rather than scanning tracks_df per track.
track_stats = tracks_df.groupby("track_id")["t"].agg(
    first_frame="min", last_frame="max", track_length="size"
)
first_frame = track_stats["first_frame"]
last_frame = track_stats["last_frame"]

results = []
for child in tqdm(track_stats.index):
    # get_root returns `child` unchanged when it has no parent in inv_forest_graph.
    # Named root_id so it does not overwrite the `root` data-directory path.
    root_id = get_root(child, inv_forest_graph)
    results.append({
        "child_id": child,
        "root_id": root_id,
        # NaN if the root id has no rows in tracks_df
        "root_frame": first_frame.get(root_id, np.nan),
        "leaf_frame": last_frame[child],
    })

# Convert results to a DataFrame and attach each track's length (its number of rows in tracks_df).
# counts_df holds per-frame counts with no "track_id" column, so track lengths are computed here.
# NOTE(review): rows-per-track_id is assumed to be the intended "track_length" — confirm.
df_roots = pd.DataFrame(results).merge(
    track_stats[["track_length"]], how="left", left_on="child_id", right_index=True
)

In [None]:
# Keep only tracks with at least this many entries before looking at lineage spans.
MIN_TRACK_LENGTH = 10
# .copy() so later column assignments act on an independent frame, not a slice of df_roots.
df_roots_ft = df_roots.loc[df_roots["track_length"] >= MIN_TRACK_LENGTH].copy()
print(df_roots_ft.shape)

In [None]:
# Frames elapsed between the lineage root's first appearance and this track's last frame.
# assign() returns a new frame, avoiding chained assignment on a possible slice of df_roots.
df_roots_ft = df_roots_ft.assign(span=df_roots_ft["leaf_frame"] - df_roots_ft["root_frame"])

fig = px.scatter(
    df_roots_ft,
    x="leaf_frame",
    y="span",
    labels={"leaf_frame": "last frame of track", "span": "frames since lineage root appeared"},
)
fig.show()

In [None]:
from ultrack.tracks.gap_closing import close_tracks_gaps

# Gap-closing parameters forwarded to ultrack's close_tracks_gaps.
# NOTE(review): scale presumably weights the spatial axes (z first) for anisotropic voxels — confirm.
gap_scale = np.asarray([3.0, 1.0, 1.0])
test = close_tracks_gaps(
    tracks_df,
    max_gap=3,
    max_radius=50,
    scale=gap_scale,
)

In [None]:
# Preview the gap-closed tracks table (first rows only, rather than the whole frame).
test.head()

In [None]:
np.unique(test["track_id"]).size  # number of distinct tracks after gap closing

In [None]:
np.unique(tracks_df["track_id"]).size  # number of distinct tracks before gap closing