## Comparing VAE architectures
This notebook compares the performance of different VAE architectures. Specifically, we are testing how model depth (the number of convolutional layers) and the size of the latent space affect:
1. Image reconstruction quality
2. Model generalizability
3. Biological information content of the latent space

In [None]:
import glob as glob
import os

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from pythae.models import AutoModel
from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
from pythae.samplers import NormalSampler

from functions.utilities import path_leaf

#### Get paths to data, figures, and latent space outputs

In [None]:
# Root of the morphseq project. Set the MORPHSEQ_ROOT environment variable to
# point elsewhere instead of editing this cell (the macOS location used
# previously was "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/morphseq/").
root = os.environ.get(
    "MORPHSEQ_ROOT",
    "E:\\Nick\\Dropbox (Cole Trapnell's Lab)\\Nick\\morphseq\\",
)

# Training run whose models are inspected in this notebook
train_name = "20230815_vae"
# The trailing '' makes os.path.join end train_dir with a path separator
train_dir = os.path.join(root, "training_data", train_name, '')

# All entries of this run whose name contains "depth", sorted for a stable order
model_path_list = sorted(glob.glob(os.path.join(train_dir, '*depth*')))
model_name_list = [path_leaf(m) for m in model_path_list]

# Directory for figures of this training run (created if missing)
output_dir = os.path.join(train_dir, "figures")
os.makedirs(output_dir, exist_ok=True)

#### Load a trained model and create the training dataset
The data are split into three sets (only the "train" set is loaded below):
- "Train" data were used to train the model
- "Eval" data were used to assess the model during training
- "Test" data were untouched during the training process

In [None]:
import pandas as pd
from functions.pythae_utils import *  # presumably provides make_dynamic_rs_transform and MyCustomDataset — TODO confirm

# Position (in the sorted model_name_list) of the model to inspect
MODEL_INDEX = -3
model_name = model_name_list[MODEL_INDEX]

mdir = os.path.join(train_dir, model_name)

# Pick the last entry of the model directory in sorted order
# (presumably the most recent training run — TODO confirm naming scheme)
last_training = sorted(os.listdir(mdir))[-1]

trained_model = AutoModel.load_from_folder(
    os.path.join(mdir, last_training, 'final_model'))

m_fig_path = os.path.join(mdir, last_training, "figures")
# model_figpath_list is not defined in any earlier cell, so the bare
# .append() raised NameError on a fresh kernel; initialise it here.
model_figpath_list = [m_fig_path]

# load data frame with results
morph_df = pd.read_csv(os.path.join(m_fig_path, "embryo_stats_df.csv"), index_col=0)

# Image size passed to the resize transform; the GMM cell below also
# allocates its stacked-image array with these dimensions
main_dims = (576, 256)
data_transform = make_dynamic_rs_transform(main_dims)

train_dataset = MyCustomDataset(
    root=os.path.join(train_dir, "train"),
    transform=data_transform,
    return_name=True
)

In [None]:
from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
from pythae.samplers import NormalSampler

## Draw naive samples using the Gaussian prior

In [None]:
# Imported here as well: this cell runs before the later cell that imports pyplot
from matplotlib import pyplot as plt

# Display grid for the generated images
n_rows, n_cols = 5, 5

# Sampler drawing latent codes from the model's Gaussian prior
normal_sampler = NormalSampler(
    model=trained_model
)

# Generate one image per grid panel
gen_data = normal_sampler.sample(
    num_samples=n_rows * n_cols
)

# show results with normal sampler
# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(2 * n_cols, 2 * n_rows),
                         squeeze=False)
# zip stops at the shorter sequence, so fewer samples than panels cannot raise IndexError
for ax, img in zip(axes.flat, gen_data):
    ax.imshow(img.cpu().squeeze(0), cmap='gray')
for ax in axes.flat:
    ax.axis('off')
plt.tight_layout(pad=0.)
plt.show()

## Fit a Gaussian mixture model (GMM) and check whether it yields more plausible samples

In [None]:
# Number of mixture components for the GMM sampler
GMM_N_COMPONENTS = 10

# set up GMM sampler config
gmm_sampler_config = GaussianMixtureSamplerConfig(
    n_components=GMM_N_COMPONENTS
)

# create gmm sampler
gmm_sampler = GaussianMixtureSampler(
    sampler_config=gmm_sampler_config,
    model=trained_model
)

# Stack every training image into one (N, 1, H, W) array for the sampler.
# float32 halves memory compared with np.empty's float64 default; the model
# consumes float32 anyway, so the values it sees are unchanged.
# NOTE(review): relies on pythae casting inputs to float tensors — confirm.
n_images = len(train_dataset)
train_stack = np.empty((n_images, 1, main_dims[0], main_dims[1]), dtype=np.float32)

for t in range(n_images):
    # Item 0 of each dataset entry is the image; [0] takes its first slice
    # along axis 0 (presumably the single channel — TODO confirm). No
    # .tolist() round-trip through Python floats is needed.
    train_stack[t, 0] = np.asarray(train_dataset[t][0])[0]

# fit the sampler
gmm_sampler.fit(train_stack)

In [None]:
# Draw one GMM sample per panel of the 5 x 5 display grid below
gen_data = gmm_sampler.sample(num_samples=5 * 5)

In [None]:
from matplotlib import pyplot as plt

# Display grid for the GMM samples
n_rows, n_cols = 5, 5

# show results with gmm sampler
# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(2 * n_cols, 2 * n_rows),
                         squeeze=False)
# zip stops at the shorter sequence, so fewer samples than panels cannot raise IndexError
for ax, img in zip(axes.flat, gen_data):
    ax.imshow(img.cpu().squeeze(0), cmap='gray')
for ax in axes.flat:
    ax.axis('off')
plt.tight_layout(pad=0.)
plt.show()