### Compare hotfish stage scaling to theoretical expectation

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
from glob2 import glob

### Set paths

In [None]:
# Project root on the shared Dropbox folder.
root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"

# Hooke model whose outputs are analysed here, and the folder holding its projections.
model_name = "bead_expt_linear"
latent_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/seq_data/emb_projections/latent_projections/"

# Derived folders. The trailing "" makes os.path.join end each path with a
# separator, so file names can be appended to them directly.
data_path = os.path.join(root, "results", "20250312", "morph_latent_space", "")
model_path = os.path.join(latent_path, model_name, "")

# Figure output folder (lives under the slides tree, not under `root`).
fig_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/morphseq/20250320/morphseq_mdl_params/"

# Make sure the data and figure folders exist; existing folders are left as they are.
for out_dir in (data_path, fig_path):
    os.makedirs(out_dir, exist_ok=True)

### Load data

In [None]:
# Morphology latent encodings, indexed by snip_id.
hf_morph_df = pd.read_csv(os.path.join(data_path, "hf_pca_morph_df.csv")).set_index("snip_id")

# Hooke latent encodings (not loaded at the moment).
# seq_df = pd.read_csv(data_path + "hf_seq_df.csv", index_col=0)

# Hooke time predictions; the first CSV column is used as the index.
seq_time_df = pd.read_csv(os.path.join(model_path, "time_predictions.csv"), index_col=0)

# Metadata table that allows the morphology and sequencing tables to be linked.
morphseq_df = pd.read_csv(os.path.join(root, "metadata", "morphseq_metadata.csv"))

### Subset for hotfish2 

In [None]:
# Keep only metadata rows whose snip_id appears in the morphology table
# (this drops a couple of observations that had QC problems).
has_morph = np.isin(morphseq_df["snip_id"], hf_morph_df.index)
hf_morphseq_df = morphseq_df.loc[has_morph, :].reset_index()

# Attach the morphology-based stage (matched on snip_id) as "morph_stage_hpf".
# Left merge: rows without a match are kept and get NaN.
morph_stage = hf_morph_df.loc[:, ["mdl_stage_hpf"]]
hf_morphseq_df = hf_morphseq_df.merge(morph_stage, left_on="snip_id", right_index=True, how="left")
hf_morphseq_df = hf_morphseq_df.rename(columns={"mdl_stage_hpf": "morph_stage_hpf"})

# Attach the sequencing-based pseudostage (matched on sample) as "seq_stage_hpf".
seq_stage = seq_time_df.loc[:, ["pseudostage"]]
hf_morphseq_df = hf_morphseq_df.merge(seq_stage, left_on="sample", right_index=True, how="left")
hf_morphseq_df = hf_morphseq_df.rename(columns={"pseudostage": "seq_stage_hpf"})

### Make seq vs morph plots

In [None]:
from src.functions.plot_functions import format_2d_plotly

# Summarise each (temperature, timepoint) cohort by the mean and std of both
# stage estimates.
group_keys = ["temperature", "timepoint"]
stage_cols = ["seq_stage_hpf", "morph_stage_hpf"]
cohort_stage_df = hf_morphseq_df.loc[:, group_keys + stage_cols].groupby(group_keys).agg(["mean", "std"])

# Flatten the (column, statistic) MultiIndex into names like "seq_stage_hpf_mean".
cohort_stage_df.columns = ["_".join(col) for col in cohort_stage_df.columns]
cohort_stage_df = cohort_stage_df.reset_index()

# The group keys are index levels while the columns are flattened, so they come
# back from reset_index() without a trailing "_"; this rename then matches nothing.
cohort_stage_df = cohort_stage_df.rename(columns={"timepoint_": "timepoint", "temperature_": "temperature"})


# Styling shared with the plotting cells further down.
marker_size = 14
colormap = "RdBu_r"

# x (and y) values of the dashed y = x reference line
ref_vec = np.linspace(14, 48)

# Cohort means: transcriptional stage against morphological stage, with the
# cohort std shown as error bars on both axes.
fig = px.scatter(
    cohort_stage_df,
    x="morph_stage_hpf_mean",
    y="seq_stage_hpf_mean",
    error_x="morph_stage_hpf_std",
    error_y="seq_stage_hpf_std",
    color="temperature",
    symbol="timepoint",
    color_continuous_scale=colormap,
    range_color=[19, 35],
)

# zero cap width on the vertical error bars
fig.update_traces(error_y=dict(width=0))
# fig.update_traces(mode="lines+markers", line=dict(color="white", width=0.5))

identity_line = go.Scatter(
    x=ref_vec,
    y=ref_vec,
    mode="lines",
    line=dict(color="white", width=2.5, dash="dash"),
    showlegend=False,
)
fig.add_trace(identity_line)

axis_labels = ["morphological stage (hpf)", "transcriptional stage (hpf)"]

fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20)  # , show_gridlines=False)

# fig.update_xaxes(range=[24, 48])
# fig.update_yaxes(range=[24, 48])

fig.show()

# Save a static and an interactive copy under the same stem.
fig.write_image(os.path.join(fig_path, "seq_vs_morph_stage.png"), scale=2)
fig.write_html(os.path.join(fig_path, "seq_vs_morph_stage.html"))

In [None]:
# Within-cohort spread (std) of the transcriptional stage against the spread of
# the morphological stage, one point per (temperature, timepoint) cohort.
fig = px.scatter(cohort_stage_df, y="seq_stage_hpf_std", x="morph_stage_hpf_std",
                 color="temperature", symbol="timepoint", color_continuous_scale=colormap, range_color=[19, 35])

# fig.update_traces(mode="lines+markers", line=dict(color="white", width=0.5))

# x (and y) values of the dashed y = x reference line
ref_vec = np.linspace(0, 4)

fig.add_trace(go.Scatter(x=ref_vec, y=ref_vec, mode="lines", line=dict(color="white", width=2.5, dash="dash"), showlegend=False))

# Both axes show standard deviations, so the labels name the spread rather than
# the stage itself.
axis_labels = ["morphological stage std (hpf)", "transcriptional stage std (hpf)"]

fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20)  # , show_gridlines=False)

fig.show()

fig.write_image(fig_path + "seq_vs_morph_stage_noise.png", scale=2)
fig.write_html(fig_path + "seq_vs_morph_stage_noise.html")

### Use basic Arrhenius model to make predictions

In [None]:
from functools import partial
from scipy.optimize import least_squares

# Arrhenius model for the developmental rate: dev_rate = A*exp(-E/(RT)), with T in Kelvin
# R is the gas constant, 8.314 J/(mol*K)
# from Toulany et al 2023: E=65.2 (units not stated here; presumably kJ/mol -- TODO confirm)
# both A and E are fit below, using only the cohorts raised at the temperatures in fit_temps

# temperatures (deg C) whose observations are used for the fit
fit_temps = np.asarray([25, 28.5, 32])

def pd_arr(times, temps, params, R=8.314):
    """Predict developmental stage from an Arrhenius rate law.

    The stage is modelled as elapsed time multiplied by a temperature-dependent
    rate, ``params[0] * exp(-params[1] / (R * temps))``.

    Args:
        times: Elapsed time(s); a scalar or an array.
        temps: Temperature(s) that broadcast against ``times``. The callers in
            this file pass Kelvin. Integer-valued arrays are accepted.
        params: Sequence whose first two entries are the prefactor ``A`` and
            the activation energy ``E`` (in units consistent with ``R``).
        R: Gas constant used in the exponent.

    Returns:
        ``times * A * exp(-E / (R * temps))``, with the shape produced by
        broadcasting ``times`` against ``temps``.
    """
    prefactor, activation = params[0], params[1]
    # True division instead of ``temps**-1``: NumPy refuses to raise an integer
    # array to a negative integer power, so integer temperatures would raise.
    inv_temps = np.true_divide(1.0, temps)
    arr_pd = np.multiply(times, prefactor * np.exp(-activation * inv_temps / R))
    return arr_pd
    
# def arrhenius_fit(params, times,  temps, stages_morph, stages_seq):
#     # R = 8.314
#     # inv_temps = temps**-1
#     stages_morph_pd = pd_arr(times=times, temps=temps, params=params[:2])# * np.exp(-params[1] * inv_temps / R)
#     stages_seq_pd = pd_arr(times=times, temps=temps, params=[params[0], params[2]])#params[0] * np.exp(-params[2] * inv_temps / R)

#     ds_vec = stages_seq.ravel() - stages_seq_pd.ravel()
#     dm_vec = stages_morph.ravel() - stages_morph_pd.ravel()
    
#     return np.hstack((ds_vec, dm_vec))

def arrhenius_fit(params, times,  temps, stages):
    """Return the residuals of the Arrhenius stage model.

    Args:
        params: Model parameters, passed straight through to ``pd_arr``.
        times: Elapsed times of the observations.
        temps: Temperatures of the observations.
        stages: Observed stages to compare the prediction against.

    Returns:
        Flat array of ``predicted - observed`` differences, one per observation.
    """
    predicted = pd_arr(times=times, temps=temps, params=params)

    # flatten both sides so the optimizer always receives a 1-D residual vector
    return np.subtract(predicted.ravel(), stages.ravel())

# Observations belonging to the cohorts used for fitting
fit_filter = np.isin(hf_morphseq_df["temperature"], fit_temps)
fit_df = hf_morphseq_df.loc[fit_filter]

# Fit inputs as plain arrays; temperature is shifted from deg C to Kelvin
temp_vec = fit_df["temperature"].to_numpy() + 273.15
time_vec = fit_df["timepoint"].to_numpy()
seq_stage_vec = fit_df["seq_stage_hpf"].to_numpy()
morph_stage_vec = fit_df["morph_stage_hpf"].to_numpy()

# Residual functions of the parameters alone, one per stage estimate
# NOTE(review): the stage columns come from left merges, so they may contain NaN;
# least_squares rejects non-finite residuals -- confirm the inputs are complete.
shared_inputs = dict(temps=temp_vec, times=time_vec)
arr_fit_seq = partial(arrhenius_fit, stages=seq_stage_vec, **shared_inputs)
arr_fit_morph = partial(arrhenius_fit, stages=morph_stage_vec, **shared_inputs)
# arr_fit = partial(arrhenius_fit, temps=temp_vec, times=time_vec, stages_seq=seq_stage_vec, stages_morph=morph_stage_vec)

# initial guess for the two model parameters
x0 = [1, 200]

# independent fits against the transcriptional and the morphological stage
res_seq = least_squares(arr_fit_seq, x0)
res_morph = least_squares(arr_fit_morph, x0)

In [None]:
# Fitted parameters for each stage estimate
arr_params_seq = res_seq.x
arr_params_morph = res_morph.x

# Evaluate both fitted models at every cohort's timepoint and temperature (deg C -> K)
time_vec = cohort_stage_df["timepoint"].to_numpy()
temp_vec = cohort_stage_df["temperature"].to_numpy() + 273.15
arr_prediction_seq = pd_arr(time_vec, temp_vec, arr_params_seq)
arr_prediction_morph = pd_arr(time_vec, temp_vec, arr_params_morph)

# x (and y) values of the dashed y = x reference line
ref_vec = np.linspace(14, 50)

# Observed transcriptional stage (cohort mean, std as error bar) against the
# stage predicted by the fitted model
fig = px.scatter(
    cohort_stage_df,
    x=arr_prediction_seq,
    y="seq_stage_hpf_mean",
    error_y="seq_stage_hpf_std",
    color="temperature",
    symbol="timepoint",
    color_continuous_scale=colormap,
    range_color=[19, 35],
)

# zero cap width on the vertical error bars
fig.update_traces(error_y=dict(width=0))
# fig.update_traces(mode="lines+markers", line=dict(color="white", width=0.5))

identity_line = go.Scatter(
    x=ref_vec,
    y=ref_vec,
    mode="lines",
    line=dict(color="white", width=2.5, dash="dash"),
    showlegend=False,
)
fig.add_trace(identity_line)

axis_labels = ["predicted stage (Arrhenius)", "transcriptional stage (hpf)"]

fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20)  # , show_gridlines=False)

fig.show()

fig.write_image(os.path.join(fig_path, "seq_vs_expected_stage.png"), scale=2)
fig.write_html(os.path.join(fig_path, "seq_vs_expected_stage.html"))

In [None]:
# Order in which the temperatures are added to the figure series
temps_to_show = np.asarray([28.5, 25, 32, 19, 33.5, 35])

# Sub-folder for the series; the trailing "" keeps a separator at the end
arr_path = os.path.join(fig_path, "morph_temp_series", "")
os.makedirs(arr_path, exist_ok=True)

# One figure per prefix of temps_to_show, starting with the first three temperatures
for t in range(3, len(temps_to_show) + 1):

    # cohorts whose temperature is among the first t entries
    temp_filter = np.isin(cohort_stage_df["temperature"], temps_to_show[:t])

    # observed morphological stage against the model prediction for those cohorts
    fig = px.scatter(
        cohort_stage_df.loc[temp_filter],
        x=arr_prediction_morph[temp_filter],
        y="morph_stage_hpf_mean",
        error_y="morph_stage_hpf_std",
        color="temperature",
        symbol="timepoint",
        color_continuous_scale=colormap,
        range_color=[19, 35],
    )

    # zero cap width on the vertical error bars
    fig.update_traces(error_y=dict(width=0))
    # fig.update_traces(mode="lines+markers", line=dict(color="white", width=0.5))

    # dashed y = x reference line
    fig.add_trace(go.Scatter(x=ref_vec, y=ref_vec, mode="lines", line=dict(color="white", width=2.5, dash="dash"), showlegend=False))

    axis_labels = ["predicted stage (Arrhenius)", "morphological stage (hpf)"]

    fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20)  # , show_gridlines=False)

    # static and interactive copies share a stem numbered by the temperature count
    out_stem = arr_path + f"morph_vs_expected_stage_{t:02}"
    fig.write_image(out_stem + ".png", scale=2)
    fig.write_html(out_stem + ".html")

# only the last figure of the series is displayed
fig.show()

### Replot morph shift metrics

In [None]:
# Stage shift = stage minus the cohort's timepoint, for the observed
# morphological mean and for the model prediction
cohort_timepoints = cohort_stage_df["timepoint"]
cohort_stage_df["stage_shift_morph"] = cohort_stage_df["morph_stage_hpf_mean"] - cohort_timepoints
cohort_stage_df["stage_shift_morph_pd"] = arr_prediction_morph - cohort_timepoints

# Observed shift against temperature, with the cohort std as error bar
fig = px.scatter(
    cohort_stage_df,
    x="temperature",
    y="stage_shift_morph",
    error_y="morph_stage_hpf_std",
    color="temperature",
    symbol="timepoint",
    color_continuous_scale=colormap,
    range_color=[19, 35],
)

# fig.update_traces(error_y=dict(width=0))
# fig.update_traces(mode="lines+markers", line=dict(color="white", width=0.5))
# sym_list = ["circle", "diamond", "square"]

axis_labels = ["temperature (C)", "stage shift (δₜ)"]

fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20)

# First export: observations only.
# NOTE(review): the .html stems ("morph_vs_temp", "morph_vs_temp_pd") differ from
# the .png stems ("morph_shift_vs_temp", "morph_shift_vs_temp_pd") -- confirm intended.
fig.write_image(fig_path + "morph_shift_vs_temp.png", scale=2)
fig.write_html(fig_path + "morph_vs_temp.html")

# Overlay the predicted shift as one dashed line per timepoint; the second
# timepoint in the list is drawn thicker than the others
for t, tp in enumerate([24, 30, 36]):
    t_filter = cohort_stage_df["timepoint"] == tp
    lw = 3 if t == 1 else 1
    tp_df = cohort_stage_df.loc[t_filter]
    fig.add_trace(go.Scatter(x=tp_df["temperature"],
                             y=tp_df["stage_shift_morph_pd"],
                             mode="lines", line=dict(color="white", width=lw, dash="dash"), showlegend=False))

# Second export: observations plus predictions.
fig.write_image(fig_path + "morph_shift_vs_temp_pd.png", scale=2)
fig.write_html(fig_path + "morph_vs_temp_pd.html")
fig.show()

In [None]:
# NOTE(review): pd_df_mean and null_df_mean are not defined in the visible part
# of this notebook -- this cell looks like a leftover; confirm where they come from.
ind = 0

# "total_se_mean" of pd_df_mean (x) against the same column of null_df_mean (y)
fig = px.scatter(
    pd_df_mean,
    x="total_se_mean",
    y=null_df_mean["total_se_mean"],
    color="temperature",
    symbol="timepoint",
)  # log_x=True, log_y=True)

# same fixed range on both axes
fig.update_xaxes(range=[0, 4])
fig.update_yaxes(range=[0, 4])

fig.update_traces(marker=dict(size=8))
fig.update_layout(width=1000, height=800)
fig.show()

In [None]:
# NOTE(review): display-only cell; pca_cols_morph is not defined in the visible
# part of this notebook -- confirm it is still needed.
pca_cols_morph

In [None]:
# NOTE(review): display-only cell; mean_cols is not defined in the visible part
# of this notebook -- confirm it is still needed.
mean_cols