### This notebook fits a reference spline to HF and AB reference data

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
from glob2 import glob
from pathlib import Path
from tqdm import tqdm
from src.functions.plot_functions import format_2d_plotly, format_3d_plotly
import plotly.io as pio            # <-- run this once near the top of the notebook
pio.templates.default = "plotly"

In [None]:
# Load latent embeddings and embryo metadata for our current best model.
# NOTE: the directories below are machine-specific absolute paths.
# root = "/media/nick/hdd02/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"
fig_root = Path("/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/morphseq/20250730")
fig_path = fig_root / "chem_screen_results"
fig_path.mkdir(parents=True, exist_ok=True)

root = Path("/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/")
model_name = "20241107_ds_sweep01_optimum"
model_class = "legacy"
# later_than = 20250501
experiments = [
    "20240812",
    "20250215",
    "20250612_24hpf_ctrl_atf6",
    "20250612_24hpf_wfs1_ctcf",
    "20250612_30hpf_ctrl_atf6",
    "20250612_30hpf_wfs1_ctcf",
    "20250612_36hpf_ctrl_atf6",
    "20250612_36hpf_wfs1_ctcf",
]

# Latent embeddings: one CSV per experiment, stacked row-wise (original row indices kept)
latent_path = root / "analysis" / "latent_embeddings" / model_class / model_name
latent_df = pd.concat(
    [pd.read_csv(latent_path / f"morph_latents_{exp}.csv") for exp in tqdm(experiments)]
)

# Embryo metadata: same layout, one CSV per experiment
meta_path = root / "metadata" / "embryo_metadata_files"
meta_df = pd.concat(
    [pd.read_csv(meta_path / f"{exp}_embryo_metadata.csv") for exp in tqdm(experiments)]
)

# (disabled) removal of problematic embryo IDs
# rm_ids = ["20250624_chem02_35C_T00_1216_C02_e01", "20250624_chem02_35C_T01_1711_C02_e01","20250625_chem02_35C_T02_1228_C02_e01"]
# meta_df = meta_df.loc[~meta_df["embryo_id"].isin(rm_ids)]

# Directory for saved results (string with a trailing separator)
out_path = os.path.join(root, "results", "20240707", "")
os.makedirs(out_path, exist_ok=True)

#### Merge embeddings with metadata

In [None]:
# Metadata columns to carry into the merged table; all other columns come from latent_df.
keep_cols = [
    'snip_id', 'well', 'nd2_series_num', 'microscope', 'time_int', 'genotype',
    'chem_perturbation', 'start_age_hpf', 'temperature',
    'well_qc_flag', 'Time Rel (s)', 'use_embryo_flag', 'frame_flag', 'dead_flag',
]

# Inner join on snip_id: only snips present in both tables survive
master_df = pd.merge(meta_df[keep_cols], latent_df, how="inner", on="snip_id")
master_df = master_df.assign(experiment_date=master_df["experiment_date"].astype(str))

# Row counts before / after the join
for frame in (latent_df, master_df):
    print(frame.shape)

master_df.tail()

#### Make some helpful flags

In [None]:
# Tag each row with a short experiment id: the two reference experiments get
# their own ids, every other experiment is pooled under "gene7".
ref_exp_ids = {"20240812": "ref0", "20250215": "ref1"}
crisp_flag_vec = [ref_exp_ids.get(exp, "gene7") for exp in master_df["experiment_date"]]

master_df["exp_id"] = crisp_flag_vec

# split datasets: reference embryos vs. everything else
is_ref = master_df["exp_id"].isin(list(ref_exp_ids.values()))
ref_embryo_df = master_df.loc[is_ref].copy()
crisp_embryo_df = master_df.loc[~is_ref].copy()

# NOTE(review): every later cell (PCA fit, bootstrap, plots) reads the
# non-reference set through the name `chem_embryo_df`, which no cell defines, so
# a fresh-kernel "Run All" stopped with a NameError. Binding it to the
# non-reference split is the only definition consistent with this notebook;
# confirm this is the intended dataset for the downstream "chem" analysis.
chem_embryo_df = crisp_embryo_df

# # get time point info
# time_int_vec = [int(exp.split("_")[3].replace("T", "")) for exp in chem_embryo_df["experiment_date"]]
# chem_embryo_df.loc[:, "time_int"] = time_int_vec

# time_stamp_vec = [int(exp.replace("_check","").split("_")[-1]) for exp in chem_embryo_df["experiment_date"]]
# chem_embryo_df.loc[:, "time_stamp"] = time_stamp_vec

# datetime_stamp_vec = [int(exp.replace("_check","").split("_")[0] + 
#                           exp.replace("_check","").split("_")[-1]) for exp in chem_embryo_df["experiment_date"]]
# chem_embryo_df.loc[:, "datetime_stamp"] = datetime_stamp_vec

# # spot fix for a metadata issue
# chem_embryo_df.loc[chem_embryo_df["experiment_date"]=="20250703_chem3_34C_T01_1457", "temperature"] = 34

# # create new embryo_id var
# eid_vec = chem_embryo_df["exp_id"].str.cat(chem_embryo_df["well"], sep="_").str.cat(chem_embryo_df["temperature"].astype(int).astype(str) 
#                                                                                     + "C", sep="_")
# snip_vec = chem_embryo_df["exp_id"].str.cat(chem_embryo_df["well"], sep="_").str.cat(chem_embryo_df["temperature"].astype(int).astype(str)
#                                                                                      + "C", sep="_").str.cat(
#                                         "T" + chem_embryo_df["time_int"].astype(str).str.zfill(4) , sep="_"      # col 3 (int → 3-digit string)
#                                     )
# chem_embryo_df["snip_id"] = snip_vec
# chem_embryo_df["embryo_id"] = eid_vec

### QC

In [None]:
# # Apply quick fix to some chem labels in #2
# # chem_embryo_df.loc[chem_embryo_df["exp_id"]=="chem2", "chem_perturbation"].unique()
# fix_dict = {'tgfb_i_13':'tgfb_i_6', 'wnt_i_13':'wnt_i_6', 'fgf_i_13':'fgf_i_6', 'bmp_i_13':'bmp_i_6'}

# m = chem_embryo_df["exp_id"].eq("chem2")
# # …and replace only there
# chem_embryo_df.loc[m, "chem_perturbation"] = (
#     chem_embryo_df.loc[m, "chem_perturbation"].replace(fix_dict)
# )

# # remove problem observations from dataset
# qc_emb_list = ["chem_C10_28C", "chem2_A01_35C", "chem2_E01_35C", "chem2_A11_35C", "chem2_B11_35C", 
#                "chem2_D12_34C", "chem2_F12_34C", "chem2_G12_34C", "chem2_H12_34C", "chem_C02_34C",  
#                "chem2_B11_28C", "chem2_A11_34C", "chem_A05_28C", "chem_A02_35C", "chem2_C05_28C"]
# qc_snip_list = ["chem3_F06_35C_T0001"]

# print(chem_embryo_df.shape)
# chem_embryo_df = chem_embryo_df.loc[~chem_embryo_df["snip_id"].isin(qc_snip_list)]
# print(chem_embryo_df.shape)
# chem_embryo_df = chem_embryo_df.loc[~chem_embryo_df["embryo_id"].isin(qc_emb_list)]
# print(chem_embryo_df.shape)
# qc_mask = (~chem_embryo_df["dead_flag"]) & (~chem_embryo_df["frame_flag"]) 
# chem_embryo_df = chem_embryo_df.loc[qc_mask]
# print(chem_embryo_df.shape)

### Fit PCA to just the ref and hotfish data

In [None]:
from sklearn.decomposition import PCA
import re 

# params
n_components = 10
z_pattern = "z_mu_b"
# latent columns whose name matches z_pattern are the PCA inputs
mu_cols = [col for col in ref_embryo_df.columns if re.search(z_pattern, col)]
pca_cols = [f"PCA_{p:02}_bio" for p in range(n_components)]

# fit
# Balance the fit: draw (without replacement) as many reference rows as there are
# chem rows. The draw can never exceed the number of reference rows, so cap it.
np.random.seed(345)
n_ref_fit = min(ref_embryo_df.shape[0], chem_embryo_df.shape[0])
ref_indices = np.random.choice(ref_embryo_df.shape[0], n_ref_fit, replace=False)
morph_pca = PCA(n_components=n_components)
# ref_indices are row POSITIONS (0..n-1), so they must go through .iloc;
# ref_embryo_df keeps master_df's index labels, which .loc would look up instead.
morph_pca.fit(pd.concat([chem_embryo_df[mu_cols], ref_embryo_df[mu_cols].iloc[ref_indices]]))

# transform
ref_pca_array = morph_pca.transform(ref_embryo_df[mu_cols])
chem_pca_array = morph_pca.transform(chem_embryo_df[mu_cols])

# Carry identifying metadata alongside the PCs; "time_int" is renamed to "timepoint".
# Values are copied positionally (via to_numpy) because the PCA frames have a fresh index.
to_cols = ["snip_id", "embryo_id", "exp_id", "temperature", "timepoint", "chem_perturbation"]
from_cols = ["snip_id", "embryo_id", "exp_id", "temperature", "time_int", "chem_perturbation"]
ref_pca_df = pd.DataFrame(ref_pca_array, columns=pca_cols)
ref_pca_df[to_cols] = ref_embryo_df[from_cols].to_numpy()


chem_pca_df = pd.DataFrame(chem_pca_array, columns=pca_cols)
chem_pca_df[to_cols] = chem_embryo_df[from_cols].to_numpy()

In [None]:
# Sanity check: number of rows belonging to one specific embryo id
(chem_embryo_df["embryo_id"] == "chem3_A06_34C").sum()

In [None]:
# Running total of the variance fraction explained by successive PCs
var_cumulative = np.cumsum(morph_pca.explained_variance_ratio_)
fig = px.line(x=np.arange(n_components), y=var_cumulative, markers=True)

# Global font applied to the whole figure
global_font = dict(family="Arial, sans-serif", size=18, color="black")
fig.update_layout(
    title="PCA decomposition of morphVAE latent space",
    xaxis=dict(title="PC number"),
    yaxis=dict(title="total variance explained"),
    font=global_font,
)

fig = format_2d_plotly(
    fig,
    axis_labels=["morph PC number", "total variance explained"],
    marker_size=12,
)

fig.show()

fig.write_image(os.path.join(fig_path, "morph_pca_var_explained.png"))

## Conduct vector calculations

In [None]:
from itertools import product
import numpy as np
import pandas as pd

# Bootstrap estimates of centroid-difference vectors in PCA space.
#
# For every experiment and every (non-control temperature, perturbation) pair,
# four groups of embryos are compared:
#   cl = control perturbation at the control temperature
#   ch = control perturbation at the tested temperature
#   tl = tested perturbation at the control temperature
#   th = tested perturbation at the tested temperature
# Each group is resampled with replacement n_bootstrap times; the norm of each
# centroid difference and the spread of each group are averaged over the
# resamples, and the bootstrap standard deviation is reported as the SE.
#
# NOTE: the resampling uses the global NumPy RNG (np.random.choice) and no seed
# is set in this cell, so results depend on the RNG state left by earlier cells.

# --- Setup ---
n_pc = 10                         # number of leading PCs used for the vectors
times_to_use = np.asarray([0])    # timepoint values kept for the analysis
ctrl_temp = 28                    # control temperature
ctrl_pert = "DMSO_6"              # control perturbation label
n_bootstrap = 1000                # bootstrap resamples per comparison

analysis_df = chem_pca_df.loc[chem_pca_df["timepoint"].isin(times_to_use)].reset_index(drop=True)
analysis_df["temperature"] = analysis_df["temperature"].astype(int)
exp_id_vec = chem_embryo_df["exp_id"].unique()
df_list = []   # one single-row DataFrame per (experiment, temp, chem) comparison

# --- Define vector metrics to record ---
# Each entry maps the four centroids (th, ch, tl, cl) to a difference vector.
delta_vecs = {
    "th_ch": lambda th, ch, tl, cl: th - ch,
    "th_tl": lambda th, ch, tl, cl: th - tl,
    "th_cl": lambda th, ch, tl, cl: th - cl,
    "tl_cl": lambda th, ch, tl, cl: tl - cl,
    "ch_cl": lambda th, ch, tl, cl: ch - cl,
    # th minus the additive prediction cl + (ch - cl) + (tl - cl)
    "pd":    lambda th, ch, tl, cl: th - (cl + (ch - cl) + (tl - cl)),
}

# Only the keys of this dict are used below (the lambdas are never called);
# it names the groups whose spread is recorded.
centroid_stds = {
    "cl": lambda cl: cl,
    "ch": lambda ch: ch,
    "tl": lambda tl: tl,
    "th": lambda th: th,
}

for exp in tqdm(exp_id_vec):
    mask = analysis_df["exp_id"] == exp
    pert_id_vec = analysis_df.loc[mask, "chem_perturbation"].unique()
    temp_id_vec = analysis_df.loc[mask, "temperature"].unique()
    # only non-control temperatures are tested
    temp_id_vec = temp_id_vec[temp_id_vec != ctrl_temp]
    
    # control perturbation at control temperature (shared by every pair in this experiment)
    cl_mask = mask & analysis_df["chem_perturbation"].eq(ctrl_pert) & analysis_df["temperature"].eq(ctrl_temp)
    cl_array = analysis_df.loc[cl_mask, pca_cols[:n_pc]].to_numpy()
    
    for temp, chem in product(temp_id_vec, pert_id_vec):
        ch_mask = mask & analysis_df["chem_perturbation"].eq(ctrl_pert) & analysis_df["temperature"].eq(temp)
        ch_array = analysis_df.loc[ch_mask, pca_cols[:n_pc]].to_numpy()
        tl_mask = mask & analysis_df["chem_perturbation"].eq(chem) & analysis_df["temperature"].eq(ctrl_temp)
        tl_array = analysis_df.loc[tl_mask, pca_cols[:n_pc]].to_numpy()
        th_mask = mask & analysis_df["chem_perturbation"].eq(chem) & analysis_df["temperature"].eq(temp)
        th_array = analysis_df.loc[th_mask, pca_cols[:n_pc]].to_numpy()

        # skip comparisons where any of the four groups has fewer than 3 embryos
        if min(cl_array.shape[0], ch_array.shape[0], tl_array.shape[0], th_array.shape[0]) < 3:
            continue

        # --- Bootstrap storage ---
        # vec_metrics keeps the full difference vectors; it is filled below but
        # not used in the summary written to `out`.
        vec_metrics = {k: np.zeros((n_bootstrap, n_pc)) for k in delta_vecs}
        std_metrics = {k: np.zeros((n_bootstrap,)) for k in centroid_stds}
        # Will store bootstrap magnitudes (for SEs)
        mag_metrics = {k: np.zeros(n_bootstrap) for k in delta_vecs}

#         mag_metrics["thch_diff_tlcl"] = np.zeros(n_bootstrap) 
        
        # --- Bootstrap loop ---
        for i in range(n_bootstrap):
            # resample each group independently, with replacement, at its own size
            sample = {
                "cl": cl_array[np.random.choice(cl_array.shape[0], cl_array.shape[0], replace=True)],
                "ch": ch_array[np.random.choice(ch_array.shape[0], ch_array.shape[0], replace=True)],
                "tl": tl_array[np.random.choice(tl_array.shape[0], tl_array.shape[0], replace=True)],
                "th": th_array[np.random.choice(th_array.shape[0], th_array.shape[0], replace=True)],
            }
            centroids = {k: np.mean(sample[k], axis=0) for k in sample}
            # group spread: square root of the summed per-PC variances
            stds      = {k: np.sqrt(np.sum(np.var(sample[k], axis=0))) for k in sample}
            
            # Compute all deltas (vector and its Euclidean norm)
            for key, func in delta_vecs.items():
                vec = func(centroids["th"], centroids["ch"], centroids["tl"], centroids["cl"])
                vec_metrics[key][i, :] = vec
                mag_metrics[key][i] = np.linalg.norm(vec)
            # `func` is unused here; only the group key is needed
            for key, func in centroid_stds.items():
                std_metrics[key][i] = stds[key]
            
            # Difference of differences (directly on the bootstrap magnitudes)
#             mag_metrics["thch_diff_tlcl"][i] = (
#                 mag_metrics["th_ch"][i] - mag_metrics["tl_cl"][i]
#             )

        # --- Summarize to DataFrame ---
        out = {"chem": chem, "temp": temp, "exp_id": exp}
        # Delta magnitudes & SEs (bootstrap mean and bootstrap std, ddof=1)
        for key in mag_metrics:
            out[f"{key}_delta"] = mag_metrics[key].mean()
            out[f"{key}_delta_se"] = mag_metrics[key].std(ddof=1)
        # Centroid stds
        for key, arr in std_metrics.items():
            out[f"{key}_std"] = arr.mean()
            out[f"{key}_std_se"] = arr.std(ddof=1)

        df_list.append(pd.DataFrame([out]))

# one row per (experiment, temperature, perturbation) comparison that passed the size check
pert_pd_df = pd.concat(df_list, ignore_index=True)

Combine like treatments into aggregate groups

In [None]:
def summarize_equal_n(g: pd.DataFrame) -> pd.Series:
    """Combine the rows of one group, weighting every row equally (equal-n rule).

    Parameters
    ----------
    g : pd.DataFrame
        Rows to combine. Must contain an ``exp_id`` column. Columns ending in
        ``_delta`` or ``_std`` are treated as point estimates; each column
        ending in ``_se`` is the standard error of the column with the same
        name minus the ``_se`` suffix, which must also be present.

    Returns
    -------
    pd.Series
        One value per ``_delta`` / ``_std`` / ``_se`` column (in the column
        order of ``g``), followed by ``source`` (the sorted unique ``exp_id``
        values joined with "+") and ``n_rep`` (the number of rows combined).

    Raises
    ------
    KeyError
        If an ``_se`` column has no matching point-estimate column.
    """
    k = len(g)  # number of rows in the group
    out = {}

    for col in g.columns:
        # --- standard errors -------------------------------------------------
        if col.endswith("_se"):
            means = g[col[: -len("_se")]]
            # SE of the average of k estimates: root-sum-square of the row SEs over k
            within = np.sqrt((g[col] ** 2).sum()) / k
            # spread of the row estimates themselves; undefined for a single row
            between = means.std(ddof=1) / np.sqrt(k) if k > 1 else 0
            out[col] = np.sqrt(within ** 2 + between ** 2)

        # --- point estimates -------------------------------------------------
        elif col.endswith(("_delta", "_std")):
            out[col] = g[col].mean()

    # --- meta information ----------------------------------------------------
    out["source"] = "+".join(sorted(g["exp_id"].unique()))
    out["n_rep"] = k
    return pd.Series(out)

# Pool the per-experiment rows that share a perturbation and temperature,
# then turn the group keys back into ordinary columns.
grouped = pert_pd_df.groupby(["chem", "temp"])
pert_pd_summ = grouped.apply(summarize_equal_n).reset_index()

pert_pd_summ.head()

In [None]:
import re, itertools, colorsys
import pandas as pd
import plotly.express as px
import matplotlib.colors as mcolors

# Perturbation labels that receive a colour and a display name.
# Each label is "_"-separated and ends in an integer token (read as `time` by parse()).
labels = [
    "DMSO_6", "pi3k_lo_i_6", "pi3k_hi_i_6", "ra_lo_i_6", "ra_hi_i_6",
    "tgfb_lo_i_6", "tgfb_hi_i_6", "shh_i_6", "notch_i_6", "ra_i_6",
    "tgfb_i_13", "fgf_i_13", "wnt_i_13", "mTOR_i_6", "hsp90_i_6",
    "bmp_i_13", "bmp_i_6", "tgfb_i_6", "fgf_i_6", "wnt_i_6", "nfkb_i_6"
]

###############################################################################
# 1.  Parse label  → treatment_type, time
###############################################################################
def parse(label):
    """Split a perturbation label into its treatment type and time.

    The label is split on "_". The last token is converted to an int and
    returned as ``time``. The treatment type is the first token, or the first
    two tokens joined by "_" when the label has four or more tokens
    (base-dosage-i-time).

    Returns a dict with keys ``label``, ``treatment_type`` and ``time``.
    """
    tokens = label.split('_')
    if len(tokens) >= 4:
        # dosage token present: keep it as part of the treatment type
        treat_type = "_".join(tokens[:2])
    else:
        treat_type = tokens[0]
    return {"label": label, "treatment_type": treat_type, "time": int(tokens[-1])}

# One row per label, with columns label / treatment_type / time as returned by parse()
meta = pd.DataFrame(parse(s) for s in labels)

###############################################################################
# 2.  Build a consistent color dictionary
###############################################################################
def adjust_lightness(hex_color, factor):
    """Return `hex_color` with its HLS lightness multiplied by `factor`.

    The scaled lightness is clamped to the range [0, 1]; hue and saturation
    are left unchanged. The result is returned as a hex colour string.
    """
    hue, lightness, saturation = colorsys.rgb_to_hls(*mcolors.to_rgb(hex_color))
    scaled = min(1, max(0, lightness * factor))
    return mcolors.to_hex(colorsys.hls_to_rgb(hue, scaled, saturation))

palette = itertools.cycle(px.colors.qualitative.Dark24)

# One base colour per treatment_type, taken from the palette in order of first
# appearance; DMSO is skipped here and pinned to grey afterwards.
base_color = {}
for ttype in meta["treatment_type"].unique():
    if ttype.lower() == "dmso":
        continue
    base_color[ttype] = next(palette)
base_color["DMSO"] = "#808080"     # fixed grey

# Within a treatment type, later times get a lighter shade of the base colour
# (lightness factor runs from 1 for the earliest time to 1.5 for the latest).
color_dict = {}
for ttype, sub in meta.groupby("treatment_type"):
    times = sorted(sub["time"].unique())
    span = max(1, len(times) - 1)
    factors = {t: 1 + 0.5 * (rank / span) for rank, t in enumerate(times)}
    for label, t in zip(sub["label"], sub["time"]):
        color_dict[label] = (
            base_color["DMSO"]
            if ttype == "DMSO"
            else adjust_lightness(base_color[ttype], factors[t])
        )

###############################################################################
# 3.  Attach colours to the label table and to the summary table
###############################################################################
meta["colour"] = meta["label"].map(color_dict)

pert_pd_summ = pert_pd_summ.merge(meta, how="left", left_on="chem", right_on="label")

In [None]:
# Spelled-out versions of the dose tokens
dose_map = {"lo": "low", "hi": "high"}

# label -> display name, e.g. "<BASE> <dose> <time>hpf" (dose only when present)
pretty_dict = {}
for label, ttype, t in zip(meta["label"], meta["treatment_type"], meta["time"]):
    pieces = ttype.split("_")
    words = [pieces[0].upper()]                    # PI3K, TGFB, …
    if len(pieces) > 1:
        # unknown dose tokens are passed through unchanged
        words.append(dose_map.get(pieces[1], pieces[1]))
    words.append(f"{t}hpf")
    pretty_dict[label] = " ".join(words)
    
    
def apply_pretty_names(fig, pretty):
    """Replace trace legend/hover names with prettier versions, in place.

    Parameters
    ----------
    fig : figure-like object
        Anything exposing an iterable ``data`` of traces with ``name``,
        ``legendgroup`` and ``hovertemplate`` attributes.
    pretty : dict
        Maps an original trace name to its display name. Traces whose name is
        not a key are left untouched.

    Returns
    -------
    None
        The traces of ``fig`` are modified in place.
    """
    for tr in fig.data:
        old = tr.name
        if old not in pretty:
            continue
        new = pretty[old]
        # Rewrite the hover text using the ORIGINAL name, before the trace is
        # renamed; doing it afterwards would search for the new name and change
        # nothing. Templates that reference customdata are left alone.
        if tr.hovertemplate and "%{customdata}" not in tr.hovertemplate:
            tr.hovertemplate = tr.hovertemplate.replace(old, new)
        tr.name = new
        tr.legendgroup = new

## Visualize embryo morphologies

In [None]:
# Display name for every embryo, used for colouring and the legend
chem_pca_df["chem"] = chem_pca_df["chem_perturbation"].map(pretty_dict)
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

# color_dict is keyed by raw labels; re-key it by display name for plotly
pretty_colour_map = {pretty_dict[raw]: colour for raw, colour in color_dict.items()}

fig = px.scatter_3d(
    chem_pca_df,
    x=x_col,
    y=y_col,
    z=z_col,
    color="chem",
    symbol="temperature",
    color_discrete_map=pretty_colour_map,
)

fig.show()


## Compare vector quantities

In [None]:
# Treatment strength at the control temperature (x) vs. at the elevated
# temperature (y), one point per perturbation / temperature.
# Work on a copy that carries the display name used for colouring.
pert_pd_summ = pert_pd_summ.assign(chem_pretty=pert_pd_summ["chem"].map(pretty_dict))

# color_dict is keyed by raw labels; re-key it by display name for plotly
pretty_colour_map = {pretty_dict[raw]: colour for raw, colour in color_dict.items()}

fig = px.scatter(
    pert_pd_summ,
    x="tl_cl_delta",
    y="th_ch_delta",
    error_x="tl_cl_delta_se",
    error_y="th_ch_delta_se",
    color="chem_pretty",
    symbol="temp",
    color_discrete_map=pretty_colour_map,
)

# Show each colour only once in the legend: walking the traces back to front,
# the first trace seen for a colour keeps its entry, later ones are hidden.
seen_colours = set()
for trace in reversed(fig.data):
    colour = trace.marker.color
    if colour in seen_colours:
        trace.showlegend = False
    seen_colours.add(colour)

# Thin error bars without end caps
fig.update_traces(
    error_x=dict(thickness=1, width=0),
    error_y=dict(thickness=1, width=0),
)

# Upper axis limits (xm is also where the y = x guide line ends)
xm = 5
ym = 5

fig = format_2d_plotly(
    fig,
    axis_labels=["treatment strength: 28C (δ1)", "treatment strength: hot (δ3)"],
    font_size=14,
    marker_size=12,
)

# Dashed y = x guide line
ref_line = np.linspace(-0.25, xm)
fig.add_trace(
    go.Scatter(
        x=ref_line,
        y=ref_line,
        mode="lines",
        line=dict(dash="dash", color="white"),
        showlegend=False,
    )
)

fig.update_layout(
    width=800,
    height=600,
    xaxis=dict(range=[-0.25, xm], zeroline=False),
    yaxis=dict(range=[-0.25, ym], zeroline=False),
    legend=dict(title="treatment"),
)

fig.update_xaxes(domain=[0, 0.9])

fig.show()
fig.write_image(os.path.join(fig_path, "screen01_treat_lo_treat_hi.png"))

In [None]:
# Second version of the treatment-strength plot: points are coloured by whether
# the hot-minus-reference difference in treatment strength clears a z threshold;
# everything else is greyed out (#808080).
pert_pd_summ["thch_tlcl_delta"] = pert_pd_summ["th_ch_delta"] - pert_pd_summ["tl_cl_delta"]
# SE of the difference: the two standard errors combined as sqrt(a**2 + b**2)
sig_denom = np.sqrt(pert_pd_summ["th_ch_delta_se"]**2 + pert_pd_summ["tl_cl_delta_se"]**2)
pert_pd_summ["thch_tlcl_z"] = np.divide(pert_pd_summ["thch_tlcl_delta"].to_numpy(), sig_denom)
# flagged as "significant" when the difference exceeds one combined SE (z > 1)
pert_pd_summ["thch_tlcl_sig"] = pert_pd_summ["thch_tlcl_z"] > 1
pert_pd_summ["thch_tlcl_delta_se"] = sig_denom

pert_pd_summ["sig_label"] = pert_pd_summ["thch_tlcl_sig"].map({
    True: "significant",
    False: "not significant"
})

fig = px.scatter(
    pert_pd_summ,
    x="tl_cl_delta", y="th_ch_delta",
    error_x="tl_cl_delta_se", error_y="th_ch_delta_se",
    color="sig_label", 
    # a list (not a set) keeps the hover fields in a fixed order between runs
    hover_data=["chem", "thch_tlcl_delta", "thch_tlcl_z"],
    color_discrete_map={
        "not significant": "#808080",
        "significant": "#00C000"
    },
    symbol="temp"
)

# Show each colour (significance label) only once in the legend: walking the
# traces back to front, later traces with an already-seen colour are hidden.
seen_colours = set()

for tr in fig.data[::-1]:
    col = tr.marker.color
    if col in seen_colours:
        tr.showlegend = False      # second/third… symbol variant → hide
    else:
        seen_colours.add(col) 
        
# Thin error bars without end caps
fig.update_traces(
    error_x=dict(thickness=1, width=0),  # Remove horizontal caps
    error_y=dict(thickness=1, width=0),  # Remove vertical caps
)

# Upper axis limits (xm is also where the y = x guide line ends)
xm = 5
ym = 5


fig = format_2d_plotly(fig, axis_labels=["treatment strength (ref)", 
                                         "treatment strength (hot)"], 
                                         font_size=14, marker_size=12)

# Dashed y = x guide line
ref_line = np.linspace(-0.25, xm)
fig.add_trace(go.Scatter(x=ref_line, y=ref_line, mode="lines", line=dict(dash="dash", color="white"), showlegend=False))


fig.update_layout(
    xaxis=dict(range=[-0.25, xm], zeroline=False),
    yaxis=dict(range=[-0.25, ym], zeroline=False),
    width=800,
    height=600,
    legend=dict(title="temperature effect")
)

fig.update_xaxes(domain=[0, 0.9])

fig.show()
fig.write_image(os.path.join(fig_path, "screen01_treat_lo_treat_hi_sig.png"))

### Step 2: make sure that the 28C and 34/35C treatments look distinct from one another

In [None]:
# Temperature effect on treatment strength (x) vs. distance between the hot and
# control-temperature treated centroids (y), coloured by treatment.
fig = px.scatter(
    pert_pd_summ,
    x="thch_tlcl_delta", y="th_tl_delta",
    error_x="thch_tlcl_delta_se", error_y="th_tl_delta_se",
    # a list (not a set) keeps the hover fields in a fixed order between runs
    hover_data=["chem", "temp"],
    color="chem_pretty",                         # ← use pretty name column
    symbol="temp",
    color_discrete_map={
        pretty_dict[k]: v for k, v in color_dict.items()
    }
)

# Show each colour (treatment) only once in the legend: walking the traces back
# to front, later traces with an already-seen colour are hidden.
seen_colours = set()

for tr in fig.data[::-1]:
    col = tr.marker.color          # unique per 'chem'
    if col in seen_colours:
        tr.showlegend = False      # second/third… symbol variant → hide
    else:
        seen_colours.add(col) 
        
# Thin error bars without end caps
fig.update_traces(
    error_x=dict(thickness=1, width=0),  # Remove horizontal caps
    error_y=dict(thickness=1, width=0),  # Remove vertical caps
)

# Axis limits; currently only referenced by the disabled range settings below
xm = 3.5
ym = 3.5


fig = format_2d_plotly(fig, axis_labels=["temperature effect (δ₃ - δ₁)", "effect difference (δ₂)"], 
                       marker_size=10)

# (disabled) vertical guide line at x = 0
# ref_line = np.linspace(-1, xm)
# x0 = np.zeros_like(ref_line)
# fig.add_trace(go.Scatter(x=x0, y=ref_line, mode="lines", line=dict(dash="dash", color="white"), showlegend=False))


fig.update_layout(
#     xaxis=dict(range=[-1, xm], zeroline=False),
#     yaxis=dict(range=[-0.25, ym], zeroline=False),
    width=800,
    height=600,
    legend=dict(title="treatment")
)

fig.update_xaxes(domain=[0, 0.9])

fig.show()
fig.write_image(os.path.join(fig_path, "screen02_delta_heat_vs_delta_treat.png"))

In [None]:
# Compare every row's th_tl_delta against the DMSO_6 / 35C reference row and
# nominate candidates that exceed it (z > 1) AND were flagged in thch_tlcl_sig.
ref_filter = (pert_pd_summ["chem"] == "DMSO_6") & (pert_pd_summ["temp"] == 35)
n_ref_rows = int(ref_filter.sum())
# The subtraction below broadcasts a single reference value over all rows; with
# zero or several matches it would fail obscurely or pair rows up arbitrarily.
if n_ref_rows != 1:
    raise ValueError(
        f"expected exactly one DMSO_6 / 35C reference row in pert_pd_summ, found {n_ref_rows}"
    )

pert_pd_summ["th_tl_delta_delta"] = pert_pd_summ["th_tl_delta"].to_numpy() -  \
                                    pert_pd_summ.loc[ref_filter, "th_tl_delta"].to_numpy()

# SE of the difference: row SE and reference SE combined as sqrt(a**2 + b**2)
ref_se = pert_pd_summ.loc[ref_filter,"th_tl_delta_se"].to_numpy()
sig_col = np.sqrt(pert_pd_summ["th_tl_delta_se"].to_numpy()**2 + ref_se**2)
pert_pd_summ["th_tl_delta_delta_se"] = sig_col
pert_pd_summ["th_tl_delta_delta_z"] = np.divide(pert_pd_summ["th_tl_delta_delta"], 
                                                pert_pd_summ["th_tl_delta_delta_se"]) 

bin_vec = (pert_pd_summ["th_tl_delta_delta_z"].to_numpy() > 1) & pert_pd_summ["thch_tlcl_sig"].to_numpy()
string_vec = ["candidate" if b else "other" for b in bin_vec]
pert_pd_summ["candidate_label"] = string_vec

fig = px.scatter(
    pert_pd_summ,
    x="thch_tlcl_delta", y="th_tl_delta",
    error_x="thch_tlcl_delta_se", error_y="th_tl_delta_se",
    color="candidate_label",
    symbol="temp",
    # a list (not a set) keeps the hover fields in a fixed order between runs
    hover_data=["chem", "temp"],
    color_discrete_map={
        "other": "#808080",
        "candidate": "#00C000"
    }
)

# Show each colour (candidate label) only once in the legend: walking the traces
# back to front, later traces with an already-seen colour are hidden.
seen_colours = set()

for tr in fig.data[::-1]:
    col = tr.marker.color
    if col in seen_colours:
        tr.showlegend = False      # second/third… symbol variant → hide
    else:
        seen_colours.add(col) 
        
# Thin error bars without end caps
fig.update_traces(
    error_x=dict(thickness=1, width=0),  # Remove horizontal caps
    error_y=dict(thickness=1, width=0),  # Remove vertical caps
)

# Axis limits; not applied to this figure
xm = 3.5
ym = 3.5


fig = format_2d_plotly(fig, axis_labels=["temperature effect (δ₃ - δ₁)", "effect difference (δ₂)"], 
                       marker_size=10)

fig.update_layout(
    width=800,
    height=600,
    legend=dict(title="nominated treatments")
)

fig.update_xaxes(domain=[0, 0.9])

fig.show()
fig.write_image(os.path.join(fig_path, "screen02_delta_heat_vs_delta_treat_sig.png"))

In [None]:
# pert_pd_summ["std_delta"] = pert_pd_summ["chem_hi_std"] - pert_pd_summ["chem_lo_std"]
# pert_pd_summ["std_delta_se"] = np.sqrt(pert_pd_summ["chem_lo_std_se"]**2 + pert_pd_summ["chem_hi_std_se"]**2)
# # plot_filter = (~pert_pd_summ["temp"].eq(28))# & (pert_pd_df["temp"].eq(35)) #  (~pert_pd_df["exp_id"].eq("chem2")) &

# fig = px.scatter(
#     pert_pd_summ,
#     y="std_delta", x="heat_delta",
#     error_y="std_delta_se", error_x="heat_delta_se",
#     color="chem_pretty",                         # ← use pretty name column
#     symbol="temp",
#     color_discrete_map={
#         pretty_dict[k]: v for k, v in color_dict.items()
#     }
# )

# seen_colours = set()        # remember which chem (colour) we've kept

# for tr in fig.data[::-1]:
#     col = tr.marker.color          # unique per 'chem'
#     if col in seen_colours:
#         tr.showlegend = False      # second/third… symbol variant → hide
#     else:
#         seen_colours.add(col) 
        
# # Increase marker size and remove error bar caps
# fig.update_traces( # Increase dot size
#     error_x=dict(thickness=1, width=0),  # Remove horizontal caps
#     error_y=dict(thickness=1, width=0),  # Remove vertical caps
# )

# # Add reference lines at x=0 and y=0
# xm = 3.5
# ym = 1.5


# fig = format_2d_plotly(fig, axis_labels=["(treatment hi) - (treatment lo)", "(sigma lo) - (sigma hi)"], 
#                        marker_size=10)

# ref_line = np.linspace(-1, xm)
# x0 = np.zeros_like(ref_line)
# fig.add_trace(go.Scatter(x=x0, y=ref_line, mode="lines", line=dict(dash="dash", color="white"), showlegend=False))


# fig.update_layout(
#     xaxis=dict(range=[-1, xm], zeroline=False),
#     yaxis=dict(range=[-0.25, ym], zeroline=False),
#     width=800,
#     height=600
# )

# fig.show()

## Take a closer look at the leading candidates

#### First fit a spline to WT

In [None]:
from src.functions.spline_fitting_v2 import spline_fit_wrapper
import time
import re 
from tqdm import tqdm 


# pattern = r"PCA_.*_bio"
# pattern = r"z_mu_b"
# Spline-fit settings passed to spline_fit_wrapper
n_boots = 50
n_spline_points = 500
boot_size = 1000
n_pc = 5          # only the first n_pc PCA columns are passed to the fit

# Fit on the "ref0" reference experiment only
is_ref0 = ref_pca_df["exp_id"] == "ref0"
fit_pca_df = ref_pca_df[is_ref0]

# Drop the trailing PCA columns so the wrapper only sees the first n_pc of them
spline_input = fit_pca_df.drop(columns=pca_cols[n_pc:])
spline_df = spline_fit_wrapper(
    spline_input,
    n_boots=n_boots,
    n_spline_points=n_spline_points,
    stage_col="timepoint",
    obs_weights=None,
    boot_size=boot_size,
)

In [None]:
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors

N = 10    # number of nearest neighbors looked up per spline point
n_pc = 5  # number of leading PCA columns that define the search space

# The spline points are queried on pca_cols[:n_pc], so the neighbor index has to be
# fit on exactly the same columns. Fitting on all of pca_cols makes kneighbors()
# raise a feature-count mismatch whenever len(pca_cols) > n_pc.
nn_cols = pca_cols[:n_pc]
nbrs = NearestNeighbors(n_neighbors=N, algorithm='auto')
nbrs.fit(fit_pca_df[nn_cols].values)

# For each row of spline_df, find its N nearest rows in fit_pca_df.
distances, indices = nbrs.kneighbors(spline_df[nn_cols].values)

# time_int of every neighbor; positional indexing gives shape (len(spline_df), N).
neighbor_timeints = fit_pca_df["time_int"].values[indices]

# Per-spline-point mean and standard error (sample std / sqrt(N)) of neighbor time_int.
mean_timeint = neighbor_timeints.mean(axis=1)
se_timeint = neighbor_timeints.std(axis=1, ddof=1) / np.sqrt(N)

# Store the summaries alongside the spline coordinates.
spline_df["nn_timeint_mean"] = mean_timeint
spline_df["nn_timeint_se"] = se_timeint

In [None]:
# Display pca_cols, the column labels used to index fit_pca_df and spline_df in PCA space.
pca_cols

#### Wnt-i (8ss)

In [None]:
# 3D scatter of the first three PCA columns for chem_perturbation "wnt_i_13"
# (timepoint < 1, exp_id "chem"), with the fitted spline and reference points overlaid.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 1)
    & chem_pca_df["chem_perturbation"].isin(["wnt_i_13"])
    & chem_pca_df["exp_id"].isin(["chem"])
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    color="temperature",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "temperature"],
    # NOTE(review): keys are pretty_dict values while color is "temperature";
    # confirm this map is still intended here.
    color_discrete_map={pretty_dict[k]: v for k, v in color_dict.items()},
)

# Spline drawn as a line whose color follows spline_df["timepoint"].
fig.add_trace(
    go.Scatter3d(
        x=spline_df[plot_cols[0]],
        y=spline_df[plot_cols[1]],
        z=spline_df[plot_cols[2]],
        mode="lines",
        line=dict(color=spline_df["timepoint"], width=7,
                  colorscale="YlOrRd_r", cmin=10, cmax=46),
    )
)

fig = format_3d_plotly(fig, marker_size=10, axis_labels=["PC 0", "PC 1", "PC 2"], theme="dark")

# Faint reference cloud, added after formatting.
# NOTE(review): colors come from ref_embryo_df while coordinates come from ref_pca_df;
# assumes the two frames are row-aligned — confirm.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color=ref_embryo_df["time_int"], opacity=0.1),
        showlegend=False,
    )
)

fig.show()

In [None]:
# Pull snips for each condition

#### HSP90-i (shield)

In [None]:
# 3D scatter for chem_perturbation in {"hsp90_i_6", "DMSO_6"} with timepoint < 1,
# excluding exp_id "chem3", over a faint reference cloud.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 1)
    & chem_pca_df["chem_perturbation"].isin(["hsp90_i_6", "DMSO_6"])
    & ~chem_pca_df["exp_id"].isin(["chem3"])
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    color="temperature",
    # NOTE(review): symbol uses the "chem" column while the filter uses
    # "chem_perturbation"; confirm "chem" exists in chem_pca_df.
    symbol="chem",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "temperature", "exp_id"],
    # NOTE(review): keys are pretty_dict values while color is "temperature";
    # confirm this map is still intended here.
    color_discrete_map={pretty_dict[k]: v for k, v in color_dict.items()},
)

fig = format_3d_plotly(fig, marker_size=10, axis_labels=["PC 0", "PC 1", "PC 2"], theme="light")

# NOTE(review): colors come from ref_embryo_df while coordinates come from ref_pca_df;
# assumes the two frames are row-aligned — confirm.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color=ref_embryo_df["time_int"], opacity=0.01),
        showlegend=False,
    )
)
fig.show()

#### BMP-i (8ss)

In [None]:
# 3D scatter for chem_perturbation in {"bmp_i_13", "DMSO_6"} with timepoint < 1 and
# exp_id "chem", over a faint reference cloud.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 1)
    & chem_pca_df["chem_perturbation"].isin(["bmp_i_13", "DMSO_6"])
    & chem_pca_df["exp_id"].isin(["chem"])
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    color="temperature",
    # NOTE(review): symbol uses the "chem" column while the filter uses
    # "chem_perturbation"; confirm "chem" exists in chem_pca_df.
    symbol="chem",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "temperature", "exp_id"],
    # NOTE(review): keys are pretty_dict values while color is "temperature";
    # confirm this map is still intended here.
    color_discrete_map={pretty_dict[k]: v for k, v in color_dict.items()},
)

fig = format_3d_plotly(fig, marker_size=10, axis_labels=["PC 0", "PC 1", "PC 2"], theme="light")

# NOTE(review): colors come from ref_embryo_df while coordinates come from ref_pca_df;
# assumes the two frames are row-aligned — confirm.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color=ref_embryo_df["time_int"], opacity=0.01),
        showlegend=False,
    )
)
fig.show()

#### mTOR-i

In [None]:
# 3D scatter for chem_perturbation in {"mTOR_i_6", "DMSO_6"} with timepoint < 1,
# excluding exp_id "chem3", over a faint reference cloud.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 1)
    & chem_pca_df["chem_perturbation"].isin(["mTOR_i_6", "DMSO_6"])
    & ~chem_pca_df["exp_id"].isin(["chem3"])
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    color="temperature",
    # NOTE(review): symbol uses the "chem" column while the filter uses
    # "chem_perturbation"; confirm "chem" exists in chem_pca_df.
    symbol="chem",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "temperature", "exp_id"],
    # NOTE(review): keys are pretty_dict values while color is "temperature";
    # confirm this map is still intended here.
    color_discrete_map={pretty_dict[k]: v for k, v in color_dict.items()},
)

fig = format_3d_plotly(fig, marker_size=10, axis_labels=["PC 0", "PC 1", "PC 2"], theme="light")

# NOTE(review): colors come from ref_embryo_df while coordinates come from ref_pca_df;
# assumes the two frames are row-aligned — confirm.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color=ref_embryo_df["time_int"], opacity=0.01),
        showlegend=False,
    )
)
fig.show()

#### Shh-i

In [None]:
# 3D scatter for chem_perturbation "shh_i_6" with timepoint < 2, excluding exp_id "chem3",
# over a near-transparent reference cloud.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 2)
    & chem_pca_df["chem_perturbation"].isin(["shh_i_6"])
    & ~chem_pca_df["exp_id"].isin(["chem3"])
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    color="temperature", symbol="exp_id",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "temperature", "timepoint"],
)

# Reference points from ref_pca_df drawn in almost fully transparent black.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color="rgba(0,0,0,0.01)"),
        showlegend=False,
    )
)
fig.show()

#### TGF-Beta

In [None]:
# Preview the first rows of chem_pca_df (the frame filtered and plotted in the cells below).
chem_pca_df.head()

In [None]:
# 3D scatter for chem_perturbation "tgfb_lo_i_6" with timepoint < 2 and exp_id "chem3",
# over a near-transparent reference cloud.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 2)
    & chem_pca_df["chem_perturbation"].isin(["tgfb_lo_i_6"])
    & chem_pca_df["exp_id"].isin(["chem3"])
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    color="temperature", symbol="timepoint",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "temperature"],
)

# Reference points from ref_pca_df drawn in almost fully transparent black.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color="rgba(0,0,0,0.01)"),
        showlegend=False,
    )
)
fig.show()

### HSP90

In [None]:
# 3D scatter for chem_perturbation in {"DMSO_6", "hsp90_i_6"} with timepoint < 1 and
# exp_id "chem", over a near-transparent reference cloud.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 1)
    & chem_pca_df["chem_perturbation"].isin(["DMSO_6", "hsp90_i_6"])
    & chem_pca_df["exp_id"].isin(["chem"])
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    color="temperature", symbol="chem_perturbation",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "temperature"],
)

# Reference points from ref_pca_df drawn in almost fully transparent black.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color="rgba(0,0,0,0.01)"),
        showlegend=False,
    )
)
fig.show()

### RA

In [None]:
# 3D scatter for chem_perturbation in {"ra_lo_i_6", "DMSO_6"} with timepoint < 1 and
# exp_id "chem3", over a near-transparent reference cloud.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 1)
    & chem_pca_df["chem_perturbation"].isin(["ra_lo_i_6", "DMSO_6"])
    & chem_pca_df["exp_id"].isin(["chem3"])
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    color="chem_perturbation", symbol="temperature",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "temperature"],
)

# Reference points from ref_pca_df drawn in almost fully transparent black.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color="rgba(0,0,0,0.01)"),
        showlegend=False,
    )
)
fig.show()

#### Wnt

In [None]:
# 3D scatter for chem_perturbation "wnt_i_13" with timepoint < 2, excluding exp_id "chem2",
# over a near-transparent reference cloud.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 2)
    & chem_pca_df["chem_perturbation"].isin(["wnt_i_13"])
    & ~chem_pca_df["exp_id"].isin(["chem2"])
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    color="temperature", symbol="timepoint",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "temperature"],
)

# Reference points from ref_pca_df drawn in almost fully transparent black.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color="rgba(0,0,0,0.01)"),
        showlegend=False,
    )
)
fig.show()

### BMP

In [None]:
# 3D scatter for chem_perturbation "bmp_i_13" with timepoint < 1 and exp_id "chem",
# over a near-transparent reference cloud.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 1)
    & chem_pca_df["chem_perturbation"].isin(["bmp_i_13"])
    & chem_pca_df["exp_id"].isin(["chem"])
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    # Temperature drives both the color and the marker symbol.
    color="temperature", symbol="temperature",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "temperature"],
)

# Reference points from ref_pca_df drawn in almost fully transparent black.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color="rgba(0,0,0,0.01)"),
        showlegend=False,
    )
)
fig.show()

### PI3K

In [None]:
# 3D scatter for chem_perturbation in {"DMSO_6", "pi3k_lo_i_6"} with timepoint < 1 and
# exp_id "chem3", over a near-transparent reference cloud.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 1)
    & chem_pca_df["chem_perturbation"].isin(["DMSO_6", "pi3k_lo_i_6"])
    & chem_pca_df["exp_id"].isin(["chem3"])
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    # Temperature drives both the color and the marker symbol.
    color="temperature", symbol="temperature",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "chem_perturbation", "temperature"],
)

# Reference points from ref_pca_df drawn in almost fully transparent black.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color="rgba(0,0,0,0.01)"),
        showlegend=False,
    )
)
fig.show()

### mTOR

In [None]:
# 3D scatter for chem_perturbation in {"DMSO_6", "mTOR_i_6"} with timepoint < 2,
# excluding exp_id "chem3", over a faint reference cloud.
plot_cols = pca_cols[:3]
plot_filter = (
    (chem_pca_df["timepoint"] < 2)
    & chem_pca_df["chem_perturbation"].isin(["DMSO_6", "mTOR_i_6"])
    & ~chem_pca_df["exp_id"].isin(["chem3"])
    # optional extra restriction: & chem_pca_df["temperature"].eq(34)
)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    color="temperature", symbol="chem_perturbation",
    # A list (not a set) keeps the hover fields in a fixed order across kernel sessions.
    hover_data=["snip_id", "chem_perturbation", "temperature"],
)

# NOTE(review): colors come from ref_embryo_df while coordinates come from ref_pca_df;
# assumes the two frames are row-aligned — confirm.
fig.add_trace(
    go.Scatter3d(
        x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
        mode="markers",
        marker=dict(color=ref_embryo_df["time_int"], opacity=0.01),
        showlegend=False,
    )
)
fig.show()

# Mean of the use_embryo_flag column of chem_embryo_df (displayed as the cell output).
np.mean(chem_embryo_df["use_embryo_flag"])

In [None]:
# Shape of the meta_df rows selected by `mask`.
# NOTE(review): `mask` is not defined in the nearby cells — confirm an earlier cell sets it
# so this still works after Restart & Run All.
meta_df.loc[mask].shape

In [None]:
# embryo_id values of rows where focus_flag is set but frame_flag and dead_flag are not.
focus_only = meta_df["focus_flag"] & ~meta_df["frame_flag"] & ~meta_df["dead_flag"]
check_list = meta_df.loc[focus_only, "embryo_id"].tolist()
check_list