## Exploring morph VAE output
This notebook generates visualizations and conducts analyses to assess the biological content of the latent space representations learned by our VAE models.

In [None]:
import os
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.offline as pyo

# Initialise plotly's offline notebook mode before any figure is shown.
pyo.init_notebook_mode()

#### Get paths to data, figures, and latent space outputs

In [None]:
# Root of the morphseq project tree. The commented line is the alternative
# (Windows) location of the same folder.
root = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/morphseq/"
# root = "E:\\Nick\\Dropbox (Cole Trapnell's Lab)\\Nick\\morphseq\\"

# Training data set and model whose outputs are analysed in this notebook.
train_name = "20230915_vae"
# train_name = "20231106_ds"
# /Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/morphseq/training_data/20230915_vae_flipped/z100_bs032_ne100_depth05
# model_name = "20230804_vae_full_conv_z25_bs032_ne100_depth05"
model_name = "z100_bs032_ne250_depth05_out16_temperature_sweep2"
# model_name = "z100_bs064_ne250_depth05_out16_class_ignorance_test"
train_dir = os.path.join(root, "training_data", train_name)
output_dir = os.path.join(train_dir, model_name)

# Use the training-run folder that sorts last by name. Only directories are
# considered so that a stray file in the model folder can never be picked up
# as the "last training".
training_runs = sorted(
    entry for entry in os.listdir(output_dir)
    if os.path.isdir(os.path.join(output_dir, entry))
)
if not training_runs:
    raise FileNotFoundError(f"No training-run folders found in {output_dir}")
last_training = training_runs[-1]

# Where the model's figures/data are read from, and where slides figures go.
figure_path = os.path.join(output_dir, last_training, "figures")
out_figure_path = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/slides/20231130/"
# exist_ok=True replaces the check-then-create pattern.
os.makedirs(out_figure_path, exist_ok=True)

In [None]:
# Load the table of UMAP coordinates written by the model's figure pipeline;
# the first CSV column is used as the index.
umap_df = pd.read_csv(
    os.path.join(figure_path, "umap_df.csv"),
    index_col=0,
)
# Flip the sign of the first UMAP axis.
umap_df["UMAP_00_bio_3"] *= -1
from scipy.interpolate import LinearNDInterpolator
from scipy import ndimage

In [None]:
# Resample every embryo's UMAP trajectory onto a common 0.5 hpf time grid.
umap_df_interp_list = []

# The embryo identifier is the snip_id with its last 10 characters removed.
snip_vec = np.asarray([snip_id[:-10] for snip_id in umap_df["snip_id"]])
snip_index = np.unique(snip_vec)
# tres = 0.25
hpf_interp_vec = np.arange(0, 72, 0.5)

for snip in snip_index:
    # Positions (not labels) of the rows belonging to this embryo; use iloc so
    # the code does not depend on umap_df having a 0..n-1 index.
    s_indices = np.where(snip_vec == snip)[0]
    snip_df = umap_df.iloc[s_indices]

    u0 = snip_df["UMAP_00_bio_3"].to_numpy()
    u1 = snip_df["UMAP_01_bio_3"].to_numpy()
    u2 = snip_df["UMAP_02_bio_3"].to_numpy()

    # NOTE(review): np.interp requires increasing sample points, so this
    # assumes rows are stored in order of increasing predicted stage.
    t = snip_df["predicted_stage_hpf"].to_numpy()

    # Keep only grid times inside the observed range (no extrapolation).
    interp_ref_vec = hpf_interp_vec[(hpf_interp_vec >= t[0]) & (hpf_interp_vec <= t[-1])]
    umap_array_interp = np.column_stack((
        np.interp(interp_ref_vec, t, u0),
        np.interp(interp_ref_vec, t, u1),
        np.interp(interp_ref_vec, t, u2),
        interp_ref_vec,
    ))

    df_temp = pd.DataFrame(umap_array_interp, columns=["umap00", "umap01", "umap02", "hpf"])
    df_temp["snip_id"] = snip
    # Assign a scalar: assigning the original Series would align on the
    # source-row index (not df_temp's 0..k index) and yield NaN or misaligned
    # values. Assumes the perturbation is constant within an embryo.
    df_temp["master_perturbation"] = snip_df["master_perturbation"].iloc[0]

    umap_df_interp_list.append(df_temp)

umap_df_interp = pd.concat(umap_df_interp_list, axis=0)

## Use kmeans to divide UMAP space into 50 discrete morphology "states"

In [None]:
from sklearn.cluster import KMeans

# umap_df_nn = umap_df_interp.dropna()

# Number of discrete morphology "states" (k-means clusters).
# Original author's note: the coarse level is used only to fix positions of
# Gaussian kernels.
n_states = 50

# Fit k-means on the first two interpolated UMAP coordinates only
# (fit() returns the estimator itself, so kmeans_out is the fitted model).
kmeans_out = KMeans(n_clusters=n_states, random_state=0, n_init="auto")
kmeans_out.fit(umap_df_interp[["umap00", "umap01"]])

# Attach each interpolated point's cluster id as its state label.
umap_df_interp["kmeans_label"] = kmeans_out.labels_
# cols_to_average = ["UMAP_00_bio_3", "UMAP_01_bio_3", "UMAP_02_bio_3", 
#                    "UMAP_00_bio_3_vel", "UMAP_01_bio_3_vel", "UMAP_02_bio_3_vel", "predicted_stage_hpf"]
# avg_vel_df = umap_df_wt.loc[:, cols_to_average + ["kmeans_label"]].groupby("kmeans_label").mean()



#### Now, it should be possible to calculate empirical transition matrices for N time steps

In [None]:
# Count empirical state-to-state transitions for lags of 1..n_steps_max grid
# steps. transition_count_array[lag, to, fr] is the number of times an embryo
# was in state `fr` and, `lag` steps later, in state `to`.
n_steps_max = 50
# Every entry starts at a 1e-3 pseudo-count, so a column that is later
# normalised never sums to zero.
transition_count_array = np.full((n_steps_max + 1, n_states, n_states), 1e-3)
# Lag 0 is the identity: a state maps onto itself.
transition_count_array[0, :, :] = np.eye(n_states)

snip_vec = umap_df_interp["snip_id"].to_numpy()
all_states = umap_df_interp["kmeans_label"].to_numpy()

for snip in snip_index:
    # State sequence of this embryo, in row order.
    state_vec = all_states[snip_vec == snip]
    max_lag = min(n_steps_max, len(state_vec) - 1)
    for lag in range(1, max_lag + 1):
        fr = state_vec[:-lag]
        to = state_vec[lag:]
        # np.add.at accumulates repeated (to, fr) pairs correctly, unlike
        # fancy-index "+=" which would count each distinct pair only once.
        np.add.at(transition_count_array[lag], (to, fr), 1)
        


In [None]:
# Per-state mean of the numeric columns. numeric_only=True skips
# non-numeric columns such as snip_id, on which pandas >= 2.0 would
# otherwise raise a TypeError.
avg_df = umap_df_interp.groupby("kmeans_label").mean(numeric_only=True)
avg_df

In [None]:
# umap_df_interp["umap00"] = -umap_df_interp["umap00"] 

# Scatter of the interpolated trajectories in the first two UMAP dimensions,
# coloured by k-means state label.
fig = px.scatter(
    umap_df_interp,
    x="umap00",
    y="umap01",
    color="kmeans_label",
    opacity=0.25,
)

# Fixed axis windows for the figure.
fig.update_xaxes(range=[-20, 0]).update_yaxes(range=[0, 12])

fig.show()
# Save a copy for the slides folder.
fig.write_image(os.path.join(out_figure_path, "UMAP_wt_states.png"), scale=2)
#, color_continuous_scale="magma")

In [None]:
import plotly.offline as pyo

pyo.init_notebook_mode()

# Write one frame per lag showing, for every point, the empirical probability
# of being in that point's state n_step grid steps after starting in state_id.
state_id = 5

hm_dir = os.path.join(out_figure_path, "tr_frames_emp", '')
os.makedirs(hm_dir, exist_ok=True)

step_vec = range(0, 48, 1)

# State label of every plotted point; loop-invariant, so looked up once.
point_labels = umap_df_interp["kmeans_label"].to_numpy()

for n_step in step_vec:

    # Normalise each column so it sums to 1: column j then holds the
    # distribution over destination states given start state j.
    s_emp = transition_count_array[n_step, :, :]
    s_emp = s_emp / np.sum(s_emp, axis=0)

    # Colour each point by the probability of its state given start state_id.
    color_vec_norm = s_emp[point_labels, state_id]

    # n_step/2 converts grid steps to hours (the interpolation grid is 0.5 hpf).
    fig = px.scatter(umap_df_interp, x="umap00", y="umap01", 
                        color=color_vec_norm.flatten(), opacity=0.5, color_continuous_scale="Blues", 
                        range_color=[0,0.15], title="predicted state probabilities (" + str(np.round(n_step/2,1)) + " hrs)")

    fig.update_layout(
        coloraxis_colorbar=dict(
            title="state probability",
        ),
    )

    fig.write_image(os.path.join(hm_dir, f"_UMAP_wt_states_pd_s{n_step:003}.png"), scale=2)

In [None]:
import plotly.offline as pyo

pyo.init_notebook_mode()

# Write one frame per step showing the state distribution obtained by
# repeatedly applying a single transition matrix, starting from state_id.
state_id = 5

hm_dir = os.path.join(out_figure_path, "tr_frames", '')
os.makedirs(hm_dir, exist_ok=True)

step_vec = range(0, 24, 1)

# Column-normalised lag-2 transition matrix (two 0.5 hpf grid steps, which is
# consistent with the one-hour-per-step labels in the title below).
# Build a NEW array: the previous in-place "/=" operated on a view and
# silently overwrote the counts stored in transition_count_array[2].
A_counts = transition_count_array[2, :, :]
A_pd = A_counts / np.sum(A_counts, axis=0)

# A_pd = np.linalg.matrix_power(A_base, 1/4)

# Start with all probability mass on state_id.
state_vec = np.zeros((n_states, 1))
state_vec[state_id] = 1

# State label of every plotted point; loop-invariant, so looked up once.
point_labels = umap_df_interp["kmeans_label"].to_numpy()

for n_step in step_vec:

    # Advance the distribution by one application of A_pd (frame 0 shows the
    # initial condition unchanged).
    if n_step > 0:
        state_vec = np.matmul(A_pd, state_vec)

    # Colour each point by the current probability of its state.
    color_vec_norm = state_vec[point_labels]

    fig = px.scatter(umap_df_interp, x="umap00", y="umap01", 
                        color=color_vec_norm.flatten(), opacity=0.5, color_continuous_scale="Blues", 
                        range_color=[0,0.15], title="predicted state probabilities (" + str(np.round(n_step,1)) + " hrs)")

    fig.update_layout(
        coloraxis_colorbar=dict(
            title="state probability",
        ),
    )

    fig.write_image(os.path.join(hm_dir, f"_UMAP_wt_states_pd_s{n_step:003}.png"), scale=2)

In [None]:
# Display the transition matrix A_pd as a heatmap.
px.imshow(A_pd)

In [None]:
# define function to calculate predicted potential at each point in 3D space

def predict_U_array(xyz_array, sigma_array, amp_array, Xg, Yg, Zg):
    """Evaluate a sum of negative axis-aligned Gaussians on a 3D grid.

    Args:
        xyz_array: kernel centres, one row per kernel (columns 0-2 are used).
        sigma_array: per-kernel widths along each axis, same layout.
        amp_array: per-kernel amplitudes; each kernel contributes ``-amp``
            times its Gaussian.
        Xg, Yg, Zg: coordinate arrays of identical shape at which the
            potential is evaluated.

    Returns:
        Array with the shape of ``Xg`` holding the summed potential.
    """
    potential = np.zeros(Xg.shape)
    for k in range(xyz_array.shape[0]):
        center = xyz_array[k, :]
        width = sigma_array[k, :]
        # Squared, width-scaled distance from the kernel centre along each axis.
        zx = ((Xg - center[0]) / width[0]) ** 2
        zy = ((Yg - center[1]) / width[1]) ** 2
        zz = ((Zg - center[2]) / width[2]) ** 2
        potential += -amp_array[k] * np.exp(-0.5 * (zx + zy + zz))

    return potential

In [None]:
def predict_dU_array(xyz_array, sigma_array, amp_array, Xg, Yg, Zg):
    """Evaluate the gradient of the summed-Gaussian potential on a 3D grid.

    Each kernel contributes ``(centre - grid) / sigma**2`` times its
    (negative) Gaussian value along every axis.

    Args:
        xyz_array: kernel centres, one row per kernel (columns 0-2 are used).
        sigma_array: per-kernel widths along each axis, same layout.
        amp_array: per-kernel amplitudes.
        Xg, Yg, Zg: coordinate arrays at which the gradient is evaluated.

    Returns:
        Tuple ``(dUdX, dUdY, dUdZ)`` with the shapes of ``Xg``, ``Yg``, ``Zg``.
    """
    grad_x = np.zeros(Xg.shape)
    grad_y = np.zeros(Yg.shape)
    grad_z = np.zeros(Zg.shape)
    for k in range(xyz_array.shape[0]):
        center = xyz_array[k, :]
        width = sigma_array[k, :]
        zx = ((Xg - center[0]) / width[0]) ** 2
        zy = ((Yg - center[1]) / width[1]) ** 2
        zz = ((Zg - center[2]) / width[2]) ** 2
        # Value of this kernel's (negative) Gaussian at every grid point.
        kernel_val = -amp_array[k] * np.exp(-0.5 * (zx + zy + zz))

        grad_x += (center[0] - Xg) / width[0] ** 2 * kernel_val
        grad_y += (center[1] - Yg) / width[1] ** 2 * kernel_val
        grad_z += (center[2] - Zg) / width[2] ** 2 * kernel_val

    return grad_x, grad_y, grad_z

In [None]:
def predict_U_km(xyz_array, sigma_array, amp_array, xyz_km):
    """Evaluate a sum of negative axis-aligned Gaussians at a list of points.

    Point-list counterpart of ``predict_U_array``.

    Args:
        xyz_array: kernel centres, one row per kernel (columns 0-2 are used).
        sigma_array: per-kernel widths along each axis, same layout.
        amp_array: per-kernel amplitudes; each kernel contributes ``-amp``
            times its Gaussian.
        xyz_km: query points, one row per point (columns 0-2 are used).

    Returns:
        1D array with one potential value per row of ``xyz_km``.
    """
    U_array = np.zeros((xyz_km.shape[0],))
    for i in range(xyz_array.shape[0]):
        xyz = xyz_array[i, :]
        sig = sigma_array[i, :]
        # The z term is scaled by sig[2]; it previously used sig[0], which
        # gave wrong values whenever the x and z widths differ.
        U_array += -amp_array[i]*np.exp(-0.5*(((xyz_km[:, 0]-xyz[0])/sig[0])**2 +
                                             ((xyz_km[:, 1]-xyz[1])/sig[1])**2 +
                                             ((xyz_km[:, 2]-xyz[2])/sig[2])**2))

    return U_array

In [None]:
def predict_dU_km(xyz_array, sigma_array, amp_array, xyz_km):
    """Evaluate the per-axis derivative term of the Gaussian sum at points.

    Each kernel contributes ``(point - centre) / sigma**2`` times its
    (negative) Gaussian value along every axis.

    NOTE(review): the displacement here is ``point - centre`` whereas
    ``predict_dU_array`` uses ``centre - point``, so the two functions differ
    in sign. Confirm whether this one is meant to return the negative
    gradient.

    Args:
        xyz_array: kernel centres, one row per kernel (columns 0-2 are used).
        sigma_array: per-kernel widths along each axis, same layout.
        amp_array: per-kernel amplitudes.
        xyz_km: query points, one row per point.

    Returns:
        Array with the shape of ``xyz_km``; columns 0-2 hold the x, y, z terms.
    """
    deriv = np.zeros(xyz_km.shape)
    px_, py_, pz_ = xyz_km[:, 0], xyz_km[:, 1], xyz_km[:, 2]
    for k in range(xyz_array.shape[0]):
        center = xyz_array[k, :]
        width = sigma_array[k, :]
        zx = ((px_ - center[0]) / width[0]) ** 2
        zy = ((py_ - center[1]) / width[1]) ** 2
        zz = ((pz_ - center[2]) / width[2]) ** 2
        # Value of this kernel's (negative) Gaussian at every query point.
        kernel_val = -amp_array[k] * np.exp(-0.5 * (zx + zy + zz))

        deriv[:, 0] += (px_ - center[0]) / width[0] ** 2 * kernel_val
        deriv[:, 1] += (py_ - center[1]) / width[1] ** 2 * kernel_val
        deriv[:, 2] += (pz_ - center[2]) / width[2] ** 2 * kernel_val

    return deriv

In [None]:
# Draw random kernel widths (uniform in [0, 5), one per centre and axis) and
# random amplitudes (uniform in [0, 1)), then render the resulting potential.
# NOTE(review): xyz_gauss is not defined in the visible cells, and Xg/Yg/Zg
# are only assigned in a later cell - confirm the intended run order.
sigma_array = np.random.rand(xyz_gauss.shape[0], xyz_gauss.shape[1]) * 5
amp_array = np.random.rand(xyz_gauss.shape[0])

U_out = predict_U_array(xyz_gauss, sigma_array, amp_array, Xg, Yg, Zg)

# Semi-transparent volume rendering of the potential on the grid.
fig = go.Figure(
    data=go.Volume(
        x=Xg.flatten(),
        y=Yg.flatten(),
        z=Zg.flatten(),
        value=U_out.flatten(),
        # isomin=0.25,
        # isomax=0.7,
        opacity=0.25,
        surface_count=25,
        colorscale="ice",
    )
)

fig.show()


In [None]:
# Inspect the shape of xyz_fine (not defined in the visible cells).
xyz_fine.shape

In [None]:
def objective(param_vec):
    """Sum of squared differences between predicted and target derivatives.

    ``param_vec`` is a flat vector holding four values per kernel in
    row-major order: three widths followed by one amplitude. The kernel
    centres (``xyz_gauss``), query points (``xyz_fine``) and targets
    (``du_xyz_fine``) are read from notebook globals.
    """
    n_kernels = xyz_gauss.shape[0]
    params = np.reshape(param_vec, (n_kernels, 4))
    widths, amps = params[:, :-1], params[:, -1]

    residual = predict_dU_km(xyz_gauss, widths, amps, xyz_fine) - du_xyz_fine
    return np.sum(residual**2)

In [None]:
# Evaluate the objective at the initial guess.
# NOTE(review): x0 is only assigned in the following cell, so that cell has
# to be run before this one.
objective(x0)

In [None]:
from scipy.optimize import minimize

# Random initial guess: one row per kernel, laid out as
# [sigma_x, sigma_y, sigma_z, amplitude].
n_gauss = xyz_gauss.shape[0]
x0 = np.random.rand(n_gauss, 4)*5

bnd1 = (0.1, None)   # widths: bounded below to keep them away from zero
bnd2 = (None, None)  # amplitudes: unbounded

# Bounds must follow the flattened (row-major) parameter layout that
# `objective` reshapes, i.e. [s, s, s, amp] repeated per kernel. The previous
# "3n width bounds then n amplitude bounds" ordering applied the bounds to
# the wrong parameters.
bnds = (bnd1, bnd1, bnd1, bnd2) * n_gauss

solution = minimize(objective, x0.flatten(), options={'maxiter':5000}, bounds=bnds)
# objective(x0)

In [None]:
# Display the optimizer result object returned by minimize.
solution

In [None]:
# Unpack the fitted parameters: one row per kernel, with the widths in all
# but the last column and the amplitude in the last column.
sol_vec = solution.x
param_array_sol = np.reshape(sol_vec, (xyz_gauss.shape[0], 4))
fit_sigma = param_array_sol[:, :-1]
fit_amp = param_array_sol[:, -1]

# Fitted potential and its gradient on the 3D grid.
U_pd = predict_U_array(xyz_gauss, fit_sigma, fit_amp, Xg, Yg, Zg)
dU_pd = predict_dU_array(xyz_gauss, fit_sigma, fit_amp, Xg, Yg, Zg)

# Semi-transparent volume rendering of the fitted potential.
fig = go.Figure(
    data=go.Volume(
        x=Xg.flatten(),
        y=Yg.flatten(),
        z=Zg.flatten(),
        value=U_pd.flatten(),
        # isomin=0.25,
        # isomax=0.7,
        opacity=0.25,
        surface_count=25,
        colorscale="ice",
    )
)
fig.update_layout(template="plotly")
fig.show()

In [None]:
# Heatmap of the slice of U_pd at index 25 along its last axis.
px.imshow(U_pd[:, :, 25])

In [None]:
# Finite differences of the potential along each array axis.
# NOTE(review): np.diff returns raw differences between neighbouring grid
# nodes; they are not divided by the grid spacing.
# NOTE(review): if Xg/Yg/Zg come from np.meshgrid with its default "xy"
# indexing (as in the interpolation cell near the end of this notebook),
# array axis 0 follows y and axis 1 follows x - confirm that axis=0 is
# really the x direction here.
dUdX = np.diff(U_out, axis=0)
dUdY = np.diff(U_out, axis=1)
dUdZ = np.diff(U_out, axis=2)

# Volume rendering of the axis-0 differences; the coordinate grids are
# trimmed by one element along axis 0 to match the shape of dUdX.
fig = go.Figure(data=go.Volume(
    x=Xg[:-1, :, :].flatten(), 
    y=Yg[:-1, :, :].flatten(), 
    z=Zg[:-1, :, :].flatten(),
    value=dUdX.flatten(),
#     isomin=0.25,
#     isomax=0.7,
    opacity=0.25,
    surface_count=25,
    colorscale="ice"
    ))

fig.show()

In [None]:
# Analytic gradient of the random-parameter potential on the grid; dU_out is
# the (dUdX, dUdY, dUdZ) tuple returned by predict_dU_array.
dU_out = predict_dU_array(xyz_gauss, sigma_array, amp_array, Xg, Yg, Zg)

# Volume rendering of the first gradient component.
fig = go.Figure(
    data=go.Volume(
        x=Xg.flatten(),
        y=Yg.flatten(),
        z=Zg.flatten(),
        value=dU_out[0].flatten(),
        # isomin=0.25,
        # isomax=0.7,
        opacity=0.25,
        surface_count=25,
        colorscale="ice",
    )
)

fig.show()

In [None]:
def objective(x):
    """Toy quadratic: squared distance of ``(x[0], x[1])`` from ``(3, 4)``.

    NOTE(review): this reuses the name of the fitting objective defined in an
    earlier cell and replaces it once this cell has been run.
    """
    dx = x[0] - 3
    dy = x[1] - 4
    return dx**2 + dy**2

In [None]:
# Inspect the grid shape after dropping the last element along axis 0.
Zg[:-1, :, :].shape

In [None]:
# 3D cone plot of the per-cluster mean velocities.
# NOTE(review): avg_vel_df is not defined in the visible cells (the line that
# builds it is commented out in the k-means cell) - confirm where it is
# created.
df = avg_vel_df.copy()

# Clamp each velocity component to [-max_v, max_v].
max_v = 1.5
vel_cols = ["UMAP_00_bio_3_vel", "UMAP_01_bio_3_vel", "UMAP_02_bio_3_vel"]
df[vel_cols] = df[vel_cols].clip(lower=-max_v, upper=max_v)

cone_trace = go.Cone(
    x=df["UMAP_00_bio_3"],
    y=df["UMAP_01_bio_3"],
    z=df["UMAP_02_bio_3"],
    u=df["UMAP_00_bio_3_vel"],
    v=df["UMAP_01_bio_3_vel"],
    w=df["UMAP_02_bio_3_vel"],
    colorscale='Blues',
    sizemode="absolute",
    sizeref=1,
)
fig = go.Figure(data=cone_trace)

fig.update_layout(
    scene=dict(
        aspectratio=dict(x=1, y=1, z=0.8),
        camera_eye=dict(x=1.2, y=1.2, z=0.6),
    )
)

fig.show()

### Experiment with using simple linear interpolation to obtain an estimate for the developmental "potential"

In [None]:
# Work on an explicit copy so that adding the label column below modifies an
# independent frame (avoids pandas' SettingWithCopy ambiguity on the dropna
# result).
umap_df_all = umap_df.dropna().copy()
n_points = umap_df_all.shape[0]

# Aim for about 24 observations per cluster; always request at least one
# cluster so a small table does not ask KMeans for zero clusters.
n_points_per_cluster = 24
n_clusters = max(1, n_points // n_points_per_cluster)

# Cluster in the full 3D UMAP space.
# NOTE: this rebinds kmeans_out, replacing the coarse state model fitted in
# the earlier k-means cell.
kmeans_out = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto").fit(
        umap_df_all.loc[:, ["UMAP_00_bio_3", "UMAP_01_bio_3", "UMAP_02_bio_3"]])

umap_df_all["kmeans_label"] = kmeans_out.labels_

# Per-cluster mean position, velocity and predicted stage.
cols_to_average = ["UMAP_00_bio_3", "UMAP_01_bio_3", "UMAP_02_bio_3", 
                   "UMAP_00_bio_3_vel", "UMAP_01_bio_3_vel", "UMAP_02_bio_3_vel", "predicted_stage_hpf"]
avg_vel_df_all = umap_df_all.loc[:, cols_to_average + ["kmeans_label"]].groupby("kmeans_label").mean()

In [None]:
# `ff` was never imported in this notebook, so the cell raised a NameError;
# import plotly's figure factory here, where it is used.
import plotly.figure_factory as ff

# 2D quiver plot of the per-cluster mean velocities in the first two UMAP
# dimensions.
fig = ff.create_quiver(x=avg_vel_df_all["UMAP_00_bio_3"], y=avg_vel_df_all["UMAP_01_bio_3"], 
                       u=avg_vel_df_all["UMAP_00_bio_3_vel"], v=avg_vel_df_all["UMAP_01_bio_3_vel"],
                       scale=0.25, showlegend=False)

# Overlay the cluster centres, coloured by mean predicted stage.
fig.add_trace(go.Scatter(x=avg_vel_df_all["UMAP_00_bio_3"], y=avg_vel_df_all["UMAP_01_bio_3"],
                    mode='markers',
                    marker=dict(color=avg_vel_df_all["predicted_stage_hpf"], size=4),
                        showlegend=False))

fig.show()

In [None]:
# Cluster-centre positions ...
X = avg_vel_df_all["UMAP_00_bio_3"].to_numpy()
Y = avg_vel_df_all["UMAP_01_bio_3"].to_numpy()
Z = avg_vel_df_all["UMAP_02_bio_3"].to_numpy()

# ... and the mean velocity components at those positions.
dX = avg_vel_df_all["UMAP_00_bio_3_vel"].to_numpy()
dY = avg_vel_df_all["UMAP_01_bio_3_vel"].to_numpy()
dZ = avg_vel_df_all["UMAP_02_bio_3_vel"].to_numpy()

# Regular grid with n_bins + 1 nodes per axis spanning the data range.
n_bins = 30
xx = np.linspace(min(X), max(X), num=n_bins + 1)
yy = np.linspace(min(Y), max(Y), num=n_bins + 1)
zz = np.linspace(min(Z), max(Z), num=n_bins + 1)

# 3D grid for interpolation. NOTE: np.meshgrid defaults to "xy" indexing, so
# array axis 0 follows yy and axis 1 follows xx.
Xg, Yg, Zg = np.meshgrid(xx, yy, zz)

# Scattered sample locations, one row per cluster centre.
xyz_grid_long = np.column_stack((X, Y, Z))

# Linearly interpolate each velocity component onto the grid; grid nodes
# outside the convex hull of the samples are filled with 0.
interp_dx, interp_dy, interp_dz = (
    LinearNDInterpolator(xyz_grid_long, component.flatten(), fill_value=0)
    for component in (dX, dY, dZ)
)
dXI = interp_dx(Xg, Yg, Zg)
dYI = interp_dy(Xg, Yg, Zg)
dZI = interp_dz(Xg, Yg, Zg)

# Apply mild gaussian smoothing (sigma of one grid cell)
dXIS = ndimage.gaussian_filter(dXI, 1)
dYIS = ndimage.gaussian_filter(dYI, 1)
dZIS = ndimage.gaussian_filter(dZI, 1)

# Magnitude of the smoothed vector field at every grid node.
UM = np.sqrt(np.square(dXIS) + np.square(dYIS) + np.square(dZIS))

# Semi-transparent volume rendering of the magnitude.
fig = go.Figure(
    data=go.Volume(
        x=Xg.flatten(),
        y=Yg.flatten(),
        z=Zg.flatten(),
        value=UM.flatten(),
        isomin=0.25,
        # isomax=0.7,
        opacity=0.25,
        surface_count=25,
        colorscale="ice",
    )
)

fig.update_layout(template="plotly")
# fig.update_layout(scene_xaxis_showticklabels=False,
#                   scene_yaxis_showticklabels=False,
#                   scene_zaxis_showticklabels=False)
fig.show()

In [None]:
# now let's try a crude integration approach to get a representation of the potential itself