## Exploring morph VAE output
This notebook generates visualizations and runs analyses to assess the biological content of the latent space representations learned by our VAE models.

In [1]:
from functions.pythae_utils import *
import os
from pythae.models import VAE, VAEConfig
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.models import AutoModel
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import umap
import numpy as np
from sklearn.preprocessing import StandardScaler
from pythae.samplers import NormalSampler


#### Get paths to data, figures, and latent space outputs

In [None]:
root = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/morphseq/"
# root = "E:\\Nick\\Dropbox (Cole Trapnell's Lab)\\Nick\\morphseq\\"
train_name = "20230804_vae_full"
model_name = "20230804_vae_full_conv_z25_bs032_ne100_depth05"
train_dir = os.path.join(root, "training_data", train_name)
output_dir = os.path.join(train_dir, model_name) 

# get path to model
last_training = sorted(os.listdir(output_dir))[-1]
trained_model = AutoModel.load_from_folder(
    os.path.join(output_dir, last_training, 'final_model'))

# path to figures and data
figure_path = os.path.join(output_dir, last_training, "figures")

#### Create DataLoader objects for train, eval, and test sets
- "Train" data were used to train the model
- "Eval" data were used to assess the model during training
- "Test" data were untouched during the training process

In [None]:
import pandas as pd
# main_dims = (576, 256) # size of images used for training
# data_transform = make_dynamic_rs_transform(main_dims)

# mode_vec = ["train", "eval", "test"]
# data_sampler_vec = []
# for mode in mode_vec:
#     ds_temp = MyCustomDataset(
#         root=os.path.join(train_dir, mode),
#         transform=data_transform
#     )
#     data_sampler_vec.append(ds_temp)

# load data frame with results
morph_df = pd.read_csv(os.path.join(figure_path, "embryo_stats_df.csv"), index_col=0)

In [None]:
morph_df.columns

#### Question 1: how well does the model generalize?
We can compare reconstruction error across the three data groups. In a perfect world, the model would reconstruct embryos it has never seen as well as embryos that were in the training set.

In [None]:
import plotly.express as px
import plotly.graph_objects as go

# load vectors that contain reconstruction loss for images
train_mse_vec = np.load(os.path.join(figure_path, "train_set_recon_loss.npy"))
eval_mse_vec = np.load(os.path.join(figure_path, "eval_set_recon_loss.npy"))
test_mse_vec = np.load(os.path.join(figure_path, "test_set_recon_loss.npy"))

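# normalizing constant: smallest loss across all three sets
# (concatenate first, since the vectors can differ in length)
global_min = np.min(np.concatenate([train_mse_vec, eval_mse_vec, test_mse_vec]))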

# calculate averages
train_mu = np.mean(train_mse_vec-global_min)
eval_mu = np.mean(eval_mse_vec-global_min)
test_mu = np.mean(test_mse_vec-global_min)

fig = px.histogram(morph_df, x="recon_mse", color="train_cat")

# fig.update_layout(legend=[ f"training images (mu={np.round(train_mu)})", f"eval images (mu={np.round(eval_mu)})", f"test images (mu={np.round(test_mu)})"]) 


fig.update_layout(barmode='overlay')
fig.update_traces(opacity=0.5)

fig.show()

Note that the train and eval sets should have identical compositions, so the eval set is likely a slightly better point of comparison. Overall, the model does better on training data, but not overwhelmingly so.
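
To put a number on "better, but not overwhelmingly so", one option is to summarize each group's reconstruction error relative to the training set. The sketch below is illustrative only: `summarize_recon_error` is a hypothetical helper, and the arrays are synthetic stand-ins rather than the real loss vectors loaded above.

```python
import numpy as np

def summarize_recon_error(errors_by_group, reference="train"):
    """Return mean, median, and mean relative to the reference group."""
    ref_mean = float(np.mean(errors_by_group[reference]))
    summary = {}
    for name, values in errors_by_group.items():
        values = np.asarray(values, dtype=float)
        summary[name] = {
            "mean": float(values.mean()),
            "median": float(np.median(values)),
            "ratio_to_ref": float(values.mean() / ref_mean),
        }
    return summary

# Synthetic stand-ins for the train/eval/test loss vectors (not real results)
rng = np.random.default_rng(0)
demo = {
    "train": rng.normal(0.065, 0.003, size=1000),
    "eval": rng.normal(0.068, 0.003, size=200),
    "test": rng.normal(0.068, 0.003, size=300),
}
summary = summarize_recon_error(demo)
for name, stats in summary.items():
    print(f"{name}: mean={stats['mean']:.4f}, ratio={stats['ratio_to_ref']:.3f}")
```

With the real data, the same call could take `{"train": train_mse_vec, "eval": eval_mse_vec, "test": test_mse_vec}`. A ratio close to 1 for the eval and test groups would suggest that the model generalizes well.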

Let's look at a couple of comparison images to get a sense of what these reconstruction error values mean qualitatively.

In [None]:
from skimage import io
import plotly.express as px

# the example images come from one specific training run; build the paths from
# output_dir instead of repeating the absolute path
example_dir = os.path.join(output_dir, "VAE_training_2023-08-09_08-35-26", "figures")

im_train_avg = io.imread(os.path.join(example_dir, "train_images", "im_2966_loss06647.tiff"))
im_train_good = io.imread(os.path.join(example_dir, "train_images", "im_15056_loss06313.tiff"))
im_train_bad = io.imread(os.path.join(example_dir, "train_images", "im_11428_loss07044.tiff"))

im_test_avg = io.imread(os.path.join(example_dir, "test_images", "im_4090_loss06825.tiff"))

fig = px.imshow(im_train_avg, title="Representative training image")
fig.show()

In [None]:
fig = px.imshow(im_test_avg, title="Representative test image")
fig.show()

#### Question 2: What does latent space "look" like?
We cannot visualize the full space directly, but we can use UMAP to generate low-dimensional representations. It is not clear how informative these projections are, but they are worth having for reference.

In [None]:
# the UMAP embeddings are stored as columns of morph_df (UMAP_00, UMAP_01)


fig = px.scatter(morph_df, x="UMAP_00", y="UMAP_01", color="train_cat", opacity=0.5, template="plotly",
                title="VAE latent space by training class")

fig.show()

In [None]:
morph_df["experiment_id"] = morph_df["experiment_date"].astype('category',copy=True)

fig = px.scatter(morph_df, x="UMAP_00", y="UMAP_01", color="experiment_id", opacity=0.5, template="plotly",
                title="VAE latent space by experiment date")

fig.show()

In [None]:
fig = px.scatter(morph_df, x="UMAP_00", y="UMAP_01", color="master_perturbation", opacity=0.5, template="plotly",
                title="VAE latent space by perturbation")

fig.show()

In [None]:
fig = px.scatter(morph_df, x="UMAP_00", y="UMAP_01", color="predicted_stage_hpf", opacity=0.5, template="plotly",
                title="VAE latent space by age")

fig.show()

### Next steps

1. **Dig further into latent space**
    - What causes the mirroring in UMAP space?
    - What do the outlier points look like? 
    
<br>

2. **Build quantitative pipeline for assessing information content of latent space**
    - Use simple NN to predict age and perturbation from latent space
    - Use multivariate regression (or similar) to look at how much info is _linearly_ decodable
    - Short-term: use these methods to identify optimal model architecture and hyperparameters
    - Mid-term: I want to use a simple linear model to reduce latent space to a biologically salient subspace
    
<br>

3. **Improve and extend the model**
    - Deeper decoder to improve image fidelity
    - Try incorporating embryoNET as the encoder
    
<br>

4. **Test VAE variants to improve biological salience of latent space**
    - TC-VAE to enforce statistically independent latent variables
    - Apply a metric constraint to incentivize the model to learn representations of key biological variables (time, genotype) while minimizing the impact of "nuisance" variables (embryo pose, experimental variability, etc.)
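
The linear-decodability idea in step 2 can be sketched with ordinary least squares: fit a linear map from latent vectors to a target such as age, then report R² on held-out samples. This is a minimal sketch on synthetic data, not the planned pipeline. `linear_decoding_r2`, the 25-dimensional latents, and the simulated "age" are all illustrative assumptions.

```python
import numpy as np

def linear_decoding_r2(latents, target, test_fraction=0.25, seed=0):
    """Fit ordinary least squares on a training split and return held-out R^2."""
    latents = np.asarray(latents, dtype=float)
    target = np.asarray(target, dtype=float)
    n = latents.shape[0]
    rng = np.random.default_rng(seed)
    order = rng.permutation(n)
    n_test = max(1, int(round(test_fraction * n)))
    test_idx, train_idx = order[:n_test], order[n_test:]

    # Append a column of ones so the model includes an intercept
    design = np.column_stack([latents, np.ones(n)])
    coef, *_ = np.linalg.lstsq(design[train_idx], target[train_idx], rcond=None)

    pred = design[test_idx] @ coef
    ss_res = np.sum((target[test_idx] - pred) ** 2)
    ss_tot = np.sum((target[test_idx] - target[test_idx].mean()) ** 2)
    return 1.0 - ss_res / ss_tot

# Synthetic example: "age" depends linearly on two latent dimensions plus noise
rng = np.random.default_rng(1)
z = rng.normal(size=(500, 25))
age = 24.0 + 6.0 * z[:, 0] - 3.0 * z[:, 3] + rng.normal(scale=0.5, size=500)
r2_signal = linear_decoding_r2(z, age)
r2_noise = linear_decoding_r2(z, rng.normal(size=500))
print(f"R^2 with linear signal: {r2_signal:.3f}")
print(f"R^2 for pure noise target: {r2_noise:.3f}")
```

Scoring on a held-out split matters here: with many latent dimensions, in-sample R² overstates how much information is actually decodable. Comparing this score against a nonlinear predictor would show how much of the information is only nonlinearly accessible.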