In [None]:
import os
import pandas as pd
import glob as glob

# Set root path to the project directory. The location is machine-specific, so
# it can be overridden with the MORPHSEQ_ROOT environment variable instead of
# editing this cell; the previous hard-coded path remains the default.
root = os.environ.get("MORPHSEQ_ROOT", "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/morphseq/")
# root = "E:\\Nick\\Dropbox (Cole Trapnell's Lab)\\Nick\\morphseq\\"

# read in metadata (first CSV column becomes the index)
metadata_path = os.path.join(root, 'metadata', '')
embryo_metadata_df = pd.read_csv(os.path.join(metadata_path, "embryo_metadata_df.csv"), index_col=0)

# path to image and snip files (trailing '' keeps a trailing path separator)
im_snip_dir = os.path.join(root, 'training_data', 'bf_embryo_snips', '')
mask_snip_dir = os.path.join(root, 'training_data', 'bf_embryo_masks', '')

# get lists of snips and masks; the directory part is escaped so that glob
# metacharacters in the root path (e.g. "[") are matched literally
im_snip_files = sorted(glob.glob(os.path.join(glob.escape(im_snip_dir), "*.tif")))
emb_mask_files = sorted(glob.glob(os.path.join(glob.escape(mask_snip_dir), "emb*.tif")))

## Plot basic trends in embryo growth over time

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import plotly.offline as pyo
pyo.init_notebook_mode()

# Filter out low-quality frames.
# A boolean mask + reset_index() (not in-place) yields an independent frame, so
# the column assignment below no longer triggers SettingWithCopyWarning.
# As before, the old index is kept as a column by reset_index().
use_mask = (embryo_metadata_df["use_embryo_flag"] == 1).to_numpy()
embryo_metadata_use = embryo_metadata_df.loc[use_mask].reset_index()

# snip_id is "<embryo_id>_<time_int>"
embryo_metadata_use["snip_id"] = (
    embryo_metadata_use["embryo_id"] + "_" + embryo_metadata_use["time_int"].astype(str)
)

# plot surface area over time, one colour per embryo
fig = px.scatter(embryo_metadata_use, x='predicted_stage_hpf', y="surface_area_um", color="embryo_id",
                 opacity=0.5, title="surface area vs. age", template="plotly")
fig.show()

## Use wck-AB to calculate curve for flagging SA outliers

In [None]:
# Make a master perturbation class: use the chemical perturbation when one is
# recorded, otherwise fall back to the genotype.
# astype(str) turns missing values into the literal string "nan", which is what
# the fallback test looks for.
chem_perturbation = embryo_metadata_use["chem_perturbation"].astype(str)

# Assign the whole column in one step. The previous chained form
# (df[col].iloc[idx] = ...) writes to a temporary object and silently does
# nothing when pandas Copy-on-Write is active.
embryo_metadata_use["master_perturbation"] = chem_perturbation.where(
    chem_perturbation != "nan", embryo_metadata_use["genotype"]
)

print(np.unique(embryo_metadata_use["master_perturbation"].values))

In [None]:
import scipy
import scipy.signal  # savgol_filter lives in the signal submodule; import it explicitly

# NOTE: perturbation_class_vec, frac_cushion, ul and ll are not used in this
# cell; they are kept in case other cells rely on them.
perturbation_class_vec = np.unique(embryo_metadata_use["master_perturbation"])

min_embryos = 10         # minimum reference frames per window needed to estimate a percentile
hpf_window = 0.75        # half-width of the sliding window, in units of predicted_stage_hpf
frac_cushion = 1.15
offset_cushion = 100000  # additive margin placed on top of the smoothed percentile curve
prct = 95                # percentile of reference surface area used as the raw bound
ul = np.max(embryo_metadata_use['predicted_stage_hpf'])
ll = np.min(embryo_metadata_use['predicted_stage_hpf'])

# reference population: wck-AB frames only
use_indices = np.where(embryo_metadata_use["master_perturbation"] == 'wck-AB')[0]

sa_vec_ref = embryo_metadata_use["surface_area_um"].iloc[use_indices].values
time_vec_ref = embryo_metadata_use['predicted_stage_hpf'].iloc[use_indices].values

sa_vec_all = embryo_metadata_use["surface_area_um"].values
time_vec_all = embryo_metadata_use['predicted_stage_hpf'].values

# time grid on which the bound is estimated (145 points from 0 to 72)
time_index = np.linspace(0, 72, 145)
percentile_array = np.full((len(time_index),), np.nan)

# sliding-window percentile of the reference surface area; grid points with
# too few reference frames are left as NaN and filled in below
for t, t_center in enumerate(time_index):
    in_window = (time_vec_ref >= t_center - hpf_window) & (time_vec_ref <= t_center + hpf_window)
    if np.count_nonzero(in_window) >= min_embryos:
        percentile_array[t] = np.percentile(sa_vec_ref[in_window], prct)

# Fill in blanks.
# The previous bookkeeping stored the last *under-populated* index in last_i,
# so the tail was filled with NaN (and a NaN first_i/last_i raised TypeError
# when used as a slice bound). Interior gaps were also left as NaN, which the
# smoothing filter then spread to neighbouring points.
valid_idx = np.flatnonzero(~np.isnan(percentile_array))
if valid_idx.size == 0:
    raise ValueError(
        "No time window contains at least %d wck-AB reference frames; "
        "cannot build the surface-area bound." % min_embryos
    )
first_i = int(valid_idx[0])
last_i = int(valid_idx[-1])
# np.interp holds the first/last valid value outside [first_i, last_i] and
# interpolates linearly across interior gaps
percentile_array = np.interp(time_index, time_index[valid_idx], percentile_array[valid_idx])

# smooth and add the additive cushion
sa_bound_sm = offset_cushion + scipy.signal.savgol_filter(percentile_array, window_length=5, polyorder=2)

# Flag outliers.
# Evaluate the bound at each frame's own time by linear interpolation, i.e.
# exactly the "decision boundary" line drawn below. The previous np.digitize
# loop compared each bin against the bound at its upper edge, never visited the
# last two bin ids (so those frames kept the default True), and wrote through a
# chained assignment that is a no-op under pandas Copy-on-Write.
# Frames with a NaN time or surface area compare False and are not flagged.
sa_bound_all = np.interp(time_vec_all, time_index, sa_bound_sm)
embryo_metadata_use['sa_outlier_flag'] = sa_vec_all > sa_bound_all

# plot surface area over time, coloured by outlier status
fig = px.scatter(embryo_metadata_use, x='predicted_stage_hpf', y="surface_area_um", color="sa_outlier_flag",
                 opacity=0.5, title="surface area vs. age", template="plotly")

fig.add_trace(go.Scatter(x=time_index, y=sa_bound_sm, mode="lines", name="decision boundary"))

fig.show()

## How far prior to "official" death should I call embryos?

In [None]:
# filter on QC metrics, excluding live/dead status
# The flags are cast to bool so that `~` is a logical NOT even if the columns
# are stored as 0/1 integers (presumably they are boolean -- confirm against the CSV).
qc_fail = (
    embryo_metadata_df["bubble_flag"].astype(bool).values
    | embryo_metadata_df["focus_flag"].astype(bool).values
    | embryo_metadata_df["frame_flag"].astype(bool).values
    | embryo_metadata_df["no_yolk_flag"].astype(bool).values
)

# Only the working copy is modified. The previous chained assignment also
# overwrote embryo_metadata_df["use_embryo_flag"], changing the source frame
# that other cells filter on.
embryo_metadata_df2 = embryo_metadata_df.copy()
embryo_metadata_df2["qc_flag"] = ~qc_fail
embryo_metadata_df2 = embryo_metadata_df2.loc[embryo_metadata_df2["qc_flag"].values].copy()

# Keep wck-AB rows only. The perturbation label is derived from this frame's
# own rows; positional indices taken from embryo_metadata_use do not line up
# with embryo_metadata_df2 because the two frames are filtered differently.
chem_perturbation = embryo_metadata_df2["chem_perturbation"].astype(str)
master_perturbation = chem_perturbation.where(chem_perturbation != "nan", embryo_metadata_df2["genotype"])
wck_indices = np.where((master_perturbation == 'wck-AB').values)[0]
embryo_metadata_df2 = embryo_metadata_df2.iloc[wck_indices].copy()

# group embryos by age: bin edges 8, 18, 28, 38, 48
time_bins_death = np.linspace(8, 48, 5)

# calculate time relative to death
embryo_id_vec = embryo_metadata_df2["embryo_id"].values
dead_vec = (embryo_metadata_df2["dead_flag"] == True).values
stage_vec = embryo_metadata_df2["predicted_stage_hpf"].values.astype(float)

n_rows = len(embryo_metadata_df2)
ever_dead_vec = np.zeros(n_rows, dtype=bool)
hours_from_death_vec = np.full(n_rows, np.inf)   # stays inf for embryos never marked dead
age_at_death_vec = np.full(n_rows, np.nan)       # stays NaN for embryos never marked dead

embryo_id_index = np.unique(embryo_id_vec)

for eid in embryo_id_index:
    e_indices = np.where(embryo_id_vec == eid)[0]
    dead_rows = e_indices[dead_vec[e_indices]]
    if dead_rows.size == 0:
        continue
    # time of the first row (in row order) flagged dead for this embryo
    d_time = stage_vec[dead_rows[0]]
    ever_dead_vec[e_indices] = True
    age_at_death_vec[e_indices] = d_time
    # restricted to this embryo's rows; previously the whole column was
    # overwritten on every iteration, leaving only the last embryo's offset
    hours_from_death_vec[e_indices] = stage_vec[e_indices] - d_time

# write whole columns at once (no chained indexing)
embryo_metadata_df2["ever_dead_flag"] = ever_dead_vec
embryo_metadata_df2["hours_from_death"] = hours_from_death_vec
embryo_metadata_df2["age_at_death"] = age_at_death_vec

# np.digitize: group i means time_bins_death[i-1] <= age_at_death < time_bins_death[i]
# (0 = before the first edge, len(bins) = at or after the last edge).
# Embryos with no recorded death get inf so they sort after every group.
death_age_group = np.digitize(age_at_death_vec, bins=time_bins_death).astype(float)
death_age_group[np.isnan(age_at_death_vec)] = np.inf
embryo_metadata_df2["death_age_group"] = death_age_group

In [None]:
# Compare length trajectories leading up to death against a reference built
# from embryos in a later death-age group (or with no recorded death), aligned
# to death times resampled from the embryos that died.

min_embryos = 5                         # minimum embryos that died in a group for it to be analysed
rel_time_ref = np.linspace(-6, 2, 16)   # common grid of time relative to death (16 points, -6 to 2)
random_seed = 0                         # fixed seed so the resampled death times are reproducible
rng = np.random.default_rng(random_seed)

dead_len_array = np.full((len(rel_time_ref), len(time_bins_death)), np.nan)
healthy_len_array = np.full((len(rel_time_ref), len(time_bins_death)), np.nan)


def _length_profile(frame, eid, d_time, rel_time_grid):
    """Interpolate one embryo's length_um onto rel_time_grid, aligned to d_time.

    frame:         metadata frame with embryo_id, predicted_stage_hpf and length_um columns
    eid:           embryo_id to extract
    d_time:        scalar subtracted from predicted_stage_hpf
    rel_time_grid: relative-time points at which to evaluate the length

    Returns a 1-D array with one value per grid point. Outside the embryo's
    observed range np.interp holds the first/last observed value.
    """
    rows = np.where(frame["embryo_id"].values == eid)[0]
    rel_time = frame["predicted_stage_hpf"].values[rows].astype(float) - d_time
    length = frame["length_um"].values[rows].astype(float)
    # np.interp requires increasing x-coordinates; sort in case rows are not in time order
    order = np.argsort(rel_time, kind="stable")
    return np.interp(rel_time_grid, rel_time[order], length[order])


death_group_vec = embryo_metadata_df2["death_age_group"].values
embryo_id_vec = embryo_metadata_df2["embryo_id"].values

for t, dt in enumerate(time_bins_death):

    # embryos that died in group t vs. embryos in a later group (inf = no recorded death)
    died_eids = np.unique(embryo_id_vec[death_group_vec == t])
    survived_eids = np.unique(embryo_id_vec[death_group_vec > t])

    # initialize arrays
    dt_array_dead = np.full((len(rel_time_ref), len(died_eids)), np.nan)
    dt_array_live = np.full((len(rel_time_ref), len(survived_eids)), np.nan)

    if len(died_eids) < min_embryos:
        # too few deaths in this group: column t of both output arrays stays NaN
        continue

    d_times = []
    for e, eid in enumerate(died_eids):
        e_indices = np.where(embryo_id_vec == eid)[0]
        # scalar death time (age_at_death is constant within an embryo)
        d_time = float(np.unique(embryo_metadata_df2["age_at_death"].iloc[e_indices].values)[0])
        d_times.append(d_time)
        dt_array_dead[:, e] = _length_profile(embryo_metadata_df2, eid, d_time, rel_time_ref)

    # take averages
    dead_len_array[:, t] = np.nanmean(dt_array_dead, axis=1)

    if len(survived_eids) > 0:
        # randomly sample death times (with replacement) for the reference embryos
        death_times_rd = rng.choice(d_times, size=len(survived_eids), replace=True)

        for e, eid in enumerate(survived_eids):
            dt_array_live[:, e] = _length_profile(embryo_metadata_df2, eid, death_times_rd[e], rel_time_ref)

        healthy_len_array[:, t] = np.nanmean(dt_array_live, axis=1)
        
            
            


In [None]:
# One colour per death-age group: dashed line for embryos that died in the
# group, solid line for the reference embryos. Each curve is normalized by its
# value at the first relative-time point.
color_list = px.colors.qualitative.Alphabet
fig = go.Figure()

n_groups = len(time_bins_death)
for group_idx in range(n_groups):
    group_color = color_list[group_idx]

    dead_norm = dead_len_array[:, group_idx] / dead_len_array[0, group_idx]
    healthy_norm = healthy_len_array[:, group_idx] / healthy_len_array[0, group_idx]

    dead_trace = go.Scatter(
        x=rel_time_ref,
        y=dead_norm,
        marker={"color": group_color},
        line={"dash": "dash"},
    )
    healthy_trace = go.Scatter(
        x=rel_time_ref,
        y=healthy_norm,
        marker={"color": group_color},
        line={"color": group_color},
    )

    fig.add_trace(dead_trace)
    fig.add_trace(healthy_trace)

layout_options = {
    "title": 'Change in length prior to death',
    "xaxis_title": 'hours from death',
    "yaxis_title": 'length (normalized)',
}
fig.update_layout(**layout_options)

fig.show()

There is likely no perfect point at which to set the cutoff. 4 hours before death would be the conservative choice, but I worry that this would rule out many perfectly healthy embryos. I might adopt 2 hours as a compromise, in the interest of putting as much data as possible into the model. This will likely be worth circling back to, though.