In [None]:
import os

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import scipy.optimize

In [None]:
# Root of the morphseq project tree. The original hard-coded location stays the
# default; set the MORPHSEQ_ROOT environment variable to run this notebook on a
# different machine without editing the cell.
root = os.environ.get("MORPHSEQ_ROOT", "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/")

embryo_metadata_path = os.path.join(root, "metadata", "combined_metadata_files", "embryo_metadata_df01.csv")
embryo_metadata_df = pd.read_csv(embryo_metadata_path)
# Experiment dates are matched against string literals (e.g. "20240626") later on,
# so store them as strings rather than whatever dtype read_csv inferred.
embryo_metadata_df["experiment_date"] = embryo_metadata_df["experiment_date"].astype(str)
embryo_metadata_df.head()

In [None]:
############
# Clean up the chemical perturbation variable and create a master perturbation variable.
#
# Boolean masks are used instead of np.where(...)[0] with .loc: .loc is label-based,
# so positional indices only select the intended rows while the frame still has a
# default RangeIndex.
chem_perturbation = embryo_metadata_df["chem_perturbation"].astype(str)
# astype(str) renders missing entries as the literal string 'nan'; relabel those "None".
chem_perturbation = chem_perturbation.mask(chem_perturbation == "nan", "None")
embryo_metadata_df["chem_perturbation"] = chem_perturbation

# Master perturbation class: the chemical perturbation where one is recorded,
# otherwise the genotype of that row.
embryo_metadata_df["master_perturbation"] = chem_perturbation.where(
    chem_perturbation != "None", embryo_metadata_df["genotype"]
)

In [None]:
# Join on additional perturbation info from the perturbation name key.
pert_name_key = pd.read_csv(os.path.join(root, 'metadata', "perturbation_name_key.csv"))
embryo_metadata_df = embryo_metadata_df.merge(pert_name_key, how="left", on="master_perturbation", indicator=True)

# With a left join, rows whose master_perturbation is absent from the key are
# flagged "left_only" in the indicator column; treat any such row as an error.
unmatched = embryo_metadata_df["_merge"] != "both"
if unmatched.any():
    # Series.unique() + str() rather than np.unique() + join: the offending values may
    # be NaN or otherwise non-string, which would make sorting/joining raise a
    # TypeError and hide the message we actually want to show.
    problem_perts = sorted(str(pert) for pert in embryo_metadata_df.loc[unmatched, "master_perturbation"].unique())
    raise ValueError("Some perturbations were not found in key: " + ', '.join(problem_perts))
embryo_metadata_df = embryo_metadata_df.drop(columns=["_merge"])

### Look at embryo size (surface area) vs. predicted stage

In [None]:
ref_date00 = "20240626"
ref_date01 = "20230620"

# Columns carried over from the full metadata table into each per-date frame.
DATE_COLS = ["snip_id", "embryo_id", "time_int", "Time Rel (s)", "short_pert_name",
             "phenotype", "control_flag", "predicted_stage_hpf", "surface_area_um", "use_embryo_flag"]
SA_QUANTILE = 0.95      # surface-area quantile taken within each stage group
MAX_STAGE_HPF_01 = 14   # ref_date01 only contributes stage groups up to this value


def build_date_key(metadata_df, experiment_date, quantile=SA_QUANTILE):
    """Build the per-stage surface-area key for one experiment date.

    Parameters
    ----------
    metadata_df : pandas.DataFrame
        Embryo metadata containing "experiment_date" and every column in DATE_COLS.
    experiment_date : str
        Value of "experiment_date" to select.
    quantile : float
        Quantile of "surface_area_um" computed within each stage group.

    Returns
    -------
    date_df : pandas.DataFrame
        All rows for the date, restricted to DATE_COLS, with a fresh index.
    date_df_ref : pandas.DataFrame
        Reference rows only (phenotype "wt" or control_flag == 1, and
        use_embryo_flag set), with an added "stage_group_hpf" column holding
        predicted_stage_hpf rounded to the nearest integer.
    date_key_df : pandas.DataFrame
        One row per "stage_group_hpf" with the requested quantile of "surface_area_um".
    """
    date_df = metadata_df.loc[metadata_df["experiment_date"] == experiment_date, DATE_COLS].reset_index(drop=True)

    # Reference embryos: wild-type phenotype or flagged control, and marked usable.
    is_reference = (date_df["phenotype"].to_numpy() == "wt") | (date_df["control_flag"].to_numpy() == 1)
    is_reference = is_reference & date_df["use_embryo_flag"]

    # .copy() so the column added below is written to an independent frame rather
    # than to a slice of date_df (which triggers SettingWithCopyWarning).
    date_df_ref = date_df.loc[is_reference].copy()
    date_df_ref["stage_group_hpf"] = np.round(date_df_ref["predicted_stage_hpf"])

    # Surface-area quantile per rounded stage.
    date_key_df = (
        date_df_ref.loc[:, ["stage_group_hpf", "surface_area_um"]]
        .groupby(["stage_group_hpf"])
        .quantile(quantile)
        .reset_index()
    )
    return date_df, date_df_ref, date_key_df


date_df00, date_df_ref00, date_key_df00 = build_date_key(embryo_metadata_df, ref_date00)
date_df01, date_df_ref01, date_key_df01 = build_date_key(embryo_metadata_df, ref_date01)

# Combined key: every stage group from ref_date00 plus the early stage groups
# (<= MAX_STAGE_HPF_01) from ref_date01.
early_key_df01 = date_key_df01.loc[date_key_df01["stage_group_hpf"] <= MAX_STAGE_HPF_01, :]
date_key_df = pd.concat([date_key_df00, early_key_df01], axis=0, ignore_index=True)

px.scatter(x=date_key_df["stage_group_hpf"], y=date_key_df["surface_area_um"])

### Fit a sigmoidal function to generate an SA-based staging key

In [None]:
# Stage grid on which the fitted curve is evaluated (50 points, np.linspace's default).
t_vec_full = np.linspace(0, 72, num=50)
# Data being fitted: one point per stage group in the combined key.
t_vec = date_key_df["stage_group_hpf"]
sa_vec = date_key_df["surface_area_um"]


def sigmoid(params, t_vec=t_vec):
    """Evaluate sa(t) = p0 + p1 * t**p2 / (p3**p2 + t**p2).

    Parameters
    ----------
    params : sequence of 4 floats
        (p0, p1, p2, p3) as used in the formula above.
    t_vec : array-like
        Stage values at which to evaluate. NOTE: the default is bound when this
        function is defined, so re-run this cell after date_key_df changes.

    Returns
    -------
    Predicted values, same shape/type as ``t_vec``.
    """
    t_pow = t_vec ** params[2]
    return params[0] + params[1] * t_pow / (params[3] ** params[2] + t_pow)


def loss_fun(params, sa_vec=sa_vec, t_vec=t_vec):
    """Residuals (prediction minus observation) passed to least_squares.

    ``t_vec`` is handed to ``sigmoid`` explicitly so both inputs of the residual
    are visible here instead of one of them hiding in sigmoid's default.
    """
    return sigmoid(params, t_vec=t_vec) - sa_vec


x0 = [4e5, 1e6, 2, 24]                   # initial guess for (p0, p1, p2, p3)
lb = (0, 0, 0, 0)                        # all parameters constrained to be non-negative
ub = (np.inf, np.inf, np.inf, np.inf)
params_fit = scipy.optimize.least_squares(loss_fun, x0, bounds=(lb, ub))
# Fail loudly instead of silently saving a key built from an unconverged fit.
assert params_fit.success, f"least_squares did not converge: {params_fit.message}"

sa_pd_full = sigmoid(params_fit.x, t_vec=t_vec_full)

fig = px.scatter(x=date_key_df["stage_group_hpf"], y=date_key_df["surface_area_um"])
fig.add_trace(go.Scatter(x=t_vec_full, y=sa_pd_full, mode="lines", name="sigmoid fit"))
fig.show()

### Save

In [None]:
# Tabulate the fitted curve on the evaluation grid and write it out as the staging key.
stage_key_df = pd.DataFrame({"stage_hpf": t_vec_full, "sa_um": sa_pd_full})
stage_key_path = os.path.join(root, "metadata", "stage_ref_df01.csv")
stage_key_df.to_csv(stage_key_path, index=False)

In [None]:
# Load the previously saved staging key; its first CSV column is used as the index.
stage_key_prev_path = os.path.join(root, "metadata", "stage_ref_df.csv")
stage_key_prev = pd.read_csv(stage_key_prev_path, index_col=0)
stage_key_prev.head()

In [None]:
# Overlay the newly fitted staging key on the previously saved one.
fig = px.scatter(x=stage_key_df["stage_hpf"], y=stage_key_df["sa_um"])
# Name both traces so the legend distinguishes the new key from the previous one.
# update_traces runs before add_trace, so it only touches the px.scatter trace.
fig.update_traces(name="new key (stage_ref_df01.csv)", showlegend=True)
fig.add_trace(go.Scatter(x=stage_key_prev["stage_hpf"], y=stage_key_prev["sa_um"],
                         name="previous key (stage_ref_df.csv)"))
fig.update_layout(xaxis_title="stage_hpf", yaxis_title="sa_um")
fig.show()