In [None]:
import pandas as pd
import numpy as np
import os
from src.utilities.plot_functions import format_2d_plotly
from ultrack.tracks.graph import get_paths_to_roots, tracks_df_forest
from glob2 import glob
from tqdm import tqdm

## Load tracking data

In [None]:
# Project root. Defaults to the original local Dropbox location; set the
# KILLI_TRACKER_ROOT environment variable to run this notebook on another machine
# without editing the code.
root = os.environ.get(
    "KILLI_TRACKER_ROOT",
    "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/killi_tracker/",
)


def load_stitched_tracks(project, config, tracking_name, frame_offset=0, well="well0000"):
    """Read one `tracks_stitched_fluo.csv` table and shift its time column.

    Parameters
    ----------
    project, config, tracking_name : str
        Path components under ``<root>/tracking/``.
    frame_offset : int, default 0
        Value added to the ``t`` column after loading.
    well : str, default "well0000"
        Well sub-directory.

    Returns
    -------
    (str, pandas.DataFrame)
        The CSV path that was read and the loaded table.
    """
    track_path = os.path.join(
        root, "tracking", project, config, well, tracking_name, "tracks_stitched_fluo.csv"
    )
    tracks = pd.read_csv(track_path)
    if frame_offset:
        tracks["t"] = tracks["t"] + frame_offset
    return track_path, tracks


# main tracking dataset(s)
project_name = "20250311_LCP1-NLSMSC_local"
tracking_config = "tracking_20250328_redux"
tracking_name0 = "track_0000_2200"
tracking_name1 = "track_2000_2339"

track_path0, tracks_df0 = load_stitched_tracks(project_name, tracking_config, tracking_name0)

# The offsets below equal the first number in the corresponding tracking name.
track_path1, tracks_df1 = load_stitched_tracks(
    project_name, tracking_config, tracking_name1, frame_offset=2000
)

# fluo mask-based tracking
project_name_m = "20250311_LCP1-NLSMSC_marker_local"
tracking_name_m = "track_1200_2339"

track_path_m, marker_tracks_df = load_stitched_tracks(
    project_name_m, tracking_config, tracking_name_m, frame_offset=1200
)

## Load mask data

In [None]:
# Full mask dataset: read every per-file CSV in the project folder and stack them.
full_mask_fluo_dir = os.path.join(root, "built_data", "fluorescence_data", project_name, "")
fluo_frames = sorted(glob(os.path.join(full_mask_fluo_dir, "*.csv")))
fluo_df_list = [pd.read_csv(csv_path) for csv_path in tqdm(fluo_frames)]
fluo_df_full = pd.concat(fluo_df_list, axis=0, ignore_index=True)

# Marker mask dataset: same procedure, different project folder.
marker_mask_fluo_dir = os.path.join(root, "built_data", "fluorescence_data", project_name_m, "")
marker_fluo_frames = sorted(glob(os.path.join(marker_mask_fluo_dir, "*.csv")))
fluo_df_list = [pd.read_csv(csv_path) for csv_path in tqdm(marker_fluo_frames)]
fluo_df_marker = pd.concat(fluo_df_list, axis=0, ignore_index=True)
# Frame offset for the marker masks was left disabled in the original:
# fluo_df_marker.loc[:, "frame"] = fluo_df_marker.loc[:, "frame"] + 1200

In [None]:
# Largest time index present in the first tracking segment
tracks_df0["t"].max()

## Combine

In [None]:
# Give the tracking tables the same column names the mask tables are used with
# elsewhere in this notebook ("frame", "mean_fluo"), and tag every table with its source.
track_column_map = {"t": "frame", "track_id": "nucleus_id", "fluo_mean": "mean_fluo"}

tracks_df0 = tracks_df0.assign(df="main0").rename(columns=track_column_map)
tracks_df1 = tracks_df1.assign(df="main1").rename(columns=track_column_map)
marker_tracks_df = marker_tracks_df.assign(df="marker").rename(columns=track_column_map)

fluo_df_full["df"] = "mask"
fluo_df_marker["df"] = "mask_marker"

frames_to_stack = [tracks_df0, tracks_df1, marker_tracks_df, fluo_df_full, fluo_df_marker]
master_df = pd.concat(frames_to_stack, axis=0, ignore_index=True)

# Stage is plotted as "stage (hpf)" below; 26 is the offset at frame 0 and 1.5 / 60
# the per-frame increment (presumably 1.5 minutes per frame - confirm).
master_df["stage"] = 26 + master_df["frame"] * 1.5 / 60

In [None]:
# Output folder for this project's tracking figures; created if it does not exist yet.
fig_path = os.path.join(
    root,
    "figures", "tracking", project_name, tracking_config,
)
os.makedirs(fig_path, exist_ok=True)

## Plot numbers of lcp+ cells over time according to our various datasets

In [None]:
import plotly.express as px

fluo_thresh = 115   # mean-fluorescence cutoff for calling a detection lcp+
window_size = 25    # rolling-mean window, in rows (one row per stage value per dataset)

# 1.0 where a detection exceeds the threshold, else 0.0, so that summing counts lcp+ cells.
master_df["fluo_flag"] = master_df["mean_fluo"].gt(fluo_thresh).astype(float)

# Number of lcp+ detections per (stage, dataset).
master_df_g = (
    master_df[["stage", "df", "fluo_flag"]]
    .groupby(["stage", "df"])
    .sum()
    .reset_index()
)

# Smooth each dataset's count series independently.
master_df_g["fluo_trend"] = (
    master_df_g
    .groupby("df")["fluo_flag"]
    .transform(lambda counts: counts.rolling(window=window_size, min_periods=1).mean())
)

axis_labels = ["stage (hpf)", "number of lcp+ cells"]
fig = px.line(master_df_g, x="stage", y="fluo_trend", color="df")
fig = format_2d_plotly(fig, axis_labels=axis_labels, font_size=18)
fig.update_traces(line_width=5)

fig.write_image(os.path.join(fig_path, "lcp_cell_count_by_method.png"), scale=2)

fig.show()


In [None]:
# Inverse of the stage formula used for master_df (stage = 26 + frame * 1.5 / 60):
# the frame index that corresponds to a stage of 76. Evaluates to 2000.0.
(76 - 26) * 60 / 1.5

## Investigate apparent quality of lineage graph

In [None]:
from ultrack.tracks.graph import tracks_df_forest, inv_tracks_df_forest
from ultrack.tracks.gap_closing import tracks_starts, tracks_ends

# Re-read the first tracking table from disk; the lineage code below works with
# the file's own column names ("t", "track_id", "fluo_mean").
track_path0 = os.path.join(
    root, "tracking", project_name, tracking_config,
    "well0000", tracking_name0, "tracks_stitched_fluo.csv",
)
tracks_df_test = pd.read_csv(track_path0)

# Mapping used below to step from a track id to its parent track id.
leaf_to_root = inv_tracks_df_forest(tracks_df_test)

# Per-track start and end tables.
starts, ends = tracks_starts(tracks_df_test), tracks_ends(tracks_df_test)

In [None]:
# how far back, on average, can we go?
# One row per track end; every one of these tracks is traced upward through
# `leaf_to_root` until a track with no recorded parent is reached.
lineage_df = ends.drop(columns=["parent_track_id", "id"]).reset_index(drop=True)

# Set membership is O(1) per step (the original scanned an array of keys each time).
child_ids = set(leaf_to_root.keys())

# Columns taken from the root track's start row. The original list named "z" twice
# and never "x", so the root x-coordinate was lost and the later "x" -> "xs" rename
# had nothing to act on.
root_cols = ["track_id", "t", "z", "y", "x", "fluo_mean"]

trace_records = []
for track_id in tqdm(lineage_df["track_id"], "tracing from leaf to root..."):
    track_curr = track_id
    n_layers = 0  # number of child -> parent links followed
    while track_curr in child_ids:
        track_curr = leaf_to_root[track_curr]
        n_layers += 1
    trace_records.append((track_curr, n_layers, track_id))

# Attach the start row of each root in a single merge instead of one boolean-mask
# lookup (plus an assignment into a slice) per leaf. An inner merge keeps the original
# behaviour of contributing no row when a root id is absent from `starts`.
trace_df = pd.DataFrame(trace_records, columns=["track_id", "n_branches", "leaf_id"])
root_rows = trace_df.merge(starts.loc[:, root_cols], on="track_id", how="inner")

# Kept as a list of DataFrames because the next cell concatenates `root_df_list`.
root_df_list = [root_rows.loc[:, root_cols + ["n_branches", "leaf_id"]]]



In [None]:
print("Generating root df...")
# Rename root-track columns so they do not collide with the leaf columns on merge.
root_column_names = {
    "track_id": "root_id",
    "t": "ts",
    "z": "zs",
    "y": "ys",
    "x": "xs",
    "fluo_mean": "fluo_s",
}
root_df = pd.concat(root_df_list, axis=0, ignore_index=True).rename(columns=root_column_names)

print("Merging...")
lineage_df = lineage_df.merge(root_df, how="left", left_on="track_id", right_on="leaf_id")

# Span between a track's end time and the start time of the root it traces back to.
lineage_df["tree_length"] = lineage_df["t"] - lineage_df["ts"]
print("Done.")

In [None]:
# Distribution of tree lengths (track end time minus root start time)
fig = px.histogram(data_frame=lineage_df, x="tree_length")
fig.show()

In [None]:
# Distribution of the number of parent links between each track and its root
fig = px.histogram(data_frame=lineage_df, x="n_branches")
fig.show()

In [None]:
# Track end time against the start time of the root it traces back to
fig = px.scatter(data_frame=lineage_df, x="t", y="ts")
fig.show()

In [None]:
# Length-weighted mean of tree_length (sum of squares over sum), scaled by 1.5 / 60,
# the same per-frame factor used to compute master_df["stage"].
lineage_df["tree_length"].pow(2).sum() / lineage_df["tree_length"].sum() * 1.5 / 60

In [None]:
import plotly.graph_objects as go

# Choose a rolling window size (adjust based on your data):
window_size = 40

# NOTE(review): `lcp_df` is not created anywhere in the visible part of this notebook,
# so this cell fails on a fresh kernel unless it is defined elsewhere - confirm.
# The code below reads its "stage", "data type" and "n_lcp_cells" columns.
# Sort the DataFrame by stage so that rolling is done in the correct order:
lcp_df = lcp_df.sort_values("stage")

# Build a data-type -> color mapping from a Plotly qualitative palette. Colors are
# cycled so that having more data types than palette entries cannot leave a type
# without a color (zip would silently drop the extras).
color_sequence = px.colors.qualitative.Plotly
unique_types = sorted(lcp_df["data type"].unique())  # sort for consistency
color_map = {
    dt: color_sequence[i % len(color_sequence)]
    for i, dt in enumerate(unique_types)
}

# Axis labels for THIS figure. They must be set before formatting; otherwise the
# labels left over from an earlier plot are used.
axis_labels = ["stage (hpf)", "number of detected lcp+ cells"]

fig = go.Figure()
fig = format_2d_plotly(fig, axis_labels=axis_labels, font_size=18)

# Per data type, compute the centred rolling mean and std of the cell counts.
# min_periods=1 keeps values at the edges of each series.
# (The original computed this identical table twice; once is enough.)
lcp_df_sorted = lcp_df.sort_values("stage")
df_trend = lcp_df_sorted.groupby("data type").apply(
    lambda x: x.assign(
        moving_avg=x["n_lcp_cells"].rolling(window=window_size, center=True, min_periods=1).mean(),
        moving_std=x["n_lcp_cells"].rolling(window=window_size, center=True, min_periods=1).std()
    )
).reset_index(drop=True)

# Utility: Convert HEX color to RGBA (for the translucent fill)
def hex_to_rgba(hex_color, alpha=0.2):
    """Convert a hex color string into an ``rgba(r,g,b,a)`` string.

    Parameters
    ----------
    hex_color : str
        Hex color with or without leading ``#``. Accepts 6 digits (``"#636EFA"``),
        3-digit shorthand (``"#fff"``), or 8 digits, in which case the trailing
        two digits are ignored and `alpha` is used instead.
    alpha : float, default 0.2
        Opacity written as the fourth component.

    Returns
    -------
    str
        ``"rgba(r,g,b,alpha)"`` with integer channel values.

    Raises
    ------
    ValueError
        If the string does not have 3, 6 or 8 hex digits, or contains
        non-hexadecimal characters.
    """
    digits = hex_color.lstrip("#")
    if len(digits) == 3:
        # Expand shorthand: "abc" -> "aabbcc"
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        # Without this check, odd lengths (e.g. 5 digits) were silently mis-parsed.
        raise ValueError(f"expected a 3-, 6- or 8-digit hex color, got {hex_color!r}")
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"

# Add traces for each data type: trendline and shaded error band.
for dt in unique_types:
    df_sub = df_trend[df_trend["data type"] == dt].copy()
    # Get the color for this group from our mapping:
    base_color = color_map[dt]

    # Add trendline (moving average) trace:
    avg_trace = go.Scatter(
        x=df_sub["stage"],
        y=df_sub["moving_avg"],
        mode="lines",
        name=f"{dt}",
        line=dict(color=base_color, width=3)
    )
    fig.add_trace(avg_trace)

    # Calculate upper and lower bounds:
    upper_bound = df_sub["moving_avg"] + df_sub["moving_std"]
    lower_bound = df_sub["moving_avg"] - df_sub["moving_std"]

    # Create a translucent shaded region for ±1 standard deviation.
    # The polygon runs forward along the upper bound and back along the lower bound.
    stage_vals = df_sub["stage"].to_numpy()
    error_band_trace = go.Scatter(
        x=np.concatenate([stage_vals, stage_vals[::-1]]),
        y=np.concatenate([upper_bound.to_numpy(), lower_bound.to_numpy()[::-1]]),
        fill="toself",
        fillcolor=hex_to_rgba(base_color, alpha=0.4),
        line=dict(color="rgba(255,255,255,0)"),
        hoverinfo="skip",
        showlegend=False,
        name=f"{dt} ±1 SD"
    )
    fig.add_trace(error_band_trace)

# Axis labels for this figure. They were previously assigned here but never applied;
# set only the title text so any title styling already on the figure is kept.
axis_labels = ["stage (hpf)", "number of detected lcp+ cells"]
fig.update_layout(xaxis_title_text=axis_labels[0], yaxis_title_text=axis_labels[1])

fig.show()

# Join with os.path.join: `fig_path` has no trailing separator, so plain string
# concatenation wrote the file next to the figure directory instead of inside it.
fig.write_image(os.path.join(fig_path, "n_lcp_cells_vs_stage.png"))

### Segments look shite. What about the raw masks?
Manual inspection indicates that a number of raw masks corresponding to lcp+ nuclei are being dropped during tracking, which is frustrating.

In [None]:
from glob2 import glob
from tqdm import tqdm

# Read every per-file fluorescence CSV for the main project and stack them into one table.
fluo_path = os.path.join(root, "built_data", "fluorescence_data", project_name, "")
fluo_df_path_list = sorted(glob(fluo_path + "*.csv"))
fluo_df_list = [pd.read_csv(fluo_p) for fluo_p in tqdm(fluo_df_path_list)]

fluo_df = pd.concat(fluo_df_list, axis=0, ignore_index=True)

In [None]:
# Number of rows above the fluorescence threshold: raw masks first, then tracks.
# NOTE(review): `tracks_df` is not defined in the visible part of this notebook
# (only tracks_df0 / tracks_df1 / marker_tracks_df are) - confirm which table is meant.
print((fluo_df["mean_fluo"] > fluo_thresh).sum())
print((tracks_df["mean_fluo"] > fluo_thresh).sum())

We see substantially more high-fluo detections in the raw masks than in the tracks. Let's look at trends over time.

In [None]:
N = 50  # for example

# For every value of 'frame', keep the N rows with the highest 'mean_fluo'.
top_fluo_df = (
    fluo_df
    .groupby("frame", group_keys=False)
    .apply(lambda frame_rows: frame_rows.nlargest(N, columns="mean_fluo"))
    .reset_index(drop=True)
)

In [None]:
# Brightest raw-mask detections per frame
fig = px.scatter(data_frame=top_fluo_df, x="frame", y="mean_fluo")
fig.show()

In [None]:
# Per frame, count the raw-mask rows whose mean fluorescence exceeds the threshold.
is_bright = fluo_df["mean_fluo"] > fluo_thresh
fi, fc = np.unique(fluo_df.loc[is_bright, "frame"], return_counts=True)

fig = px.scatter(x=fi, y=fc)
fig.show()

Clearly we're losing a ton during the tracking process. Sad.