#### This notebook looks at temperature-dependent changes to embryo morphology

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
from glob2 import glob

In [None]:
# Resolve the on-disk locations of the current best model's outputs and of the
# folder that figures from this notebook are written to.
# root = "/media/nick/hdd02/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"

root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"
train_name = "20241107_ds"
model_name = "SeqVAE_z100_ne150_sweep_01_block01_iter030"
train_dir = os.path.join(root, "training_data", train_name, "")
output_dir = os.path.join(train_dir, model_name)

# get path to model: take the last entry (in sorted order) under output_dir.
# Fail with a clear message instead of an IndexError when nothing is there.
training_candidates = sorted(glob(os.path.join(output_dir, "*")))
if not training_candidates:
    raise FileNotFoundError(f"No training output found under {output_dir!r}")
training_path = training_candidates[-1]
# NOTE(review): dirname() yields the parent folder (output_dir), not the run
# name; this variable is not used anywhere else in the notebook.
training_name = os.path.dirname(training_path)
read_path = os.path.join(training_path, "figures", "")

# path to figures and data
fig_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/morphseq/20250312/morph_metrics/_archive/"
os.makedirs(fig_path, exist_ok=True)

In [None]:
# Read the per-snip morphology statistics and the latent-embedding table that
# were written to the model's "figures" folder.
morph_df = pd.read_csv(os.path.join(read_path, "embryo_stats_df.csv"), index_col=0)
umap_df = pd.read_csv(os.path.join(read_path, "umap_df.csv"), index_col=0)
print(umap_df.shape)

# Bring embryo identity and experiment time onto the embedding table; the
# shape is printed before and after so an unexpected row change is visible.
snip_metadata = morph_df[["snip_id", "embryo_id", "experiment_time"]]
umap_df = umap_df.merge(snip_metadata, how="left", on=["snip_id"])
print(umap_df.shape)

### Make 3D UMAP and PCA for hotfish experiments

In [None]:
# Experiments that make up the hotfish (HF) data set.
# ('20240813_extras' is currently excluded.)
HF_experiments = np.asarray(['20240813_24hpf', '20240813_30hpf', '20240813_36hpf'])

# Restrict both tables to the hotfish experiments. reset_index() keeps the old
# index as an "index" column.
hf_morph_df = morph_df.loc[morph_df["experiment_date"].isin(HF_experiments), :].reset_index()
hf_umap_df = umap_df.loc[umap_df["experiment_date"].isin(HF_experiments), :].reset_index()

# Snips excluded from the embedding table as outliers.
hf_outlier_snips = np.asarray(["20240813_24hpf_F06_e00_t0000", "20240813_36hpf_D03_e00_t0000", "20240813_36hpf_C03_e00_t0000"])
is_outlier = hf_umap_df["snip_id"].isin(hf_outlier_snips)
hf_umap_df = hf_umap_df.loc[~is_outlier, :]

In [None]:
# make umap scatter of the hotfish embryos, coloured by temperature.
# hover_data is given as a list (a documented input type) so the hover fields
# appear in a fixed order; a set has no defined iteration order.
fig = px.scatter_3d(
    hf_umap_df,
    x="UMAP_00_bio_3",
    y="UMAP_01_bio_3",
    z="UMAP_02_bio_3",
    color="temperature",
    hover_data=["predicted_stage_hpf", "experiment_date", "snip_id"],
)
fig.update_traces(marker=dict(size=6))
fig.show()

# Save both a static image and an interactive copy.
fig.write_image(os.path.join(fig_path, "hotfish_umap.png"))
fig.write_html(os.path.join(fig_path, "hotfish_umap.html"))

In [None]:
# PCA scatter of the hotfish embryos, coloured by temperature.
# hover_data is given as a list (a documented input type) so the hover fields
# appear in a fixed order; a set has no defined iteration order.
fig = px.scatter_3d(
    hf_umap_df,
    x="PCA_00_bio",
    y="PCA_01_bio",
    z="PCA_02_bio",
    color="temperature",
    hover_data=["predicted_stage_hpf", "experiment_date", "snip_id"],
)
fig.update_traces(marker=dict(size=6))
fig.show()

# Save both a static image and an interactive copy.
fig.write_image(os.path.join(fig_path, "hotfish_pca.png"))
fig.write_html(os.path.join(fig_path, "hotfish_pca.html"))

### Problem: 28C is our control group, but we don't have stage-matching due to stage shifting
**Potential solution:** search for reference embryos from timelapse data that closely overlap with 28C, but which also extend out into later timepoints

In [None]:
# Selection criteria for reference embryos.
short_pert_name = "wt_ab"  # genotype
target_stage = 44  # alive through at least this point
start_stage = 18

# Per-embryo range of predicted stages (one row per date/embryo/perturbation).
group_keys = ["experiment_date", "embryo_id", "short_pert_name"]
embryo_df = (
    morph_df[group_keys + ["predicted_stage_hpf"]]
    .groupby(group_keys)["predicted_stage_hpf"]
    .agg(["min", "max"])
    .reset_index()
)

# Keep embryos of the requested genotype whose observed stages span
# [start_stage, target_stage].
pert_filter = embryo_df["short_pert_name"] == short_pert_name
stage_filter = (embryo_df["min"] <= start_stage) & (embryo_df["max"] >= target_stage)
embryo_df = embryo_df.loc[stage_filter & pert_filter, :]
# embryo_df.shape

# Embedding rows belonging to the selected reference embryos.
ref_umap_df = umap_df.merge(embryo_df[["embryo_id"]], how="inner", on="embryo_id")

In [None]:
# Hotfish embryos in UMAP space with the reference-embryo trajectories overlaid.
# hover_data is given as a list so the hover fields appear in a fixed order.
fig = px.scatter_3d(
    hf_umap_df,
    x="UMAP_00_bio_3",
    y="UMAP_01_bio_3",
    z="UMAP_02_bio_3",
    color="temperature",
    hover_data=["predicted_stage_hpf", "experiment_date"],
)

# One faint line per reference embryo; groupby visits embryo ids in sorted
# order and avoids rebuilding a boolean mask over the whole table per embryo.
for eid, embryo_rows in ref_umap_df.groupby("embryo_id"):
    fig.add_trace(
        go.Scatter3d(
            x=embryo_rows["UMAP_00_bio_3"],
            y=embryo_rows["UMAP_01_bio_3"],
            z=embryo_rows["UMAP_02_bio_3"],
            mode="lines",
            line=dict(color='rgba(0, 0, 0, 0.2)'),
            showlegend=False,
        )
    )

fig.update_traces(marker=dict(size=4))
fig.show()

fig.write_image(os.path.join(fig_path, "hotfish_umap_ref.png"))
fig.write_html(os.path.join(fig_path, "hotfish_umap_ref.html"))

In [None]:
# Hotfish embryos in PCA space with the reference-embryo trajectories overlaid.
# hover_data is given as a list so the hover fields appear in a fixed order.
fig = px.scatter_3d(
    hf_umap_df,
    x="PCA_00_bio",
    y="PCA_01_bio",
    z="PCA_02_bio",
    color="temperature",
    hover_data=["predicted_stage_hpf", "experiment_date"],
)

# One faint line per reference embryo; groupby visits embryo ids in sorted
# order and avoids rebuilding a boolean mask over the whole table per embryo.
for eid, embryo_rows in ref_umap_df.groupby("embryo_id"):
    fig.add_trace(
        go.Scatter3d(
            x=embryo_rows["PCA_00_bio"],
            y=embryo_rows["PCA_01_bio"],
            z=embryo_rows["PCA_02_bio"],
            mode="lines",
            line=dict(color='rgba(0, 0, 0, 0.2)'),
            showlegend=False,
        )
    )

fig.update_traces(marker=dict(size=4))
fig.show()

fig.write_image(os.path.join(fig_path, "hotfish_pca_ref.png"))
fig.write_html(os.path.join(fig_path, "hotfish_pca_ref.html"))

### Look at variability by time cohort and temp 

In [None]:
# avg_early_timepoint

### Experiment with fitting a 3D spline to the reference embryos

In [None]:
from src.functions.spline_fitting_v2 import LocalPrincipalCurve
import time
import re 
from tqdm import tqdm 

# Fit a principal curve through the reference embryos in PCA space, repeated on
# bootstrap resamples so that a point-wise spread of the curve can be estimated.

# Every column named PCA_*_bio is used as a coordinate.
pattern = r"PCA_.*_bio"
pca_cols = [col for col in ref_umap_df.columns if re.search(pattern, col)]

# Settings forwarded to LocalPrincipalCurve (see that class for their meaning).
bandwidth = .5
max_iter = 2500
tol = 1e-5
angle_penalty_exp = 0.5

# Bootstrap settings: number of replicates, rows drawn per replicate (capped at
# 2500) and number of points requested from each fit.
n_boots = 50
boot_size = min(ref_umap_df.shape[0], 2500)
num_points = 2500

# Extract PCA coordinates
pert_array = ref_umap_df[pca_cols].values

# Candidate start anchors: observations within 2 hpf of the earliest stage
min_time = ref_umap_df["predicted_stage_hpf"].min()
early_mask = (ref_umap_df["predicted_stage_hpf"] >= min_time) & \
             (ref_umap_df["predicted_stage_hpf"] < min_time + 2)
early_points = ref_umap_df.loc[early_mask, pca_cols].values
early_options = np.arange(early_points.shape[0])

# Candidate end anchors: observations within 2 hpf of the latest stage
max_time = ref_umap_df["predicted_stage_hpf"].max()
late_mask = (ref_umap_df["predicted_stage_hpf"] >= (max_time - 2))
late_points = ref_umap_df.loc[late_mask, pca_cols].values
late_options = np.arange(late_points.shape[0])

# generate array to store spline fits: (curve point, PCA dimension, replicate)
spline_boot_array = np.zeros((num_points, len(pca_cols), n_boots))

# A single seeded generator drives ALL random draws below, so the whole
# bootstrap (resampled rows and anchor points) is reproducible.
rng = np.random.RandomState(42)

for n in tqdm(range(n_boots)):
    # Resample observations with replacement
    subset_indices = rng.choice(len(pert_array), size=boot_size, replace=True)
    pert_array_subset = pert_array[subset_indices, :]

    # Draw one early and one late observation to anchor this replicate
    start_ind = rng.choice(early_options)
    stop_ind = rng.choice(late_options)
    start_point = early_points[start_ind, :]
    stop_point = late_points[stop_ind, :]

    # Fit LocalPrincipalCurve
    lpc = LocalPrincipalCurve(
        bandwidth=bandwidth,
        max_iter=max_iter,
        tol=tol,
        angle_penalty_exp=angle_penalty_exp
    )

    # Fit with the optional start_points/end_point to anchor the spline
    lpc.fit(
        pert_array_subset,
        start_points=start_point[None, :],
        end_point=stop_point[None, :],
        num_points=num_points
    )

    # NOTE(review): assumes the fit always yields at least one spline of shape
    # (num_points, len(pca_cols)) -- confirm against LocalPrincipalCurve.
    spline_boot_array[:, :, n] = lpc.cubic_splines[0]
# stop = time.time()


# spline_points = None
# if len(lpc.cubic_splines) > 0:
#     # If your local principal curve class stores the final spline
#     spline_points = lpc.cubic_splines[0]
# else:
#     # If no spline was built, skip
#     continue

# # Create a temporary DataFrame for the current spline
# spline_df = pd.DataFrame(spline_points, columns=["PCA_1", "PCA_2", "PCA_3"])
# spline_df["phenotype"] = pert

In [None]:
import plotly.graph_objects as go

def compute_tube_points(spline, std_err, num_circle_points=20):
    """
    Compute the vertex grid of a tube that follows a 3D curve.

    At every point of ``spline`` a ring of ``num_circle_points`` vertices is
    placed in the plane spanned by two unit vectors perpendicular to the local
    tangent (the tangent is estimated with ``np.gradient``).

    Parameters
    ----------
    spline : array-like, shape (N, 3)
        Points along the curve.
    std_err : array-like, shape (N,) or (N, 3)
        Ring size at each curve point. A 1-D input is used as a radius. A 2-D
        input scales the ring offset component-wise (per coordinate axis).
    num_circle_points : int, default 20
        Number of vertices per ring.

    Returns
    -------
    tube_x, tube_y, tube_z : ndarray, each of shape (N, num_circle_points)
        Coordinates of the ring vertices; row ``i`` is the ring around
        ``spline[i]``.

    Raises
    ------
    ValueError
        If ``spline`` is not (N, 3) or ``std_err`` does not match it.
    """
    spline = np.asarray(spline, dtype=float)
    scale = np.asarray(std_err, dtype=float)

    if spline.ndim != 2 or spline.shape[1] != 3:
        raise ValueError("This function assumes 3D data.")
    n_pts = spline.shape[0]
    if scale.ndim not in (1, 2) or scale.shape[0] != n_pts:
        raise ValueError("std_err must have shape (N,) or (N, 3) matching spline.")

    # Unit tangents; points with an undefined direction fall back to the x axis.
    tangents = np.gradient(spline, axis=0)
    norms = np.linalg.norm(tangents, axis=1)
    degenerate = ~(norms > 0)
    tangents[degenerate] = (1.0, 0.0, 0.0)
    norms[degenerate] = 1.0
    tangents = tangents / norms[:, None]

    # Helper vector that is never (nearly) parallel to the tangent.
    helper = np.where(
        np.abs(tangents[:, :1]) < 0.9,
        np.array([1.0, 0.0, 0.0]),
        np.array([0.0, 1.0, 0.0]),
    )

    # Two unit vectors perpendicular to the tangent span the ring's plane.
    n1 = np.cross(tangents, helper)
    n1 /= np.linalg.norm(n1, axis=1, keepdims=True)
    n2 = np.cross(tangents, n1)
    n2 /= np.linalg.norm(n2, axis=1, keepdims=True)

    # Unit ring for every curve point: shape (N, num_circle_points, 3).
    angles = np.linspace(0, 2 * np.pi, num_circle_points, endpoint=False)
    ring = (
        np.cos(angles)[None, :, None] * n1[:, None, :]
        + np.sin(angles)[None, :, None] * n2[:, None, :]
    )

    # Broadcast the scale over the ring vertices (and over axes when 1-D).
    scale = scale[:, None, None] if scale.ndim == 1 else scale[:, None, :]
    tube = spline[:, None, :] + scale * ring

    tube_x = np.ascontiguousarray(tube[..., 0])
    tube_y = np.ascontiguousarray(tube[..., 1])
    tube_z = np.ascontiguousarray(tube[..., 2])

    return tube_x, tube_y, tube_z



In [None]:
# Collapse the bootstrap replicates (last axis): the point-wise mean is the
# reference curve, and the spread across replicates is used as its standard error.
mean_spline = spline_boot_array.mean(axis=2)
se_spline = spline_boot_array.std(axis=2)

In [None]:
# Indices (into pca_cols) of the three PCA dimensions shown in the plot.
plot_dims = np.asarray([0, 1, 2])

# get se mesh for spline
tube_x, tube_y, tube_z = compute_tube_points(mean_spline[:, plot_dims], se_spline[:, plot_dims])

# NOTE: this mesh has no faces (i/j/k are empty) and is not added to the figure
# below; it is kept as a placeholder for an uncertainty tube.
se_mesh = go.Mesh3d(
    x=tube_x.flatten(),
    y=tube_y.flatten(),
    z=tube_z.flatten(),
    i=[], j=[], k=[],  # You would need to compute triangle indices based on the grid structure
    color='lightblue',
    opacity=0.2,
    name='Uncertainty'
)


plot_strings = [pca_cols[p] for p in plot_dims]

# hover_data is given as a list so the hover fields appear in a fixed order.
fig = px.scatter_3d(
    hf_umap_df,
    x=plot_strings[0],
    y=plot_strings[1],
    z=plot_strings[2],
    opacity=1,
    color="temperature",
    hover_data=["predicted_stage_hpf", "experiment_date", "snip_id"],
)

fig.update_traces(marker=dict(size=5, showscale=False))

# Mean bootstrap curve through the reference embryos.
fig.add_trace(go.Scatter3d(x=mean_spline[:, plot_dims[0]], y=mean_spline[:, plot_dims[1]],
                           z=mean_spline[:, plot_dims[2]],
                           mode="lines", line=dict(color="darkblue", width=4), name="reference curve"))

# fig.add_traces(go.Scatter3d(x=[P2[0]], y=[P2[1]], z=[P2[2]], mode="markers"))

# fig.add_traces(se_mesh)

fig.show()

fig.write_image(os.path.join(fig_path, "hotfish_pca_with_spline.png"))
fig.write_html(os.path.join(fig_path, "hotfish_pca_with_spline.html"))

We can use this spline to stage embryos. The first step is to calculate the elapsed experimental time along each segment of the spline. This, plus an estimated starting time, will give us a calibration curve.

In [None]:
# def compute_m_distance(ref_data, mean_spline, se_spline):
#     n, d = ref_data.shape
#     m, _ = mean_spline.shape
#     dist_matrix = np.empty((n, m))
#     # Compute the Mahalanobis distance for each pair (i, j)
#     for j in tqdm(range(m)):
#         # Construct the diagonal inverse covariance matrix for spline point j
#         # inv_var = 1.0 / (se_spline[j] ** 2)  # shape: (d,)
#         se_total = np.sum(se_spline[j, :]**2)
#         diff = np.sum((ref_data - mean_spline[j])**2, axis=1)  # shape: (n, d)
#         # Compute squared Mahalanobis distances: sum over dimensions of diff**2 * inv_var
#         dist_matrix[:, j] = np.sqrt(diff / se_total) #np.sqrt(np.sum(diff**2 * inv_var, axis=1))
        
#     closest_indices = np.argmin(dist_matrix, axis=1)
#     closest_dist = np.min(dist_matrix, axis=1)

#     return closest_dist, closest_indices, dist_matrix

#### Use reference data to get calibration curve

In [None]:
from scipy.spatial import distance_matrix

# Euclidean distance from every reference observation (rows) to every point of
# the mean spline (columns).
ref_data = ref_umap_df[pca_cols].to_numpy()
ref_dist = distance_matrix(ref_data, mean_spline)

# Express distances in units of a single pooled spline spread: the root of the
# mean (over spline points) of the summed per-dimension variances.
sigma = np.sqrt(np.sum(se_spline**2, axis=1).mean())
ref_dist_z = ref_dist / sigma

# Gaussian weights, zeroed beyond 2 * sqrt(number of PCA dimensions).
weight_cutoff = 2 * np.sqrt(len(pca_cols))
ref_weights = np.where(ref_dist_z > weight_cutoff, 0.0, np.exp(-0.5 * ref_dist_z**2))

# Spline index assigned to each observation: currently the closest spline point
# (a weight-averaged index over knot_i_vec was the alternative considered).
knot_i_vec = np.arange(num_points)[None, :]
ref_ci_vec = ref_dist.argmin(axis=1)

In [None]:
# Attach the closest-spline-point index to each reference observation.
ref_umap_df["knot_index"] = ref_ci_vec
# ref_umap_df["knot_dist"] = ref_cd

diff_cols = pca_cols + ["experiment_time", "knot_index"]
diff_cols_lb = [col + "_diff" for col in diff_cols]

# Remove *_diff columns left over from a previous run of this cell.
ref_umap_df = ref_umap_df.drop(columns=diff_cols_lb, errors="ignore")

# calculate morphological flux for each embryo: row-to-row change in PCA
# position, experiment time and spline index within each embryo_id group
ref_umap_df[diff_cols_lb] = ref_umap_df.groupby('embryo_id')[diff_cols].diff()

# The first row of every embryo has no predecessor and so gets NaN diffs; fill
# them from the following row. `.bfill()` replaces the deprecated
# `fillna(method='bfill')` with identical behaviour.
# NOTE(review): this back-fills every column of the table, and is not limited
# to rows of the same embryo -- confirm that is intended.
ref_umap_df = ref_umap_df.bfill()

In [None]:
# Distribution of the closest-spline-point index over the reference observations.
px.histogram(ref_umap_df["knot_index"])

In [None]:
# Ok, now get average flux for each knot point
ref_umap_df["dtds"] = np.divide(ref_umap_df["experiment_time_diff"], ref_umap_df["knot_index_diff"])
m_flux_array = np.zeros(mean_spline.shape)
t_flux_array = np.zeros((mean_spline.shape[0], 1))
pd_time_array = np.zeros((mean_spline.shape[0], 1))
# inlier_filter = ref_cd <= (2 * np.sqrt(len(pca_cols)))

for t in range(m_flux_array.shape[0]):
    knot_indices = np.where(ref_weights[:, t] > 0)[0]
    if len(knot_indices) > 0:
        # m_flux_array[t, :] = np.mean(ref_umap_df.loc[knot_ref_filter, diff_cols[:-2]], axis=0)
        pdt_vals = ref_umap_df.loc[knot_indices, "predicted_stage_hpf"]
        dt_vals = ref_umap_df.loc[knot_indices, "experiment_time_diff"]
        ds_vals = ref_umap_df.loc[knot_indices, "knot_index_diff"]
        # dtds_vals = ref_umap_df.loc[knot_indices, "dtds"]
        # inf_filter = ~np.isinf(dtds_vals)
        wt_vec = ref_weights[knot_indices, t]
        dt_avg = np.sum(np.multiply(dt_vals, wt_vec)) / np.sum(wt_vec)
        ds_avg = np.sum(np.multiply(ds_vals, wt_vec)) / np.sum(wt_vec)
        t_flux_array[t, :] = dt_avg / ds_avg
        # t_flux_array[t, :] = np.sum(np.multiply(dtds_vals[inf_filter], wt_vec[inf_filter])) / np.sum(wt_vec[inf_filter])
        pd_time_array[t, :] = np.sum(np.multiply(pdt_vals, wt_vec)) / np.sum(wt_vec)


In [None]:
from sklearn.linear_model import LinearRegression
# Build a per-spline-point staging table and calibrate it against the predicted
# stage of the reference embryos.

# Drop columns from a previous run so the merge below does not create duplicates.
check_cols = ["spline_trend_hpf", "spline_stage_hpf", "spline_flux_hpf", "pd_time_hpf"]
for col in check_cols:
    if col in ref_umap_df.columns.tolist():
        ref_umap_df = ref_umap_df.drop(labels=[col], axis=1)
        
# calculate predicted stage as a function of knot position
# NOTE(review): this overwrites the start_stage defined in the reference-embryo
# selection cell and is not used again in this notebook.
start_stage = pd_time_array[0, 0]
# Cumulative time along the spline; np.cumsum without an axis flattens the
# (num_points, 1) array to 1-D.
# NOTE(review): the division by 3600 presumably converts seconds to hours -- confirm.
knot_trend = np.cumsum(t_flux_array) / 3600

# fit a linear model to estimate starting stage
# stage_mdl = reg = LinearRegression().fit(knot_slope, y) #knot_stage_hpf = start_stage + 

# One row per spline point: index, cumulative trend, local flux, weighted mean
# predicted stage, and the spline's PCA coordinates.
spline_stage_df = pd.DataFrame(np.arange(num_points), columns=["knot_index"])
spline_stage_df["spline_trend_hpf"] = knot_trend
spline_stage_df["spline_flux_hpf"] = t_flux_array / 3600
spline_stage_df["pd_time_hpf"] = pd_time_array
spline_stage_df[pca_cols] = mean_spline

# join on fields
ref_umap_df = ref_umap_df.merge(spline_stage_df.loc[:, ["knot_index", "spline_trend_hpf", "spline_flux_hpf"]], how="left", on="knot_index")

# run regression to get offset
# reg: predicted stage as a linear function of spline trend. Only its intercept
# is used below; the fitted slope is not applied.
reg = LinearRegression().fit(ref_umap_df["spline_trend_hpf"].to_numpy()[:, None], ref_umap_df["predicted_stage_hpf"].to_numpy()[:, None])
# reg_inv: maps a predicted stage onto the offset spline-stage scale.
reg_inv = LinearRegression().fit(ref_umap_df["predicted_stage_hpf"].to_numpy()[:, None], 
                                 ref_umap_df["spline_trend_hpf"].to_numpy()[:, None] + reg.intercept_)

# use offset to get trend
spline_stage_df["spline_stage_hpf"] = spline_stage_df["spline_trend_hpf"].copy() + reg.intercept_
ref_umap_df["spline_stage_hpf"] = ref_umap_df["spline_trend_hpf"].copy() + reg.intercept_

# assign stages to ref embryos (sanity check)
fig = px.line(spline_stage_df, x="knot_index", y="spline_stage_hpf")
# fig = px.scatter(spline_stage_df, x="pd_time_hpf", y="spline_stage_hpf")
fig.show()

In [None]:
# A reference observation counts as an inlier if it has non-zero weight for at
# least one spline point.
inlier_filter = np.max(ref_weights, axis=1) > 0

# Spline-derived stage against predicted stage for the reference embryos.
# NOTE(review): colouring by inlier_filter assumes ref_umap_df rows are still
# in the same order as the rows of ref_weights -- confirm.
fig = px.scatter(ref_umap_df, x="predicted_stage_hpf", y="spline_stage_hpf", color=inlier_filter)

fig.update_layout(width=800, height=600)

fig.show()

### Now use spline to calibrate hotfish embryos

In [None]:
# Stage the hotfish embryos by projecting them onto the reference spline.

# Remove staging columns left over from a previous run of this cell. The drop
# must act on hf_umap_df itself: dropping from ref_umap_df would silently
# replace the hotfish table with the reference table on a re-run.
check_cols = ["spline_trend_hpf", "spline_stage_hpf", "spline_flux_hpf", "pd_time_hpf"]
hf_umap_df = hf_umap_df.drop(columns=check_cols, errors="ignore")

# Distance from every hotfish observation (rows) to every spline point (columns).
hf_data = hf_umap_df[pca_cols].values
hf_dist = distance_matrix(hf_data, mean_spline)

# Assign each observation to its closest spline point.
hf_ci_vec = np.argmin(hf_dist, axis=1)
hf_umap_df["knot_index"] = hf_ci_vec

# Transfer every non-PCA column of the spline staging table via knot_index.
spl_transfer_cols = [col for col in spline_stage_df.columns.tolist() if col not in pca_cols]
hf_umap_df = hf_umap_df.merge(spline_stage_df.loc[:, spl_transfer_cols], how="left", on="knot_index")

In [None]:
# Spline-derived stage against predicted stage for the hotfish embryos,
# coloured by temperature.
fig = px.scatter(hf_umap_df, x="predicted_stage_hpf", y="spline_stage_hpf", color="temperature")

fig.update_layout(width=800, height=600)

fig.show()

### Clearly we run into some issues with 35C: those embryos have diverged too far from the reference curve to register accurately
What if we try nearest neighbors (NN) instead?

In [None]:
# Number of reference neighbours used per hotfish observation.
k_nn = 5

# Distance from every hotfish observation (rows) to every reference observation (columns).
hf_ref_dist = distance_matrix(hf_data, ref_data)

# Partial sort per row: after partitioning at position k_nn, the first k_nn
# columns hold the indices of the k_nn smallest distances (in no particular order).
partitioned = np.argpartition(hf_ref_dist, kth=k_nn, axis=1)
nn_indices = partitioned[:, :k_nn]

In [None]:
# Stage each hotfish observation as the mean predicted stage of its nearest
# reference neighbours, then map that onto the spline-stage scale with reg_inv.
pd_stage_vec = ref_umap_df["predicted_stage_hpf"].to_numpy()
hf_nn_stage_vec = pd_stage_vec[nn_indices].mean(axis=1)
hf_umap_df["nn_stage_hpf"] = hf_nn_stage_vec
hf_umap_df["nn_spline_stage_hpf"] = reg_inv.predict(hf_nn_stage_vec.reshape(-1, 1))

In [None]:
# Nearest-neighbour-derived stage against predicted stage for the hotfish
# embryos, coloured by temperature.
fig = px.scatter(hf_umap_df, x="predicted_stage_hpf", y="nn_spline_stage_hpf", color="temperature")

fig.update_layout(width=800, height=600)

fig.show()

In [None]:
# Compare the two staging approaches directly: nearest-neighbour stage (x)
# against closest-spline-point stage (y).
fig = px.scatter(hf_umap_df, x="nn_spline_stage_hpf", y="spline_stage_hpf", color="temperature")

fig.update_layout(width=800, height=600)

fig.show()

### Now...let's use stages to estimate phenotypic severity and variability as a function of temperature and time

In [None]:
# first, calculate mean divergence from the WT trajectory
# We have to first re-find the most appropriate index, since this will NOT always be closest spline index
# For each hotfish observation, find the spline point whose stage is closest to
# the observation's nearest-neighbour stage. This is the stage-matched point,
# which need not be the spatially closest one.
hf_stage_col = hf_umap_df["nn_spline_stage_hpf"].to_numpy()[:, None]
knot_stage_col = spline_stage_df["spline_stage_hpf"].to_numpy()[:, None]
stage_dist_mat = distance_matrix(hf_stage_col, knot_stage_col)

# Store the index as a plain 1-D vector rather than relying on pandas to
# squeeze an (n, 1) array into a single column.
hf_umap_df["spline_index_adjusted"] = np.argmin(stage_dist_mat, axis=1)

# Distance in PCA space from each observation to its stage-matched spline point.
row_idx = np.arange(hf_dist.shape[0])
hf_umap_df["wt_spline_dist"] = hf_dist[row_idx, hf_umap_df["spline_index_adjusted"].to_numpy()]

In [None]:
# Bin observations into integer-hpf cohorts by rounding the predicted stage.
hf_umap_df["timepoint"] = np.round(hf_umap_df["predicted_stage_hpf"].to_numpy()).astype(int)

# Mean and standard deviation of the distance to the WT spline per
# (timepoint, temperature) cohort.
cohort_dist_df = (
    hf_umap_df[["timepoint", "temperature", "wt_spline_dist"]]
    .groupby(["timepoint", "temperature"])
    .agg(["mean", "std"])
    .reset_index()
)

# Flatten the two-level column index. The grouping keys have an empty second
# level, so strip the separator to get "timepoint" rather than "timepoint_".
cohort_dist_df.columns = ['_'.join(map(str, col)).strip("_") for col in cohort_dist_df.columns.values]
cohort_dist_df.head()

In [None]:
# Distance to the stage-matched WT spline point per observation, by integer-hpf
# cohort and temperature.
fig = px.scatter(hf_umap_df, x="timepoint", y="wt_spline_dist", color="temperature")
fig.show()

### There is some subtlety here...
Quick tangent: what if we fit a surface?

In [None]:
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline

# Polynomial regression of predicted stage on PCA position: expand the inputs
# into polynomial features, then fit an ordinary linear model on them.
degree = 2  # polynomial degree of the feature expansion
model = Pipeline(steps=[
    ("poly", PolynomialFeatures(degree=degree, include_bias=True)),
    ("linear", LinearRegression()),
])

# Train on the reference embryos: X holds PCA coordinates, y the predicted stage.
X = ref_umap_df[pca_cols].to_numpy()
y = ref_umap_df["predicted_stage_hpf"].to_numpy()
model.fit(X, y)

# Evaluate the fitted surface at the hotfish embryos' PCA coordinates.
X_new = hf_umap_df[pca_cols].to_numpy()
hf_surf_predictions = model.predict(X_new)

In [None]:
# Quick look at the first rows of the hotfish table.
hf_umap_df.head()

In [None]:
# Stage predicted by the polynomial surface for each hotfish observation,
# against its integer-hpf cohort. y is passed as an array aligned row-by-row
# with hf_umap_df.
fig = px.scatter(hf_umap_df, x="timepoint", y=hf_surf_predictions, color="temperature")
fig.show()

In [None]:
import numpy as np
import plotly.graph_objects as go

# Visualise the fitted polynomial stage model as a surface over the first two
# input dimensions, with all remaining dimensions held at their mean.
# X is the (n_samples x N) training input and model the fitted pipeline.
N = X.shape[1]

# Determine the grid range for the first two dimensions
x_min, x_max = X[:, 0].min(), X[:, 0].max()
y_min, y_max = X[:, 1].min(), X[:, 1].max()

# Create a grid for the first two dimensions
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                     np.linspace(y_min, y_max, 100))

# Fix remaining dimensions (e.g., at their mean value)
fixed_values = np.mean(X[:, 2:], axis=0) if N > 2 else []

# Prepare grid data for prediction: reshape xx and yy to columns and append fixed values
grid_points = np.c_[xx.ravel(), yy.ravel()]
if N > 2:
    # Repeat fixed values for each grid point
    fixed_values_repeated = np.tile(fixed_values, (grid_points.shape[0], 1))
    grid_points = np.hstack([grid_points, fixed_values_repeated])

# Predict the time values for the grid points
predictions = model.predict(grid_points).reshape(xx.shape)

# Scatter the reference observations (z = predicted stage) and overlay the
# fitted surface. hover_data is given as a list so the hover fields appear in a
# fixed order.
fig = px.scatter_3d(
    ref_umap_df,
    x="PCA_00_bio",
    y="PCA_01_bio",
    z="predicted_stage_hpf",
    color="predicted_stage_hpf",
    hover_data=["predicted_stage_hpf", "experiment_date"],
    color_continuous_scale="magma",
)

fig.update_traces(marker=dict(size=5))
fig.add_trace(go.Surface(x=xx, y=yy, z=predictions, colorscale='magma', opacity=0.8))
fig.update_layout(
    title='Visualization of the Fitted Polynomial Surface',
    scene=dict(
        xaxis_title='Dimension 1',
        yaxis_title='Dimension 2',
        zaxis_title='Predicted Time'
    )
)
fig.show()

In [None]:
# NOTE(review): P2 is only assigned in the shortest-path cell further down, so
# this cell raises NameError when the notebook is run top to bottom.
P2

In [None]:
import numpy as np
from scipy.optimize import minimize
from scipy import interpolate

# hf_umap_df.reset_index(drop=True, inplace=True)

# Shortest piecewise-linear path from one hotfish embryo (P1) to a point on the
# reference spline (P2), with every intermediate point constrained to the level
# set model.predict(P) == T of the fitted polynomial stage model.
# N (the number of PCA dimensions) is defined in the surface-plot cell above.

embryo_i = 94
T = hf_umap_df.loc[embryo_i, "nn_spline_stage_hpf"]
P1 = hf_umap_df.loc[embryo_i, pca_cols].to_numpy().astype(float)
# NOTE(review): T is taken from nn_spline_stage_hpf but is matched against
# pd_time_hpf here, and against the model's predicted stage in the constraint;
# confirm that these are meant to be on the same scale.
spline_i = np.argmin(np.abs(spline_stage_df["pd_time_hpf"] - T))
P2 = spline_stage_df.loc[spline_i, pca_cols].to_numpy().astype(float)
M = 25
# T is the fixed level-set value the intermediate points must satisfy
# P1 and P2 are numpy arrays of shape (N,)
# M is the number of segments (M+1 points total)

def total_length(points):
    """Return the length of the polyline P1 -> intermediate points -> P2.

    ``points`` is the flattened array of intermediate points.
    """
    points = points.reshape(-1, N)
    # Prepend P1 and append P2
    all_points = np.vstack([P1, points, P2])
    # Compute differences between consecutive points
    diffs = np.diff(all_points, axis=0)
    # Compute Euclidean distances for each segment
    distances = np.sqrt(np.sum(diffs**2, axis=1))
    return np.sum(distances)

def constraint_func(points, model=model):
    """Return model.predict(P) - T for every intermediate point (zero when satisfied)."""
    points = points.reshape(-1, N)
    # Named stage_pred (not `pd`) so the pandas alias is not shadowed.
    stage_pred = model.predict(points)
    return stage_pred - T

# Number of free points
num_free = M - 1

# Initial guess: linear interpolation between P1 and P2
init_points = np.linspace(P1, P2, M+1)[1:-1].flatten()

# Define constraints for scipy.optimize.minimize
constraints = {'type': 'eq', 'fun': constraint_func}

# SLSQP is what minimize() selects by default for a constrained problem; it is
# named explicitly so the choice is visible.
result = minimize(total_length, init_points, method="SLSQP", constraints=constraints,
                  options={"maxiter": 2500})

# Reshape the result into M-1 points
optimal_points = result.x.reshape(-1, N)

# Combine with endpoints for the full curve
optimal_curve = np.vstack([P1, optimal_points, P2])
geodesic_length = total_length(result.x)

In [None]:
# Optimised path (first three dimensions) together with its two endpoints and
# the mean reference spline.
fig = px.line_3d(x=optimal_curve[:, 0], y=optimal_curve[:, 1], z=optimal_curve[:, 2])
# Endpoint markers: P1 (hotfish embryo) and P2 (point on the spline).
fig.add_trace(go.Scatter3d(x=[P1[0]], y=[P1[1]], z=[P1[2]], mode="markers"))
fig.add_trace(go.Scatter3d(x=[P2[0]], y=[P2[1]], z=[P2[2]], mode="markers"))
fig.add_traces(go.Scatter3d(x=mean_spline[:, plot_dims[0]], y=mean_spline[:, plot_dims[1]], 
                            z=mean_spline[:, plot_dims[2]],
                           mode="lines", line=dict(color="darkblue", width=4), name="reference curve"))
fig.show()

In [None]:
# Inspect the optimizer's result object returned by minimize().
result

In [None]:
# Straight-line (Euclidean) distance between P1 and P2, for comparison with
# the length of the constrained path.
np.sqrt(np.sum((P1-P2)**2))

In [None]:
# Check the constraint: the model's prediction at each optimised intermediate
# point should be close to the target value T.
print(T)
model.predict(optimal_points)

In [None]:
# Row position(s) of a specific snip in the hotfish table.
np.where(hf_umap_df["snip_id"]=="20240813_30hpf_H07_e00_t0000")[0]

In [None]:
# List the columns currently present in the hotfish table.
hf_umap_df.columns