#### Use CCA to look for axes of correspondence between morph and seq modalities

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
from glob2 import glob

### Set paths

In [None]:
# project root on Dropbox; the data locations below hang off this folder
root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"

# path to save data (the trailing "" leaves a path separator on the end,
# so file names can be appended directly)
out_path = os.path.join(root, "results", "20250312", "morph_latent_space", "")

# path to figures and data
fig_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/morphseq/20250312/morphseq_cca/"

# make sure both folders exist; nothing happens if they are already there
for _dir in (out_path, fig_path):
    os.makedirs(_dir, exist_ok=True)

### Load data

In [None]:
# locations of the three input tables
morph_csv = os.path.join(out_path, "hf_pca_morph_df.csv")
seq_csv = os.path.join(out_path, "hf_seq_df.csv")
metadata_csv = os.path.join(root, "metadata", "morphseq_metadata.csv")

# morph latent encodings
morph_df = pd.read_csv(morph_csv)

# hooke latent encodings (first CSV column becomes the row index)
seq_df = pd.read_csv(seq_csv, index_col=0)

# metadata df that allows us to link the two
morphseq_df = pd.read_csv(metadata_csv)

### Subset for hotfish2 

In [None]:
import re

# keep only metadata rows belonging to the three hotfish2 experiments
hf_experiments = np.asarray(["20240813_24hpf", "20240813_30hpf", "20240813_36hpf"])
in_hf = np.isin(morphseq_df["experiment_date"], hf_experiments)
hf_morphseq_df = morphseq_df.loc[in_hf, :].reset_index(drop=True)

# latent columns to use from each modality
# (an earlier variant selected morph columns containing "z_mu_b" instead)
pattern = r"PCA_.*_bio"
bio_pca_re = re.compile(pattern)
pca_cols_morph = [col for col in morph_df.columns if bio_pca_re.search(col)]
pca_cols_seq = [col for col in seq_df.columns if "PCA" in col]

# subset morph: attach each snip's sample id, index by snip_id, and keep
# only the selected PCA columns plus the sample id
snip_to_sample = hf_morphseq_df[["snip_id", "sample"]]
hf_morph_df = (
    snip_to_sample
    .merge(morph_df, how="inner", on="snip_id")
    .set_index("snip_id")
    .loc[:, pca_cols_morph + ["sample"]]
)

# subset seq dataset: match each sample id against the index of seq_df,
# then index the result by sample
hf_seq_df = (
    hf_morph_df[["sample"]]
    .merge(seq_df, how="inner", left_on="sample", right_index=True)
    .set_index("sample")
)
print(hf_seq_df.shape)

# get rid of sample col
hf_morph_df = hf_morph_df.drop(columns=["sample"])
print(hf_morph_df.shape)

# keep only metadata rows whose snip_id is still present in hf_morph_df
# (original note: a couple of observations had QC problems)
kept = np.isin(hf_morphseq_df["snip_id"], hf_morph_df.index)
# NOTE(review): reset_index() without drop=True keeps the old row labels as
# an extra "index" column -- confirm whether that is intended
hf_morphseq_df = hf_morphseq_df.loc[kept, :].reset_index()

# add the model-predicted stage; merge keys are whatever columns the two
# frames share (no explicit `on=`)
stage_df = morph_df[["snip_id", "mdl_stage_hpf"]]
hf_morphseq_df = hf_morphseq_df.merge(stage_df)
print(hf_morphseq_df.shape)

### Make Arrhenius plots

In [None]:
# NOTE(review): mean_squared_error is not used in this cell; the import is
# kept in case other cells rely on it.
from sklearn.metrics import mean_squared_error


def _cohort_mean(frame):
    """Average every column of *frame* within each (temperature, timepoint) cohort.

    Returns a flat table: the two grouping keys come back as ordinary
    columns and each averaged column is renamed "<column>_mean".
    """
    cohort = frame.groupby(["temperature", "timepoint"]).agg(["mean"])
    # agg(["mean"]) yields two-level (column, "mean") labels; flatten them
    cohort.columns = [f"{name}_{stat}" for name, stat in cohort.columns]
    return cohort.reset_index()


def _error_tables(sq_err, value_cols, meta):
    """Build per-observation and per-cohort tables from squared errors.

    sq_err     -- 2-D array of squared errors, one column per entry of value_cols
    value_cols -- column names to give the squared-error columns
    meta       -- frame supplying "timepoint" and "temperature", row-aligned
                  with sq_err

    Returns (per_obs, per_cohort). per_obs holds the squared errors plus
    "total_se", the square root of their row sum; per_cohort is the
    (temperature, timepoint) mean of per_obs.
    """
    per_obs = pd.DataFrame(sq_err, columns=value_cols)
    per_obs["total_se"] = np.sqrt(np.sum(per_obs[value_cols], axis=1))
    for key in ("timepoint", "temperature"):
        per_obs[key] = meta[key].to_numpy()
    return per_obs, _cohort_mean(per_obs)


# NOTE(review): morph_pd_df, mean_cols and n_dim_out are not assigned in any
# visible cell; this cell only runs if they already exist in the kernel.
# Presumably mean_cols holds the "<PCA column>_mean" names matching
# pca_cols_morph[:n_dim_out] -- confirm.

# get cohort averages
meta_cols = ["snip_id", "timepoint", "temperature"]
morph_df_true = hf_morph_df.reset_index()
morph_df_true = morph_df_true.merge(morphseq_df.loc[:, meta_cols], how="left", on="snip_id")
morph_df_mean = _cohort_mean(morph_df_true.drop(labels=["snip_id"], axis=1))

# merge back to original obs, so every row carries its cohort's mean values
morph_df_null = morph_df_true.loc[:, meta_cols].merge(
    morph_df_mean, how="left", on=["timepoint", "temperature"])

# extract just the PCA values to compare
eval_cols = pca_cols_morph[:n_dim_out]
Y_pd = morph_pd_df[mean_cols].values
Y_mean = morph_df_null[mean_cols].values
Y_true = morph_df_true[eval_cols].values

# squared error per observation and PCA dimension
pd_error = (Y_true - Y_pd) ** 2
null_error = (Y_true - Y_mean) ** 2

# convert to DataFrames and average within each cohort
pd_df, pd_df_mean = _error_tables(pd_error, eval_cols, morph_df_true)
null_df, null_df_mean = _error_tables(null_error, eval_cols, morph_df_true)

In [None]:
ind = 0

# x axis: cohort-mean total_se of morph_pd_df against the observed values;
# y axis: the same quantity for the cohort-mean (null) values
null_total_se = null_df_mean["total_se_mean"]
fig = px.scatter(
    pd_df_mean,
    x="total_se_mean",
    y=null_total_se,
    color="temperature",
    symbol="timepoint",
)
# (log axes via log_x=True, log_y=True were tried and are currently disabled)

# each update_* call returns the same figure, so they can be chained
fig = (
    fig.update_traces(marker=dict(size=8))
    .update_layout(width=1000, height=800)
    .update_xaxes(range=[0, 4])
    .update_yaxes(range=[0, 4])
)
fig.show()

In [None]:
# display the morph_df column names that matched the r"PCA_.*_bio" pattern
pca_cols_morph

In [None]:
# NOTE(review): mean_cols is never assigned in the visible cells; presumably
# it is left over from earlier kernel state -- confirm where it is defined
mean_cols