#### Look for axes of correspondence between the morph and seq modalities (CCA is imported for this, but the analysis below currently fits an MLP regressor)

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
from glob2 import glob
from sklearn.cross_decomposition import CCA

### Set paths

In [None]:
# Root of the morphseq project data
root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"
train_name = "20241107_ds"
model_name = "SeqVAE_z100_ne150_sweep_01_block01_iter030"
train_dir = os.path.join(root, "training_data", train_name, "")
output_dir = os.path.join(train_dir, model_name)

# get path to morph model: the last entry (in sorted order) under output_dir
training_runs = sorted(glob(os.path.join(output_dir, "*")))
if not training_runs:
    raise FileNotFoundError(f"No training runs found in {output_dir}")
training_path = training_runs[-1]
# basename gives the run's own name (dirname would just return output_dir)
training_name = os.path.basename(os.path.normpath(training_path))
morph_read_path = os.path.join(training_path, "figures", "")

# set path to hooke projections (derived from root rather than re-typed)
hooke_model_name = "bead_expt_linear"
latent_path = os.path.join(root, "seq_data", "emb_projections", "latent_projections", "")
hooke_model_path = os.path.join(latent_path, hooke_model_name, "")

# results directory: the input CSVs below are read from here as well
out_path = os.path.join(root, "results", "20240303", "")
os.makedirs(out_path, exist_ok=True)

# path to figures
fig_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/morphseq/20250312/morphseq_cca/"
os.makedirs(fig_path, exist_ok=True)

### Load data

In [None]:
# Per-observation latent encodings for each modality (first CSV column is used as the index)
morph_df = pd.read_csv(os.path.join(out_path, "hf_morph_df.csv"), index_col=0)
seq_df = pd.read_csv(os.path.join(out_path, "hf_seq_df.csv"), index_col=0)

# Metadata table that links morph snips to seq samples
morphseq_df = pd.read_csv(os.path.join(root, "metadata", "morphseq_metadata.csv"))

# Spline trajectories in each latent space, indexed by developmental stage (hpf)
morph_spline_df = pd.read_csv(os.path.join(out_path, "spline_morph_df.csv")).set_index("stage_hpf")
seq_spline_df = pd.read_csv(os.path.join(out_path, "spline_seq_df.csv")).set_index("stage_hpf")

### Subset to the hotfish2 experiments

In [None]:
import re

# the hotfish2 (HF) experiments analyzed in this notebook
hf_experiments = np.asarray(["20240813_24hpf", "20240813_30hpf", "20240813_36hpf"])
hf_morphseq_df = morphseq_df.loc[np.isin(morphseq_df["experiment_date"], hf_experiments), :].reset_index(drop=True)

# latent columns used from each modality
pattern = r"PCA_.*_bio"
pca_cols_morph = [col for col in morph_df.columns if re.search(pattern, col)]
pca_cols_seq = [col for col in seq_df.columns if "PCA" in col]

# subset morph: one row per snip, keeping "sample" so it can be linked to seq
hf_morph_df = hf_morphseq_df.loc[:, ["snip_id", "sample"]].merge(morph_df, how="inner", on="snip_id")
hf_morph_df = hf_morph_df.set_index("snip_id")
hf_morph_df = hf_morph_df.loc[:, pca_cols_morph + ["sample"]]

# The seq and morph arrays are later paired by row position, so drop morph
# observations whose sample has no seq projection rather than letting the two
# tables silently diverge in length/order.
has_seq = hf_morph_df["sample"].isin(seq_df.index).to_numpy()
if not has_seq.all():
    print(f"dropping {int((~has_seq).sum())} morph observations with no matching seq sample")
hf_morph_df = hf_morph_df.loc[has_seq, :]

# subset seq dataset: row i is the seq projection for the sample of hf_morph_df row i
hf_seq_df = seq_df.loc[hf_morph_df["sample"].to_numpy(), :].rename_axis("sample")
if len(hf_seq_df) != len(hf_morph_df):
    raise ValueError("seq_df has duplicate sample ids; morph and seq rows cannot be paired one-to-one")
print(hf_seq_df.shape)

# get rid of sample col
hf_morph_df = hf_morph_df.drop(columns=["sample"])
print(hf_morph_df.shape)

# keep only metadata rows that survived the filtering above (QC problems,
# missing seq match), in the same row order as hf_morph_df
hf_morphseq_df = hf_morphseq_df.set_index("snip_id").loc[hf_morph_df.index, :].reset_index()
if len(hf_morphseq_df) != len(hf_morph_df):
    raise ValueError("duplicate snip_ids in metadata; cannot align metadata with morph rows")
print(hf_morphseq_df.shape)

### Extract the spline and observed-data PCA columns to fit

In [None]:
from sklearn.decomposition import PCA

# The input tables already hold PCA coordinates (columns matched on "PCA"), so
# they are used directly instead of refitting a PCA here.
# NOTE(review): an earlier note claimed these components capture >99% of the
# variance in both modalities; that is not verifiable from this notebook.
n_components = len(pca_cols_morph)

# Observed HF embryos in each latent space.
# Assumes hf_morph_df and hf_seq_df are row-aligned (built in the subset cell).
morph_pca = hf_morph_df.loc[:, pca_cols_morph].to_numpy()
seq_pca = hf_seq_df.loc[:, pca_cols_seq].to_numpy()

# Spline trajectories in each latent space
morph_spline_pca = morph_spline_df.loc[:, pca_cols_morph].to_numpy()
seq_spline_pca = seq_spline_df.loc[:, pca_cols_seq].to_numpy()

### Visualize the two latent spaces

In [None]:
# HF embryos in the first three morph PCs, colored by temperature,
# with the morph spline overlaid as a line
fig = px.scatter_3d(
    x=morph_pca[:, 0],
    y=morph_pca[:, 1],
    z=morph_pca[:, 2],
    color=hf_morphseq_df["temperature"],
)

morph_spline_trace = go.Scatter3d(
    x=morph_spline_pca[:, 0],
    y=morph_spline_pca[:, 1],
    z=morph_spline_pca[:, 2],
    mode="lines",
)
fig.add_traces(morph_spline_trace)

fig.update_traces(marker=dict(size=5))
fig.update_layout(title="morphology space")
fig.show()

In [None]:
# HF embryos in the first three seq PCs, colored by temperature (stage shown
# on hover), with the seq spline overlaid as a line
fig = px.scatter_3d(
    x=seq_pca[:, 0],
    y=seq_pca[:, 1],
    z=seq_pca[:, 2],
    color=hf_morphseq_df["temperature"],
    hover_data=[hf_morphseq_df["stage_hpf"]],
)

seq_spline_trace = go.Scatter3d(
    x=seq_spline_pca[:, 0],
    y=seq_spline_pca[:, 1],
    z=seq_spline_pca[:, 2],
    mode="lines",
)
fig.add_traces(seq_spline_trace)

fig.update_traces(marker=dict(size=5))
fig.update_layout(title="transcriptional space")
fig.show()

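### Fit an MLP mapping seq space to morph space. The plan is two phases: first pretrain on the WT splines, then fine-tune on HF data (the pretraining cell is currently commented out, so the model below is trained on HF data only)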

In [None]:
# Regressor and metric used by the MLP cell below.
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error

# DISABLED: spline-pretraining phase. Everything below is commented out, so no
# pretraining happens and the next cell trains on HF data only.
# NOTE(review): if re-enabled, `mlp_pre.score` returns R^2 (not MSE), and the
# output dimension here (n_dim_out = 4) differs from the next cell's (3), which
# would be incompatible with warm-start fine-tuning of the same model.
# # ndim_out = 3
# # Initialize the MLPRegressor with warm_start enabled.
# # warm_start=True allows the model to continue training from its current state.
# mlp_pre = MLPRegressor(hidden_layer_sizes=(50, 50),  
#                        warm_start=True,    # Retain previous weights between fit calls
#                        max_iter=50000,       # Number of iterations per fit call
#                        random_state=42)

# # Pretrain on the spline data to learn the dominant trend
# n_dim_out = 4
# n_dim_in = 5
# XS = seq_spline_pca[:, :n_dim_in]
# YS = morph_spline_pca[:, :n_dim_out]

# mlp_pre.fit(XS, YS)
# # m_spline_pd = mlp_pre.predict(XS)
# print("Pretraining MSE on spline data:", mlp_pre.score(XS, YS))

In [None]:
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error

n_dim_out = 3  # number of leading morph PCs to predict
n_dim_in = 5  # number of leading seq PCs used as predictors

# seq -> morph regression; row i of seq_pca and morph_pca must be the same embryo
X = seq_pca[:, :n_dim_in]
y = morph_pca[:, :n_dim_out]
if X.shape[0] != y.shape[0]:
    raise ValueError(f"seq ({X.shape[0]}) and morph ({y.shape[0]}) arrays have different numbers of rows")

# split into train/test with a seeded generator so the split (and the reported
# scores) is reproducible, like the fixed random_state of the MLP below
split_seed = 42
test_frac = 0.15
n_total = X.shape[0]
n_test = int(test_frac * n_total)
indices = np.arange(n_total)
rng = np.random.default_rng(split_seed)
test_indices = rng.choice(indices, n_test, replace=False)
train_indices = indices[~np.isin(indices, test_indices)]

X_train = X[train_indices]
X_test = X[test_indices]
Y_train = y[train_indices]
Y_test = y[test_indices]

# train directly on HF data (spline pretraining is currently disabled)
mlp = MLPRegressor(random_state=42, max_iter=2000000, hidden_layer_sizes=(50, 50))
mlp.fit(X_train, Y_train)

# see how well we did; MLPRegressor.score returns R^2 (not MSE), so report both
morph_train_pred = mlp.predict(X_train)
morph_test_pred = mlp.predict(X_test)
print("R^2 on train data:", mlp.score(X_train, Y_train))
print("MSE on train data:", mean_squared_error(Y_train, morph_train_pred))
print("R^2 on test data:", mlp.score(X_test, Y_test))
print("MSE on test data:", mean_squared_error(Y_test, morph_test_pred))

In [None]:
# MLP-predicted morphology (first three morph PCs) for the held-out test embryos,
# colored by temperature. .iloc keeps the colors positionally aligned with the
# rows of morph_test_pred regardless of hf_morphseq_df's index labels.
fig = px.scatter_3d(x=morph_test_pred[:, 0], y=morph_test_pred[:, 1], z=morph_test_pred[:, 2],
                    color=hf_morphseq_df["temperature"].iloc[test_indices])
fig.update_traces(marker=dict(size=5))
fig.update_layout(title="predicted morphology space (test set)")
fig.show()

In [None]:
# MLP-predicted morphology (first three morph PCs) for the training embryos,
# colored by temperature. .iloc keeps the colors positionally aligned with the
# rows of morph_train_pred regardless of hf_morphseq_df's index labels.
fig = px.scatter_3d(x=morph_train_pred[:, 0], y=morph_train_pred[:, 1], z=morph_train_pred[:, 2],
                    color=hf_morphseq_df["temperature"].iloc[train_indices])
fig.update_traces(marker=dict(size=5))
fig.update_layout(title="predicted morphology space (train set)")
fig.show()