### This notebook fits a reference spline to HF and AB reference data

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
from glob2 import glob

In [None]:

# --- input/output locations (machine-specific: edit `root` / `fig_path` to run elsewhere) ---
root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"
train_name = "20241107_ds"
model_name = "SeqVAE_z100_ne150_sweep_01_block01_iter030" 
train_dir = os.path.join(root, "training_data", train_name, "")
output_dir = os.path.join(train_dir, model_name) 

# get path to model: the last entry (lexicographic order) under output_dir
# NOTE(review): presumably this is the most recent training run — confirm naming sorts chronologically
training_runs = sorted(glob(os.path.join(output_dir, "*")))
if not training_runs:
    raise FileNotFoundError(f"No training runs found under {output_dir}")
training_path = training_runs[-1]
# name of the run folder itself (dirname would return the parent model directory instead)
training_name = os.path.basename(os.path.normpath(training_path))
read_path = os.path.join(training_path, "figures", "")

# path to save data
# NOTE(review): results date "20240310" differs from the slides date "20250312" below — confirm intended
out_path = os.path.join(root, "results", "20240310", "")
os.makedirs(out_path, exist_ok=True)

# path to figures and data
fig_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/morphseq/20250312/morph_metrics/"
os.makedirs(fig_path, exist_ok=True)

In [None]:
# Per-snip morphology table (latent coordinates plus embryo/experiment metadata) stored next to the model's figures.
morph_df = pd.read_csv(os.path.join(read_path, "embryo_stats_df.csv"), index_col=0)

### Make 3D UMAP and PCA for hotfish experiments

In [None]:
# Hotfish (HF) experiment dates to analyse ('20240813_extras' is intentionally excluded)
HF_experiments = np.asarray(['20240813_24hpf', '20240813_30hpf', '20240813_36hpf'])
# snips flagged by hand as outliers
hf_outlier_snips = np.asarray(["20240813_24hpf_F06_e00_t0000", "20240813_36hpf_D03_e00_t0000", "20240813_36hpf_C03_e00_t0000"]) 

in_hf_experiment = np.isin(morph_df["experiment_date"], HF_experiments)
hf_morph_df = morph_df.loc[in_hf_experiment, :].reset_index()

is_outlier = np.isin(hf_morph_df["snip_id"], hf_outlier_snips)
hf_morph_df = hf_morph_df.loc[~is_outlier, :]

### Problem: 28C is our control group, but we don't have stage-matching due to stage shifting
**Potential solution:** search for reference embryos from timelapse data that closely overlap with 28C, but which also extend out into later timepoints

In [None]:
# Reference cohort: embryos of one genotype whose predicted stage spans [start_stage, target_stage]
short_pert_name = "wt_ab" # genotype
target_stage = 42 # alive through at least this point
start_stage = 18

# per-embryo min/max predicted stage
embryo_df = (
    morph_df
    .groupby(["experiment_date", "embryo_id", "short_pert_name"])["predicted_stage_hpf"]
    .agg(["min", "max"])
    .reset_index()
)

is_genotype = embryo_df["short_pert_name"] == short_pert_name
spans_window = embryo_df["min"].le(start_stage) & embryo_df["max"].ge(target_stage)
ref_embryo_df = embryo_df.loc[spans_window & is_genotype, :]

# keep every snip belonging to a selected reference embryo
ref_morph_df = morph_df.merge(ref_embryo_df.loc[:, ["embryo_id"]], how="inner", on="embryo_id")

### Refit PCA to just the reference and hotfish data

In [None]:
from sklearn.decomposition import PCA
import re 

# params
n_components = 5
z_pattern = "z_mu_b"  # regex selecting the latent-mean columns used as PCA input
mu_cols = [col for col in ref_morph_df.columns if re.search(z_pattern, col)]
if not mu_cols:
    raise ValueError(f"No columns in ref_morph_df match pattern {z_pattern!r}")
pca_cols = [f"PCA_{p:02}_bio" for p in range(n_components)]
# per-row metadata carried into the PCA tables ("predicted_stage_hpf" becomes "timepoint")
meta_cols = ["snip_id", "embryo_id", "temperature", "predicted_stage_hpf"]


def _build_pca_df(source_df, pca_array):
    """Join PCA scores with per-row metadata from `source_df`, preserving each column's dtype.

    Parameters
    ----------
    source_df : DataFrame whose rows are in the same order as `pca_array`.
    pca_array : array of PCA scores with one column per entry of `pca_cols`.

    Returns
    -------
    DataFrame with the PCA columns followed by snip_id, embryo_id, temperature and
    timepoint (predicted stage floored to whole hours).
    """
    # Concatenating frames (rather than assigning a mixed-type `.to_numpy()` block) keeps
    # temperature/timepoint numeric instead of collapsing everything to object dtype.
    meta = (source_df[meta_cols]
            .reset_index(drop=True)  # align positionally with the rows of pca_array
            .rename(columns={"predicted_stage_hpf": "timepoint"}))
    pca_df = pd.concat([pd.DataFrame(pca_array, columns=pca_cols), meta], axis=1)
    pca_df["timepoint"] = np.floor(pca_df["timepoint"].astype(float))
    return pca_df


# fit on the pooled reference + hotfish latents so both share one PC basis
morph_pca = PCA(n_components=n_components)
morph_pca.fit(pd.concat([ref_morph_df[mu_cols], hf_morph_df[mu_cols]]))

# transform
ref_pca_array = morph_pca.transform(ref_morph_df[mu_cols])
hf_pca_array = morph_pca.transform(hf_morph_df[mu_cols])

ref_pca_df = _build_pca_df(ref_morph_df, ref_pca_array)
hf_pca_df = _build_pca_df(hf_morph_df, hf_pca_array)

In [None]:
# Cumulative fraction of latent-space variance captured by the first k PCs
cum_var_explained = np.cumsum(morph_pca.explained_variance_ratio_)
pc_numbers = np.arange(1, n_components + 1)

fig = px.line(x=pc_numbers, y=cum_var_explained, markers=True)
fig.update_layout(
    title_text="PCA decomposition of morphVAE latent space",
    xaxis_title_text="PC number",
    yaxis_title_text="total variance explained",
    font_family="Arial, sans-serif",
    font_size=18,  # global font size
    font_color="black",
)

fig.show()

fig.write_image(os.path.join(fig_path, "morph_pca_var_explained.png"))

### Experiment with fitting a 3D spline to the reference data

In [None]:
# Spline-fitting set: every reference-cohort row plus the 28.5C hotfish control rows
is_control_temp = hf_pca_df["temperature"] == 28.5
fit_pca_df = pd.concat([ref_pca_df, hf_pca_df[is_control_temp]], ignore_index=True)

One problem I have noticed is that there is a systematic divergence between the reference trajectory and the 28.5C cohort at ~24hpf. This leads to weird results for other temperature cohorts at this time point. I therefore want to adjust the reference trajectory so that it "flows" closer to my 28.5C control embryos. The simplest way I can think of to achieve this is to add my 28.5C embryos to the spline-fitting dataset and assign them high weights, so that the model must fit them.

In [None]:
from src.functions.spline_fitting_v2 import spline_fit_wrapper
import time
import re 
import warnings
from tqdm import tqdm 

# Target share of the total fitting weight carried by hotfish rows. This is ad-hoc currently.
alpha = 0.25
emb_vec = fit_pca_df["embryo_id"]
# flag hotfish rows by membership in the hotfish table rather than a hard-coded date substring
hf_flags = np.isin(emb_vec, hf_pca_df["embryo_id"]).astype(int)

n_obs = len(hf_flags)
n_hf = int(hf_flags.sum())
n_ref = n_obs - n_hf
spline_weight_vec = np.ones(n_obs)
# rescale so hotfish weights sum to alpha*n_obs and reference weights to (1-alpha)*n_obs;
# an empty group is skipped instead of dividing by zero
if n_hf > 0:
    spline_weight_vec[hf_flags == 1] = alpha * n_obs / n_hf
if n_ref > 0:
    spline_weight_vec[hf_flags == 0] = (1 - alpha) * n_obs / n_ref
if n_hf == 0 or n_ref == 0:
    warnings.warn(f"spline weighting degenerate: {n_hf} hotfish rows, {n_ref} reference rows; alpha not applied")

n_boots = 50
n_spline_points = 2500
boot_size = 1000

# fit normal version
spline_df_orig = spline_fit_wrapper(fit_pca_df, n_boots=n_boots, n_spline_points=n_spline_points, stage_col="timepoint", 
                               obs_weights=None, boot_size=boot_size)
# fit weighted version
spline_df = spline_fit_wrapper(fit_pca_df, n_boots=n_boots, n_spline_points=n_spline_points, stage_col="timepoint", 
                               obs_weights=spline_weight_vec, boot_size=boot_size)


In [None]:
plot_dims = np.asarray([0, 1, 2])

plot_strings = [pca_cols[p] for p in plot_dims]
# axis labels follow the plotted dimensions (1-based PC numbering)
axis_titles = [f"morph PC {p + 1}" for p in plot_dims]

fig = px.scatter_3d(hf_pca_df, x=plot_strings[0], y=plot_strings[1], z=plot_strings[2], opacity=1,
                    color=hf_pca_df["temperature"].astype(float), color_continuous_scale="RdBu_r",
                    hover_data=["timepoint"])


def _spline_trace(curve_df, name, **line_kwargs):
    """Return a black 3D line trace of `curve_df` along the plotted PCA dimensions."""
    return go.Scatter3d(x=curve_df.loc[:, plot_strings[0]],
                        y=curve_df.loc[:, plot_strings[1]],
                        z=curve_df.loc[:, plot_strings[2]],
                        mode="lines", line=dict(color="black", **line_kwargs), name=name)


# distinct legend names so the two fits can be told apart
fig.add_traces(_spline_trace(spline_df_orig, "reference curve (unweighted)", width=3, dash="dash"))
fig.add_traces(_spline_trace(spline_df, "reference curve (weighted)", width=4))

# style only the embryo scatter markers, not the spline lines
fig.update_traces(marker=dict(size=10, line=dict(color="black", width=1)), selector=dict(mode="markers"))

fig.update_layout(width=1200, height=1000,
                  scene=dict(xaxis=dict(title=axis_titles[0]),
                             yaxis=dict(title=axis_titles[1]),
                             zaxis=dict(title=axis_titles[2])),
                  font=dict(
                      family="Arial, sans-serif",
                      size=16,  # global font size
                      color="black"
                  ),
                  coloraxis_colorbar=dict(
                      x=1,  # Increase x to move the colorbar rightwards
                      y=0.5,   # vertically centred
                      len=0.5  # colorbar length as a fraction of plot height
                  ))

fig.show()

fig.write_image(os.path.join(fig_path, "hotfish_pca_with_spline.png"))
fig.write_html(os.path.join(fig_path, "hotfish_pca_with_spline.html"))

### Next, fit a polynomial surface to estimate embryo stages
Let's experiment with fitting derivatives so we can utilize experimental clock time

In [None]:
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
import joblib

# Regress stage (hpf) on a polynomial expansion of the reference PCA coordinates.
degree = 2  # or any degree you choose
model = Pipeline([
    ('poly', PolynomialFeatures(degree=degree, include_bias=True)),
    ('linear', LinearRegression())
])

frac_to_fit = 0.8
split_seed = 42  # fixed so the train/test split, and hence the saved model, is reproducible
rng = np.random.default_rng(split_seed)

X = ref_pca_df[pca_cols].values
n_train = int(np.floor(frac_to_fit * X.shape[0]))
X_indices = np.arange(X.shape[0])
train_indices = rng.choice(X_indices, n_train, replace=False)
test_indices = np.setdiff1d(X_indices, train_indices)

X_train = X[train_indices]
X_test = X[test_indices]

# force numeric target in case timepoint arrived as object dtype
y = ref_pca_df["timepoint"].to_numpy(dtype=float)
y_train = y[train_indices]
y_test = y[test_indices]

model.fit(X_train, y_train)

# held-out check of the stage model
if len(test_indices) > 0:
    y_pd = model.predict(X_test)
    test_rmse = float(np.sqrt(np.mean((y_pd - y_test) ** 2)))
    print(f"held-out RMSE: {test_rmse:.2f} hpf (n_test={len(y_test)})")

### Use surface fit to generate consistent stage predictions

In [None]:
# Annotate hotfish, reference and (weighted) spline points with the model-predicted stage.
for stage_frame in (hf_pca_df, ref_pca_df, spline_df):
    stage_frame["mdl_stage_hpf"] = model.predict(stage_frame[pca_cols].values)

In [None]:
from scipy.interpolate import interp1d
import warnings

# Resample the weighted spline onto a regular grid of model-predicted stage so it can be
# aligned with the transcriptional spline.
n_points = 250
t_start = 12.9
t_stop = 50
t_vec = np.linspace(t_start, t_stop, n_points)

# model-predicted stage at each spline point: the interpolation abscissa
t_vec_orig = spline_df["mdl_stage_hpf"].to_numpy()

# interp1d sorts its x values, so a non-monotonic stage along the spline would silently
# reorder points and distort the curve
stage_steps = np.diff(t_vec_orig)
if not (np.all(stage_steps > 0) or np.all(stage_steps < 0)):
    warnings.warn("mdl_stage_hpf is not strictly monotonic along the spline; resampled curve may be distorted")

# interp1d refuses to extrapolate; report the available range explicitly
t_min, t_max = np.min(t_vec_orig), np.max(t_vec_orig)
if t_start < t_min or t_stop > t_max:
    raise ValueError(f"requested stage range [{t_start}, {t_stop}] hpf exceeds spline coverage "
                     f"[{t_min:.2f}, {t_max:.2f}] hpf")

# get new PCA values
interp = interp1d(t_vec_orig, spline_df[pca_cols].values, axis=0)
pca_array_interp = interp(t_vec)

# new frame holding the resampled PCA coordinates and their stage grid
spline_df_new = pd.DataFrame(pca_array_interp, columns=pca_cols)
spline_df_new["stage_hpf"] = t_vec

spline_df_new.head()

In [None]:
# fig = px.scatter(x=t_vec, y=spline_df_new["PCA_00_bio"])
# fig.add_traces(go.Scatter(x=t_vec_orig, y=spline_df["PCA_00_bio"]))
# fig.show()

### Save

In [None]:
# Persist the stage-annotated PCA tables, both spline versions, and the fitted stage model.
tables_to_save = {
    "hf_morph_df.csv": hf_pca_df,
    "ab_ref_morph_df.csv": ref_pca_df,
    "spline_morph_df.csv": spline_df_new,
    "spline_morph_df_full.csv": spline_df,
}
for file_name, table in tables_to_save.items():
    table.to_csv(os.path.join(out_path, file_name), index=False)

joblib.dump(model, os.path.join(out_path, 'morph_stage_model.joblib'))