In [None]:
import os
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.offline as pyo

pyo.init_notebook_mode()

In [None]:
from glob2 import glob

# --- Where the trained model and its outputs live ---
root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"
train_name = "20240507"
model_name = "VAE_z100_ne250_base_model"
training_name = "VAE_training_2024-05-07_21-11-49"  # "SeqVAE_training_2024-01-09_13-17-47"

train_dir = os.path.join(root, "training_data", train_name)
output_dir = os.path.join(train_dir, model_name)

# Folder belonging to this specific training run.
training_path = os.path.join(output_dir, training_name)

# Input tables are read from the run's "figures" folder; the empty last
# component makes the joined path end with a separator.
read_path = os.path.join(training_path, "figures", "")

# --- Where this notebook writes its figures ---
fig_root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/20240515/"
figure_path = os.path.join(fig_root, training_name)
os.makedirs(figure_path, exist_ok=True)

In [None]:
# Read the UMAP table and the metadata summary from the run's figure folder,
# and the age key from the project-level metadata folder.
umap_csv = os.path.join(read_path, "umap_df.csv")
meta_csv = os.path.join(read_path, "meta_summary_df.csv")
age_key_csv = os.path.join(root, "metadata", "age_key_df.csv")

umap_df = pd.read_csv(umap_csv, index_col=0)
meta_df = pd.read_csv(meta_csv, index_col=0)
age_key = pd.read_csv(age_key_csv)

# Attach the inferred stage to every UMAP row; the left join keeps rows that
# have no match in the age key.
stage_lookup = age_key.loc[:, ["snip_id", "inferred_stage_hpf_reg"]]
umap_df = umap_df.merge(stage_lookup, how="left", on="snip_id")
# metric_df = pd.read_csv(os.path.join(figure_path, "metric_df.csv"), index_col=0)

In [None]:
# Unique perturbation labels in the table (the result is not stored).
np.unique(umap_df["master_perturbation"])

# Perturbation labels grouped by coarse class; inverted below into the
# label -> class lookup that the rest of the notebook uses.
perts_by_class = {
    "WT": ['DMSO', 'H2B-mScarlet', 'ethanol_ctrl', 'sox10GFP', 'tbx5a-SG', 'wik',
           'wik-ctrl-inj', 'wik-inj-ctrl'],
    "Chem": ['Fgf_025', 'Fgf_050', 'Fgf_075', 'Fgf_100', 'Fgf_150',
             'Shh_025', 'Shh_050', 'Shh_075', 'Shh_100',
             "TGFB-i", 'Wnt-i', "notch-i"],
    "Gene": ['gdf3', 'lmx1b', "noto", 'sox10GFP-inj-cr', 'sox10GFP-inj-nick', 'tbxta'],
    "Other": ["Uncertain"],
}
pert_key = {
    pert: class_name
    for class_name, perts in perts_by_class.items()
    for pert in perts
}

# Plain dict indexing: a label that is missing from pert_key raises KeyError.
cat_vec = list(map(pert_key.__getitem__, umap_df["master_perturbation"].tolist()))
umap_df["pert_class"] = cat_vec

In [None]:
# Per-perturbation summary: number of staged embryos plus the earliest and
# latest inferred stage. Named aggregation yields flat column names
# ("count", "min", "max") directly, so no MultiIndex flattening is needed.
pert_count_df = umap_df.groupby(["pert_class", "master_perturbation"], as_index=False).agg(
    count=("inferred_stage_hpf_reg", "count"),
    min=("inferred_stage_hpf_reg", "min"),
    max=("inferred_stage_hpf_reg", "max"),
)

# Span of stages covered by each perturbation, and the same span scaled by the
# largest span in the table.
pert_count_df["stage_range"] = pert_count_df["max"] - pert_count_df["min"]
pert_count_df["stage_range_norm"] = pert_count_df["stage_range"] / pert_count_df["stage_range"].max()

# Share of all counted embryos, expressed in degrees (the column sums to 360).
pert_count_df["width_norm"] = pert_count_df["count"] / pert_count_df["count"].sum() * 360

In [None]:
# Preview the first rows of the per-perturbation summary table.
pert_count_df.head()

### What about sunburst?

In [None]:
import plotly.express as px
import numpy as np

# Sunburst: inner ring = perturbation class, outer ring = perturbation;
# wedge size follows the embryo count.
sunburst_options = dict(
    path=['pert_class', 'master_perturbation'],
    values='count',
    template="plotly",
    color='pert_class',
    hover_data=['stage_range'],
    color_continuous_scale='RdBu',
)
fig = px.sunburst(pert_count_df, **sunburst_options)
fig.show()

### Let's try making a polar bar plot

In [None]:
import plotly.graph_objects as go

# Angular width of each wedge in degrees (count-proportional, sums to 360).
wedge_widths = pert_count_df["width_norm"].to_numpy()
# Centre each wedge on the midpoint of its own arc so consecutive wedges tile
# the circle; evenly spaced centres would make unequal-width wedges overlap.
wedge_centers = np.cumsum(wedge_widths) - wedge_widths / 2

# One colour per perturbation class, cycling through the palette if there are
# more classes than colours. This always yields exactly one colour per bar.
class_palette = ["#E4FF87", "#709BFF", "#FFAA70", "#FFDF70", "#B6FFB4"]
class_codes, _ = pd.factorize(pert_count_df["pert_class"])
wedge_colors = [class_palette[code % len(class_palette)] for code in class_codes]

fig = go.Figure(go.Barpolar(
    # `base` is the per-bar radial start; `r` is the bar length measured from it,
    # so each wedge spans from the minimum to the maximum observed stage.
    base=pert_count_df["min"].to_numpy(),
    r=pert_count_df["stage_range"].to_numpy(),
    theta=wedge_centers,
    width=wedge_widths,
    marker_color=wedge_colors,
    marker_line_color="black",
    marker_line_width=2,
    opacity=0.8
))

fig.update_layout(
    template=None,
    polar = dict(
        radialaxis = dict(range=[0, 72], showticklabels=False, ticks=''),
        angularaxis = dict(showticklabels=False, ticks='')
    )
)

fig.show()

In [None]:
import plotly.express as px
import numpy as np
# df = px.data.gapminder().query("year == 2007")
# Treemap of embryo counts, nested by perturbation class then perturbation.
# The summary table's column is named "count"; there is no "counts" column.
count_col = "count"
fig = px.treemap(pert_count_df, path=["pert_class", 'master_perturbation'], values=count_col,
                  color=count_col,
                  color_continuous_scale='RdBu',
                  # Count-weighted mean of the counts is used as the colour-scale midpoint.
                  color_continuous_midpoint=np.average(pert_count_df[count_col],
                                                       weights=pert_count_df[count_col]))
fig.update_layout(margin = dict(t=50, l=25, r=25, b=25))
fig.show()

In [None]:
import math
# Plot settings that later cells also read.
color_var = "predicted_stage_hpf"
marker_size = 6
marker_opacity = 0.1
angle = 0

# Camera eye sits on the unit circle at `angle` (radians), raised to height za.
za = 0.5
vec = np.asarray([math.cos(angle), math.sin(angle), za])
camera = dict(eye=dict(zip("xyz", vec)))

# 3D UMAP of every embryo, coloured by `color_var`.
fig = px.scatter_3d(
    umap_df,
    x="UMAP_00_3", y="UMAP_01_3", z="UMAP_02_3",
    color=color_var,
    opacity=0.005,
    labels={'predicted_stage_hpf': "age (hpf)",
            'master_perturbation': "genotype"},
    hover_data=["snip_id"],
)
fig.update_traces(marker={'size': marker_size})
fig.update_layout(template="plotly")
fig.update_layout(scene_camera=camera, scene_dragmode='orbit')

# Name the axes and hide their tick labels.
hidden_ticks = dict(showticklabels=False)
fig.update_layout(scene=dict(
    xaxis_title='UMAP 1',
    yaxis_title='UMAP 2',
    zaxis_title='UMAP 3',
    xaxis=dict(hidden_ticks),
    yaxis=dict(hidden_ticks),
    zaxis=dict(hidden_ticks),
))

fig.show()

In [None]:
# wik_indices is used by this and later cells but is never assigned in the
# notebook, so a fresh kernel would fail with NameError here.
# NOTE(review): assumed to select the rows whose master_perturbation is "wik" — confirm.
# These are positional indices; they also work as .loc labels because the merge
# in the data-loading cell leaves umap_df with a default 0..n-1 index.
wik_indices = np.flatnonzero((umap_df["master_perturbation"] == "wik").to_numpy())

# Same 3D UMAP view as for the full table, restricted to the wik subset.
fig = px.scatter_3d(umap_df.loc[wik_indices], x="UMAP_00_3", y="UMAP_01_3", z="UMAP_02_3",
                            color=color_var, opacity=0.005,
                            labels={'predicted_stage_hpf': "age (hpf)",
                                    'master_perturbation': "genotype"},  hover_data=["snip_id"])

fig.update_traces(marker={'size': marker_size})

fig.update_layout(template="plotly")

# Camera eye on the unit circle at `angle` (radians), raised to height za.
za = 0.5
vec = np.asarray([math.cos(angle), math.sin(angle), za])
camera = dict(
    eye=dict(x=vec[0], y=vec[1], z=vec[2]))

fig.update_layout(scene_camera=camera, scene_dragmode='orbit')

# Name the axes and hide their tick labels.
fig.update_layout(scene = dict(
                xaxis_title='UMAP 1',
                yaxis_title='UMAP 2',
                zaxis_title='UMAP 3',
                xaxis = dict(showticklabels=False),
                yaxis = dict(showticklabels=False),
                zaxis = dict(showticklabels=False)))

fig.show()

In [None]:
from sklearn.svm import OneClassSVM

# Fit a one-class SVM on the 3D UMAP coordinates and store its prediction for
# every row as the "outlier_flags" column.
umap_axes = ["UMAP_00_3", "UMAP_01_3", "UMAP_02_3"]
X = umap_df[umap_axes].to_numpy()

clf = OneClassSVM(gamma='auto')
clf.fit(X)
umap_df["outlier_flags"] = clf.predict(X)

In [None]:
# Camera eye on the unit circle at `angle` (radians), raised to height za.
za = 0.5
vec = np.asarray([math.cos(angle), math.sin(angle), za])
camera = dict(eye=dict(zip("xyz", vec)))

# wik subset of the 3D UMAP, coloured by the one-class SVM flag.
wik_subset = umap_df.loc[wik_indices]
fig = px.scatter_3d(
    wik_subset,
    x="UMAP_00_3", y="UMAP_01_3", z="UMAP_02_3",
    color="outlier_flags",
    opacity=0.05,
    labels={'predicted_stage_hpf': "age (hpf)",
            'master_perturbation': "genotype"},
)
fig.update_traces(marker={'size': marker_size})
fig.update_layout(template="plotly")
fig.update_layout(scene_camera=camera, scene_dragmode='orbit')

# Name the axes and hide their tick labels.
no_tick_labels = dict(showticklabels=False)
fig.update_layout(scene=dict(
    xaxis_title='UMAP 1',
    yaxis_title='UMAP 2',
    zaxis_title='UMAP 3',
    xaxis=dict(no_tick_labels),
    yaxis=dict(no_tick_labels),
    zaxis=dict(no_tick_labels),
))

fig.show()

In [None]:
import math
from tqdm import tqdm

def make_rotating_figure(plot_df, angle_vec, frame_dir, marker_opacity=0.5, marker_size=6, color_var=None):
    """Save one PNG of a 3D UMAP scatter per camera angle and return the figure.

    Parameters
    ----------
    plot_df : DataFrame
        Must contain the columns "UMAP_00_bio_3", "UMAP_01_bio_3",
        "UMAP_02_bio_3" and the column named by `color_var`.
    angle_vec : iterable of float
        Camera angles in radians; the camera eye is placed at
        2 * (cos(angle), sin(angle), 0.3).
    frame_dir : str
        Directory that receives the frames, named
        "umap_scatter_<color_var>_<NNN>.png". Created if it does not exist.
    marker_opacity : float
        Opacity passed to the scatter.
    marker_size : int
        Marker size applied to all traces.
    color_var : str or None
        Column used for colouring; None means "predicted_stage_hpf".

    Returns
    -------
    The plotly figure, with the camera left at the last angle in `angle_vec`
    (or at the default camera when `angle_vec` is empty).
    """
    if color_var is None:
        color_var = "predicted_stage_hpf"

    # write_image does not create missing folders.
    os.makedirs(frame_dir, exist_ok=True)

    # Only the camera changes between frames, so build the figure once.
    fig = px.scatter_3d(plot_df, x="UMAP_00_bio_3", y="UMAP_01_bio_3", z="UMAP_02_bio_3",
                        color=color_var, opacity=marker_opacity,
                        labels={'predicted_stage_hpf': "age (hpf)",
                                'master_perturbation': "genotype"})

    fig.update_traces(marker={'size': marker_size})

    fig.update_layout(template="plotly")

    fig.update_layout(scene_dragmode='orbit')

    fig.update_layout(scene = dict(
                    xaxis_title='UMAP 1',
                    yaxis_title='UMAP 2',
                    zaxis_title='UMAP 3',
                    xaxis = dict(showticklabels=False),
                    yaxis = dict(showticklabels=False),
                    zaxis = dict(showticklabels=False)))

    za = 0.3
    for iter_i, angle in enumerate(tqdm(angle_vec)):
        # Eye position: point on the circle at `angle`, lifted by za, then doubled.
        vec = np.asarray([math.cos(angle), math.sin(angle), za]) * 2
        camera = dict(
            eye=dict(x=vec[0], y=vec[1], z=vec[2]))
        fig.update_layout(scene_camera=camera)

        frame_name = f"umap_scatter_{color_var}_{iter_i:03}.png"
        fig.write_image(os.path.join(frame_dir, frame_name), scale=2)

    return fig

In [None]:
# 25 camera angles from 1.25*pi to 3.25*pi radians. Both endpoints are
# included and they are a full turn apart, so the first and last frame show
# the same view.
angle_vec = np.linspace(1.25*np.pi, 3.25*np.pi, 25)

# Frames go into a sub-folder of the figure directory.
frame_dir = os.path.join(figure_path, "hpf_umap_frames", "")
os.makedirs(frame_dir, exist_ok=True)

wik_frame_df = umap_df.iloc[wik_indices]
fig = make_rotating_figure(wik_frame_df, angle_vec, frame_dir)
fig.show()

In [None]:
# Display the directory the rotation frames were written to.
frame_dir

In [None]:
# look at the umap
# 3D scatter of the "bio" UMAP coordinates for the wik subset, coloured by
# predicted stage.
wik_bio_df = umap_df.iloc[wik_indices]
fig = px.scatter_3d(
    wik_bio_df,
    x="UMAP_00_bio_3",
    y="UMAP_01_bio_3",
    z="UMAP_02_bio_3",
    color='predicted_stage_hpf',
    opacity=0.5,
    template="plotly",
)

fig.update_traces(marker=dict(size=6))

fig.show()
# fig.write_image(os.path.join(out_figure_path, "UMAP_wt_scatter_bio.png"))

In [None]:
import plotly.graph_objects as go

# gdf3_indices is never assigned in the notebook, so a fresh kernel would fail
# with NameError here.
# NOTE(review): assumed to select the rows whose master_perturbation is "gdf3" — confirm.
# Index labels are used because the rows are looked up with .loc below.
gdf3_indices = umap_df.index[umap_df["master_perturbation"] == "gdf3"]

# Base layer: wik subset in the "bio" UMAP space.
fig = px.scatter_3d(umap_df.iloc[wik_indices], x="UMAP_00_bio_3", y="UMAP_01_bio_3", z="UMAP_02_bio_3",
                         opacity=0.5,
                         template="plotly")

# Overlay: gdf3 embryos as a second marker trace.
fig.add_trace(go.Scatter3d(x=umap_df.loc[gdf3_indices, "UMAP_00_bio_3"],
                           y=umap_df.loc[gdf3_indices, "UMAP_01_bio_3"],
                           z=umap_df.loc[gdf3_indices, "UMAP_02_bio_3"],
                           mode="markers",
                           marker=dict(opacity=0.5)))

# Same marker size for both traces.
fig.update_traces(
    marker=dict(size=6)
    )

fig.show()

In [None]:
import plotly.graph_objects as go

# metric_df is not loaded anywhere earlier in the notebook (its read_csv in the
# data-loading cell is commented out), so read it here from the location that
# commented-out line points to.
# NOTE(review): confirm metric_df.csv really lives under figure_path.
metric_df = pd.read_csv(os.path.join(figure_path, "metric_df.csv"), index_col=0)

# One histogram trace per distance metric, in the original plotting order.
metric_names = ["euc_bio_rand", "euc_bio", "euc_nbio_rand", "euc_nbio"]

fig = go.Figure()
for metric_name in metric_names:
    fig.add_trace(go.Histogram(x=metric_df.loc[:, metric_name], name=metric_name))

fig.show()