### Notebook examining how temperature affects developmental stage at the cell, tissue, and embryo levels

In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import glob2 as glob
import os
from src.functions.plot_functions import format_2d_plotly

Load the cell-level colData from the sequencing experiment (previously exported to a CSV file).

In [None]:
# --- paths ------------------------------------------------------------------
# Every input/output location hangs off the same lab-Dropbox root.
dropbox_root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/"

# sequencing project + the dated build whose colData we read
project_root = os.path.join(dropbox_root, "morphseq", "seq_data", "hotfish", "")
data_folder = os.path.join(project_root, "built_data", "20240813", "")

# where figures for this analysis are written (created if missing)
fig_folder = os.path.join(dropbox_root, "slides", "morphseq", "20250312", "nn_staging", "")
os.makedirs(fig_folder, exist_ok=True)

# which latent-projection regression model to pull combined tables from
model_name = "bead_expt_linear"  # alternative: "t_spline_inter"
latent_path = os.path.join(dropbox_root, "morphseq", "seq_data", "emb_projections", "latent_projections", "")
model_path = os.path.join(latent_path, model_name, "")

# --- data -------------------------------------------------------------------
col_df = pd.read_csv(os.path.join(data_folder, "col_data.csv"))
col_df.head()

In [None]:
# Combined count matrix and its metadata, exported alongside the chosen model's latent projections.
counts_df = pd.read_csv(os.path.join(model_path, "combined_counts.csv"), index_col=0)
metadata_df = pd.read_csv(os.path.join(model_path, "combined_metadata.csv"), index_col=0)

# keep only the rows belonging to the hotfish2 experiment
is_hotfish2 = metadata_df["expt"].eq("hotfish2")
hf_meta_df = metadata_df[is_hotfish2]

### Step 1: whole-embryo staging — use nearest-neighbor (NN) cell ages to estimate stage by cohort

In [None]:
def predicted_stage_hpf(timepoint, temp):
    """Expected developmental stage (hpf) from the linear temperature model.

    stage = 6 + (timepoint - 6) * (0.055 * temp - 0.57)

    The rate term equals ~1 at 28.5 C, so at that temperature expected stage tracks clock time.
    Works element-wise on scalars, numpy arrays, or pandas Series.
    """
    return 6 + (timepoint - 6) * (0.055 * temp - 0.57)


# Per-embryo mean NN transcriptional age. Select the column explicitly: the original
# `.mean("mean_nn_time")` passed the string as GroupBy.mean's `numeric_only` argument.
emb_stage_df = (
    col_df.groupby(["embryo_ID", "temp", "timepoint"])["mean_nn_time"]
    .mean()
    .reset_index()
)

# Cohort (temp x timepoint) mean and embryo-to-embryo SD of that per-embryo age
cohort_stage_df = (
    emb_stage_df.groupby(["temp", "timepoint"])["mean_nn_time"]
    .agg(mean_nn_time_mean="mean", mean_nn_time_std="std")
    .reset_index()
)

# Stage expected from the linear model, for both tables
emb_stage_df["predicted_stage"] = predicted_stage_hpf(emb_stage_df["timepoint"], emb_stage_df["temp"])
cohort_stage_df["predicted_stage"] = predicted_stage_hpf(cohort_stage_df["timepoint"], cohort_stage_df["temp"])

In [None]:
# Disabled cross-check: attach each embryo's value from combined_metadata (its mean_nn_time
# renamed to "ccs_time") and scatter it against the per-embryo NN staging computed above.
# Uncomment to compare the two stage estimates.
# emb_stage_df = emb_stage_df.merge(hf_meta_df.rename(columns={"mean_nn_time":"ccs_time"}), how="left", 
#                                   left_on="embryo_ID", right_index=True)

# fig = px.scatter(emb_stage_df, x="ccs_time", y="mean_nn_time")
# fig.show()

In [None]:
# Cohort-mean NN transcriptional stage vs. the stage expected from the linear temperature model,
# restricted to the temperatures also used in Dorrity et al. (28, 32, 34 C).
d_temp_vec = np.asarray([28, 32, 34])
d_filter = np.isin(cohort_stage_df["temp"], d_temp_vec)

ref_vec = np.linspace(22, 44)  # span of the x = y reference line
marker_size = 14

colorscale = "RdBu_r"
range_color = [19, 35]  # fixed color limits so temperature colors match across figures

fig = px.scatter(
    cohort_stage_df.loc[d_filter],
    x="predicted_stage",
    y="mean_nn_time_mean",
    error_y="mean_nn_time_std",  # embryo-to-embryo SD within each cohort
    color="temp",
    color_continuous_scale=colorscale,
    range_color=range_color,
    labels={"temp": "temperature"},
)
fig.update_traces(error_y=dict(width=0))  # error bars without end caps

# x = y line: cohorts on it are staged exactly as the linear model predicts
fig.add_trace(go.Scatter(x=ref_vec, y=ref_vec, mode="lines", line=dict(color="white", width=2.5), showlegend=False))

axis_labels = ["expected stage (hpf)", "molecular staging <br> (nn-transcriptional age)"]
fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20, show_gridlines=False)

fig.show()

fig.write_image(fig_folder + "cohort_nn_seq_stage_dorrity.png", scale=2)
fig.write_html(fig_folder + "cohort_nn_seq_stage_dorrity.html")

In [None]:
# Same comparison as above, but for ALL temperature cohorts (wider x = y range to cover 19C).
ref_vec = np.linspace(14, 44)  # span of the x = y reference line
marker_size = 14

colorscale = "RdBu_r"
range_color = [19, 35]  # fixed color limits so temperature colors match across figures

fig = px.scatter(
    cohort_stage_df,
    x="predicted_stage",
    y="mean_nn_time_mean",
    error_y="mean_nn_time_std",  # embryo-to-embryo SD within each cohort
    color="temp",
    color_continuous_scale=colorscale,
    range_color=range_color,
    labels={"temp": "temperature"},
)
fig.update_traces(error_y=dict(width=0))  # error bars without end caps

# x = y line: cohorts on it are staged exactly as the linear model predicts
fig.add_trace(go.Scatter(x=ref_vec, y=ref_vec, mode="lines", line=dict(color="white", width=2.5), showlegend=False))

axis_labels = ["expected stage (hpf)", "molecular staging <br> (nn-transcriptional age)"]
fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20, show_gridlines=False)

fig.show()

fig.write_image(fig_folder + "cohort_nn_seq_stage_all.png", scale=2)
fig.write_html(fig_folder + "cohort_nn_seq_stage_all.html")

In the broadest strokes, things align with expectation: low-temperature embryos (19C and 24C) consistently register as younger than their warmer contemporaries. But there are some strange trends as well. 19C embryos start out FAR older than expected (~25 hpf instead of 14 hpf) and then appear to get "younger". The 32C cohort actually outpaces the 34C and 35C cohorts by the final timepoint. We also see massive embryo-to-embryo variability in the 34C and 35C cohorts. For all of these trends, it behooves us to dig into the underlying cell types and tissues to figure out what is going on.

### Look at pairwise correlations between tissue stages

In [None]:
# Attach a coarse-grained "cell_type_broad" label to every cell via its fine-grained cell_type.
ct_broad_df = pd.read_csv(os.path.join(project_root, "data", "unique_ct_full.csv"), index_col=0)
ct_map_df = ct_broad_df.loc[:, ["cell_type", "cell_type_broad"]].drop_duplicates()

# - drop any cell_type_broad left by a previous run, so re-executing this cell does not create
#   cell_type_broad_x / cell_type_broad_y columns;
# - validate="many_to_one" raises if a cell_type maps to more than one broad type, which would
#   otherwise silently duplicate cells and inflate every downstream count.
col_df = col_df.drop(columns=["cell_type_broad"], errors="ignore").merge(
    ct_map_df, how="left", on="cell_type", validate="many_to_one"
)

col_df.head()

In [None]:
from itertools import product

from tqdm import tqdm

n_emb_thresh = 8  # minimum embryo count a tissue needs (applied in the filter cell below)
group_var = "cell_type_broad"

# Per-embryo, per-tissue mean NN time and number of contributing cells
emb_tissue_stage_df = (
    col_df.groupby(["embryo_ID", group_var, "temp", "timepoint"])["mean_nn_time"]
    .agg(["mean", "size"])
    .reset_index()
    .rename(columns={"mean": "mean_nn_time", "size": "n_cells"})
)

# Pad to the full tissue x embryo grid. Missing combinations get NaN in every other column,
# including temp.
tissue_values = emb_tissue_stage_df[group_var].unique()
embryo_values = emb_tissue_stage_df["embryo_ID"].unique()
complete_index = pd.DataFrame(list(product(tissue_values, embryo_values)), columns=[group_var, "embryo_ID"])
emb_tissue_stage_df = complete_index.merge(emb_tissue_stage_df, on=[group_var, "embryo_ID"], how="left")

standard_cell_types = np.unique(emb_tissue_stage_df[group_var])

# Drop NaN temps BEFORE sorting. Python's sorted() cannot order a list containing NaN
# (every comparison with NaN is False), so sorting first could leave temperatures out of order
# and break the positional indexing into the per-temperature lists used later.
temp_vec = np.sort(emb_tissue_stage_df["temp"].dropna().unique().astype(float))

cc_mat_list = []
cc_mat_list_null = []
cov_mat_list = []
count_mat_list = []

rng = np.random.default_rng(371)

for temp in tqdm(temp_vec):
    temp_df = emb_tissue_stage_df.loc[emb_tissue_stage_df["temp"] == temp, ["embryo_ID", group_var, "mean_nn_time"]]

    # embryos x tissues matrix; tissues never seen at this temperature become all-NaN columns
    pivot_df = (
        temp_df.pivot(index="embryo_ID", columns=group_var, values="mean_nn_time")
        .reindex(columns=standard_cell_types)
    )

    # pairwise (NaN-aware) Pearson correlation and covariance between tissues
    cc_mat_list.append(pivot_df.corr(method="pearson"))
    cov_mat_list.append(pivot_df.cov())

    # number of embryos in which each pair of tissues is jointly observed
    observed = pivot_df.notna().astype(int)
    count_mat_list.append(observed.T.dot(observed).to_numpy())

    # Null model: permute the observed values across the observed cells only. The missingness
    # pattern is preserved, so null correlations rest on the same embryo overlap as the real ones.
    # (The previous version drew with replacement and also filled the originally-missing cells.)
    values = pivot_df.to_numpy(dtype=float)
    observed_mask = ~np.isnan(values)
    null_values = np.full(values.shape, np.nan)
    null_values[observed_mask] = rng.permutation(values[observed_mask])
    null_df = pd.DataFrame(null_values, index=pivot_df.index, columns=pivot_df.columns)

    cc_mat_list_null.append(null_df.corr(method="pearson"))

In [None]:
# Keep tissues observed in at least `n_emb_thresh` embryos at EVERY temperature, and drop "other".
# Each count matrix's diagonal holds per-tissue embryo counts and bounds its column, so the
# column-max of the across-temperature minimum equals min over temps of that diagonal.
stacked_counts = np.stack(count_mat_list, axis=2)
min_count_array = stacked_counts.min(axis=2)  # no squeeze: keeps 2-D even with a single tissue
tissue_labels = np.asarray(standard_cell_types)
filter_vec = (min_count_array.max(axis=0) >= n_emb_thresh) & (tissue_labels != "other")
keep_types = tissue_labels[filter_vec]


def _subset_tissues(mat):
    """Restrict a tissue x tissue DataFrame to `keep_types` (label-based, so re-running is a no-op)."""
    return mat.loc[keep_types, keep_types]


# Label-based selection makes this cell idempotent; the old boolean-mask version crashed on re-run
# because the mask was rebuilt against already-filtered matrices.
cc_mat_list = [_subset_tissues(m) for m in cc_mat_list]
cc_mat_list_null = [_subset_tissues(m) for m in cc_mat_list_null]
cov_mat_list = [_subset_tissues(m) for m in cov_mat_list]

In [None]:
# Tissue ordering for all heatmaps below: average-linkage clustering on correlation distance.
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform

t_ind = 3  # index into temp_vec whose observed correlation matrix defines the ordering

corr = cc_mat_list[t_ind]

# Correlation distance d = 1 - r.
# - Pairs with undefined r (too few shared embryos) are treated as uncorrelated (d = 1);
#   otherwise linkage() raises on non-finite distances.
# - Clip tiny negative round-off.
# - Zero the diagonal.
dist = np.array((1 - corr).fillna(1.0), dtype=float)
dist = np.clip(dist, 0.0, None)
np.fill_diagonal(dist, 0.0)

# condensed (upper-triangle) form required by linkage
condensed_dist = squareform(dist, checks=False)

Z = linkage(condensed_dist, method="average")

# leaf order of the dendrogram, computed without plotting
dendro = dendrogram(Z, no_plot=True)
order = dendro["leaves"]

In [None]:
# Null-model correlation matrix for one temperature, shown in the observed-data dendrogram order.
null_t_ind = 4  # index into temp_vec

cc_mat = cc_mat_list_null[null_t_ind].iloc[order, order]
fig = px.imshow(
    cc_mat,
    color_continuous_scale="RdBu_r",
    range_color=[-1, 1],
    width=1000,
    height=1000,
)
fig.show()

In [None]:
# One dendrogram-ordered tissue x tissue correlation heatmap per temperature, saved as PNG + HTML.
for t, temp in enumerate(temp_vec):
    cc_mat = cc_mat_list[t].iloc[order, order]
    fig = px.imshow(cc_mat, color_continuous_scale="RdBu_r", range_color=[-1, 1])

    # Bare, dark "panel" styling: no axis titles, ticks or colorbar; black page so the 1-px
    # gaps between cells read as grid lines.
    fig.update_layout(
        width=1000,
        height=1000,
        xaxis_title="",
        yaxis_title="",
        font=dict(color="white", family="Arial, sans-serif", size=18),
        plot_bgcolor="black",
        paper_bgcolor="black",
        coloraxis_showscale=False,
        margin=dict(l=10, r=10, t=10, b=10),
    )
    fig.update_xaxes(showticklabels=False, automargin=False)
    fig.update_yaxes(showticklabels=False, automargin=False)
    fig.update_traces(xgap=1, ygap=1)

    stem = fig_folder + f"tissue_cc_{int(temp)}C"
    fig.write_image(stem + ".png", scale=2)
    fig.write_html(stem + ".html")

# display only the last temperature's heatmap
fig.show()

### Look at the distribution of correlation coefficients (CCs) across temperatures

In [None]:
# Long-format table of pairwise tissue correlations: one row per tissue pair (strict lower
# triangle) per temperature, plus the pooled null values labelled "null".
# Pairs with NaN observed r are dropped (NaN <= 1 is False). The same mask is applied to the
# null values so both distributions cover identical pairs.
cc_list = []
cc_null_val_list = []
for t, temp in enumerate(temp_vec):
    rows, cols = np.tril_indices(cc_mat_list[t].shape[0], k=-1)
    observed_vals = cc_mat_list[t].to_numpy()[rows, cols]
    null_vals = cc_mat_list_null[t].to_numpy()[rows, cols]

    valid = observed_vals <= 1
    cc_null_val_list.append(null_vals[valid])
    cc_list.append(pd.DataFrame({"cc": observed_vals[valid]}).assign(temperature=int(temp)))

null_df = pd.DataFrame({"cc": np.concatenate(cc_null_val_list)})
null_df["temperature"] = "null"

cc_df = pd.concat([*cc_list, null_df], axis=0, ignore_index=True)
cc_df.head()

In [None]:
import plotly.colors as pc

# Observed correlations only. .copy() makes this an independent frame, so the column assignment
# below does not write into a slice (SettingWithCopyWarning).
cc_df = cc_df.loc[cc_df["temperature"] != "null"].copy()

# per-temperature summary of the pairwise correlations (used by the scatter cell below)
group_stats = cc_df.groupby("temperature")["cc"].agg(
    mean="mean",
    std="std",
    sem="sem",  # standard error of the mean
).reset_index()

groups = cc_df["temperature"].unique()

# Map each temperature onto the RdBu_r colorscale, normalised over [min_val, max_val] C.
min_val, max_val = 17, 38
colors = {
    int(val): pc.sample_colorscale("RdBu_r", (val - min_val) / (max_val - min_val))[0]
    for val in groups
}

# Order boxes by numeric temperature (plain string categories would put e.g. "8" after "35").
temp_order = [str(v) for v in sorted(groups)]
cc_df["temperature_group"] = pd.Categorical(
    cc_df["temperature"].astype(str), categories=temp_order, ordered=True
)
cc_df = cc_df.sort_values("temperature_group")

fig = px.box(cc_df, y="cc", x="temperature", color="temperature", color_discrete_map=colors)
fig.update_traces(width=0.8)

fig = format_2d_plotly(fig, font_size=20, axis_labels=["temperature (C)", "correlation coefficient"])

fig.update_traces(opacity=0.95, line=dict(width=5))  # thick box outlines

fig.show()

fig.write_image(fig_folder + "tissue_cc_box.png", scale=2)
fig.write_html(fig_folder + "tissue_cc_box.html")

In [None]:


# Mean (+/- SEM) pairwise tissue-stage correlation per temperature.
# `temperature` arrives as an object column, because it was mixed with the "null" label upstream.
# Cast it to numeric so plotly treats color as continuous: RdBu_r/range_color then apply, and all
# points form one trace that the connecting line can join.
stats_plot_df = group_stats.assign(
    temperature=pd.to_numeric(group_stats["temperature"])
).sort_values("temperature")

fig = px.scatter(stats_plot_df, x="temperature", y="mean", error_y="sem", color="temperature",
                 color_continuous_scale="RdBu_r", range_color=[17, 38])

fig = format_2d_plotly(fig, axis_labels=["temperature (C)", "correlation coefficient"], font_size=20, marker_size=24)

fig.update_traces(mode="lines+markers", line=dict(color="white"))
fig.update_layout(yaxis=dict(range=[0, 0.3]), coloraxis_showscale=False)

# The bars are vertical (error_y); styling error_x, as before, had no effect.
fig.update_traces(error_y=dict(color="white", width=0))

fig.show()

fig.write_image(fig_folder + "tissue_cc_scatter.png", scale=2)
fig.write_html(fig_folder + "tissue_cc_scatter.html")