#### This notebook looks at temperature-dependent changes to embryo morphology

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
from glob2 import glob

In [None]:
# Locate the training run of the current best model and create the output folders.
# load embryo_df for our current best model
# root = "/media/nick/hdd02/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"

root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"
train_name = "20241107_ds"
model_name = "SeqVAE_z100_ne150_sweep_01_block01_iter030" 
# the trailing "" makes os.path.join end the path with a separator
train_dir = os.path.join(root, "training_data", train_name, "")
output_dir = os.path.join(train_dir, model_name) 

# get path to model: the lexicographically last entry found under output_dir
# (indexing with [-1] raises IndexError when nothing matches)
training_path = sorted(glob(os.path.join(output_dir, "*")))[-1]
# NOTE(review): dirname() yields the parent folder of training_path, not the run's own name,
# and training_name is not used again in the visible notebook -- confirm whether basename was meant
training_name = os.path.dirname(training_path)
read_path = os.path.join(training_path, "figures", "")

# path to save data
out_path = os.path.join(root, "results", "20240303", "")
os.makedirs(out_path, exist_ok=True)

# path to figures and data (absolute path, not derived from root)
fig_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/morphseq/20250312/morph_metrics/"
os.makedirs(fig_path, exist_ok=True)

In [None]:
# Read the per-snip morphology table and the latent-space (UMAP/PCA) table of this model run.
morph_df = pd.read_csv(os.path.join(read_path, "embryo_stats_df.csv"), index_col=0)
umap_df = pd.read_csv(os.path.join(read_path, "umap_df.csv"), index_col=0)

# Attach embryo ID and experiment time to every latent-space row. The shape is printed
# before and after the left join so an unexpected change in row count is easy to spot.
print(umap_df.shape)
umap_df = umap_df.merge(
    morph_df[["snip_id", "embryo_id", "experiment_time"]],
    on="snip_id",
    how="left",
)
print(umap_df.shape)

### Make 3D UMAP and PCA for hotfish experiments

In [None]:
# Keep only the hotfish experiments and remove manually flagged outlier snips.
HF_experiments = np.asarray(['20240813_24hpf', '20240813_30hpf', '20240813_36hpf']) #, '20240813_extras'])
# reset_index() without drop=True keeps the previous index as an "index" column
hf_morph_df = morph_df.loc[np.isin(morph_df["experiment_date"], HF_experiments), :].reset_index()
hf_umap_df = umap_df.loc[np.isin(umap_df["experiment_date"], HF_experiments), :].reset_index()
hf_outlier_snips = np.asarray(["20240813_24hpf_F06_e00_t0000", "20240813_36hpf_D03_e00_t0000", "20240813_36hpf_C03_e00_t0000"])
# .copy() makes the filtered frame independent of its parent: later cells add columns to
# hf_umap_df, which would otherwise raise pandas' SettingWithCopyWarning.
hf_umap_df = hf_umap_df.loc[~np.isin(hf_umap_df["snip_id"], hf_outlier_snips), :].copy()

In [None]:
# make umap scatter
# hover_data is passed as a list: a set has no defined order, so the order of the hover
# fields could differ from run to run.
fig = px.scatter_3d(hf_umap_df, x="UMAP_00_bio_3", y="UMAP_01_bio_3", z="UMAP_02_bio_3",
                    color="temperature", hover_data=["predicted_stage_hpf", "experiment_date", "snip_id"])
fig.update_traces(marker=dict(size=6))
fig.show()

# save a static PNG and an interactive HTML copy
fig.write_image(os.path.join(fig_path, "hotfish_umap.png"))
fig.write_html(os.path.join(fig_path, "hotfish_umap.html"))

In [None]:
# 3D scatter of the first three biological PCA components, colored by temperature.
# hover_data is passed as a list: a set has no defined order, so the order of the hover
# fields could differ from run to run.
fig = px.scatter_3d(hf_umap_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio",
                    color="temperature", hover_data=["predicted_stage_hpf", "experiment_date", "snip_id"])
fig.update_traces(marker=dict(size=6))
fig.show()

# save a static PNG and an interactive HTML copy
fig.write_image(os.path.join(fig_path, "hotfish_pca.png"))
fig.write_html(os.path.join(fig_path, "hotfish_pca.html"))

### Problem: 28C is our control group, but we don't have stage-matching due to stage shifting
**Potential solution:** search for reference embryos from timelapse data that closely overlap with 28C, but which also extend out into later timepoints

In [None]:
# Select reference embryos: those labelled with `short_pert_name` whose predicted stages
# start at or before `start_stage` and reach at least `target_stage`.
short_pert_name = "wt_ab" # genotype
start_stage = 18
target_stage = 44 # alive through at least this point

# earliest and latest predicted stage of every embryo
embryo_df = (
    morph_df[["experiment_date", "embryo_id", "predicted_stage_hpf", "short_pert_name"]]
    .groupby(["experiment_date", "embryo_id", "short_pert_name"])["predicted_stage_hpf"]
    .agg(["min", "max"])
    .reset_index()
)

stage_filter = (embryo_df["min"] <= start_stage) & (embryo_df["max"] >= target_stage)
pert_filter = embryo_df["short_pert_name"] == short_pert_name

embryo_df = embryo_df.loc[stage_filter & pert_filter, :]
# embryo_df.shape

# keep only latent-space rows that belong to a selected embryo
ref_umap_df = umap_df.merge(embryo_df[["embryo_id"]], on="embryo_id", how="inner")

In [None]:
# Hotfish embryos in UMAP space with the reference embryo trajectories drawn underneath.
# hover_data is passed as a list: a set has no defined order, so the order of the hover
# fields could differ from run to run.
fig = px.scatter_3d(hf_umap_df, x="UMAP_00_bio_3", y="UMAP_01_bio_3", z="UMAP_02_bio_3",
                    color="temperature", hover_data=["predicted_stage_hpf", "experiment_date"])

# one faint line trace per reference embryo
embryo_index = np.unique(ref_umap_df["embryo_id"])
for eid in embryo_index:
    e_filter = ref_umap_df["embryo_id"]==eid
    fig.add_traces(go.Scatter3d(x=ref_umap_df.loc[e_filter, "UMAP_00_bio_3"],
                                y=ref_umap_df.loc[e_filter, "UMAP_01_bio_3"],
                                z=ref_umap_df.loc[e_filter, "UMAP_02_bio_3"], mode="lines",
                                line=dict(color='rgba(0, 0, 0, 0.2)'), showlegend=False ))

fig.update_traces(marker=dict(size=4))
fig.show()

# save a static PNG and an interactive HTML copy
fig.write_image(os.path.join(fig_path, "hotfish_umap_ref.png"))
fig.write_html(os.path.join(fig_path, "hotfish_umap_ref.html"))

In [None]:
# Hotfish embryos in PCA space with the reference embryo trajectories drawn underneath.
# hover_data is passed as a list: a set has no defined order, so the order of the hover
# fields could differ from run to run.
fig = px.scatter_3d(hf_umap_df, x="PCA_00_bio", y="PCA_01_bio", z="PCA_02_bio",
                    color="temperature", hover_data=["predicted_stage_hpf", "experiment_date"])

# one faint line trace per reference embryo
embryo_index = np.unique(ref_umap_df["embryo_id"])
for eid in embryo_index:
    e_filter = ref_umap_df["embryo_id"]==eid
    fig.add_traces(go.Scatter3d(x=ref_umap_df.loc[e_filter, "PCA_00_bio"],
                                y=ref_umap_df.loc[e_filter, "PCA_01_bio"],
                                z=ref_umap_df.loc[e_filter, "PCA_02_bio"], mode="lines",
                                line=dict(color='rgba(0, 0, 0, 0.2)'), showlegend=False ))

fig.update_traces(marker=dict(size=4))
fig.show()

# save a static PNG and an interactive HTML copy
fig.write_image(os.path.join(fig_path, "hotfish_pca_ref.png"))
fig.write_html(os.path.join(fig_path, "hotfish_pca_ref.html"))

### Experiment with fitting 3D spline to reference embryos

In [None]:
from src.functions.spline_fitting_v2 import LocalPrincipalCurve
import time
import re 
from tqdm import tqdm 

# Bootstrap a LocalPrincipalCurve through the reference embryos in PCA space.
# Every replicate resamples the reference points and anchors the curve at a randomly chosen
# early point and a randomly chosen late point.
pattern = r"PCA_.*_bio"
pca_cols = [col for col in ref_umap_df.columns if re.search(pattern, col)]
# pca_cols = [col for col in ref_umap_df.columns.tolist() if "PCA" in col] #["PCA_00_bio", "PCA_01_bio", "PCA_02_bio"]
bandwidth = .5
max_iter = 2500
tol = 1e-5
angle_penalty_exp = 0.5
n_boots = 50
boot_size = np.min([ref_umap_df.shape[0], 2500])
num_points = 5000

# Extract PCA coordinates
pert_array = ref_umap_df[pca_cols].values

# Candidate start points: observations within the first 2 hpf of the reference data
min_time = ref_umap_df["predicted_stage_hpf"].min()
early_mask = (ref_umap_df["predicted_stage_hpf"] >= min_time) & \
             (ref_umap_df["predicted_stage_hpf"] < min_time + 2)
early_points = ref_umap_df.loc[early_mask, pca_cols].values

early_options = np.arange(early_points.shape[0])

# Candidate end points: observations within the last 2 hpf of the reference data
max_time = ref_umap_df["predicted_stage_hpf"].max()
late_mask = (ref_umap_df["predicted_stage_hpf"] >= (max_time - 2))
late_points = ref_umap_df.loc[late_mask, pca_cols].values
late_options = np.arange(late_points.shape[0])
# generate array to store spline fits: (curve point, PCA dimension, bootstrap replicate)
spline_boot_array = np.zeros((num_points, len(pca_cols), n_boots))

# Seeded generator used for ALL random draws below so the bootstrap is reproducible
rng = np.random.RandomState(42)

for n in tqdm(range(n_boots)):
    # Randomly select a subset of points (with replacement) for fitting
    subset_indices = rng.choice(len(pert_array), size=boot_size, replace=True)
    pert_array_subset = pert_array[subset_indices, :]

    # Draw the anchors from the same seeded generator (previously taken from the unseeded
    # global np.random state, which made results differ between runs)
    start_ind = rng.choice(early_options)
    stop_ind = rng.choice(late_options)
    start_point = early_points[start_ind, :]
    stop_point = late_points[stop_ind, :]

    # Fit LocalPrincipalCurve
    lpc = LocalPrincipalCurve(
        bandwidth=bandwidth,
        max_iter=max_iter,
        tol=tol,
        angle_penalty_exp=angle_penalty_exp
    )

    # Fit with the optional start_points/end_point to anchor the spline
    lpc.fit(
        pert_array_subset,
        start_points=start_point[None, :],
        end_point=stop_point[None, :],
        num_points=num_points
    )

    # store the first fitted curve of this replicate
    spline_boot_array[:, :, n] = lpc.cubic_splines[0]


In [None]:
# Point-wise summary of the bootstrap curves: the mean over replicates (last axis) and
# the standard deviation over replicates, used as the standard-error estimate.
mean_spline = spline_boot_array.mean(axis=2)
se_spline = spline_boot_array.std(axis=2)

In [None]:
# Plot the hotfish embryos in three PCA dimensions together with the mean reference curve.
plot_dims = np.asarray([0, 1, 2])

# get se mesh for spline
# tube_x, tube_y, tube_z = compute_tube_points(mean_spline[:, plot_dims], se_spline[:, plot_dims])

# se_mesh = go.Mesh3d(
#     x=tube_x.flatten(),
#     y=tube_y.flatten(),
#     z=tube_z.flatten(),
#     i=[], j=[], k=[],  # You would need to compute triangle indices based on the grid structure
#     color='lightblue',
#     opacity=0.2,
#     name='Uncertainty'
# )


plot_strings = [pca_cols[p] for p in plot_dims]

# hover_data is passed as a list: a set has no defined order, so the order of the hover
# fields could differ from run to run.
fig = px.scatter_3d(hf_umap_df, x=plot_strings[0], y=plot_strings[1], z=plot_strings[2], opacity=1,
                    color="temperature", hover_data=["predicted_stage_hpf", "experiment_date", "snip_id"])

fig.update_traces(marker=dict(size=5, showscale=False))

# overlay the bootstrap-averaged reference curve
fig.add_traces(go.Scatter3d(x=mean_spline[:, plot_dims[0]], y=mean_spline[:, plot_dims[1]],
                            z=mean_spline[:, plot_dims[2]],
                           mode="lines", line=dict(color="darkblue", width=4), name="reference curve"))

# fig.add_traces(go.Scatter3d(x=[P2[0]], y=[P2[1]], z=[P2[2]], mode="markers"))

# fig.add_traces(se_mesh)

fig.show()

# save a static PNG and an interactive HTML copy
fig.write_image(os.path.join(fig_path, "hotfish_pca_with_spline.png"))
fig.write_html(os.path.join(fig_path, "hotfish_pca_with_spline.html"))

### Next, fit a polynomial surface to estimate embryo stages
Let's experiment with fitting derivatives so we can utilize experimental clock time

In [None]:
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline

# Polynomial regression from PCA coordinates to predicted stage, fit on the reference embryos.
# The pipeline expands the inputs into polynomial terms and then fits a linear model on them.
degree = 2  # or any degree you choose
model = Pipeline(steps=[
    ('poly', PolynomialFeatures(degree=degree, include_bias=True)),
    ('linear', LinearRegression()),
])

# Inputs X are (n_samples x N) PCA coordinates; targets y are the (n_samples,) stages.
frac_to_fit = 0.8
X = ref_umap_df[pca_cols].values
y = ref_umap_df["predicted_stage_hpf"].values

# Random split: frac_to_fit of the rows train the model, the remaining rows test it.
X_indices = np.arange(X.shape[0])
n_train = int(np.floor(frac_to_fit * X.shape[0]))
train_indices = np.random.choice(X_indices, n_train, replace=False)
test_indices = np.setdiff1d(X_indices, train_indices)

X_train, X_test = X[train_indices], X[test_indices]
y_train, y_test = y[train_indices], y[test_indices]

model.fit(X_train, y_train)
y_pd = model.predict(X_test)

# held-out check: true stage (x) against model prediction (y)
fig = px.scatter(x=y_test, y=y_pd)
fig.show()
# X_new = hf_umap_df[pca_cols].values
# # You can then use the model to predict or analyze the polynomial surface.
# hf_surf_predictions = model.predict(X_new)

In [None]:
# Apply the fitted stage model to the hotfish embryos.
X_hf = hf_umap_df[pca_cols].values
hf_stage_predictions = model.predict(X_hf)

# "timepoint" is predicted_stage_hpf rounded to the nearest integer;
# "mdl_stage_hpf" is the stage predicted by the polynomial model.
hf_umap_df["timepoint"] = hf_umap_df["predicted_stage_hpf"].to_numpy().round().astype(int)
hf_umap_df["mdl_stage_hpf"] = hf_stage_predictions

fig = px.scatter(hf_umap_df, x="timepoint", y=hf_stage_predictions, color="temperature")
fig.show()

In [None]:
# Table of the mean reference curve: one row per curve point with its index,
# the stage the model predicts at that point, and its PCA coordinates.
spline_stage_pd = model.predict(mean_spline)
spline_stage_df = pd.DataFrame({"knot_index": np.arange(num_points)})
spline_stage_df["pd_stage_hpf"] = spline_stage_pd
spline_stage_df[pca_cols] = mean_spline

### Can we visualize the developmental "surface"?

In [None]:
# from scipy.interpolate import griddata

# # Create a grid over the domain of your data.
# X0 = ref_umap_df[pca_cols].to_numpy()
# z=-model.predict(X0) 

# # grid_x = np.linspace(x.min(), x.max(), 100)
# # grid_y = np.linspace(y.min(), y.max(), 100)

In [None]:
import umap

# Fit a 2D UMAP on the reference embryos' PCA coordinates, then project both the
# reference and the hotfish embryos with the fitted model.
# NOTE(review): no random_state is passed, so the embedding presumably changes between runs -- confirm
umap_model = umap.UMAP(n_components=2)

# Compute the embedding
umap_model.fit(ref_umap_df[pca_cols].values)
embedding = umap_model.transform(ref_umap_df[pca_cols].values)
hf_embedding = umap_model.transform(hf_umap_df[pca_cols].values)

In [None]:
from scipy.interpolate import griddata
from scipy.spatial import distance_matrix

# Create a grid over the domain of your data.
# x, y: 2D UMAP coordinates of the reference embryos; z: negated, rescaled model stage.
# NOTE: `y` here rebinds the name used for the regression targets in the model-fitting cell.
x=embedding[:, 0]
y=embedding[:, 1]
z=-model.predict(X) / 60  # NOTE(review): the divisor 60 looks like a display scaling -- confirm

# fig = px.scatter_3d(x=x, y=y, z=z, color=z)
# fig.show()
# Pad each bound outward by 10% of its magnitude. Scaling the bounds directly
# (0.9*min, 1.1*max) moves a negative bound inward and would cut off part of the data.
grid_x = np.linspace(x.min() - 0.1*np.abs(x.min()), x.max() + 0.1*np.abs(x.max()), 100)
grid_y = np.linspace(y.min() - 0.1*np.abs(y.min()), y.max() + 0.1*np.abs(y.max()), 100)
grid_x, grid_y = np.meshgrid(grid_x, grid_y)

# distance from every grid node to its nearest reference embryo in the embedding
xy_long = np.c_[grid_x.ravel()[:, None], grid_y.ravel()[:, None]]
dist_vec = np.min(distance_matrix(xy_long, embedding), axis=1)
# Interpolate the scattered data onto the grid.
# grid_z = griddata(points=(x, y), values=z, xi=(grid_x, grid_y), method='cubic')

# grid_x.shape
# Create the surface plot.
# fig = go.Figure(data=[go.Surface(z=grid_z, x=grid_x, y=grid_y)])
# fig.update_layout(title="3D Surface from Scattered Data", scene=dict(
#                     xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
# fig.show()

In [None]:
from scipy.ndimage import gaussian_filter
# px.histogram(dist_vec)
# Build the smoothed stage surface over the 2D embedding and plot the hotfish embryos on it.
dist_thresh = 1.5
dist_mat = dist_vec.reshape(100, 100)

# nearest-neighbour interpolation onto the grid followed by Gaussian smoothing
grid_z = griddata(points=(x, y), values=z, xi=(grid_x, grid_y), method='nearest')
grid_z_smoothed = gaussian_filter(grid_z, sigma=2, mode="nearest")
# blank out grid nodes that are farther than dist_thresh from every reference embryo
grid_z_smoothed[dist_mat > dist_thresh] = np.nan

# same sign flip and scaling as the surface values
hf_umap_df["mdl_stage_plot"] = -hf_umap_df["mdl_stage_hpf"] / 60

# Create the surface plot.
fig = px.scatter_3d(
    x=hf_embedding[:, 0],
    y=hf_embedding[:, 1],
    z=hf_umap_df["mdl_stage_plot"],
    color=hf_umap_df["temperature"],
)
fig.update_traces(marker=dict(size=5))

fig.add_trace(go.Surface(x=grid_x, y=grid_y, z=grid_z_smoothed, opacity=0.5))
fig.update_layout(
    title="3D Surface from Scattered Data",
    scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
)

fig.show()

fig.write_image(os.path.join(fig_path, "ab_developmental_surface.png"))
fig.write_html(os.path.join(fig_path, "ab_developmental_surface.html"))

### Calculate mean and standard deviation in embryo morphology

In [None]:
# Display the columns currently available on the hotfish table.
hf_umap_df.columns

In [None]:
# Collapse the hotfish embryos into (timepoint, temperature) cohorts, taking the mean and
# standard deviation of the model stage and of every PCA component.
hf_cohort_df = (
    hf_umap_df[["timepoint", "temperature", "mdl_stage_hpf"] + pca_cols]
    .groupby(["timepoint", "temperature"])
    .agg(["mean", "std"])
    .reset_index()
)
# Flatten the two-level column index into "<column>_<statistic>" names. The group keys have
# an empty second level and therefore become "timepoint_" and "temperature_".
hf_cohort_df.columns = ['_'.join(str(level) for level in col).strip() for col in hf_cohort_df.columns.values]
hf_cohort_df.head()

In [None]:
# Plot the cohort means in three PCA dimensions together with the mean reference curve.
plot_dims = np.asarray([0, 1, 2])
mean_pca_cols = [f"{col}_mean" for col in pca_cols]
plot_strings = [mean_pca_cols[p] for p in plot_dims]
x_dim, y_dim, z_dim = plot_dims

fig = px.scatter_3d(
    hf_cohort_df,
    x=plot_strings[0],
    y=plot_strings[1],
    z=plot_strings[2],
    opacity=1,
    color="temperature_",
    hover_data=["timepoint_"],
)
fig.update_traces(marker=dict(size=5, showscale=False))

# overlay the bootstrap-averaged reference curve
reference_curve = go.Scatter3d(
    x=mean_spline[:, x_dim],
    y=mean_spline[:, y_dim],
    z=mean_spline[:, z_dim],
    mode="lines",
    line=dict(color="darkblue", width=4),
    name="reference curve",
)
fig.add_traces(reference_curve)

# fig.add_traces(go.Scatter3d(x=[P2[0]], y=[P2[1]], z=[P2[2]], mode="markers"))

# fig.add_traces(se_mesh)

fig.show()

fig.write_image(os.path.join(fig_path, "avg_hotfish_pca_with_spline.png"))
fig.write_html(os.path.join(fig_path, "avg_hotfish_pca_with_spline.html"))

### Use JAX to generate predicted developmental gradients at each point in latent space

In [None]:
import jax
import jax.numpy as jnp

def make_jax_functions(model):
    """Build a JAX evaluator (value and input-gradient) for a fitted polynomial pipeline.

    Parameters
    ----------
    model : object exposing ``named_steps`` with a ``'poly'`` step that has ``powers_``
        (an (m x d) array of exponents) and a ``'linear'`` step that has ``coef_`` and
        ``intercept_``.

    Returns
    -------
    predict_and_grad : callable ``(params, X_new) -> (preds, grads)``
    params : tuple ``(theta, intercept)`` of JAX arrays taken from the linear step.
    """
    steps = model.named_steps
    # one row of exponents per polynomial term
    exponents = jnp.array(steps['poly'].powers_)
    theta = jnp.array(steps['linear'].coef_)
    intercept = jnp.array(steps['linear'].intercept_)

    def predict_single(x, theta, intercept):
        """Evaluate the polynomial at one sample ``x`` of shape (d,) and return a scalar."""
        # x broadcasts against the (m, d) exponent table; the product over the d inputs
        # gives the m polynomial terms of this sample.
        monomials = jnp.prod(jnp.power(x, exponents), axis=1)
        return jnp.dot(monomials, theta) + intercept

    def predict_and_grad(params, X_new):
        """Return predictions and input-gradients for every row of ``X_new``.

        ``params`` is ``(theta, intercept)``. ``preds`` holds one value per row and
        ``grads`` holds the gradient of the prediction with respect to that row's inputs.
        """
        def scalar_model(sample):
            return predict_single(sample, params[0], params[1])

        # vmap applies the single-sample function and its gradient across the batch
        preds = jax.vmap(scalar_model)(X_new)
        grads = jax.vmap(jax.grad(scalar_model))(X_new)
        return preds, grads

    return predict_and_grad, (theta, intercept)

### Calculate stage and morphological deltas

In [None]:
# get stage shift: the stage the model predicts at each cohort's mean PCA position,
# minus the cohort's timepoint
cohort_mean_pcs = hf_cohort_df[mean_pca_cols].to_numpy()
hf_cohort_df["stage_hpf_mean"] = model.predict(cohort_mean_pcs)
hf_cohort_df["stage_shift_hpf"] = hf_cohort_df["stage_hpf_mean"].sub(hf_cohort_df["timepoint_"])

# JAX evaluator of the fitted model, used below to get gradients in PCA space
predict_and_grad, params = make_jax_functions(model)

In [None]:
from scipy.spatial import distance_matrix
# For every cohort: distance to the reference curve, total variance, and the part of the
# variance that lies along the stage gradient.
sd_pca_cols = [col +"_std" for col in pca_cols]
# Get morphological shift
# Assume distances are small enough that they can be linearized
# stage_dist_mat = distance_matrix(hf_cohort_df["stage_hpf_mean"].values[:, None], spline_stage_df["pd_stage_hpf"].values[:, None])
# Distance in PCA space from each cohort mean to each point of the mean reference curve;
# the closest curve point serves as that cohort's reference.
stage_dist_mat = distance_matrix(hf_cohort_df[mean_pca_cols], spline_stage_df[pca_cols])
hf_cohort_df["knot_index"] = np.argmin(stage_dist_mat, axis=1)
n_null_perms = 100  # random permutations used for the null estimate
for row in tqdm(range(hf_cohort_df.shape[0])):
    pca_obs = hf_cohort_df.loc[row, mean_pca_cols].to_numpy() # morph mean
    pca_obs_var = hf_cohort_df.loc[row, sd_pca_cols].to_numpy()**2 # morph variance (std squared)
    knot_i = hf_cohort_df.loc[row, "knot_index"]
    pca_ref = spline_stage_df.loc[spline_stage_df["knot_index"]==knot_i, pca_cols].to_numpy() # nearest point on the reference curve

    # get phenotypic distance
    hf_cohort_df.loc[row, "morph_shift"] = np.sqrt(np.sum((pca_obs - pca_ref)**2))

    # record total variance
    hf_cohort_df.loc[row, "total_variance"] = np.sum(pca_obs_var)

    # use gradient to decompose variance
    stage_pd, grad_pd = predict_and_grad(params, pca_obs[None, :])
    grad_u = np.asarray(grad_pd / np.sqrt(np.sum(grad_pd**2)))[0]  # unit vector along the stage gradient
    # Variance of the projection onto a unit vector u is sum(u_i**2 * var_i) when only
    # per-component variances are used (components treated as uncorrelated). The weights
    # u_i**2 are non-negative and sum to 1, so the result lies between 0 and total_variance.
    # Weighting by u_i itself is not a variance and can be negative.
    grad_w = grad_u**2
    var_null = 0
    for n in range(n_null_perms):
        # null: the same weights assigned to randomly shuffled components
        rand_w = np.random.permutation(grad_w)
        var_null += np.dot(rand_w, pca_obs_var)

    hf_cohort_df.loc[row, "stage_variance"] = np.dot(grad_w, pca_obs_var)

    hf_cohort_df.loc[row, "stage_variance_null"] = var_null/n_null_perms

    # remainder of the variance after removing the part along the stage gradient
    hf_cohort_df.loc[row, "morph_variance"] = hf_cohort_df.loc[row, "total_variance"] - hf_cohort_df.loc[row, "stage_variance"]

In [None]:
# Scratch check: display one random permutation of the last gradient unit vector.
np.random.permutation(grad_u)

In [None]:
# Distance of each cohort to the reference curve against timepoint, colored by temperature.
fig = px.scatter(hf_cohort_df, x="timepoint_", y="morph_shift", color="temperature_")
fig.show()

In [None]:
# Stage shift (model stage minus timepoint) of each cohort against timepoint, colored by temperature.
fig = px.scatter(hf_cohort_df, x="timepoint_", y="stage_shift_hpf", color="temperature_")
fig.show()

In [None]:
# Variance remaining after removing the stage-gradient part, per cohort, against timepoint.
fig = px.scatter(hf_cohort_df, x="timepoint_", y="morph_variance", color="temperature_")
fig.update_traces(marker=dict(size=8))
fig.show()

In [None]:
# Variance along the stage gradient, per cohort, against timepoint.
fig = px.scatter(hf_cohort_df, x="timepoint_", y="stage_variance", color="temperature_")
fig.update_traces(marker=dict(size=8))
fig.show()

In [None]:
# Total variance (sum over PCA components), per cohort, against timepoint.
fig = px.scatter(hf_cohort_df, x="timepoint_", y="total_variance", color="temperature_")
fig.update_traces(marker=dict(size=8))
fig.show()

In [None]:
# Stage variance against morph variance per cohort, on a square figure with equal fixed axis ranges.
fig = px.scatter(hf_cohort_df, x="morph_variance", y="stage_variance", color="temperature_", symbol="timepoint_")
fig.update_traces(marker=dict(size=8))
fig.update_layout(
            height=800,
            width=800,
            xaxis=dict(range=[0, 1.7]), 
            yaxis=dict(range=[0, 1.7])
        )
fig.show()

In [None]:
# Stage variance against its permutation null per cohort, on a square figure with equal fixed axis ranges.
fig = px.scatter(hf_cohort_df, x="stage_variance_null", y="stage_variance", color="temperature_", symbol="timepoint_")
fig.update_traces(marker=dict(size=8))
fig.update_layout(
            height=800,
            width=800,
            xaxis=dict(range=[0, 0.5]), 
            yaxis=dict(range=[0, 0.5])
        )
fig.show()

In [None]:
# Stage variance against the cohort standard deviation of the model-predicted stage (axis ranges left automatic).
fig = px.scatter(hf_cohort_df, x="mdl_stage_hpf_std", y="stage_variance", color="temperature_", symbol="timepoint_")
fig.update_traces(marker=dict(size=8))
fig.update_layout(
            height=800,
            width=800,
            # xaxis=dict(range=[0, 0.5]), 
            # yaxis=dict(range=[0, 0.5])
        )
fig.show()

### Make figure showing images for sanity check purposes

In [None]:
import skimage.io as io

# Load the image of every hotfish snip, in the row order of hf_umap_df.
# The image folder is derived from train_dir so it follows `root` when the notebook is run
# on another machine (with the current root it resolves to the same path as before).
image_path = os.path.join(train_dir, "images", "0", "")
hf_snip_vec = hf_umap_df["snip_id"].to_numpy()
hf_time_vec = hf_umap_df["timepoint"].to_numpy()
hf_temp_vec = hf_umap_df["temperature"].to_numpy()
# one image per snip; image_list[i] belongs to hf_snip_vec[i]
image_list = [io.imread(os.path.join(image_path, snip_id + ".jpg")) for snip_id in hf_snip_vec]

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Write one 2 x 4 tile of example images per (timepoint, temperature) cohort.
image_path = os.path.join(fig_path, "cohort_images", "")
os.makedirs(image_path, exist_ok=True)
im_shape = image_list[0].shape
n_tiles = 8  # images per figure: two rows of four
blank = np.zeros(im_shape, dtype=np.uint8)  # filler when a cohort has fewer than n_tiles images

# The loop variable is named `tp` rather than `time` so it does not overwrite the `time`
# module imported earlier in the notebook.
for tp in np.unique(hf_time_vec):
    for temp in np.unique(hf_temp_vec):
        obs_indices = np.where((hf_time_vec==tp) & (hf_temp_vec==temp))[0]

        # first n_tiles images of the cohort, padded with blank tiles
        tiles = [image_list[i] for i in obs_indices[:n_tiles]]
        tiles += [blank] * (n_tiles - len(tiles))

        # np.block joins each inner list horizontally and stacks the two rows vertically
        tiled_image = np.block([tiles[:4],
                                tiles[4:]])

        fig = px.imshow(tiled_image, color_continuous_scale="gray", title=f"{temp:02}C @{tp:02}hpf")

        # Update layout for better display
        # fig.update_layout(
        #     height=600,
        #     width=1200,
        #     title_text="Multiple Images in Plotly"
        # )

        fig.write_image(image_path + f"embryo_images_tp{tp:02}_temp{temp:02}.png", engine="kaleido")

# fig.show()

### Is it possible to fit to the derivatives?

In [None]:
# get point-over-point differences of the PCA coordinates and experiment time within each embryo
cols_to_diff = pca_cols + ["experiment_time"]
diff_cols = [col + "_diff" for col in cols_to_diff]
dt_cols = [col + "_dt" for col in cols_to_diff]
ref_umap_df_dt = ref_umap_df.copy()
# diff() is taken in the existing row order of each embryo
# NOTE(review): assumes rows are already ordered by time within an embryo -- confirm
ref_umap_df_dt[diff_cols] = ref_umap_df_dt.groupby('embryo_id')[cols_to_diff].diff()
# The first row of each embryo has no previous point, so its difference is NaN; fill it from
# the next row. DataFrame.bfill() replaces the deprecated fillna(method='bfill') and, like
# the original call, applies to every column of the frame.
ref_umap_df_dt = ref_umap_df_dt.bfill()

# we want the rate of change of time with respect to each PCA coordinate:
# (difference in experiment_time) / (difference in that coordinate).
# A zero coordinate difference produces inf/nan here.
ref_umap_df_dt[dt_cols[:-1]] = np.divide(ref_umap_df_dt[diff_cols[-1]].values[:, None], ref_umap_df_dt[diff_cols[:-1]].values)

In [None]:
# Suppose we have K measurement points in an N-dimensional space.
# D_data: (K, N) array of points.
# G_data: (K, N) array of measured gradients at those points.
# d: polynomial degree

def multiindex_list(N, d):
    """Return every exponent tuple of length N whose entries sum to at most d.

    Tuples are grouped by total degree (0 first, then 1, ..., d). Within one degree the
    exponent is placed on the lowest-numbered axes first, e.g. for N=2 and degree 2 the
    order is (2, 0), (1, 1), (0, 2).
    """
    exponents = []
    for total in range(d + 1):
        # Each frontier entry is (partial exponent list, lowest axis that may still be
        # incremented). Only incrementing axes at or after the previous one means every
        # multi-index is produced exactly once.
        frontier = [([0] * N, 0)]
        for _ in range(total):
            frontier = [
                (vec[:axis] + [vec[axis] + 1] + vec[axis + 1:], axis)
                for vec, first in frontier
                for axis in range(first, N)
            ]
        exponents.extend(tuple(vec) for vec, _ in frontier)
    return exponents

def build_A(D_data, multi_indices=None):
    """Build the design matrix of polynomial-term derivatives at each data point.

    For every point ``Dk`` in ``D_data`` and every input dimension ``j`` one row is
    produced. Its entry for the term with exponents ``alpha`` is the partial derivative
    of ``Dk**alpha`` with respect to dimension ``j``: ``alpha[j] * Dk**(alpha - e_j)``.

    Parameters
    ----------
    D_data : sequence of K points, each of length N.
    multi_indices : optional list of exponent tuples. When omitted, the module-level
        ``multiindices`` list is used, so existing calls ``build_A(D_data)`` still work.

    Returns
    -------
    numpy array with K * N rows and one column per multi-index.

    The rows are collected in a local list and the dimension is taken from each point.
    The function therefore no longer appends to a module-level ``A`` or reads a
    module-level ``N``, and can be called repeatedly with the same result.
    """
    if multi_indices is None:
        multi_indices = multiindices

    rows = []
    for Dk in D_data:
        n_dims = len(Dk)
        for j in range(n_dims):
            row = []
            for alpha in multi_indices:
                # A term that does not contain dimension j has zero derivative.
                if alpha[j] == 0:
                    row.append(0.0)
                    continue
                # Compute Dk^(alpha - e_j); factors with exponent 0 contribute 1.
                term = 1.0
                for i in range(n_dims):
                    exponent = alpha[i] - (1 if i == j else 0)
                    if exponent > 0:
                        term *= Dk[i] ** exponent
                row.append(alpha[j] * term)
            rows.append(row)

    return np.array(rows)

def build_b(G_data):
    """Flatten a (K, N) array of measured gradients into the right-hand-side vector.

    Values are ordered point by point and, within a point, dimension by dimension, which
    matches the row order produced by ``build_A``.

    The result is built from ``G_data`` alone. The function no longer appends to a
    module-level ``b``, so it can be called repeatedly with the same result.
    """
    gradients = np.array(G_data)  # np.array copies, so the caller's data is not aliased
    n_points, n_dims = gradients.shape  # input must be two-dimensional
    return gradients.reshape(n_points * n_dims)

def evaluate_polynomial_array(D, multiindices, c):
    """
    Evaluate a multivariate polynomial at several points.

    Parameters:
    - D: array-like of shape (M, N); each row is one N-dimensional point.
    - multiindices: list of exponent tuples, one per polynomial term.
    - c: coefficients, paired with ``multiindices`` in order.

    Returns:
    - numpy array of shape (M,) holding the polynomial value at every point.
    """
    points = np.asarray(D)
    n_points, _ = points.shape  # unpacking enforces a two-dimensional input
    predictions = np.zeros(n_points)

    for weight, exponents in zip(c, multiindices):
        # Raise each coordinate to its exponent and multiply across coordinates to get
        # the value of this term at every point.
        monomial = np.prod(points ** np.array(exponents), axis=1)
        predictions += weight * monomial
    return predictions

In [None]:
# Fit polynomial coefficients from the measured gradients by linear least squares.
N = len(pca_cols)
d = 2  # for example, quadratic polynomial

# Get multi-index list for polynomial basis.
multiindices = multiindex_list(N, d)
num_terms = len(multiindices)

# Build design matrix A and measurement vector b.
# There will be K * N equations (each derivative component).
# A and b are reset to empty lists before the builder functions are called.
A = []
b = []
# points where gradients were measured, and the dt/dPC values at those points
D_data = ref_umap_df_dt[pca_cols].to_numpy()
G_data = ref_umap_df_dt[dt_cols[:-1]].to_numpy()  

A = build_A(D_data)
b = build_b(G_data)

# Solve the least squares problem
# NOTE(review): G_data may contain inf/nan from zero coordinate differences -- confirm lstsq input is finite
c, residuals, rank, s = np.linalg.lstsq(A, b, rcond=None)

In [None]:
# Evaluate the gradient-fitted polynomial at the reference points.
# NOTE(review): the division by 3600 presumably converts seconds to hours -- confirm units of experiment_time
prediction = evaluate_polynomial_array(D_data, multiindices, c) / 3600

In [None]:


# Compare the gradient-fitted prediction (y) with predicted_stage_hpf (x) for the reference embryos.
fig = px.scatter(x=ref_umap_df["predicted_stage_hpf"], y=prediction)
fig.show()