#### This notebook looks at temperature-dependent changes to embryo morphology

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
from glob2 import glob
from src.functions.plot_functions import format_3d_plotly, rotate_figure, format_2d_plotly

In [None]:
# load embryo_df for our current best model
# root = "/media/nick/hdd02/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"

# project root on this machine (the commented-out line above is an alternate root)
root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"

# directory the input tables and the fitted stage model are read from
read_path = os.path.join(root, "results", "20250312", "morph_latent_space", "")

# output directories for figures and figure data; both are created if missing
fig_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/morphseq/20250312/morph_metrics/"
fig_data_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/morphseq/20250312/data/morph_metrics/"
os.makedirs(fig_path, exist_ok=True)
os.makedirs(fig_data_path, exist_ok=True)

In [None]:
import joblib

# load datasets
hf_pca_df = pd.read_csv(os.path.join(read_path, "hf_pca_morph_df.csv"))
ref_pca_df = pd.read_csv(os.path.join(read_path, "ab_ref_pca_morph_df.csv"))
spline_df = pd.read_csv(os.path.join(read_path, "spline_morph_df_full.csv"))
# keep each spline row's index as an explicit column so knots can be looked up by it later
spline_df["knot_index"] = spline_df.index

# Load the previously fitted stage model from disk
morph_stage_model = joblib.load(os.path.join(read_path, 'morph_stage_model.joblib'))

### Make 3D PCA plot for embryos from hotfish experiments

In [None]:
# set plot parameters
# camera settings passed to rotate_figure, and the marker size passed to format_3d_plotly
zoom_factor = 0.21
z_rotation = -30
elevation = 5
marker_size = 5

# fixed axis limits; the later 3D PCA cells reuse these so the frames line up
xrange = [-2.2, 2.5]
yrange = [-2.2, 1.6]
zrange = [-2.1, 1.2]

# make fig
# one marker per row of hf_pca_df: color encodes temperature, symbol encodes timepoint
# NOTE(review): the later cells use range_color=[19, 38]; confirm the wider [17, 39] here is intended
fig = px.scatter_3d(hf_pca_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio", 
                    color="temperature", symbol="timepoint",
                    color_continuous_scale="RdBu_r", range_color=[17, 39], 
                    hover_data={"timepoint", "snip_id"})

fig.layout.scene.xaxis.range = xrange
fig.layout.scene.yaxis.range = yrange
fig.layout.scene.zaxis.range = zrange

pca_axis_labels=["morph PC 1", "morph PC 2", "morph PC 3"]
fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="dark", marker_size=marker_size)

fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)

fig.show()

# save both a static image and an interactive copy
fig.write_image(os.path.join(fig_path, "hotfish_pca_all.png"), scale=2)
fig.write_html(os.path.join(fig_path, "hotfish_pca_all.html"))

### Make series of plots illustrating the location of different temperatures along the trajectory

In [None]:
from tqdm import tqdm 


# Temperatures are highlighted cumulatively in this order, one more per frame.
temps_to_plot = np.asarray([28.5, 19, 25, 32, 33.5, 35])

# Frame t (t >= 1) colors the first t temperatures; frame 0 is the starting frame.
for t in tqdm(range(0, len(temps_to_plot)+1)):

    
    
    if t > 0:
        opacity = 1
        t_filter = np.isin(hf_pca_df["temperature"], temps_to_plot[:t]) 
        # plot_df = hf_pca_df.loc[t_filter]
    else:
        # frame 0: still select the first temperature, but draw its markers with zero opacity
        opacity = 0
        t_filter = np.isin(hf_pca_df["temperature"], temps_to_plot[:1]) 
        # plot_df = dummy_df
        
    plot_df = hf_pca_df.loc[t_filter]
    
    fig = px.scatter_3d(plot_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio", 
                        color="temperature", symbol="timepoint", opacity=opacity,
                        color_continuous_scale="RdBu_r", range_color=[19, 38], 
                        hover_data={"timepoint", "snip_id"})

    fig.layout.scene.xaxis.range = xrange
    fig.layout.scene.yaxis.range = yrange
    fig.layout.scene.zaxis.range = zrange

    pca_axis_labels=["morph PC 1", "morph PC 2", "morph PC 3"]
    fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="dark", marker_size=marker_size)
    
    # only the first highlighted frame gets black marker outlines
    if t == 1:
        fig.update_traces(marker=dict(line=dict(color="black")))
    
    # every embryo outside the current selection is drawn as a gray marker with a white outline
    fig.add_traces(go.Scatter3d(x=hf_pca_df.loc[~t_filter, "PCA_00_bio"], 
                                y=hf_pca_df.loc[~t_filter, "PCA_01_bio"],
                                z=hf_pca_df.loc[~t_filter, "PCA_02_bio"], 
                                mode="markers", marker=dict(size=4, opacity=1, color="gray", line=dict(color="white", width=0.75)),
                                showlegend=False))
    
    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)
    

    # one PNG and one HTML per frame, numbered by t
    fig.write_image(os.path.join(fig_path, f"hotfish_pca_{t:02}.png"), scale=2)
    fig.write_html(os.path.join(fig_path, f"hotfish_pca_{t:02}.html"))
    
# display only the last frame (all listed temperatures highlighted)
fig.show()



### Make rotating plot

In [None]:
# Render the full hotfish PCA scatter from 50 camera angles spanning 0-360 degrees.
angle_vec = np.linspace(0, 360, 50)

frame_path = os.path.join(fig_path, "hf_pca_rot_frames", "")
os.makedirs(frame_path, exist_ok=True)

# initialize figure
# Marker opacity is stated explicitly so this cell does not depend on a loop
# variable left behind by an earlier cell.
fig = px.scatter_3d(hf_pca_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio",
                    color="temperature", symbol="timepoint", opacity=1,
                    color_continuous_scale="RdBu_r", range_color=[19, 38],
                    hover_data={"timepoint", "snip_id"})

fig.layout.scene.xaxis.range = xrange
fig.layout.scene.yaxis.range = yrange
fig.layout.scene.zaxis.range = zrange

pca_axis_labels = ["morph PC 1", "morph PC 2", "morph PC 3"]
fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="dark", marker_size=marker_size)

# rotate: the figure is built once, only the camera changes between frames
for a, angle in enumerate(tqdm(angle_vec)):
    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation + angle, elev_rotation=elevation)
    fig.write_image(os.path.join(frame_path, f"hotfish_pca_rot{a:02}.png"), scale=2)

# display the last rendered angle
fig.show()

### Problem: 28C is our control group, but we don't have stage-matching due to stage shifting
**Potential solution:** search for reference embryos from timelapse data that closely overlap with 28C, but which also extend out into later timepoints

In [None]:
from tqdm import tqdm

# Animate the reference (AB) trajectories over time on top of the hotfish scatter:
# frame t shows each reference embryo's trajectory up to t_lim_vec[t].
ab_frame_path = os.path.join(fig_path, "hf_ab_pca_time_frames", "")
os.makedirs(ab_frame_path, exist_ok=True)

t_lim_vec = np.linspace(12, 44, 50)

# loop-invariant inputs
pca_axis_labels = ["morph PC 1", "morph PC 2", "morph PC 3"]
embryo_index = np.unique(ref_pca_df["embryo_id"])

for t, t_lim in enumerate(tqdm(t_lim_vec)):

    fig = px.scatter_3d(hf_pca_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio",
                        color="temperature", symbol="timepoint",
                        color_continuous_scale="RdBu_r", range_color=[19, 38],
                        hover_data={"timepoint", "snip_id"})

    fig.layout.scene.xaxis.range = xrange
    fig.layout.scene.yaxis.range = yrange
    fig.layout.scene.zaxis.range = zrange

    fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="dark", marker_size=marker_size)

    # keep only reference observations up to this frame's time limit, so the
    # trajectories actually grow from frame to frame
    t_filter = ref_pca_df["timepoint"] <= t_lim
    for eid in embryo_index:
        e_filter = (ref_pca_df["embryo_id"] == eid) & t_filter
        fig.add_traces(go.Scatter3d(x=ref_pca_df.loc[e_filter, "PCA_00_bio"],
                                    y=ref_pca_df.loc[e_filter, "PCA_01_bio"],
                                    z=ref_pca_df.loc[e_filter, "PCA_02_bio"], mode="lines",
                                    opacity=0.15,  # Applies to the whole trace.
                                    line=dict(color='white'),
                                    showlegend=False))

    # fixed camera for every frame
    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)

    # number frames by this loop's own index so each frame gets its own file
    fig.write_image(os.path.join(ab_frame_path, f"hotfish_pca_ab_angle{t:02}.png"), scale=2)

# display the last frame
fig.show()

# 

In [None]:
# Rotate the hotfish scatter with the reference (AB) trajectories overlaid.
ab_rot_frame_path = os.path.join(fig_path, "hf_pca_ab_rot_frames", "")
os.makedirs(ab_rot_frame_path, exist_ok=True)

# The scene is identical at every angle, so it is built once and only the
# camera is moved inside the loop.
fig = px.scatter_3d(hf_pca_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio",
                    color="temperature", symbol="timepoint",
                    color_continuous_scale="RdBu_r", range_color=[19, 38],
                    hover_data={"timepoint", "snip_id"})

fig.layout.scene.xaxis.range = xrange
fig.layout.scene.yaxis.range = yrange
fig.layout.scene.zaxis.range = zrange

pca_axis_labels = ["morph PC 1", "morph PC 2", "morph PC 3"]
fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="dark", marker_size=marker_size)

# Show reference trajectories up to the final time limit of the time animation.
# This is taken from t_lim_vec explicitly rather than from a leftover loop variable.
ref_t_max = t_lim_vec[-1]
t_filter = ref_pca_df["timepoint"] <= ref_t_max
embryo_index = np.unique(ref_pca_df["embryo_id"])
for eid in embryo_index:
    e_filter = (ref_pca_df["embryo_id"] == eid) & t_filter
    fig.add_traces(go.Scatter3d(x=ref_pca_df.loc[e_filter, "PCA_00_bio"],
                                y=ref_pca_df.loc[e_filter, "PCA_01_bio"],
                                z=ref_pca_df.loc[e_filter, "PCA_02_bio"], mode="lines",
                                opacity=0.15,  # Applies to the whole trace.
                                line=dict(color='white'),
                                showlegend=False))

# rotate
for a, angle in enumerate(tqdm(angle_vec)):
    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation + angle, elev_rotation=elevation)
    fig.write_image(os.path.join(ab_rot_frame_path, f"hotfish_pca_ab_time{a:02}.png"), scale=2)

# display the last rendered angle
fig.show()

### Use WT reference spline to stage embryos and measure deviations from WT

In [None]:
# names of the PCA columns (taken from ref_pca_df) used as the spline's x/y/z coordinates
plot_dims = np.asarray([0, 1, 2])
pca_cols = [col for col in ref_pca_df.columns if "PCA" in col]
plot_strings = [pca_cols[p] for p in plot_dims]

# camera and marker settings for the 3D figures that follow
zoom_factor = 0.21
z_rotation = -30
elevation = 5
marker_size = 5

fig = px.scatter_3d(hf_pca_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio", 
                            color="temperature", symbol="timepoint", 
                            color_continuous_scale="RdBu_r", range_color=[19, 38], 
                            hover_data={"timepoint", "snip_id"})

fig.layout.scene.xaxis.range = xrange
fig.layout.scene.yaxis.range = yrange
fig.layout.scene.zaxis.range = zrange

pca_axis_labels=["morph PC 1", "morph PC 2", "morph PC 3"]
fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="dark", marker_size=marker_size)

fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)

# overlay the reference spline as a single white line
fig.add_traces(go.Scatter3d(x=spline_df[plot_strings[0]], y=spline_df[plot_strings[1]], z=spline_df[plot_strings[2]],
                           mode="lines", line=dict(color="white", width=5), name="reference curve", showlegend=False))


fig.show()

# NOTE(review): the other write_image calls in this notebook pass scale=2; confirm the default scale is intended here
fig.write_image(os.path.join(fig_path, "hotfish_pca_with_spline.png"))
fig.write_html(os.path.join(fig_path, "hotfish_pca_with_spline.html"))

In [None]:
# 25 camera angles covering a 45-degree sweep; this rebinds angle_vec for the cells below
angle_vec = np.linspace(0, 45, 25)
spline_frame_path = os.path.join(fig_path, "hf_morph_spline_rot_frames", "")
os.makedirs(spline_frame_path, exist_ok=True)
# plot_options = np.arange(hf_pca_df.shape[0])
# np.random.seed(371)
# n_plot = 10
# plot_indices = np.random.choice(plot_options, n_plot, replace=False)

fig = px.scatter_3d(hf_pca_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio", 
                            color="temperature", symbol="timepoint", 
                            color_continuous_scale="RdBu_r", range_color=[19, 38], 
                            hover_data={"timepoint", "snip_id"})

fig.layout.scene.xaxis.range = xrange
fig.layout.scene.yaxis.range = yrange
fig.layout.scene.zaxis.range = zrange

pca_axis_labels=["morph PC 1", "morph PC 2", "morph PC 3"]
fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="dark", marker_size=marker_size)

fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)

# reference spline, colored along its length by the spline's mdl_stage_hpf column
fig.add_traces(go.Scatter3d(x=spline_df[plot_strings[0]], y=spline_df[plot_strings[1]], z=spline_df[plot_strings[2]],
                           mode="lines", line=dict(color=spline_df["mdl_stage_hpf"], 
                            cmin=10, cmax=44, width=7, colorscale="BuPu_r"), 
                            name="reference curve", showlegend=False))

# fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation + 45, elev_rotation=elevation)
# # rotate
# the figure is built once above; each pass only moves the camera and writes a frame
for a, angle in enumerate(tqdm(angle_vec)):
    
    
    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation + angle, elev_rotation=elevation)
    

    fig.write_image(os.path.join(spline_frame_path, f"hotfish_pca_rot{a:02}.png"), scale=2)
    # fig.write_html(os.path.join(fig_path, f"hotfish_pca_rot{a:02}.html"))
    
fig.show()

In [None]:
# from scipy.spatial import distance_matrix

# dist_mat = distance_matrix(hf_pca_df[pca_cols].values, spline_df[pca_cols].values)
# hf_pca_df["morph_dist_spline"] = np.min(dist_mat, axis=1)

# fig = px.scatter_3d(hf_pca_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio", 
#                             color="spline_stage_hpf", 
#                             color_continuous_scale="BuPu_r", range_color=[12, 44], 
#                             hover_data={"timepoint", "snip_id"}, labels={"spline_stage_hpf":"spline stage (hpf)"})

# fig.layout.scene.xaxis.range = xrange
# fig.layout.scene.yaxis.range = yrange
# fig.layout.scene.zaxis.range = zrange

# pca_axis_labels=["morph PC 1", "morph PC 2", "morph PC 3"]
# fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="dark", marker_size=marker_size)

# fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)

# fig.add_traces(go.Scatter3d(x=spline_df[plot_strings[0]], y=spline_df[plot_strings[1]], z=spline_df[plot_strings[2]],
#                            mode="lines", line=dict(color=spline_df["mdl_stage_hpf"], 
#                             cmin=10, cmax=44, width=7, colorscale="BuPu_r"), 
#                             name="reference curve", showlegend=False))

# # fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation + 45, elev_rotation=elevation)
# # # rotate

    
# fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation + angle_vec[-1], elev_rotation=elevation)


# fig.write_image(os.path.join(fig_path, f"morph_spline_stage_pca_scatter.png"), scale=2)
# fig.write_html(os.path.join(fig_path, f"morph_spline_stage_pca_scatter.html"))
    
# fig.show()


In [None]:
# 3D PCA scatter colored by each embryo's "morph_dist_spline" value, with the reference spline overlaid.
# NOTE(review): "morph_dist_spline" is not assigned in any active cell above (only in commented-out code);
# presumably it is already a column of the loaded CSV — confirm.
fig = px.scatter_3d(hf_pca_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio", color="morph_dist_spline", 
                            color_continuous_scale="magma", range_color=[0, 1.5], 
                            hover_data={"timepoint", "snip_id"}, labels={"morph_dist_spline":"dist from spline"})

fig.layout.scene.xaxis.range = xrange
fig.layout.scene.yaxis.range = yrange
fig.layout.scene.zaxis.range = zrange

pca_axis_labels=["morph PC 1", "morph PC 2", "morph PC 3"]
fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="dark", marker_size=marker_size)

fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)

fig.add_traces(go.Scatter3d(x=spline_df[plot_strings[0]], y=spline_df[plot_strings[1]], z=spline_df[plot_strings[2]],
                           mode="lines", line=dict(color=spline_df["mdl_stage_hpf"], 
                            cmin=10, cmax=44, width=7, colorscale="BuPu_r"), 
                            name="reference curve", showlegend=False))

# fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation + 45, elev_rotation=elevation)
# # rotate

    
# use the final angle of angle_vec as the camera position for the saved figure
fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation + angle_vec[-1], elev_rotation=elevation)


fig.write_image(os.path.join(fig_path, f"morph_spline_dist_pca_scatter.png"), scale=2)
fig.write_html(os.path.join(fig_path, f"morph_spline_dist_pca_scatter.html"))
    
fig.show()

### This works well for small deviations, but breaks down when things get really weird

In [None]:
from scipy.spatial import distance_matrix

# get closest spline ID
# pairwise distances between every embryo and every spline row in PCA space
dist_mat = distance_matrix(hf_pca_df[pca_cols].values, spline_df[pca_cols].values)
nn_i = np.argmin(dist_mat, axis=1)
# argmin returns row positions, so index the stage values positionally rather
# than through label-based .loc (the two agree only for a default integer index)
hf_pca_df["spline_stage_hpf"] = spline_df["mdl_stage_hpf"].to_numpy()[nn_i]

# calculate linear model prediction
# NOTE(review): the constants 6, 0.055 and 0.57 are taken as given — confirm their source
hf_pca_df["lin_mdl_pd"] = 6 + (hf_pca_df["timepoint"]-6)*(0.055*hf_pca_df["temperature"]-0.57)

In [None]:
from src.functions.plot_functions import format_2d_plotly
# Scatter of the linear-model stage prediction (x) against the nearest-spline stage (y).
colormap = "RdBu_r"
# points for the dashed y = x reference line
ref_vec = np.linspace(14, 50)
marker_size = 14

fig = px.scatter(hf_pca_df, x="lin_mdl_pd", y="spline_stage_hpf", 
                 color="temperature", symbol="timepoint",color_continuous_scale=colormap, range_color=[19, 35])


fig.add_trace(go.Scatter(x=ref_vec, y=ref_vec, mode="lines", line=dict(color="white", width=2.5, dash="dash"), showlegend=False))

axis_labels = ["expected stage (hpf)", "spline stage (hpf)"]

fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=18)#, show_gridlines=False)


fig.show()

fig.write_image(fig_path + "morph_spline_stage_scatter.png", scale=2)
fig.write_html(fig_path + "morph_spline_stage_scatter.html")

In [None]:
from src.functions.plot_functions import format_2d_plotly
# Scatter of the linear-model stage prediction (x) against the "mdl_stage_hpf" column (y).
colormap = "RdBu_r"
# points for the dashed y = x reference line
ref_vec = np.linspace(14, 50)
marker_size = 14

fig = px.scatter(hf_pca_df, x="lin_mdl_pd", y="mdl_stage_hpf",
                 color="temperature", symbol="timepoint", color_continuous_scale=colormap, range_color=[19, 35])

fig.add_trace(go.Scatter(x=ref_vec, y=ref_vec, mode="lines", line=dict(color="white", width=2.5, dash="dash"), showlegend=False))

# y is "mdl_stage_hpf", which the spline-vs-surf comparison plot labels "surf stage (hpf)";
# use the same label here so this figure is not mistaken for the spline-stage one
axis_labels = ["expected stage (hpf)", "surf stage (hpf)"]

fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=18)

fig.show()

fig.write_image(fig_path + "morph_surf_stage_scatter.png", scale=2)
fig.write_html(fig_path + "morph_surf_stage_scatter.html")

In [None]:
# Compare the two stage estimates directly: nearest-spline stage (x) against "mdl_stage_hpf" (y).
# colormap, ref_vec and marker_size are reused from the preceding cells.
fig = px.scatter(hf_pca_df, x="spline_stage_hpf", y="mdl_stage_hpf", 
                 color="temperature", symbol="timepoint",color_continuous_scale=colormap, range_color=[19, 35])


# dashed y = x reference line
fig.add_trace(go.Scatter(x=ref_vec, y=ref_vec, mode="lines", line=dict(color="white", width=2.5, dash="dash"), showlegend=False))

axis_labels = ["spline stage (hpf)", "surf stage (hpf)"]

fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20)#, show_gridlines=False)


fig.show()

fig.write_image(fig_path + "morph_spline_vs_surf_stage_scatter.png", scale=2)
fig.write_html(fig_path + "morph_spline_vs_surf_stage_scatter.html")

### Use the surface fit to look at temporal and morphological shifts and noise

In [None]:
# PCA coordinate columns of the hotfish table
pca_cols = [col for col in hf_pca_df.columns if "PCA" in col]

# get cohort averages: mean and std of stage and of each PCA coordinate
# for every (timepoint, temperature) group
cohort_cols = ["timepoint", "temperature", "mdl_stage_hpf"] + pca_cols
hf_cohort_df = (
    hf_pca_df.loc[:, cohort_cols]
    .groupby(["timepoint", "temperature"])
    .agg(["mean", "std"])
    .reset_index()
)

# Flatten the (column, statistic) header to "<column>_<statistic>". The group
# keys carry an empty statistic level and keep their plain names. The frame
# already has a fresh integer index from the reset above, so no second
# reset_index is needed (it would only add a stray "index" column).
hf_cohort_df.columns = [
    f"{name}_{stat}" if stat else name for name, stat in hf_cohort_df.columns
]

hf_cohort_df["stage_shift_hpf"] = hf_cohort_df["mdl_stage_hpf_mean"] - hf_cohort_df["timepoint"]

# use average stage to calculate morphological shift:
# find, for each cohort, the spline row whose stage is closest to the cohort's mean stage
dist_mat_stage = distance_matrix(hf_cohort_df["mdl_stage_hpf_mean"].values[:, None], spline_df["mdl_stage_hpf"].values[:, None])
nn_i_stage = np.argmin(dist_mat_stage, axis=1)
# argmin yields row positions; translate them into the spline's knot_index values
# and select rows positionally instead of treating positions as index labels
hf_cohort_df["knot_index"] = spline_df["knot_index"].to_numpy()[nn_i_stage]
spline_ref_df = spline_df.iloc[nn_i_stage, :].reset_index(drop=True)

hf_cohort_df.head()

### Use JAX to calculate developmental flux magnitude and direction

In [None]:
import jax
import jax.numpy as jnp

def make_jax_functions(model):
    """Build a JAX prediction-and-gradient function from a fitted polynomial regression.

    Parameters
    ----------
    model
        Pipeline-like object whose ``named_steps`` mapping holds a ``'poly'``
        step exposing ``powers_`` (an (m x d) array of exponents) and a
        ``'linear'`` step exposing ``coef_`` and ``intercept_``.

    Returns
    -------
    predict_and_grad : callable
        ``predict_and_grad(params, X_new) -> (preds, grads)``; see its docstring.
    params : tuple
        ``(theta, intercept)`` as JAX arrays, ready to pass as ``params``.
    """
    # Extract the PolynomialFeatures transformer and LinearRegression estimator.
    poly = model.named_steps['poly']
    linear = model.named_steps['linear']

    # Exponents for each polynomial term: an (m x d) array.
    powers = jnp.array(poly.powers_)

    # Coefficients and intercept of the linear model.
    theta = jnp.array(linear.coef_)
    intercept = jnp.array(linear.intercept_)

    def predict_single(x, theta, intercept):
        """Return the scalar prediction for a single input sample ``x`` of shape (d,)."""
        # Raise x to each row of 'powers' (m x d) and multiply across the d
        # features to get the m polynomial features of this sample.
        poly_features = jnp.prod(jnp.power(x, powers), axis=1)
        return jnp.dot(poly_features, theta) + intercept

    def predict_and_grad(params, X_new):
        """Evaluate the model and its input-gradient on a batch.

        Given parameters (theta, intercept) and inputs X_new (n_samples x d), returns:
          - preds: the prediction for each input in X_new,
          - grads: the gradient of the scalar prediction with respect to the
                   input, evaluated at each sample in X_new.
        """
        def f(x):
            return predict_single(x, params[0], params[1])

        # value_and_grad shares one forward pass between the prediction and its
        # gradient; vmap maps the per-sample function over the batch dimension.
        preds, grads = jax.vmap(jax.value_and_grad(f))(X_new)
        return preds, grads

    return predict_and_grad, (theta, intercept)

In [None]:
# build the batched predict/gradient function and its (theta, intercept) parameters from the loaded stage model
predict_and_grad, params = make_jax_functions(morph_stage_model)

In [None]:
from scipy.spatial import distance_matrix
from tqdm import tqdm
# Per-cohort column names produced by the mean/std aggregation
sd_pca_cols = [col + "_std" for col in pca_cols]
mean_cols = [col + "_mean" for col in pca_cols]

# For every (timepoint, temperature) cohort: distance of the cohort mean from its
# stage-matched spline knot, plus a split of the total PCA variance into the part
# along the stage-model gradient ("stage") and the remainder ("morph").
for row in tqdm(range(hf_cohort_df.shape[0])):

    # mean and per-coordinate variance of the morph coordinates
    pca_mu = hf_cohort_df.loc[row, mean_cols].to_numpy(dtype=float)  # morph mean
    # diagonal of the (assumed diagonal) covariance, i.e. std squared
    pca_var = hf_cohort_df.loc[row, sd_pca_cols].to_numpy(dtype=float) ** 2

    # pull point positions (needed for gradient calc)
    timepoint = hf_cohort_df.loc[row, "timepoint"]
    temperature = hf_cohort_df.loc[row, "temperature"]
    obs_filter = (hf_pca_df["timepoint"] == timepoint) & (hf_pca_df["temperature"] == temperature)
    pca_obs = hf_pca_df.loc[obs_filter, pca_cols].values

    # spline knot index
    knot_i = hf_cohort_df.loc[row, "knot_index"]
    pca_ref = spline_df.loc[spline_df["knot_index"] == knot_i, pca_cols].to_numpy()  # stage-matched comparison

    # get phenotypic distance
    hf_cohort_df.loc[row, "morph_shift"] = np.sqrt(np.sum((pca_mu - pca_ref) ** 2))

    # record total variance
    total_variance = np.sum(pca_var)
    hf_cohort_df.loc[row, "total_variance"] = total_variance
    hf_cohort_df.loc[row, "total_sigma"] = np.sqrt(total_variance)

    # use gradient to decompose variance
    # predict_and_grad is batched, so all observations of the cohort go in one call
    _, grad_pd = predict_and_grad(params, pca_obs)
    grad_pd = np.asarray(grad_pd, dtype=float)
    # unit gradient direction per observation
    # NOTE(review): an all-zero gradient would give NaN here, as before — confirm it cannot occur
    grad_u = grad_pd / np.sqrt(np.sum(grad_pd ** 2, axis=1, keepdims=True))
    # u^T diag(var) u for each observation, written without forming the diagonal matrix
    t_var_vec = (grad_u ** 2) @ pca_var

    stage_variance = np.mean(t_var_vec)
    hf_cohort_df.loc[row, "stage_variance"] = stage_variance
    hf_cohort_df.loc[row, "stage_sigma"] = np.sqrt(stage_variance)

    # whatever variance is not along the stage gradient is attributed to morphology
    morph_variance = total_variance - stage_variance
    hf_cohort_df.loc[row, "morph_variance"] = morph_variance
    hf_cohort_df.loc[row, "morph_sigma"] = np.sqrt(morph_variance)

In [None]:
# One scatter per cohort-level metric, plotted against temperature and saved as PNG + HTML.
range_color = [18, 38]
var_list = ["morph_shift", "stage_shift_hpf", "morph_sigma", "mdl_stage_hpf_std"]
# error-bar column for each metric; an empty string means "no error bars"
err_list = ["total_sigma", "mdl_stage_hpf_std", "", ""]
ylb_list = [r'morphological shift (δₘ)', 'stage shift (δₜ)',
            'morphological noise (εₘ)', 'stage noise (εₜ)']

# Error bars are switched off; set to True to draw them for metrics that name an error column.
show_error_bars = False

for var, err, ylb in zip(var_list, err_list, ylb_list):

    # error_y=None is the px.scatter default, i.e. no error bars
    error_y = err if (show_error_bars and err) else None
    fig = px.scatter(hf_cohort_df, x="temperature", y=var, error_y=error_y,
                     color="temperature", symbol="timepoint",
                     color_continuous_scale=colormap, range_color=range_color)

    axis_labels = ["temperature (C)", ylb]

    fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20)

    fig.write_image(fig_path + f"morph_model_{var}.png", scale=2)
    fig.write_html(fig_path + f"morph_model_{var}.html")

# display only the last metric's figure
fig.show()

In [None]:
# Cohort-level scatter of the stage std ("mdl_stage_hpf_std", x) against "morph_sigma" (y).
# colormap and range_color are reused from earlier cells.
fig = px.scatter(hf_cohort_df, x="mdl_stage_hpf_std", y="morph_sigma", 
                         color="temperature", symbol="timepoint",
                         color_continuous_scale=colormap, range_color=range_color)
    
# fig.add_trace(go.Scatter(x=ref_vec, y=ref_vec, mode="lines", line=dict(color="white", width=2.5, dash="dash"), showlegend=False))

axis_labels = ['stage noise (εₜ)', 'morphology noise (εₘ)']

fig = format_2d_plotly(fig, marker_size=20, axis_labels=axis_labels, font_size=20)

# this figure is displayed only; nothing is written to disk here
fig.show()

### Make 3D plot of averages with error bars

In [None]:
# Draw symmetric error bars along x, y and z for a set of 3D points
def add_error_bars(fig, x, y, z, err_x, err_y, err_z):
    """Add three gray line traces (x-, y- and z-bar) to ``fig`` for every point.

    Each bar runs from ``value - err`` to ``value + err`` along one axis while
    the other two coordinates stay at the point's position. Traces are added
    per point in x, y, z order; ``fig`` is modified in place.
    """
    centers = zip(x, y, z)
    half_widths = zip(err_x, err_y, err_z)
    for center, errs in zip(centers, half_widths):
        for axis, half in enumerate(errs):
            # start both endpoints at the point, then spread them along one axis
            start = list(center)
            stop = list(center)
            start[axis] = center[axis] - half
            stop[axis] = center[axis] + half
            segment = go.Scatter3d(
                x=[start[0], stop[0]],
                y=[start[1], stop[1]],
                z=[start[2], stop[2]],
                mode='lines',
                line=dict(color='gray', width=2),
                showlegend=False,
            )
            fig.add_trace(segment)

### Plot polynomial surface
Let's experiment with fitting derivatives so we can utilize experimental clock time

In [None]:
import umap

# Fit a 2D UMAP on the reference embryos' PCA coordinates, then project both datasets into it.
# NOTE(review): seeding NumPy's global RNG is presumably meant to make the embedding repeatable;
# UMAP also accepts a random_state argument — confirm this is sufficient.
np.random.seed(42)
umap_model = umap.UMAP(n_components=2)

# Compute the embedding
umap_model.fit(ref_pca_df[pca_cols].values)
embedding = umap_model.transform(ref_pca_df[pca_cols].values)
hf_embedding = umap_model.transform(hf_pca_df[pca_cols].values)

# stack reference rows first, hotfish rows second; full_pca uses the same row order
full_embedding = np.vstack((embedding, hf_embedding))
full_pca = np.vstack((ref_pca_df[pca_cols].values, hf_pca_df[pca_cols].values))

In [None]:
# from scipy.interpolate import griddata
# from scipy.spatial import distance_matrix

# # Create a grid over the domain of your data.
# x=ref_pca_df[pca_cols].to_numpy()[:, 0]
# y=ref_pca_df[pca_cols].to_numpy()[:, 1] #full_embedding[:, 1]
# z=morph_stage_model.predict(full_pca) 

# # fig = px.scatter_3d(x=x, y=y, z=z, color=z)
# # fig.show()
# grid_x = np.linspace(0.9*x.min(), 1.1*x.max(), 100)
# grid_y = np.linspace(0.9*y.min(), 1.1*y.max(), 100)
# grid_x, grid_y = np.meshgrid(grid_x, grid_y)

# xy_long = np.c_[grid_x.ravel()[:, None], grid_y.ravel()[:, None]]
# dist_vec = np.min(distance_matrix(xy_long, ref_pca_df[pca_cols].to_numpy()[:, :2]), axis=1)

In [None]:
from scipy.interpolate import griddata
from scipy.spatial import distance_matrix

# Create a grid over the domain of your data.
# x, y: UMAP coordinates of all embryos; z: stage predicted by the model from their PCA coordinates
x=full_embedding[:, 0]
y=full_embedding[:, 1]
z=morph_stage_model.predict(full_pca) 

# fig = px.scatter_3d(x=x, y=y, z=z, color=z)
# fig.show()
# NOTE(review): 0.9*min / 1.1*max extend the grid beyond the data only where the bound is positive;
# a negative minimum is pulled inward instead — confirm this is acceptable.
grid_x = np.linspace(0.9*x.min(), 1.1*x.max(), 100)
grid_y = np.linspace(0.9*y.min(), 1.1*y.max(), 100)
grid_x, grid_y = np.meshgrid(grid_x, grid_y)

# flatten the grid to (10000, 2) points and record each point's distance to the nearest embedded embryo
xy_long = np.c_[grid_x.ravel()[:, None], grid_y.ravel()[:, None]]
dist_vec = np.min(distance_matrix(xy_long, full_embedding), axis=1)
# Interpolate the scattered data onto the grid.
# grid_z = griddata(points=(x, y), values=z, xi=(grid_x, grid_y), method='cubic')

# grid_x.shape
# Create the surface plot.
# fig = go.Figure(data=[go.Surface(z=grid_z, x=grid_x, y=grid_y)])
# fig.update_layout(title="3D Surface from Scattered Data", scene=dict(
#                     xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
# fig.show()

In [None]:
from scipy.ndimage import gaussian_filter
# px.histogram(dist_vec)

# Grid nodes farther than this from every embedded embryo are masked out.
dist_thresh = 1.5
# Reshape with the grid's own shape (instead of a hard-coded 100 x 100) so this
# cell keeps working if the grid resolution is changed where the grid is built.
dist_mat = dist_vec.reshape(grid_x.shape)

# Nearest-neighbour interpolation of the predicted stage onto the grid,
# followed by Gaussian smoothing of the resulting surface.
grid_z = griddata(points=(x, y), values=z, xi=(grid_x, grid_y), method='nearest')
grid_z_smoothed = gaussian_filter(grid_z, sigma=2, mode="nearest")
# Blank the surface wherever there is no embryo within dist_thresh.
grid_z_smoothed[dist_mat > dist_thresh] = np.nan

# Separate copy of the model stage used as the z coordinate in the plots.
hf_pca_df["mdl_stage_plot"] = hf_pca_df["mdl_stage_hpf"].copy()
# Create the surface plot.



# fig.write_image(os.path.join(fig_path, "ab_developmental_surface.png"))
# fig.write_html(os.path.join(fig_path, "ab_developmental_surface.html"))

In [None]:
# Camera settings passed to rotate_figure
zoom_factor, z_rotation, elevation = 0.95, 275, 10

# Axis limits; the stage axis is given as [60, 10], i.e. high-to-low
xrange, yrange, zrange = [-4, 15], [-5, 15.5], [60, 10]

# Hotfish embryos in (embedding 1, embedding 2, stage) space
fig = px.scatter_3d(
    x=hf_embedding[:, 0],
    y=hf_embedding[:, 1],
    z=hf_pca_df["mdl_stage_plot"],
    color=hf_pca_df["temperature"],
    symbol=hf_pca_df["timepoint"],
    color_continuous_scale="RdBu_r",
    range_color=range_color,
)

fig = format_3d_plotly(
    fig,
    axis_labels=["morph 1", "morph 2", "stage (hpf)"],
    aspectmode="cube",
    show_gridlines=True,
)

# Semi-transparent smoothed stage surface underneath the points
fig.add_trace(
    go.Surface(
        z=grid_z_smoothed,
        x=grid_x,
        y=grid_y,
        opacity=0.5,
        colorscale="Purples",
        showlegend=False,
        showscale=False,
    )
)

fig.update_layout(
    scene=dict(
        xaxis=dict(range=xrange),
        yaxis=dict(range=yrange),
        zaxis=dict(range=zrange),
    )
)

fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)

fig.show()

# fig_path is used as a plain string prefix here
fig.write_image(fig_path + "morph_surface_dope.png", scale=2)
fig.write_html(fig_path + "morph_surface_dope.html")

In [None]:
# Render one PNG per stage cutoff: frame t shows only the hotfish embryos whose
# model stage is <= t_lim_vec[t], drawn over the smoothed stage surface.
# NOTE(review): `tqdm` and `range_color` are not defined in this cell; presumably
# they are imported / defined earlier in the notebook -- confirm.
t_lim_vec = np.linspace(12, 48, 50)
surf_frame_path = os.path.join(fig_path, "hf_morph_surf_frames", "")
os.makedirs(surf_frame_path, exist_ok=True)

for t, t_lim in enumerate(tqdm(t_lim_vec)):

    # Boolean mask of embryos at or below the current stage cutoff
    t_filter = hf_pca_df["mdl_stage_hpf"] <= t_lim 
    if np.sum(t_filter) == 0:
        # Nothing passes the cutoff yet: plot every embryo with opacity 0 instead
        # of an empty selection (presumably to keep the figure layout identical
        # across frames -- confirm).
        opacity = 0
        t_filter = hf_pca_df["mdl_stage_hpf"] <= np.inf
    else:
        opacity = 1
        # t_filter = hf_pca_df["mdl_stage_hpf"] <= t_lim 
    
    fig = px.scatter_3d(x=hf_embedding[t_filter, 0], y=hf_embedding[t_filter, 1], 
                        z=hf_pca_df.loc[t_filter, "mdl_stage_plot"], color=hf_pca_df.loc[t_filter, "temperature"], opacity=opacity,
                       symbol=hf_pca_df.loc[t_filter, "timepoint"], color_continuous_scale="RdBu_r", range_color=range_color)
    
    fig = format_3d_plotly(fig, axis_labels=["morph 1", "morph 2", "stage (hpf)"], aspectmode="cube", show_gridlines=True)

    # Fixed axis limits so that every frame shares the same bounding box
    fig.layout.scene.xaxis.range = xrange
    fig.layout.scene.yaxis.range = yrange
    fig.layout.scene.zaxis.range = zrange

    # Semi-transparent smoothed stage surface
    fig.add_trace(go.Surface(z=grid_z_smoothed, x=grid_x, y=grid_y, opacity=0.5, 
                             colorscale="Purples", showlegend=False, showscale=False))
    
    
    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)
    
#     fig.update_layout(
#         scene=dict(
#             zaxis=dict(
#                 autorange='reversed'
#             )
#         )
# )
    # rotate
    # NOTE(review): rotate_figure is called a second time with the same arguments.
    # This is redundant if the helper sets an absolute camera position, and a
    # double rotation if it is relative -- confirm and drop one of the calls.
    
    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)
    
    # Frame index is zero-padded to two digits in the file name
    fig.write_image(os.path.join(surf_frame_path, f"hotfish_pca_ab_angle{t:02}.png"), scale=2)

# Display only the last rendered frame
fig.show()

### Calculate mean and standard deviation in embryo morphology

In [None]:
# Number of hotfish embryos whose model stage exceeds 48 hpf
(hf_pca_df["mdl_stage_hpf"] > 48).sum()

In [None]:
# Every column whose name contains "PCA" is treated as a latent coordinate
pca_cols = [col for col in hf_pca_df.columns if "PCA" in col]

# Per-cohort (timepoint x temperature) mean and standard deviation of the model
# stage and of every PCA coordinate
cohort_keys = ["timepoint", "temperature"]
hf_cohort_df = (
    hf_pca_df.loc[:, cohort_keys + ["mdl_stage_hpf"] + pca_cols]
    .groupby(cohort_keys)
    .agg(["mean", "std"])
    .reset_index()
)

# Flatten the two-level column index into "<column>_<statistic>"; the grouping
# keys have an empty second level and therefore become "timepoint_" / "temperature_"
hf_cohort_df.columns = [
    "_".join(str(level) for level in col).strip()
    for col in hf_cohort_df.columns.values
]
hf_cohort_df.head()

In [None]:
# Indices of the three PCA components shown on the x, y and z axes
plot_dims = np.asarray([0, 1, 2])
mean_pca_cols = [f"{col}_mean" for col in pca_cols]
plot_strings = [mean_pca_cols[p] for p in plot_dims]
# Same components with the 5-character "_mean" suffix removed, used to index spline_df
spline_strings = [name[:-5] for name in plot_strings]

# Cohort means, colored by temperature
fig = px.scatter_3d(
    hf_cohort_df,
    x=plot_strings[0],
    y=plot_strings[1],
    z=plot_strings[2],
    opacity=1,
    color="temperature_",
    hover_data={"timepoint_"},
)

fig.update_traces(marker=dict(size=5, showscale=False))

# Reference curve from spline_df drawn as a black line
reference_curve = go.Scatter3d(
    x=spline_df[spline_strings[0]],
    y=spline_df[spline_strings[1]],
    z=spline_df[spline_strings[2]],
    mode="lines",
    line=dict(color="black", width=4),
    name="reference curve",
)
fig.add_traces(reference_curve)

fig.show()

fig.write_image(os.path.join(fig_path, "avg_hotfish_pca_with_spline.png"))
fig.write_html(os.path.join(fig_path, "avg_hotfish_pca_with_spline.html"))

### Use JAX to generate predicted developmental gradients at each point in latent space

### Calculate stage and morphological deltas

In [None]:
# get stage shift
hf_cohort_df["stage_hpf_mean"] = model.predict(hf_cohort_df[mean_pca_cols].values)
hf_cohort_df["stage_shift_hpf"] = hf_cohort_df["stage_hpf_mean"] - hf_cohort_df["timepoint_"]

predict_and_grad, params = make_jax_functions(morph_stage_model)

In [None]:
# NOTE(review): scratch cell -- displays a random permutation of `grad_u`, which is
# not defined in any cell visible here (presumably left over from the JAX gradient
# section). The result is not stored anywhere.
np.random.permutation(grad_u)

In [None]:
# Cohort `morph_shift` against timepoint, colored by temperature
fig = px.scatter(
    hf_cohort_df,
    x="timepoint_",
    y="morph_shift",
    color="temperature_",
)
fig.show()

In [None]:
# Cohort `stage_shift_hpf` against timepoint, colored by temperature
fig = px.scatter(
    hf_cohort_df,
    x="timepoint_",
    y="stage_shift_hpf",
    color="temperature_",
)
fig.show()

In [None]:
# Cohort `morph_variance` against timepoint, colored by temperature
fig = px.scatter(
    hf_cohort_df,
    x="timepoint_",
    y="morph_variance",
    color="temperature_",
)
fig.update_traces(marker_size=8)
fig.show()

In [None]:
# Cohort `stage_variance` against timepoint, colored by temperature
fig = px.scatter(
    hf_cohort_df,
    x="timepoint_",
    y="stage_variance",
    color="temperature_",
)
fig.update_traces(marker_size=8)
fig.show()

In [None]:
# Cohort `total_variance` against timepoint, colored by temperature
fig = px.scatter(
    hf_cohort_df,
    x="timepoint_",
    y="total_variance",
    color="temperature_",
)
fig.update_traces(marker_size=8)
fig.show()

In [None]:
# `stage_variance` against `morph_variance` on a square figure with identical
# axis limits; color encodes temperature, symbol encodes timepoint
fig = px.scatter(
    hf_cohort_df,
    x="morph_variance",
    y="stage_variance",
    color="temperature_",
    symbol="timepoint_",
)
fig.update_traces(marker_size=8)
fig.update_layout(height=800, width=800, xaxis_range=[0, 1.7], yaxis_range=[0, 1.7])
fig.show()

In [None]:
# `stage_variance` against `stage_variance_null` on a square figure with
# identical axis limits; color encodes temperature, symbol encodes timepoint
fig = px.scatter(
    hf_cohort_df,
    x="stage_variance_null",
    y="stage_variance",
    color="temperature_",
    symbol="timepoint_",
)
fig.update_traces(marker_size=8)
fig.update_layout(height=800, width=800, xaxis_range=[0, 0.5], yaxis_range=[0, 0.5])
fig.show()

In [None]:
# `stage_variance` against the per-cohort standard deviation of the model stage
# (axis ranges left to plotly's defaults)
fig = px.scatter(
    hf_cohort_df,
    x="mdl_stage_hpf_std",
    y="stage_variance",
    color="temperature_",
    symbol="timepoint_",
)
fig.update_traces(marker_size=8)
fig.update_layout(height=800, width=800)
fig.show()

### Make figure showing images for sanity check purposes

In [None]:
import skimage.io as io

image_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/training_data/20241107_ds/images/0/"

# Per-embryo snip id, timepoint and temperature as parallel arrays
hf_snip_vec, hf_time_vec, hf_temp_vec = (
    hf_umap_df[col].to_numpy() for col in ("snip_id", "timepoint", "temperature")
)

# One image per snip id, read from "<image_path><snip_id>.jpg" in the same order
image_list = [
    io.imread(os.path.join(image_path, snip_id + ".jpg"))
    for snip_id in hf_snip_vec
]

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Output folder for the tiled cohort images (note: rebinds image_path)
image_path = os.path.join(fig_path, "cohort_images", "")
os.makedirs(image_path, exist_ok=True)
im_shape = image_list[0].shape

# One 2 x 4 mosaic per (timepoint, temperature) combination
for time in np.unique(hf_time_vec):
    for temp in np.unique(hf_temp_vec):
        obs_indices = np.where((hf_time_vec == time) & (hf_temp_vec == temp))[0]

        # Up to the first 8 embryos of the cohort, padded with black tiles
        tiles = [image_list[idx] for idx in obs_indices[:8]]
        tiles += [np.zeros(im_shape, dtype=np.uint8) for _ in range(8 - len(tiles))]

        # First four tiles form the top row, the remaining four the bottom row
        tiled_image = np.block([tiles[:4], tiles[4:]])

        fig = px.imshow(
            tiled_image,
            color_continuous_scale="gray",
            title=f"{temp:02}C @{time:02}hpf",
        )

        fig.write_image(
            image_path + f"embryo_images_tp{time:02}_temp{temp:02}.png",
            engine="kaleido",
        )

# fig.show()

### Is it possible to fit to the derivatives?

In [None]:
# get point-over-point differences
cols_to_diff = pca_cols + ["experiment_time"]
diff_cols = [col + "_diff" for col in cols_to_diff]
dt_cols = [col + "_dt" for col in cols_to_diff]
ref_umap_df_dt = ref_umap_df.copy()
ref_umap_df_dt[diff_cols] = ref_umap_df_dt.groupby('embryo_id')[cols_to_diff].diff()
ref_umap_df_dt = ref_umap_df_dt.fillna(method='bfill') 

# we want to calculate the rate of time changes wrpt 
ref_umap_df_dt[dt_cols[:-1]] = np.divide(ref_umap_df_dt[diff_cols[-1]].values[:, None], ref_umap_df_dt[diff_cols[:-1]].values)

In [None]:
# Suppose we have K measurement points in an N-dimensional space.
# D_data: (K, N) array of points.
# G_data: (K, N) array of measured gradients at those points.
# d: polynomial degree

def multiindex_list(N, d):
    """
    Enumerate the exponent tuples of every monomial in N variables of total degree <= d.

    Parameters:
    - N: number of variables (length of each multi-index).
    - d: maximum total degree.

    Returns:
    - list of length-N tuples of non-negative ints, grouped by increasing total
      degree (0 first). Each multi-index appears exactly once.
    """
    def _grow(exponents, first_axis, remaining):
        # Hand out `remaining` unit increments to axes >= first_axis; never going
        # back to a lower axis is what prevents duplicates.
        if remaining == 0:
            yield exponents
            return
        for axis in range(first_axis, N):
            bumped = exponents[:axis] + (exponents[axis] + 1,) + exponents[axis + 1:]
            yield from _grow(bumped, axis, remaining - 1)

    # Include all degrees from 0 up to d
    return [
        alpha
        for degree in range(d + 1)
        for alpha in _grow((0,) * N, 0, degree)
    ]

def build_A(D_data, basis=None):
    """
    Build the design matrix for fitting a polynomial to measured gradients.

    Row k * N + j holds, for every multi-index alpha in the basis, the partial
    derivative of the monomial D^alpha with respect to coordinate j evaluated at
    point k, i.e. alpha[j] * D_k^(alpha - e_j).

    Rows are collected in a local list, so the function no longer appends into
    (or depends on the contents of) a notebook-level `A` and can be called
    repeatedly. The dimensionality is taken from each point rather than from a
    global `N`.

    Parameters:
    - D_data: array-like of shape (K, N); each row is one N-dimensional point.
    - basis: optional list of multi-index tuples. When omitted, the
      notebook-level `multiindices` is used, as before.

    Returns:
    - numpy array of shape (K * N, len(basis)).
    """
    if basis is None:
        # Backward-compatible default: the basis defined at notebook level.
        basis = multiindices

    rows = []
    for Dk in D_data:  # Dk has shape (N,)
        n_dims = len(Dk)
        for j in range(n_dims):
            row = []
            for alpha in basis:
                # A monomial that does not contain D_j has zero derivative.
                if alpha[j] == 0:
                    row.append(0.0)
                    continue
                # Compute Dk^(alpha - e_j); factors with exponent 0 contribute 1.
                term = 1.0
                for i in range(n_dims):
                    exponent = alpha[i] - (1 if i == j else 0)
                    if exponent > 0:
                        term *= Dk[i] ** exponent
                row.append(alpha[j] * term)
            rows.append(row)

    return np.array(rows)

def build_b(G_data):
    """
    Flatten the measured gradients into the right-hand-side vector.

    Entries are ordered point by point and, within a point, component by
    component (row-major), which matches the row order produced by build_A.
    The result is a fresh array: nothing is appended to a notebook-level `b`,
    so repeated calls return the same vector.

    Parameters:
    - G_data: array-like of shape (K, N) with the gradient measured at each point.

    Returns:
    - numpy array of shape (K * N,).
    """
    # np.array copies, so the returned vector never aliases G_data.
    return np.array(G_data).reshape(-1)

def evaluate_polynomial_array(D, multiindices, c):
    """
    Evaluate a multivariate polynomial at a batch of points.

    Parameters:
    - D: array-like of shape (M, N); each row is one N-dimensional point.
    - multiindices: list of exponent tuples, one per monomial.
    - c: coefficients, one per entry of multiindices.

    Returns:
    - numpy array of shape (M,): sum over terms of c[t] * prod_i D[:, i] ** multiindices[t][i].
    """
    points = np.asarray(D)
    n_points, _ = points.shape  # also enforces a 2-D input
    total = np.zeros(n_points)

    # Accumulate one monomial at a time, in basis order.
    for weight, exponents in zip(c, multiindices):
        # Broadcasting raises every column to its own exponent; the row-wise
        # product is the monomial value at each point.
        monomial = np.prod(np.power(points, np.array(exponents)), axis=1)
        total += weight * monomial
    return total

In [None]:
# Fit a degree-d polynomial in the PCA coordinates whose partial derivatives
# match, in the least-squares sense, the measured rates in the *_dt columns.
N = len(pca_cols)
d = 2  # for example, quadratic polynomial

# Get multi-index list for polynomial basis.
multiindices = multiindex_list(N, d)
num_terms = len(multiindices)  # not used further in this cell

# Build design matrix A and measurement vector b.
# There will be K * N equations (each derivative component).
# Start from empty containers: the names A and b are rebound to the assembled
# arrays below, so re-running this cell must reset them first (build_A / build_b
# may append into them).
A = []
b = []
D_data = ref_umap_df_dt[pca_cols].to_numpy()
# The last entry of dt_cols belongs to experiment_time and is excluded.
G_data = ref_umap_df_dt[dt_cols[:-1]].to_numpy()  

A = build_A(D_data)
b = build_b(G_data)

# Solve the least squares problem
# NOTE(review): G_data comes from a division by coordinate differences and may
# contain inf/NaN -- verify it is finite before solving.
c, residuals, rank, s = np.linalg.lstsq(A, b, rcond=None)

In [None]:
# Evaluate the fitted polynomial at the training points, then rescale by 3600
# (presumably seconds -> hours; confirm the units of experiment_time).
prediction = evaluate_polynomial_array(D_data, multiindices, c)
prediction = prediction / 3600

In [None]:


# Polynomial prediction against the stage stored in ref_umap_df
fig = px.scatter(
    x=ref_umap_df["predicted_stage_hpf"],
    y=prediction,
)
fig.show()