#### This notebook looks at temperature-dependent changes to embryo morphology

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
from glob2 import glob
from src.functions.plot_functions import format_3d_plotly, rotate_figure, format_2d_plotly

In [None]:
# Load embryo_df for our current best model.
# Every path below is derived from a single Dropbox root, so switching machines
# only requires changing `dropbox_root` (the Linux mount is kept for reference).
# dropbox_root = "/media/nick/hdd02/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/"
dropbox_root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/"

root = os.path.join(dropbox_root, "morphseq", "")

# path to read the latent-space tables and stage model from
read_path = os.path.join(root, "results", "20250312", "morph_latent_space", "")

# paths to save figures and the data behind them
slides_path = os.path.join(dropbox_root, "slides", "morphseq", "20250513")
fig_path = os.path.join(slides_path, "morph_metrics", "")
fig_data_path = os.path.join(slides_path, "data", "morph_metrics", "")
os.makedirs(fig_path, exist_ok=True)
os.makedirs(fig_data_path, exist_ok=True)

In [None]:
import joblib

# Load the hotfish (temperature-series) PCA table, the AB reference PCA table
# and the spline table from the latent-space results folder.
hf_pca_df = pd.read_csv(os.path.join(read_path, "hf_pca_morph_df.csv"))
ref_pca_df = pd.read_csv(os.path.join(read_path, "ab_ref_pca_morph_df.csv"))
spline_df = pd.read_csv(os.path.join(read_path, "spline_morph_df_full.csv"))

# Record each spline row's position (read_csv gives a fresh 0..n-1 index).
spline_df["knot_index"] = spline_df.index

# Load the previously fitted morphological stage model from disk.
morph_stage_model = joblib.load(os.path.join(read_path, "morph_stage_model.joblib"))

## Plot reference trajectories

In [None]:
from tqdm import tqdm

# Animation of the reference embryos: each frame raises the stage cutoff
# (12 -> 48 hpf) while the camera advances through a full 360-degree turn.
n_frames = 50
angle_vec = np.linspace(0, 360, n_frames)
t_lim_vec = np.linspace(12, 48, n_frames)

frame_path = os.path.join(fig_path, "ref_pca_rot_frames", "")
os.makedirs(frame_path, exist_ok=True)

# set plot parameters
zoom_factor = 0.21
z_rotation = -30 + 102
elevation = -10
marker_size = 3

xrange = [-2.2, 2.5]
yrange = [-2.7, 1.6]
zrange = [-2.1, 1.7]

pca_axis_labels = ["morph PC 1", "morph PC 2", "morph PC 3"]

# Integer ticks shared by all three axes; loop-invariant, so built once.
grid_dict = dict(
    tickmode="array",
    tickvals=[-2, -1, 0, 1, 2],
    ticktext=["-2", "-1", "0", "1", "2"],
    showgrid=True,
)

for t, (t_lim, angle) in enumerate(tqdm(zip(t_lim_vec, angle_vec), total=n_frames)):
    # Only embryos at or below the current stage cutoff; the fixed color range
    # keeps stage colors comparable from frame to frame.
    t_filter = ref_pca_df["mdl_stage_hpf"] <= t_lim
    fig = px.scatter_3d(ref_pca_df.loc[t_filter], x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio",
                        color="mdl_stage_hpf", opacity=1, range_color=[10, 42],
                        hover_data=["snip_id"])

    fig.layout.scene.xaxis.range = xrange
    fig.layout.scene.yaxis.range = yrange
    fig.layout.scene.zaxis.range = zrange

    fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, aspectmode="manual", theme="light",
                           marker_size=marker_size, marker_edge=True)

    fig.update_layout(scene=dict(xaxis=grid_dict, yaxis=grid_dict, zaxis=grid_dict))

    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation + angle, elev_rotation=elevation)

    fig.update_layout(
        coloraxis_showscale=True,
        coloraxis_colorbar=dict(title="embryo stage (hpf)"),
    )

    fig.write_image(os.path.join(frame_path, f"ref_pca_ab_angle{t:02}.png"), scale=2)

fig.show()

### Now add lines for individual embryos

In [None]:
start_stage = 12
stop_stage = 38

# Earliest ("min") and latest ("max") mdl_stage_hpf observed for each
# reference embryo.
embryo_df = (
    ref_pca_df.groupby("embryo_id")["mdl_stage_hpf"]
    .agg(["min", "max"])
    .reset_index()
)

# Keep embryos whose observations span the whole start_stage..stop_stage
# window; the first n_plot of them are drawn below.
stage_filter = embryo_df["min"].le(start_stage) & embryo_df["max"].ge(stop_stage)
embryo_ids_to_plot = embryo_df.loc[stage_filter, "embryo_id"].to_numpy()
n_plot = 3

In [None]:
# Background: every reference embryo at low opacity, colored by stage.
fig = px.scatter_3d(ref_pca_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio",
                    color="mdl_stage_hpf", opacity=.1, range_color=[10, 42],
                    hover_data=["snip_id"])

fig.layout.scene.xaxis.range = xrange
fig.layout.scene.yaxis.range = yrange
fig.layout.scene.zaxis.range = zrange

fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, aspectmode="manual", theme="light",
                       marker_size=3, marker_edge=False)

# Integer ticks shared by all three axes.
grid_dict = dict(
    tickmode="array",
    tickvals=[-2, -1, 0, 1, 2],
    ticktext=["-2", "-1", "0", "1", "2"],
    showgrid=True,
)
fig.update_layout(scene=dict(xaxis=grid_dict, yaxis=grid_dict, zaxis=grid_dict))

fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)

fig.update_layout(
    coloraxis_showscale=True,
    coloraxis_colorbar=dict(title="embryo stage (hpf)"),
)

# Frame 0 is the background alone; each later frame adds one more embryo.
fig.write_image(os.path.join(fig_path, f"embryo_lines{0:02}.png"), scale=2)
for e, embryo_id in enumerate(embryo_ids_to_plot[:n_plot]):
    # Sort by stage so the connecting line follows developmental time
    # rather than the row order of the CSV.
    embryo_pts = ref_pca_df.loc[ref_pca_df["embryo_id"] == embryo_id].sort_values(
        "mdl_stage_hpf", kind="stable"
    )
    fig.add_traces(go.Scatter3d(x=embryo_pts["PCA_00_bio"],
                                y=embryo_pts["PCA_01_bio"],
                                z=embryo_pts["PCA_02_bio"],
                                mode="markers+lines", marker=dict(size=5), line=dict(width=2),
                                showlegend=False))

    fig.write_image(os.path.join(fig_path, f"embryo_lines{e+1:02}.png"), scale=2)

fig.show()

### Flux plot

### Make 3D PCA plot for embryos from hotfish experiments

### Make series of plots illustrating the location of different temperatures along the trajectory

In [None]:
# Number of reference trajectories to draw behind the hotfish points, and the
# sorted, de-duplicated pool of reference embryo IDs to take them from.
n_plot = 75
embryo_index = np.unique(ref_pca_df["embryo_id"].to_numpy())

In [None]:
from tqdm import tqdm

# Static camera for the temperature-series frames.
z_rotation = 72
elevation = -10

# Temperatures are revealed in this order, one more per frame.
temps_to_plot = np.asarray([28.5, 19, 25, 32, 33.5, 35])
times_to_plot = np.asarray([24, 30, 36])
marker_size = 4

pca_axis_labels = ["morph PC 1", "morph PC 2", "morph PC 3"]

# Integer ticks shared by all three axes.
grid_dict = dict(
    tickmode="array",
    tickvals=[-2, -1, 0, 1, 2],
    ticktext=["-2", "-1", "0", "1", "2"],
    showgrid=True,
)

# Split out the background reference trajectories once, instead of
# re-filtering ref_pca_df for every embryo in every frame. Points are sorted by
# stage so each line follows developmental time rather than CSV row order.
ref_trajectories = [
    ref_pca_df.loc[ref_pca_df["embryo_id"] == embryo_id].sort_values(
        "mdl_stage_hpf", kind="stable"
    )
    for embryo_id in embryo_index[:n_plot]
]

for t in tqdm(range(len(temps_to_plot) + 1)):
    if t == 0:
        # Frame 0 contains every temperature at zero opacity (presumably so it
        # has the same traces/layout as the later frames but shows no points).
        opacity = 0
        tlim = temps_to_plot
    else:
        # Frame t shows the first t temperatures of temps_to_plot.
        opacity = 1
        tlim = np.asarray(temps_to_plot[:t])
    temp_filter = np.isin(hf_pca_df["temperature"], tlim)
    fig = px.scatter_3d(hf_pca_df.loc[temp_filter], x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio",
                        color="temperature", symbol="timepoint", opacity=opacity,
                        color_continuous_scale="RdBu_r", range_color=[19, 38],
                        hover_data=["timepoint", "snip_id"])

    for traj in ref_trajectories:
        fig.add_traces(go.Scatter3d(
            x=traj["PCA_00_bio"],
            y=traj["PCA_01_bio"],
            z=traj["PCA_02_bio"],
            mode="lines",
            # Plotly only honors cmin when cmax is set too; otherwise each line
            # auto-scales to its own stage range. Pin both so line colors are
            # comparable across embryos (same 10-42 hpf span as the ref plots).
            line=dict(color=traj["mdl_stage_hpf"], cmin=10, cmax=42, width=1),
            showlegend=False,
            opacity=0.25,
        ))

    fig.layout.scene.xaxis.range = xrange
    fig.layout.scene.yaxis.range = yrange
    fig.layout.scene.zaxis.range = zrange

    fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="light", marker_size=marker_size, font_size=12)

    fig.update_layout(
        scene=dict(
            # 1) lock in the axis ranges
            xaxis=dict(range=xrange, autorange=False),
            yaxis=dict(range=yrange, autorange=False),
            zaxis=dict(range=zrange, autorange=False),
            # 2) keep the aspect ratio proportional to the data ranges
            aspectmode="manual",
            aspectratio=dict(
                x=(xrange[1] - xrange[0]),
                y=(yrange[1] - yrange[0]),
                z=(zrange[1] - zrange[0]),
            ),
            # 3) starting camera eye (rotate_figure is applied afterwards)
            camera=dict(eye=dict(x=1.2, y=1.2, z=0.8)),
        )
    )

    fig = rotate_figure(fig, zoom_factor=zoom_factor * 0.95, z_rotation=z_rotation, elev_rotation=elevation)

    fig.update_layout(scene=dict(xaxis=grid_dict, yaxis=grid_dict, zaxis=grid_dict))
    # hide the colorbar and legend
    fig.update_layout(coloraxis_showscale=False, showlegend=False)

    fig.write_image(os.path.join(fig_path, f"hotfish_pca_temp{t:02}.png"), scale=2)
    fig.write_html(os.path.join(fig_path, f"hotfish_pca_temp{t:02}.html"))

fig.show()



### Make rotating plot

In [None]:
# Camera path for the rotating hotfish animation: 50 evenly spaced frames that
# sweep the azimuth from 72 to 375 degrees while tilting the elevation from
# -10 to +10 degrees.
n_rot_frames = 50
start_z_rotation, stop_z_rotation = 72, 360 + 15
start_elevation, stop_elevation = -10, 10
a_vec = np.linspace(start_z_rotation, stop_z_rotation, n_rot_frames)
e_vec = np.linspace(start_elevation, stop_elevation, n_rot_frames)

In [None]:
from tqdm import tqdm

temp_frame_path = os.path.join(fig_path, "temp_pca_rot_frames", "")
os.makedirs(temp_frame_path, exist_ok=True)
marker_size = 4

# Show every temperature fully opaque. Set explicitly instead of relying on the
# `temp_filter` / `opacity` values left over from the temperature-series loop.
opacity = 1
temp_filter = np.isin(hf_pca_df["temperature"], temps_to_plot)

pca_axis_labels = ["morph PC 1", "morph PC 2", "morph PC 3"]

# Integer ticks shared by all three axes.
grid_dict = dict(
    tickmode="array",
    tickvals=[-2, -1, 0, 1, 2],
    ticktext=["-2", "-1", "0", "1", "2"],
    showgrid=True,
)

# Background reference trajectories, split once and sorted by stage so each
# line follows developmental time.
ref_trajectories = [
    ref_pca_df.loc[ref_pca_df["embryo_id"] == embryo_id].sort_values(
        "mdl_stage_hpf", kind="stable"
    )
    for embryo_id in embryo_index[:n_plot]
]

# One frame per (azimuth, elevation) pair along the camera path.
for i, (z_rotation, elevation) in enumerate(tqdm(zip(a_vec, e_vec), total=len(a_vec))):
    fig = px.scatter_3d(hf_pca_df.loc[temp_filter], x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio",
                        color="temperature", symbol="timepoint", opacity=opacity,
                        color_continuous_scale="RdBu_r", range_color=[19, 38],
                        hover_data=["timepoint", "snip_id"])

    for traj in ref_trajectories:
        fig.add_traces(go.Scatter3d(
            x=traj["PCA_00_bio"],
            y=traj["PCA_01_bio"],
            z=traj["PCA_02_bio"],
            mode="lines",
            # cmin alone is ignored by plotly (each line would auto-scale);
            # pin both ends so colors are comparable across embryos.
            line=dict(color=traj["mdl_stage_hpf"], cmin=10, cmax=42, width=1),
            showlegend=False,
            opacity=0.25,
        ))

    fig.layout.scene.xaxis.range = xrange
    fig.layout.scene.yaxis.range = yrange
    fig.layout.scene.zaxis.range = zrange

    fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="light", marker_size=marker_size, font_size=12)

    fig.update_layout(
        scene=dict(
            # 1) lock in the axis ranges
            xaxis=dict(range=xrange, autorange=False),
            yaxis=dict(range=yrange, autorange=False),
            zaxis=dict(range=zrange, autorange=False),
            # 2) keep the aspect ratio proportional to the data ranges
            aspectmode="manual",
            aspectratio=dict(
                x=(xrange[1] - xrange[0]),
                y=(yrange[1] - yrange[0]),
                z=(zrange[1] - zrange[0]),
            ),
            # 3) starting camera eye (rotate_figure is applied afterwards)
            camera=dict(eye=dict(x=1.2, y=1.2, z=0.8)),
        )
    )

    fig = rotate_figure(fig, zoom_factor=zoom_factor * 0.95, z_rotation=z_rotation, elev_rotation=elevation)

    fig.update_layout(scene=dict(xaxis=grid_dict, yaxis=grid_dict, zaxis=grid_dict))
    # hide the colorbar and legend
    fig.update_layout(coloraxis_showscale=False, showlegend=False)

    fig.write_image(os.path.join(temp_frame_path, f"hotfish_pca_temp{i:02}.png"), scale=2)
    fig.write_html(os.path.join(temp_frame_path, f"hotfish_pca_temp{i:02}.html"))

fig.show()

In [None]:
from tqdm import tqdm

angle_vec = np.linspace(0, 360, 50)

frame_path = os.path.join(fig_path, "hf_pca_rot_frames", "")
os.makedirs(frame_path, exist_ok=True)

# View and marker settings for this animation, set explicitly so the cell does
# not depend on values earlier loops happened to leave behind.
opacity = 1
marker_size = 4
z_rotation = 72  # same base azimuth as the static hotfish frames
elevation = -10

# initialize figure with every hotfish embryo
fig = px.scatter_3d(hf_pca_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio",
                    color="temperature", symbol="timepoint", opacity=opacity,
                    color_continuous_scale="RdBu_r", range_color=[19, 38],
                    hover_data=["timepoint", "snip_id"])

fig.layout.scene.xaxis.range = xrange
fig.layout.scene.yaxis.range = yrange
fig.layout.scene.zaxis.range = zrange

pca_axis_labels = ["morph PC 1", "morph PC 2", "morph PC 3"]
fig = format_3d_plotly(fig, axis_labels=pca_axis_labels, theme="dark", marker_size=marker_size)

# Spin the camera a full turn around the base azimuth, one PNG per step.
for a, angle in enumerate(tqdm(angle_vec)):
    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation + angle, elev_rotation=elevation)
    fig.write_image(os.path.join(frame_path, f"hotfish_pca_rot{a:02}.png"), scale=2)

fig.show()