In [None]:
import os
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.offline as pyo
import time 
from glob2 import glob
from tqdm import tqdm

# Initialise plotly's offline notebook mode before any figure is shown
# (presumably what lets fig.show() render inline -- confirm against plotly docs).
pyo.init_notebook_mode()

In [None]:
# --- Configuration: location of the training run and the models to compare ---
root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"
train_name = "20231106_ds"

# Model folder names, paired position-by-position with the short labels below.
model_name_vec = [
    "SeqVAE_z100_ne250_gamma_temp_self_and_other",
    "SeqVAE_z100_ne250_triplet_loss_test_self_and_other",
    "SeqVAE_z100_ne250_gamma_temp_SELF_ONLY",
    "SeqVAE_z100_ne250_triplet_loss_SELF_ONLY",
    "MetricVAE_z100_ne250_temperature_sweep_v2",
    "VAE_z100_ne250_vanilla",
]
# Short display labels written into the "model_type" column.
short_name_vec = [
    "SeqVAE - NT-Xent",
    "SeqVAE - Triplet",
    "SeqVAE - NT-Xent (self only)",
    "SeqVAE - Triplet (self only)",
    "MetricVAE",
    "VAE",
]
train_dir = os.path.join(root, "training_data", train_name)

# Output folder for the figures; the trailing '' leaves a path separator at the end.
figure_path = os.path.join(train_dir, "vae_info_analyese", '')
os.makedirs(figure_path, exist_ok=True)

mdl_path_list = []
info_df_list = []

# Read each model's latent_info_df.csv and tag its rows with the model label/index.
for mdi, (model_name, short_name) in enumerate(zip(model_name_vec, short_name_vec)):
    model_dir = os.path.join(train_dir, model_name)
    model_path = os.path.join(model_dir, "figures")
    mdl_path_list.append(model_path)

    latent_info_df = pd.read_csv(os.path.join(model_path, "latent_info_df.csv"), index_col=0)
    latent_info_df["model_type"] = short_name
    latent_info_df["type_index"] = mdi
    info_df_list.append(latent_info_df)

# Stack the per-model tables into one frame with a fresh 0..n-1 index.
latent_df_master = pd.concat(info_df_list, axis=0, ignore_index=True)


In [None]:
# Inspect the information metrics of one model, selected by its position in model_name_vec.
mdi = 3
model_name = model_name_vec[mdi]
model_dir = os.path.join(train_dir, model_name)
# path to figures and data
model_path = os.path.join(model_dir, "figures")
# model_path is deliberately NOT appended to mdl_path_list here: the list is
# already filled by the loading loop, and appending again would add a duplicate
# entry every time this cell is re-run.

# load this model's latent information table
latent_info_df = pd.read_csv(os.path.join(model_path, "latent_info_df.csv"), index_col=0)

# information gain = baseline entropy minus pair entropy, for the "mc" and "em" columns
latent_info_df["bio_self_info_mc"] = latent_info_df["bio_entropy_base_mc"] - latent_info_df["bio_entropy_self_mc"]
latent_info_df["bio_seq_info_mc"] = latent_info_df["bio_entropy_base_mc"] - latent_info_df["bio_entropy_seq_mc"]
latent_info_df["bio_self_info_em"] = latent_info_df["bio_entropy_base_em"] - latent_info_df["bio_entropy_self_em"]
latent_info_df["bio_seq_info_em"] = latent_info_df["bio_entropy_base_em"] - latent_info_df["bio_entropy_seq_em"]

# display the table as the cell output
latent_info_df

### Calculate self and seq information gains for mc and em samples
We will do this separately for the "biological" and "nuisance" partitions.

In [None]:
# Information gain = baseline entropy minus pair entropy, computed for every
# combination of estimator suffix ("mc", "em") and pair type ("self", "seq").
for est in ("mc", "em"):
    for pair in ("self", "seq"):
        latent_df_master[f"bio_{pair}_info_{est}"] = (
            latent_df_master[f"bio_entropy_base_{est}"]
            - latent_df_master[f"bio_entropy_{pair}_{est}"]
        )

# Floored copies used as marker sizes in the scatter plots: anything at or
# below min_size is raised to min_size.
min_size = 5
for est in ("mc", "em"):
    for pair in ("self", "seq"):
        info_col = f"bio_{pair}_info_{est}"
        latent_df_master[info_col + "_plot"] = latent_df_master[info_col].clip(lower=min_size)

In [None]:
# Self-pair entropy against total entropy (EM columns); marker size follows
# the "bio_self_info_em_plot" column.
ref_line = np.linspace(-250, 200)
fig = px.scatter(
    latent_df_master,
    x="bio_entropy_base_em",
    y="bio_entropy_self_em",
    color="model_type",
    size="bio_self_info_em_plot",
    template="plotly",
    labels={
        "bio_entropy_base_em": "total entropy (nats)",
        "bio_entropy_self_em": "self pair entropy (nats)",
        "model_type": "model class",
    },
)

# y = x reference line
fig.add_trace(
    go.Scatter(x=ref_line, y=ref_line, mode="lines", line={"color": "black"}, showlegend=False)
)

# Outline the marker traces only (the selector skips the "lines" trace).
fig.update_traces(
    marker={"line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)

fig.show()
fig.write_image(os.path.join(figure_path, "self_em_entropy.png"))

In [None]:
# Sequential-pair entropy against total entropy (EM columns); marker size
# follows the "bio_seq_info_em_plot" column.
ref_line = np.linspace(-250, 200)
fig = px.scatter(
    latent_df_master,
    x="bio_entropy_base_em",
    y="bio_entropy_seq_em",
    color="model_type",
    size="bio_seq_info_em_plot",
    template="plotly",
    labels={
        "bio_entropy_base_em": "total entropy (nats)",
        "bio_entropy_seq_em": "sequential pair entropy (nats)",
        "model_type": "model class",
    },
)

# y = x reference line
fig.add_trace(
    go.Scatter(x=ref_line, y=ref_line, mode="lines", line={"color": "black"}, showlegend=False)
)

# Outline the marker traces only (the selector skips the "lines" trace).
fig.update_traces(
    marker={"line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)

fig.show()
fig.write_image(os.path.join(figure_path, "seq_em_entropy.png"))

In [None]:
# Reconstruction error against sequential information gain (EM columns),
# coloured by model class.
fig = px.scatter(
    latent_df_master,
    x="bio_seq_info_em",
    y="recon_mse_em",
    color="model_type",
    template="plotly",
    labels={
        "recon_mse_em": "image reconstruction error",
        "bio_seq_info_em": "latent information content",
        "model_type": "model class",
    },
)

# Fixed-size markers with a dark outline.
fig.update_traces(
    marker={"size": 12, "line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)

fig.show()
fig.write_image(os.path.join(figure_path, "info_seq_em_vs_recon_mse.png"))

In [None]:
# Slide figure: reconstruction error vs. information gain relative to the VAE row.
slide_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/20240207/"

# Row masks for the model classes shown on the slide.
so_ntx_bool = (latent_df_master["model_type"]=="SeqVAE - NT-Xent").to_numpy()
so_trip_bool = (latent_df_master["model_type"]=="SeqVAE - Triplet").to_numpy()
vae_bool = (latent_df_master["model_type"]=="VAE").to_numpy()
metric_bool = (latent_df_master["model_type"]=="MetricVAE").to_numpy()


# Work on a copy so latent_df_master keeps its original labels and values.
latent_df_simp = latent_df_master.copy()
# NT-Xent SeqVAE and MetricVAE are merged under a single legend entry.
latent_df_simp.loc[so_ntx_bool | metric_bool, "model_type"] = "contrastive-VAE"
plot_bool = so_ntx_bool | metric_bool | vae_bool |so_trip_bool

# Express information relative to the first "VAE" row, so that row sits at zero.
vae_val = latent_df_simp.loc[vae_bool, "bio_seq_info_em"].to_numpy()[0]
latent_df_simp["bio_seq_info_em"] = latent_df_simp["bio_seq_info_em"]-vae_val 

fig = px.scatter(latent_df_simp.loc[plot_bool], x="bio_seq_info_em", y="recon_mse_em", color="model_type", 
                  template="plotly",
                  labels=dict(recon_mse_em="image reconstruction error (MSE)", 
                              bio_seq_info_em="information gain (nats)",
                              model_type="model class"))
                 

fig.update_traces(marker=dict(size=12,
                              line=dict(width=2,
                                        color='DarkSlateGrey')),
                  selector=dict(mode='markers')
                 )
# The legend encodes model_type, so its title must say "model class"
# (it previously read "stage (hpf)", which does not match the colour variable).
fig.update_layout(
    legend_title="model class",
    font=dict(
        family="Arial",
        size=14,
        color="Black"
    ))

fig.show()
fig.write_image(os.path.join(slide_path, "info_vs_mse.png"), scale=2)

In [None]:
# Cartoon for slides: a broad cloud of "all embryos" with a tight cluster of
# "positive" examples drawn on top. Two images are written: one with only the
# broad cloud, and one with both clouds.
n_all = 1000
n_pos = 50
# NOTE: these values are used as the diagonal of the covariance matrices below,
# i.e. they act as variances, while the axis range is set to +/- 1.1 * all_sigma.
all_sigma = 10
pos_sigma = 1

# Seeded generator so the saved slide images are identical on every run.
cartoon_seed = 0
rng = np.random.default_rng(cartoon_seed)

all_samples = rng.multivariate_normal(np.zeros(2), np.asarray([[all_sigma, 0], [0, all_sigma]]), size=n_all)
pos_samples = rng.multivariate_normal(np.zeros(2), np.asarray([[pos_sigma, 0], [0, pos_sigma]]), size=n_pos)

fig = go.Figure()

# fig.add_scatter(x=[0], y=[0], mode="markers", 
#                 marker=dict(size=20, opacity=1, color="black", line=dict(width=1, color='DarkSlateGrey')),
#                 name="reference point")

fig.add_scatter(x=all_samples[:, 0], y=all_samples[:, 1], mode="markers", 
                marker=dict(size=12, color="rgb(141,160,203)", opacity=0.75, line=dict(width=1, color='DarkSlateGrey')),
                name="all embryos")



# Transparent background and hidden axes.
fig.update_layout(showlegend= True,
                template='plotly',
                plot_bgcolor='rgba(0, 0, 0, 0)',
                paper_bgcolor='rgba(0, 0, 0, 0)',
                xaxis =  {'range': [-all_sigma*1.1, all_sigma*1.1],                                     
#                         'showgrid': False,
#                         'zeroline': False,
                        'visible': False
                             },
                yaxis = {'range': [-all_sigma*1.1, all_sigma*1.1],                             
#                        'showgrid': False,
#                        'zeroline': False,
                       'visible': False
                            }
                )

fig.update_layout(
    legend_title="",
    font=dict(
        family="Arial",
        size=18,
        color="Black"
    ))

# First image: broad cloud only.
fig.write_image(os.path.join(slide_path, "info_cartoon_neg_examples.png"), scale=2)

fig.add_scatter(x=pos_samples[:, 0], y=pos_samples[:, 1], mode="markers", 
                marker=dict(color="rgb(166,216,84)", size=12, opacity=1, 
                            line=dict(width=1, color='DarkSlateGrey')),
                name="positive")

# Second image: broad cloud plus the positive cluster.
fig.write_image(os.path.join(slide_path, "info_cartoon_all_examples.png"), scale=2)

fig.show()

In [None]:

# Nuisance partition: sequential-pair entropy against total entropy (EM columns).
# model_type holds the short labels assigned from short_name_vec at load time.
ref_line = np.linspace(-50, 50)
fig = px.scatter(
    latent_df_master,
    x="n_entropy_base_em",
    y="n_entropy_seq_em",
    color="model_type",
    template="plotly",
    labels={
        "n_entropy_base_em": "total nuisance entropy (nats)",
        "n_entropy_seq_em": "sequential pair nuisance entropy (nats)",
        "model_type": "model class",
    },
)

# y = x reference line
fig.add_trace(
    go.Scatter(x=ref_line, y=ref_line, mode="lines", line={"color": "black"}, showlegend=False)
)

# Fixed-size outlined markers; the selector skips the "lines" trace.
fig.update_traces(
    marker={"size": 12, "line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)

fig.show()
fig.write_image(os.path.join(slide_path, "seq_entropy.png"))

In [None]:
# Debugging aid: display the boolean mask of latent_df_master rows whose
# model_type equals "SeqVAE - NT-Xent".
so_ntx_bool

In [None]:
# Slide copy of the information-vs-reconstruction-error scatter, written to slide_path.
# The path must be a quoted string literal; the misplaced quotes made this cell a SyntaxError.
slide_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/20240207/"

fig = px.scatter(latent_df_master, x="bio_seq_info_em", y="recon_mse_em", color="model_type", 
                  template="plotly",
                  labels=dict(recon_mse_em="image reconstruction error", 
                              bio_seq_info_em="latent information content",
                              model_type="model class"))
                 

# Fixed-size markers with a dark outline.
fig.update_traces(marker=dict(size=12,
                              line=dict(width=2,
                                        color='DarkSlateGrey')),
                  selector=dict(mode='markers'))

fig.show()
fig.write_image(os.path.join(slide_path, "info_seq_em_vs_recon_mse.png"))

In [None]:
# Sequential information gain: "em" column against "mc" column, with a y = x
# line for reference.
ref_line = np.linspace(-50, 150)
fig = px.scatter(
    latent_df_master,
    x="bio_seq_info_mc",
    y="bio_seq_info_em",
    color="model_type",
    template="plotly",
    labels={
        "bio_seq_info_mc": "sequential information (MC)",
        "bio_seq_info_em": "sequential information (EM)",
        "model_type": "model class",
    },
)

fig.add_trace(
    go.Scatter(x=ref_line, y=ref_line, mode="lines", line={"color": "black"}, showlegend=False)
)

# Fixed-size outlined markers; the selector skips the "lines" trace.
fig.update_traces(
    marker={"size": 10, "line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)

fig.show()
fig.write_image(os.path.join(figure_path, "seq_information_em_vs_mc.png"))

In [None]:
# Sequential vs. self information gain (EM columns), with a y = x line for reference.
# The labels dict must be keyed by the plotted columns: the x column is
# "bio_self_info_em", so that is the key that needs a readable axis label.
fig = px.scatter(latent_df_master, x="bio_self_info_em", y="bio_seq_info_em", color="model_type", 
                 template="plotly",
                 labels=dict(bio_self_info_em="self information (EM)", 
                             bio_seq_info_em="sequential information (EM)",
                             model_type="model class"))
ref_line = np.linspace(-50, 150)

fig.add_trace(go.Scatter(x=ref_line, y=ref_line, mode="lines", line=dict(color="black"), showlegend=False))

# Fixed-size outlined markers; the selector skips the "lines" trace.
fig.update_traces(marker=dict(size=10, 
                              line=dict(width=2,
                                        color='DarkSlateGrey')),
                  selector=dict(mode='markers'))

fig.show()
fig.write_image(os.path.join(figure_path, "self_vs_seq_information.png"))