In [None]:
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import os

In [None]:
# Project root; override with the MORPHSEQ_ROOT environment variable on other machines.
root = os.environ.get(
    "MORPHSEQ_ROOT",
    "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/",
)
metadata_df = pd.read_csv(
    os.path.join(root, "metadata", "combined_metadata_files", "embryo_metadata_df01.csv")
)
# Fraction of rows with use_embryo_flag set, computed BEFORE filtering
# (after filtering to flag == 1 it would trivially be 1.0). Assumes the flag is 0/1 or bool.
use_frac = np.mean(metadata_df["use_embryo_flag"])
# Keep only usable embryos; .copy() so later column assignments act on an independent frame.
metadata_df = metadata_df.loc[metadata_df["use_embryo_flag"] == 1, :].copy()
use_frac

In [None]:
# Distinct experiment dates present in metadata_df, compared as strings (sorted).
np.unique(metadata_df.loc[:, "experiment_date"].astype(str))

### Look at embryo surface area vs. predicted stage

In [None]:
# Quick look: surface area vs. predicted stage for a single experiment date.
sample_date = "20240306"
# Normalise experiment_date to str so it compares cleanly against date strings.
metadata_df["experiment_date"] = metadata_df["experiment_date"].astype(str)
ft = metadata_df["experiment_date"].eq(sample_date)
fig = px.scatter(
    metadata_df.loc[ft],
    x="predicted_stage_hpf",
    y="surface_area_um",
    color="use_embryo_flag",
    hover_data=["snip_id"],
)
fig.show()

### Write outline for stage calibration function that uses embryo length

#### Build reference key

In [None]:
# Columns summarised per stage bin. Restricting to numeric columns matters: since pandas 2.0,
# GroupBy.quantile raises TypeError on string columns such as snip_id / embryo_id.
STAGE_KEY_COLS = ["predicted_stage_hpf", "surface_area_um", "length_um"]


def _build_stage_key(ref_date, q=0.90):
    """Build a size-vs-stage key from a single reference experiment.

    Rows of ``metadata_df`` whose experiment_date equals ``ref_date`` are binned by
    predicted stage rounded to the nearest whole hpf (``stage_group_hpf``), and the
    ``q`` quantile of each column in ``STAGE_KEY_COLS`` is taken within each bin.

    Returns
    -------
    (stage_df, stage_key)
        The per-snip rows used, and the per-bin quantile key (one row per stage_group_hpf).
    """
    stage_df = metadata_df.loc[
        metadata_df["experiment_date"].astype(str) == ref_date,
        ["snip_id", "embryo_id", "predicted_stage_hpf", "surface_area_um", "length_um", "use_embryo_flag"],
    ].reset_index(drop=True)
    stage_df["stage_group_hpf"] = np.round(stage_df["predicted_stage_hpf"])
    stage_key = stage_df.groupby("stage_group_hpf")[STAGE_KEY_COLS].quantile(q).reset_index()
    return stage_df, stage_key


ref_date01 = "20230620"
ref_date02 = "20240626"
stage_df01, stage_key_df01 = _build_stage_key(ref_date01)
stage_df02, stage_key_df02 = _build_stage_key(ref_date02)

# Combined key: early stages (<= 12 hpf) from ref_date01 plus the full key from ref_date02.
# NOTE(review): the cut uses the 90th-percentile predicted_stage_hpf rather than
# stage_group_hpf, and the two keys can overlap in stage_group_hpf — confirm this is intended.
stage_key_df = pd.concat(
    [stage_key_df01.loc[stage_key_df01["predicted_stage_hpf"] <= 12, :], stage_key_df02],
    axis=0,
    ignore_index=True,
)

fig = px.scatter(stage_key_df, x="stage_group_hpf", y="surface_area_um")
fig.show()

In [None]:
# Explicit submodule import: older SciPy releases do not load scipy.optimize on `import scipy`.
import scipy.optimize

# Reference data for the fit: surface area vs. stage group from the combined stage key.
ref_time_vec = stage_key_df["stage_group_hpf"]
ref_sa_vec = stage_key_df["surface_area_um"]


def sigmoid(params, t_vec=None):
    """Hill-type sigmoid used to model surface area as a function of stage.

    SA(t) = params[0] + params[1] * t**n / (K**n + t**n), with n = params[2], K = params[3].
    For n > 0, params[0] is the value at t = 0, params[1] is the rise above it as t grows
    large, and K is the time at which half of that rise is reached.

    Parameters
    ----------
    params : sequence of floats
        (baseline, amplitude, Hill exponent n, half-max time K); only the first four are used.
    t_vec : array-like, optional
        Times at which to evaluate the curve (list, Series or ndarray). Defaults to the
        global ``ref_time_vec``, looked up at call time.

    Returns
    -------
    numpy.ndarray
        Curve values with the same shape as ``t_vec``.
    """
    if t_vec is None:
        # Late-bound default: resolved at call time, so the function does not require
        # ref_time_vec to exist when it is defined.
        t_vec = ref_time_vec
    t = np.asarray(t_vec, dtype=float)
    baseline, amplitude, hill_n, half_max = params[0], params[1], params[2], params[3]
    t_pow = t ** hill_n
    return baseline + amplitude * np.divide(t_pow, half_max ** hill_n + t_pow)

# define loss
def loss_fun(params, sa=ref_sa_vec):
    """Residuals of the sigmoid fit, for scipy.optimize.least_squares.

    Parameters
    ----------
    params : sequence of floats
        Parameter vector passed through to ``sigmoid``.
    sa : array-like
        Observed surface areas, ordered like the default stage vector that ``sigmoid``
        evaluates on (defaults to ``ref_sa_vec``).

    Returns
    -------
    Model-minus-observed residuals.
    """
    # Residuals against the supplied `sa` (not a hard-wired global). Converting to an array
    # avoids pandas index alignment between the model output and `sa`.
    return sigmoid(params) - np.asarray(sa, dtype=float)

# Initial guess: (baseline SA, SA gain, Hill exponent, half-max stage in hpf).
x0 = [3e5, 1.6e6, 2, 24]
# All four parameters are constrained to be non-negative.
params_fit = scipy.optimize.least_squares(
    loss_fun, x0, bounds=([0, 0, 0, 0], [np.inf, np.inf, np.inf, np.inf])
)
print(params_fit)

# Evaluate the fitted curve on a 50-point grid from 0 to 96 hpf for plotting.
full_time = np.linspace(0, 96)
sa_pd = sigmoid(params_fit.x, t_vec=full_time)

fig = px.scatter(stage_key_df, x="stage_group_hpf", y="surface_area_um")
fig.add_trace(go.Scatter(x=full_time, y=sa_pd, mode="lines", name="sigmoid fit"))
fig.update_layout(xaxis_title="stage (hpf)", yaxis_title="surface_area_um")
fig.show()


In [None]:
# Collect the fitted reference curve (surface area vs. stage) into a table for saving.
stage_ref_df = pd.DataFrame(
    {"stage_hpf": full_time, "surface_area_um": np.asarray(sa_pd, dtype=float)}
)
# TODO: choose an output location and persist, e.g.
#   stage_ref_df.to_csv(os.path.join(root, "metadata", "stage_ref_df.csv"), index=False)
stage_ref_df.head()

### Test on a sample dataset

In [None]:
register_date = "20240813_30hpf"
register_df = metadata_df.loc[
    metadata_df["experiment_date"].astype(str) == register_date,
    ["snip_id", "embryo_id", "time_int", "short_pert_name",
     "phenotype", "control_flag", "predicted_stage_hpf", "length_um", "use_embryo_flag"],
].reset_index(drop=True)

# Check for multiple age cohorts: each embryo's predicted stage at the first time point,
# rounded to the nearest 2.5 hpf, defines its cohort (0, 1, ... in increasing stage).
min_t = np.min(register_df["time_int"])
cohort_key = register_df.loc[register_df["time_int"] == min_t, ["embryo_id", "predicted_stage_hpf"]]
age_u, age_cohort = np.unique(np.round(cohort_key["predicted_stage_hpf"] / 2.5) * 2.5, return_inverse=True)
# .assign returns a new frame instead of writing into a .loc slice (SettingWithCopyWarning).
cohort_key = cohort_key.assign(cohort_id=age_cohort)

# Join cohort ids onto every row; embryos absent at min_t get NaN and are dropped by the groupby below.
register_df = register_df.merge(cohort_key.loc[:, ["embryo_id", "cohort_id"]], how="left", on="embryo_id")

# Reference embryos: phenotype "wt" or control_flag == 1. (metadata_df was already
# restricted to use_embryo_flag == 1 when it was loaded.)
ref_bool = (register_df["phenotype"].to_numpy() == "wt") | (register_df["control_flag"].to_numpy() == 1)
register_df = register_df.loc[ref_bool].copy()
register_df["stage_group_hpf"] = np.round(register_df["predicted_stage_hpf"])

# 95th-percentile values per (stage group, cohort). Numeric columns only: since pandas 2.0,
# GroupBy.quantile raises TypeError on string columns such as snip_id / phenotype.
register_key_df = (
    register_df.groupby(["stage_group_hpf", "cohort_id"])[["time_int", "predicted_stage_hpf", "length_um"]]
    .quantile(0.95)
    .reset_index()
)

fig = px.scatter(register_key_df, x="stage_group_hpf", y="length_um", color="cohort_id")
fig.show()

In [None]:
# Explicit submodule import: older SciPy releases do not load scipy.interpolate on `import scipy`.
import scipy.interpolate

# Invert the reference key: for each register length, look up the reference stage with that
# length. Lengths outside the reference range map to NaN (bounds_error=False).
# Rows with a NaN length or stage are dropped first so they cannot corrupt the interpolant.
# NOTE(review): assumes length_um rises monotonically with stage in stage_key_df; interp1d
# sorts x but does not check monotonicity, and duplicate stage groups from the concatenated
# keys would make the inverse ambiguous.
key_valid = stage_key_df.dropna(subset=["length_um", "stage_group_hpf"])
interp = scipy.interpolate.interp1d(key_valid["length_um"], key_valid["stage_group_hpf"],
                                    kind="linear", fill_value=np.nan, bounds_error=False)

ref_hpf_interp = interp(register_key_df["length_um"])
register_key_df["stage_hpf_interp"] = ref_hpf_interp

fig = px.scatter(register_key_df, x="stage_group_hpf", y="stage_hpf_interp", color="cohort_id")
fig.show()

#### Fit simple linear regression model

In [None]:
import statsmodels.api as sm

# Response: stage inferred from length via the reference key (NaN where out of range).
Y = register_key_df['stage_hpf_interp']
nan_ft = Y.notna()

# Design matrix, no intercept column: stage, cohort_id, stage^2, and the stage / stage^2
# interactions with cohort_id.
X = (
    register_key_df[['stage_group_hpf', 'cohort_id']]
    .rename(columns={'stage_group_hpf': 'stage'})
    .assign(stage2=lambda d: d['stage'] ** 2)
    .assign(
        interaction=lambda d: d['stage'] * d['cohort_id'],
        interaction2=lambda d: d['stage2'] * d['cohort_id'],
    )
)

# Fit only on rows with a defined response.
X_ft = X[nan_ft]
Y_ft = Y[nan_ft]
model = sm.OLS(Y_ft, X_ft).fit()

print(model.summary())

In [None]:
fig = px.scatter(register_key_df, x="stage_group_hpf", y="stage_hpf_interp", color="cohort_id")
fig.update_layout(coloraxis_showscale=False)
predictions = model.predict(X)
pred_arr = np.asarray(predictions)

# Overlay one fitted curve for every cohort present in X, rather than hard-coding cohort ids.
for cohort in np.unique(X["cohort_id"]):
    in_cohort = (X["cohort_id"] == cohort).to_numpy()
    cohort_stage = X["stage"].to_numpy()[in_cohort]
    order = np.argsort(cohort_stage)  # draw each line left-to-right regardless of row order
    fig.add_trace(go.Scatter(x=cohort_stage[order], y=pred_arr[in_cohort][order],
                             mode="lines", name=f"cohort {cohort:g} fit"))
fig.show()


In [None]:
from sklearn.linear_model import Ridge

Y = register_key_df['stage_hpf_interp']
# stage_hpf_interp is NaN wherever a length fell outside the reference key (interp1d with
# fill_value=np.nan); sklearn estimators reject NaN targets, so fit on valid rows only.
valid = Y.notna()

# Ridge regression on the same design matrix X used for OLS; alpha sets the L2 penalty.
# NOTE(review): features are unscaled (stage2 is stage squared, so on a much larger scale),
# so the penalty acts unevenly across coefficients — consider standardizing X.
ridge_model = Ridge(alpha=100.0, fit_intercept=True)
ridge_model.fit(X.loc[valid], Y.loc[valid])

# Coefficients (including the intercept)
print(f"Intercept: {ridge_model.intercept_}")
print(f"Other Coefficients: {ridge_model.coef_}")

### Use length deciles to estimate stage

In [None]:
# Stage/length lookup key stored in the project's metadata directory.
metadata_path = os.path.join(root, "metadata", "")
key_csv = os.path.join(metadata_path, "stage_length_key.csv")
length_key = pd.read_csv(key_csv)
px.scatter(length_key, x="length_mm_lin", y="stage_hpf")

In [None]:
# Explicit submodule import: older SciPy releases do not load scipy.interpolate on `import scipy`.
import scipy.interpolate

# NOTE(review): `length_df` was never defined in this notebook (NameError on Run All).
# As a stand-in, build it from the per-(stage group, cohort) length quantiles in
# `register_key_df`, converting `length_um` to mm to match `length_key["length_mm_lin"]`
# (assumes um / mm units as the column names suggest). Replace with the intended
# length-decile table and confirm.
length_df = (
    register_key_df
    .rename(columns={"stage_group_hpf": "stage_group"})
    .assign(length_mm=lambda d: d["length_um"] / 1000.0)
)

# Map each length onto the stage with that length in the stage/length key;
# lengths outside the key's range become NaN (bounds_error=False).
interp = scipy.interpolate.interp1d(length_key["length_mm_lin"], length_key["stage_hpf"],
                                    kind="linear", fill_value=np.nan, bounds_error=False)

stage_hpf_interp = interp(length_df["length_mm"])
length_df["stage_hpf_interp"] = stage_hpf_interp

fig = px.scatter(length_df, x="stage_group", y="stage_hpf_interp")
fig.show()

In [None]:
# Inspect the per-stage-group lengths (mm) used for the interpolation above.
length_df.loc[:, "length_mm"]