In [1]:
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import os

In [54]:
root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"
metadata_df = pd.read_csv(os.path.join(root, "metadata", "combined_metadata_files", "embryo_metadata_df02.csv"))
# metadata_df = metadata_df.loc[metadata_df["use_embryo_flag"]==1, :]
np.sum(metadata_df["use_embryo_flag"])


Columns (3,8,12,15,16,17,57,58,59) have mixed types. Specify dtype option on import or set low_memory=False.



75690

In [56]:
np.unique(metadata_df["experiment_date"].astype(str))

array(['20230525', '20230531', '20230602', '20230608', '20230613',
       '20230615', '20230620', '20230622', '20230627', '20230629',
       '20230830', '20230831', '20231110', '20231206', '20231207',
       '20231208', '20231218', '20240306', '20240307', '20240314',
       '20240404', '20240411', '20240418', '20240507', '20240509',
       '20240509_18ss', '20240509_24hpf', '20240510', '20240522',
       '20240530', '20240626'], dtype=object)

### Look at embryo length vs. predicted stage

In [49]:
sample_date = "20240626"
metadata_df["experiment_date"] = metadata_df["experiment_date"].astype(str)
ft = (metadata_df["experiment_date"]==sample_date)# & (metadata_df["embryo_id"]=="20230531_D03_e00")
fig = px.scatter(metadata_df.loc[ft, :], 
                 x="predicted_stage_hpf", y="length_um", hover_data={"snip_id"})
fig.show()

### Write outline for stage calibration function that uses embryo length

#### Build reference key

In [5]:
ref_date = "20240626"
stage_df = metadata_df.loc[metadata_df["experiment_date"]==ref_date, ["snip_id", "embryo_id", "short_pert_name", "phenotype", "control_flag", "predicted_stage_hpf", "length_um", "use_embryo_flag"]].reset_index(drop=True)
ref_bool = (stage_df.loc[:, "phenotype"].to_numpy() == "wt") | (stage_df.loc[:, "control_flag"].to_numpy() == 1)
ref_bool = ref_bool | (stage_df.loc[:, "phenotype"].to_numpy() == "uncertain")
ref_bool = ref_bool & stage_df["use_embryo_flag"]
stage_df = stage_df.loc[ref_bool]
stage_df["stage_group_hpf"] = np.round(stage_df["predicted_stage_hpf"])
stage_key_df = stage_df.groupby('stage_group_hpf').quantile(.95).reset_index()

fig = px.scatter(stage_key_df, x="stage_group_hpf", y="length_um")
fig.show()


The default value of numeric_only in DataFrameGroupBy.quantile is deprecated. In a future version, numeric_only will default to False. Either specify numeric_only or select only columns which should be valid for the function.



### Test on a sample dataset

In [52]:
register_date = "20240813_30hpf"
register_df = metadata_df.loc[metadata_df["experiment_date"].astype(str)==register_date, 
                              ["snip_id", "embryo_id", "time_int","short_pert_name", 
                        "phenotype", "control_flag", "predicted_stage_hpf", "length_um", "use_embryo_flag"]].reset_index(drop=True)

# check for multiple age cohorts
min_t = np.min(register_df["time_int"])
cohort_key = register_df.loc[register_df["time_int"]==min_t, ["embryo_id", "predicted_stage_hpf"]]
age_u, age_cohort = np.unique(np.round(cohort_key["predicted_stage_hpf"]/ 2.5) * 2.5, return_inverse=True)
cohort_key["cohort_id"] = age_cohort

# join onto main df
register_df = register_df.merge(cohort_key.loc[:, ["embryo_id", "cohort_id"]], how="left", on="embryo_id")

# calculate length percentiles
ref_bool = (register_df.loc[:, "phenotype"].to_numpy() == "wt") | (register_df.loc[:, "control_flag"].to_numpy() == 1)
ref_bool = ref_bool #& register_df["use_embryo_flag"]

register_df = register_df.loc[ref_bool]
register_df["stage_group_hpf"] = np.round(register_df["predicted_stage_hpf"])
register_key_df = register_df.groupby(['stage_group_hpf', "cohort_id"]).quantile(.95).reset_index()

fig = px.scatter(register_key_df, x="stage_group_hpf", y="length_um", color="cohort_id")
fig.show()


The default value of numeric_only in DataFrameGroupBy.quantile is deprecated. In a future version, numeric_only will default to False. Either specify numeric_only or select only columns which should be valid for the function.



In [51]:
import scipy 

interp = scipy.interpolate.interp1d(stage_key_df["length_um"], stage_key_df["stage_group_hpf"], 
                                    kind="linear", fill_value=np.nan, bounds_error=False)

ref_hpf_interp = interp(register_key_df["length_um"])
register_key_df["stage_hpf_interp"] = ref_hpf_interp

fig = px.scatter(register_key_df, x="stage_group_hpf", y="stage_hpf_interp", color="cohort_id")
fig.show()

#### Fit simple linear regression model

In [18]:
import statsmodels.api as sm


Y = register_key_df['stage_hpf_interp']

nan_ft = ~np.isnan(Y)


X = register_key_df[['stage_group_hpf', 'cohort_id']] #, columns=['cohort_id'], drop_first=True)
X = X.rename(columns={'stage_group_hpf':'stage'})
X["stage2"] = X["stage"]**2
X["interaction"] = np.prod(X[['stage', 'cohort_id']].to_numpy(), axis=1)
X["interaction2"] = np.prod(X[['stage2', 'cohort_id']].to_numpy(), axis=1)

# Add a constant (intercept term) to the predictor matrix
# X = sm.add_constant(X)

X_ft = X[nan_ft]
Y_ft = Y[nan_ft]

# Fit the OLS regression model
model = sm.OLS(Y_ft, X_ft).fit()

# Print the regression results
print(model.summary())

                                 OLS Regression Results                                
Dep. Variable:       stage_hpf_interp   R-squared (uncentered):                   0.999
Model:                            OLS   Adj. R-squared (uncentered):              0.999
Method:                 Least Squares   F-statistic:                          2.739e+04
Date:                Wed, 16 Oct 2024   Prob (F-statistic):                    5.16e-58
Time:                        08:59:19   Log-Likelihood:                         -58.039
No. Observations:                  38   AIC:                                      120.1
Df Residuals:                      36   BIC:                                      123.4
Df Model:                           2                                                  
Covariance Type:            nonrobust                                                  
                   coef    std err          t      P>|t|      [0.025      0.975]
---------------------------------------


divide by zero encountered in double_scalars



In [19]:
fig = px.scatter(register_key_df, x="stage_group_hpf", y="stage_hpf_interp", color="cohort_id")
fig.update_layout(coloraxis_showscale=False)
predictions = model.predict(X)
g0 = X["cohort_id"]==0
g1 = X["cohort_id"]==1
fig.add_trace(go.Scatter(x=X.loc[g0, "stage"], y=predictions[g0], mode="lines"))
fig.add_trace(go.Scatter(x=X.loc[g1, "stage"], y=predictions[g1], mode="lines"))
fig.show()


In [112]:
from sklearn.linear_model import Ridge

Y = register_key_df['stage_hpf_interp']

# Fit a Ridge model (which includes regularization)
ridge_model = Ridge(alpha=100.0, fit_intercept=True)  # Regularization strength controlled by alpha
ridge_model.fit(X, Y)

# Coefficients (including the intercept)
print(f"Intercept: {ridge_model.intercept_}")
print(f"Other Coefficients: {ridge_model.coef_}")

Intercept: 7.064790214497016
Other Coefficients: [0.         0.7435761  0.00414592]


### Use length deciles to estimate stage

In [15]:
metadata_path = os.path.join(root, "metadata", "")
length_key = pd.read_csv(metadata_path + "stage_length_key.csv")
px.scatter(length_key, x="length_mm_lin", y="stage_hpf")

In [13]:
import scipy 

interp = scipy.interpolate.interp1d(length_key["length_mm_lin"], length_key["stage_hpf"], 
                                    kind="linear", fill_value=np.nan, bounds_error=False)

stage_hpf_interp = interp(length_df["length_mm"])
length_df["stage_hpf_interp"] = stage_hpf_interp

fig = px.scatter(length_df, x="stage_group", y="stage_hpf_interp")
fig.show()

In [26]:
length_df["length_mm"]

0     0.925418
1     0.960535
2     1.013924
3     1.197659
4     1.479715
5     1.768502
6     1.956264
7     2.086025
8     2.275568
9     2.477180
10    2.669709
11    2.794956
12    2.914311
13    3.034803
14    3.121914
15    3.163997
16    3.266087
17    3.357655
18    3.423700
19    3.506898
20    3.564949
21    3.606951
22    3.637629
23    3.704097
24    3.723944
25    3.802948
26    3.871443
27    3.880929
28    3.883825
29    3.941939
30    3.992889
31    3.944519
Name: length_mm, dtype: float64