In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px

In [None]:
# Root folder of one tracking run. Defaults to the original author's local Dropbox
# location; set the KILLI_TRACK_ROOT environment variable to point at a different
# copy of the data without editing the notebook.
DEFAULT_ROOT = Path("/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/killi_tracker/tracking/20250419_BC1-NLSMSC/tracking_20250328_redux/well0000/track_0000_0614/")
root = Path(os.environ.get("KILLI_TRACK_ROOT", DEFAULT_ROOT))

# per-track class labels (class 0 is the subset kept as "deep" tracks further down)
track_class_df = pd.read_csv(root / "track_class_df.csv")
# sphere fit table (its xs/ys/zs columns are used as the sphere center downstream)
sphere_df = pd.read_csv(root / "sphere_fit.csv")
# tracked nuclei; rescale (z, y, x) coordinates by a per-axis factor
# NOTE(review): (3, 0.85, 0.85) looks like the z-step and xy pixel size, presumably in µm — confirm
scale_vec = np.asarray([3, 0.85, 0.85])
track_df = pd.read_csv(root / "tracks_fluo.csv")
track_df[["z", "y", "x"]] = track_df[["z", "y", "x"]].to_numpy() * scale_vec[None, :]
track_class_df.head()

In [None]:
# Attach each track's class label to every row of track_df, then keep only the
# class-0 tracks. Tracks absent from track_class_df get a NaN label via the left
# merge and are therefore excluded by the equality filter.
deep_track_df = (
    track_df
    .merge(track_class_df[["track_id", "track_class"]], how="left", on="track_id")
    .loc[lambda frame: frame["track_class"] == 0]
)

### Use nearest-neighbor (NN) statistics to estimate each cell's local density

In [None]:
from src.symmetry_breaking.cluster_tracking import (
    find_clusters_per_timepoint,
    stitch_tracklets,
    track_clusters_over_time,
)
from tqdm import tqdm

# Linking settings shared by the frame-to-frame linker (step 2) and the tracklet
# stitcher (step 3); defined once so the two stages cannot drift apart when tuning.
shared_link_kwargs = dict(
    link_metric="overlap",  # or "jaccard"
    sim_min=0.3,
    max_centroid_angle=np.deg2rad(15),
    w_sim=1.0,
    w_feat=0.7,
    w_pred=0.7,  # tune
)

# Step 1: find clusters of deep cells independently in each frame.
clusters_by_t = find_clusters_per_timepoint(
    deep_track_df,
    sphere_df,
    d_thresh=40.0,
    min_size=25,  # tune
    fluo_col="mean_fluo",
    sphere_center_cols=("xs", "ys", "zs"),
)

# Step 2: link clusters across frames (motion/feature-aware, with merges).
cluster_ts, merges_df = track_clusters_over_time(
    clusters_by_t,
    pred_step=1.0,
    **shared_link_kwargs,
)

# Step 3: stitch fragmented tracklets (bridge small gaps, fix flips).
stitched_ts, stitch_log = stitch_tracklets(
    cluster_ts,
    gap_max=2,
    window=1,
    w_size=3.0,
    max_iters=3,
    **shared_link_kwargs,
)

In [None]:
# cluster_ts.tail()

In [None]:
from src.utilities.plot_functions import format_2d_plotly

# Integer cluster ids (note: this rewrites the column of cluster_ts itself).
cluster_ts["cluster_id"] = cluster_ts["cluster_id"].astype(int)

# Mean fluorescence vs. mean degree per cluster, colored by frame.
fig = px.scatter(
    cluster_ts,
    x="fluo_mean",
    y="deg_mean",
    size="size",
    color="t",
)
fig = format_2d_plotly(fig, axis_labels=["nuclear BC1", "average degree"], marker_size=30)

# Replace the plotted marker sizes with the raw "size" column values, then fix the
# axis windows and render.
(
    fig.update_traces(marker=dict(size=cluster_ts["size"]))
    .update_xaxes(range=[100, 700])
    .update_yaxes(range=[0.9, 6])
    .show()
)

In [None]:
# String ids make plotly color each stitched cluster as its own discrete category
# (note: this rewrites the column of stitched_ts itself).
stitched_ts["cluster_id_stitched"] = stitched_ts["cluster_id_stitched"].astype(str)

# Mean fluorescence of each stitched cluster over time.
fig = px.scatter(
    stitched_ts,
    x="t",
    y="fluo_mean",
    size="size",
    color="cluster_id_stitched",
)
# NOTE(review): 614 presumably matches the last frame of this run (track_0000_0614) — confirm.
fig.update_xaxes(range=[0, 614]).show()

In [None]:
# Inspect which columns stitch_tracklets returned (reference for the plots above).
stitched_ts.columns

In [None]:
# For each time bin, bootstrap the mean "xcorr" curve across that bin's entries:
# resample entries with replacement nboots times, then record the mean and the
# standard deviation (bootstrap standard error) of the resampled means.
# NOTE(review): `t_bins` and `result_dict` are not created anywhere in this notebook
# as shown — the cell that built them must be restored before a fresh-kernel run.
# Assumes result_dict[i] is a list of dicts holding equal-length "xcorr" arrays and a
# shared "lags" vector — TODO confirm.
nboots = 100
XCORR_BOOT_SEED = 0  # fixed seed so bootstrap SEs are reproducible on re-run
boot_rng = np.random.default_rng(XCORR_BOOT_SEED)

# Lags come from the first bin. This previously read `results`, which is only
# assigned inside the loop below (NameError on a fresh kernel, stale value otherwise).
lag_vec = result_dict[0][0]["lags"]

xcorr_list = []
xcorr_se_list = []
for t in tqdm(range(len(t_bins) - 1)):
    results = result_dict[t]
    if len(results) == 0:
        # Nothing to resample: keep the bin as NaNs so the per-bin lists stay
        # aligned with time_group when reshaped downstream.
        xcorr_list.append(np.full(len(lag_vec), np.nan))
        xcorr_se_list.append(np.full(len(lag_vec), np.nan))
        continue
    xcorr_array = np.asarray([r["xcorr"] for r in results])
    n_obs = xcorr_array.shape[0]
    mu_array = np.empty((nboots, xcorr_array.shape[1]))
    for n in range(nboots):
        # resample rows (entries) with replacement
        boot_ids = boot_rng.integers(0, n_obs, size=n_obs)
        mu_array[n, :] = xcorr_array[boot_ids, :].mean(axis=0)
    xcorr_mean = mu_array.mean(axis=0)
    xcorr_se = mu_array.std(axis=0)
    xcorr_list.append(xcorr_mean)
    xcorr_se_list.append(xcorr_se)

In [None]:
# Reshape the per-bin bootstrap curves into a long (tidy) table with one row per
# (time_group, lag), ready for plotly.
# The number of groups is taken from xcorr_list itself (the data being stacked)
# instead of being re-derived from t_bins, so the two can never disagree.
n_groups = len(xcorr_list)
n_lags = len(lag_vec)
assert all(len(curve) == n_lags for curve in xcorr_list), (
    "every xcorr curve must have exactly one value per lag in lag_vec"
)

lag_long = np.tile(lag_vec, n_groups)
id_vec = np.repeat(np.arange(n_groups), n_lags)
xcorr_long = np.concatenate(xcorr_list, axis=0)
xcorr_se_long = np.concatenate(xcorr_se_list, axis=0)

x_df = pd.DataFrame({
    "time_group": id_vec,
    "lag": lag_long,
    "corr": xcorr_long,
    "corr_se": xcorr_se_long,
})

In [None]:
# Bootstrap-mean cross-correlation vs. lag, one line per time bin, with error bars
# from the bootstrap standard error.
fig = px.line(
    data_frame=x_df,
    x="lag",
    y="corr",
    color="time_group",
    error_y="corr_se",
)
fig.show()

In [None]:
# Shape of the xcorr matrix left over from the LAST iteration of the bootstrap loop
# (one row per entry of that bin's results; columns presumably index lags).
xcorr_array.shape

In [None]:
import plotly.graph_objects as go

# Sanity check for a single frame: deep-cell positions (faint markers) shown together
# with the sphere-fit center point(s) recorded for the same frame.
t = 100
test_df = deep_track_df.loc[deep_track_df["t"] == t, :]
sp_test = sphere_df.loc[sphere_df["t"] == t, :]

cell_trace = go.Scatter3d(
    x=test_df["x"],
    y=test_df["y"],
    z=test_df["z"],
    mode="markers",
    marker=dict(opacity=0.1, size=6),
)
center_trace = go.Scatter3d(
    x=sp_test["xs"],
    y=sp_test["ys"],
    z=sp_test["zs"],
    mode="markers",
    marker=dict(size=6),
)
fig = go.Figure(data=[cell_trace, center_trace])
fig.show()

In [None]:
# (rows, columns) of the single-frame deep-cell table; rows = number of deep-cell
# observations at frame t.
test_df.shape

In [None]:
# Surface area 4*pi*r^2 of a sphere with radius 501.
# NOTE(review): 501 appears hand-entered (presumably the fitted sphere radius, in the
# same units as the rescaled track coordinates) — confirm against sphere_df.
sphere_radius = 501
A = 4 * np.pi * sphere_radius ** 2

### Approximate cell density on the sphere surface

In [None]:
# Cell count divided by the sphere surface area A, scaled by 1000 * 1000
# (presumably converting per-µm² to per-mm² — confirm units).
# NOTE(review): 941 looks hand-copied from test_df.shape[0] above — confirm.
(941 / A) * 1000 * 1000