In [None]:
import os
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.offline as pyo

# Set up plotly's offline notebook rendering; connected=False (the default) injects
# plotly.js into the notebook instead of loading it from a CDN.
pyo.init_notebook_mode(connected=False)

In [None]:
from glob2 import glob

# Project root and run identifiers. The root can be overridden with the MORPHSEQ_ROOT
# environment variable so the notebook can run outside the author's machine; the
# default is the original hardcoded location.
root = os.environ.get(
    "MORPHSEQ_ROOT",
    "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/",
)
train_name = "20231106_ds"
model_name = "SeqVAE_z100_ne250_triplet_loss_test_self_and_other"
training_name = "SeqVAE_training_2024-01-06_03-55-23"  # other run: "SeqVAE_training_2024-01-09_13-17-47"
train_dir = os.path.join(root, "training_data", train_name)
output_dir = os.path.join(train_dir, model_name)

# directory of the selected training run
training_path = os.path.join(output_dir, training_name)

# input: CSVs in the run's "figures" folder (the trailing "" keeps a final path separator)
read_path = os.path.join(training_path, "figures", "")

# output: figures for slides, one sub-folder per training run.
# Override the base folder with the MORPHSEQ_FIG_ROOT environment variable.
fig_root = os.environ.get(
    "MORPHSEQ_FIG_ROOT",
    "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/20240116/",
)
figure_path = os.path.join(fig_root, training_name)
# exist_ok avoids the check-then-create race of a separate isdir() test
os.makedirs(figure_path, exist_ok=True)

In [None]:
# load datasets
# Read the UMAP embedding and the run's metadata summary from the run's figures folder.
umap_csv = os.path.join(read_path, "umap_df.csv")
umap_df = pd.read_csv(umap_csv, index_col=0)
meta_df = pd.read_csv(os.path.join(read_path, "meta_summary_df.csv"), index_col=0)
# metric_df = pd.read_csv(os.path.join(figure_path, "metric_df.csv"), index_col=0)

# Positional row numbers (not index labels) for each genotype -- index with .iloc.
wik_indices = np.flatnonzero(umap_df["master_perturbation"] == "wck-AB")
gdf3_indices = np.flatnonzero(umap_df["master_perturbation"] == "gdf3")

In [None]:
# Show the "temperature" and "gamma" columns from the run's metadata summary.
temperature = meta_df.loc[:, "temperature"].values
print(temperature)

gamma = meta_df.loc[:, "gamma"].values
print(gamma)

In [None]:
import math
from tqdm import tqdm

def make_rotating_figure(plot_df, angle_vec, frame_dir, marker_opacity=0.5, marker_size=6, color_var=None,
                         xyz_cols=("UMAP_00_bio_3", "UMAP_01_bio_3", "UMAP_02_bio_3"),
                         eye_elevation=0.3, eye_distance=2.0, image_scale=2):
    """Render a 3-D scatter from a sequence of camera angles, saving each view as a PNG.

    The figure is built once and only the camera moves between frames, so every frame
    shows exactly the same data and styling.

    Parameters
    ----------
    plot_df : pandas.DataFrame
        Rows to plot; must contain the three ``xyz_cols`` columns and ``color_var``.
    angle_vec : sequence of float
        Camera azimuths in radians, one frame per angle.
    frame_dir : str
        Directory the frames are written to (created if it does not exist).
    marker_opacity, marker_size : float
        Marker styling.
    color_var : str, optional
        Column used to colour points (default ``"predicted_stage_hpf"``); also embedded in
        the frame file names ``umap_scatter_<color_var>_<NNN>.png``.
    xyz_cols : tuple of str
        Columns plotted on the x, y and z axes.
    eye_elevation : float
        z component of the camera direction vector ``(cos a, sin a, eye_elevation)``.
    eye_distance : float
        Factor the camera direction vector is multiplied by (larger = further away).
    image_scale : float
        ``scale`` passed to ``fig.write_image``.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure, with the camera at the last angle, or at plotly's default camera
        if ``angle_vec`` is empty (no frames are written in that case).
    """
    if color_var is None:
        color_var = "predicted_stage_hpf"

    os.makedirs(frame_dir, exist_ok=True)

    # Build the scatter once; re-creating it on every frame was pure overhead and left
    # `fig` unbound (UnboundLocalError at return) when angle_vec was empty.
    x_col, y_col, z_col = xyz_cols
    fig = px.scatter_3d(plot_df, x=x_col, y=y_col, z=z_col,
                        color=color_var, opacity=marker_opacity,
                        labels={'predicted_stage_hpf': "age (hpf)",
                                'master_perturbation': "genotype"},
                        template="plotly")
    fig.update_traces(marker={'size': marker_size})
    fig.update_layout(scene=dict(
        dragmode='orbit',
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        zaxis_title='UMAP 3',
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False),
        zaxis=dict(showticklabels=False)))

    for frame_i, azimuth in enumerate(tqdm(angle_vec)):
        # The eye circles the z-axis at a fixed height: (cos a, sin a, elevation) * distance.
        eye = eye_distance * np.array([math.cos(azimuth), math.sin(azimuth), eye_elevation])
        fig.update_layout(scene_camera=dict(eye=dict(x=eye[0], y=eye[1], z=eye[2])))
        fig.write_image(os.path.join(frame_dir, f"umap_scatter_{color_var}_{frame_i:03}.png"),
                        scale=image_scale)

    return fig

In [None]:
# One full turn (2*pi) of camera azimuths starting at 1.25*pi. endpoint=False drops the
# final angle, which equals the first one modulo 2*pi and would otherwise produce a
# duplicated frame (a visible stutter when the frames are looped).
angle_vec = np.linspace(1.25 * np.pi, 3.25 * np.pi, 25, endpoint=False)
frame_dir = os.path.join(figure_path, "hpf_umap_frames", "")
os.makedirs(frame_dir, exist_ok=True)

# wild-type (wck-AB) embryos, coloured by the function's default color_var
fig = make_rotating_figure(umap_df.iloc[wik_indices], angle_vec, frame_dir)
fig.show()

In [None]:
# Display the folder the rotation frames were written to.
frame_dir

In [None]:
# Wild-type (wck-AB) embryos in the 3-D bio UMAP space, coloured by predicted stage.
fig = px.scatter_3d(
    umap_df.iloc[wik_indices],
    x="UMAP_00_bio_3", y="UMAP_01_bio_3", z="UMAP_02_bio_3",
    color="predicted_stage_hpf",
    opacity=0.5,
    labels={"predicted_stage_hpf": "age (hpf)"},
    template="plotly",
)
fig.update_traces(marker=dict(size=6))
# 3-D axes live under layout.scene; layout.xaxis_title would not label a scatter_3d plot.
fig.update_layout(scene=dict(xaxis_title="UMAP 1", yaxis_title="UMAP 2", zaxis_title="UMAP 3"))

fig.show()
# fig.write_image(os.path.join(figure_path, "UMAP_wt_scatter_bio.png"))

In [None]:
import plotly.graph_objects as go

# Overlay gdf3 embryos on the wild-type (wck-AB) embryos in the 3-D bio UMAP space.
# wik_indices / gdf3_indices come from np.where, i.e. they are *positional* row numbers,
# so they must be used with .iloc. The previous .loc lookup treated them as index
# labels; umap_df was read with index_col=0, so its labels need not be 0..n-1 and
# .loc could pick the wrong rows or raise KeyError.
umap_xyz_cols = ["UMAP_00_bio_3", "UMAP_01_bio_3", "UMAP_02_bio_3"]

fig = go.Figure()
for trace_name, row_idx in (("wck-AB", wik_indices), ("gdf3", gdf3_indices)):
    xyz = umap_df.iloc[row_idx][umap_xyz_cols]
    fig.add_trace(go.Scatter3d(
        x=xyz[umap_xyz_cols[0]],
        y=xyz[umap_xyz_cols[1]],
        z=xyz[umap_xyz_cols[2]],
        mode="markers",
        name=trace_name,
        marker=dict(size=6, opacity=0.5),
    ))

fig.update_layout(
    template="plotly",
    scene=dict(xaxis_title="UMAP 1", yaxis_title="UMAP 2", zaxis_title="UMAP 3"),
)

fig.show()

In [None]:
import plotly.graph_objects as go

# metric_df is never loaded earlier (its read in the load cell is commented out), so the
# previous version raised NameError on a fresh kernel. Load it here: first from the run's
# figures folder (where the other CSVs come from), then from the slide output folder used
# by the commented-out read.
metric_candidates = [
    os.path.join(read_path, "metric_df.csv"),
    os.path.join(figure_path, "metric_df.csv"),
]
metric_path = next((p for p in metric_candidates if os.path.isfile(p)), None)

if metric_path is None:
    print("metric_df.csv not found in:", *metric_candidates, sep="\n  ")
else:
    metric_df = pd.read_csv(metric_path, index_col=0)
    metric_cols = ["euc_bio_rand", "euc_bio", "euc_nbio_rand", "euc_nbio"]
    fig = go.Figure([go.Histogram(x=metric_df[col], name=col, opacity=0.6) for col in metric_cols])
    # overlay (rather than grouped bars) so the distributions can be compared directly
    fig.update_layout(barmode="overlay", template="plotly",
                      xaxis_title="metric value", yaxis_title="count")
    fig.show()