In [None]:
import os
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.offline as pyo

# Enable offline plotly rendering inside the notebook (plotly.js is embedded,
# not fetched from a CDN — connected=False is the library default).
pyo.init_notebook_mode(connected=False)

In [None]:
from glob2 import glob  # NOTE(review): not used in the visible cells; third-party dependency

# Project root. Set MORPHSEQ_ROOT to run this notebook on another machine;
# the default is the original author's Dropbox location.
root = os.environ.get(
    "MORPHSEQ_ROOT",
    "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/",
)
train_name = "20240204_ds_v2"
model_name = "SeqVAE_z100_ne250_triplet_loss_test_SELF_and_OTHER"

# Training run to visualise. Other runs previously inspected:
#   "SeqVAE_training_2024-02-05_23-23-11"
#   "SeqVAE_training_2024-02-06_03-12-13"
#   "SeqVAE_training_2024-02-05_21-41-32"
#   "SeqVAE_training_2024-02-05_04-16-08"  (noted as best for v2)
training_name = "SeqVAE_training_2024-02-06_16-36-54"
train_dir = os.path.join(root, "training_data", train_name)
output_dir = os.path.join(train_dir, model_name)

# Directory of this model run; its figures/ subfolder holds the CSV tables
# read in the next cell (trailing "" keeps a path separator at the end).
training_path = os.path.join(output_dir, training_name)
read_path = os.path.join(training_path, "figures", "")

# Where slide figures are written. Set MORPHSEQ_FIG_ROOT to override.
fig_root = os.environ.get(
    "MORPHSEQ_FIG_ROOT",
    "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/20240207/",
)
figure_path = os.path.join(fig_root, training_name)
# exist_ok avoids the check-then-create race and is idempotent on re-runs.
os.makedirs(figure_path, exist_ok=True)

In [None]:
# load datasets
umap_df = pd.read_csv(os.path.join(read_path, "umap_df.csv"), index_col=0)
meta_df = pd.read_csv(os.path.join(read_path, "meta_summary_df.csv"), index_col=0)
embryo_df = pd.read_csv(os.path.join(read_path, "embryo_stats_df.csv"), index_col=0)
# metric_df = pd.read_csv(os.path.join(figure_path, "metric_df.csv"), index_col=0)

wik_indices = umap_df["master_perturbation"]=="wik"
tbxta_indices = umap_df["master_perturbation"]=="tbxta"
gdf3_indices = umap_df["master_perturbation"]=="gdf3"
lmx_indices = umap_df["master_perturbation"]=="lmx1b"



In [None]:

# Dataset summary: row count, unique embryos, stage range, number of experiments.
print(len(embryo_df))
# Embryo IDs are taken as the first 16 characters of each snip_id.
eid_vec = [snip_id[:16] for snip_id in embryo_df["snip_id"]]
# print() joins its arguments with a single space, matching "label: value".
print("n_embryos:", len(np.unique(eid_vec)))
print("max age:", np.max(embryo_df["predicted_stage_hpf"]))
print("min age:", np.min(embryo_df["predicted_stage_hpf"]))
print("n experiments:", len(np.unique(embryo_df["experiment_date"])))

In [None]:
# Per-run `temperature` and `gamma` values from the meta summary table.
temperature = meta_df.loc[:, "temperature"].values
print(temperature)
gamma = meta_df.loc[:, "gamma"].values
print(gamma)

In [None]:
# Rows for the four genotypes shown throughout the notebook, and a fixed
# colour per genotype so every figure uses the same palette.
plot_bool = wik_indices | gdf3_indices | tbxta_indices | lmx_indices

color_discrete_map = {"lmx1b": "#EF553B", "wik": "#636EFA", "gdf3": "#AB63FA", "tbxta": "#00CC96"}

# 2D view of the first two *_bio_3 UMAP components, coloured by genotype.
# `labels` gives readable axis/legend titles, matching make_rotating_figure.
fig = px.scatter(
    umap_df.loc[plot_bool],
    x="UMAP_00_bio_3",
    y="UMAP_01_bio_3",
    color="master_perturbation",
    opacity=0.5,
    template="plotly",
    color_discrete_map=color_discrete_map,
    labels={
        "UMAP_00_bio_3": "UMAP 1",
        "UMAP_01_bio_3": "UMAP 2",
        "master_perturbation": "genotype",
    },
)
fig.update_traces(marker=dict(size=6))
fig.show()

In [None]:
# 3D view of the same genotype subset. Relies on plot_bool and
# color_discrete_map from the 2D scatter cell (no need to recompute them).
fig = px.scatter_3d(
    umap_df.loc[plot_bool],
    x="UMAP_00_bio_3",
    y="UMAP_01_bio_3",
    z="UMAP_02_bio_3",
    color="master_perturbation",
    opacity=0.85,
    template="plotly",
    color_discrete_map=color_discrete_map,
    labels={
        "UMAP_00_bio_3": "UMAP 1",
        "UMAP_01_bio_3": "UMAP 2",
        "UMAP_02_bio_3": "UMAP 3",
        "master_perturbation": "genotype",
    },
)
fig.show()

In [None]:
import math
from tqdm import tqdm

def make_rotating_figure(plot_df, angle_vec, frame_dir, marker_opacity=0.75, marker_size=4, color_var=None,
                         color_map=None):
    """Render a 3D UMAP scatter from an orbiting camera and save one PNG per angle.

    Parameters
    ----------
    plot_df : pd.DataFrame
        Rows to plot; must contain UMAP_00_bio_3, UMAP_01_bio_3, UMAP_02_bio_3
        and the ``color_var`` column.
    angle_vec : sequence of float
        Camera azimuths in radians; one frame is written per value.
    frame_dir : str
        Directory receiving the frames (created if missing).
    marker_opacity, marker_size : float
        Marker styling.
    color_var : str, optional
        Column used for colouring; defaults to "predicted_stage_hpf".
    color_map : dict, optional
        Discrete colour map for plotly; defaults to the notebook-level
        ``color_discrete_map``.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure, left at the camera of the last frame (no camera set if
        ``angle_vec`` is empty).
    """
    if color_var is None:
        color_var = "predicted_stage_hpf"
    if color_map is None:
        color_map = color_discrete_map

    os.makedirs(frame_dir, exist_ok=True)

    # The data, styling and axes are identical for every frame, so build the
    # figure once and only move the camera inside the loop.
    fig = px.scatter_3d(plot_df, x="UMAP_00_bio_3", y="UMAP_01_bio_3", z="UMAP_02_bio_3",
                        color=color_var, opacity=marker_opacity,
                        color_discrete_map=color_map,
                        labels={'predicted_stage_hpf': "age (hpf)",
                                'master_perturbation': "genotype"})
    fig.update_traces(marker=dict(size=marker_size, line=dict(width=0.1, color='rgba(70,70,70,0.02)')))
    fig.update_layout(template="plotly", showlegend=False, scene_dragmode='orbit')
    # Fixed axis ranges keep the scene from rescaling between frames.
    fig.update_layout(scene=dict(
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        zaxis_title='UMAP 3',
        xaxis=dict(showticklabels=False, range=[-5, 15]),
        yaxis=dict(showticklabels=False, range=[-5, 15]),
        zaxis=dict(showticklabels=False, range=[-15, 15])))

    eye_height = 0.3    # camera elevation relative to the unit horizontal circle
    eye_distance = 2    # scale factor applied to the whole eye vector
    for frame_i, angle in enumerate(tqdm(angle_vec)):
        eye = eye_distance * np.asarray([math.cos(angle), math.sin(angle), eye_height])
        fig.update_layout(scene_camera=dict(eye=dict(x=eye[0], y=eye[1], z=eye[2])))
        fig.write_image(os.path.join(frame_dir, f"umap_scatter_{color_var}_{frame_i:03}.png"), scale=2)

    return fig

In [None]:
# Provenance of the 2.62 rad start azimuth used for the rotating exports:
# frame 19 of the earlier 25-frame sweep from 1.25π to 3.25π, wrapped back by
# one full turn (= 5π/6 ≈ 2.618). Defined here explicitly so the cell runs on a
# fresh kernel instead of depending on a later cell's angle_vec.
previous_angle_vec = np.linspace(1.25 * np.pi, 3.25 * np.pi, 25)
previous_angle_vec[19] - 2 * np.pi

In [None]:
# Export rotating 3D UMAP frame sequences for progressively larger genotype
# sets. One full turn in 25 frames, starting at 2.62 rad (≈ 5π/6).
angle_vec = np.linspace(2.62, 2.62 + 2 * np.pi, 25)

# Frame-directory name (under figure_path) -> boolean row mask over umap_df.
rotation_subsets = {
    "umap3D_WT_frames": wik_indices,
    "umap3D_WT_gdf3_frames": wik_indices | gdf3_indices,
    "umap3D_WT_gdf3_tbxta_frames": wik_indices | gdf3_indices | tbxta_indices,
    "umap3D_all_frames": plot_bool,
}

for subset_name, subset_mask in rotation_subsets.items():
    frame_dir = os.path.join(figure_path, subset_name, "")
    os.makedirs(frame_dir, exist_ok=True)
    fig = make_rotating_figure(umap_df.loc[subset_mask], angle_vec, frame_dir,
                               color_var="master_perturbation")
    fig.show()