### Experiment with using Hooke latent space for embryo projections

In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import os
import glob2 as glob
import patsy
from src.functions.plot_functions import format_3d_plotly, format_2d_plotly, rotate_figure

# Location of the model outputs that are read further down.
model_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/results/20250312/HF_hooke_regressions/"

# Root folder for figures produced by this notebook.
fig_root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/morphseq/20250312/PLN/"

# The empty last component makes the joined path end with a separator, so
# file names can later be appended with plain string concatenation.
fig_folder = os.path.join(fig_root, "cov_analyses", "")
os.makedirs(fig_folder, exist_ok=True)

### Load covariance matrices from each experiment

In [None]:
# Collect the model folders whose names end in "C"; the leading part of each
# folder name is parsed below as an integer temperature.
folder_list = sorted(glob.glob(model_path + "*C"))
if not folder_list:
    # Fail with a clear message instead of the obscure TypeError that
    # set.intersection() raises when it is called with no arguments.
    raise FileNotFoundError(f"No model folders matching '*C' found in {model_path}")

cov_list = []      # raw covariance matrix from each folder (one entry per folder)
cov_d_list = []    # diagonal-only version of each cleaned matrix
temp_vec = []      # temperature parsed from each folder name
det_list = []      # log-determinant of each cleaned matrix
det_d_list = []    # log-determinant of each diagonal-only matrix
k_list = []        # dimension of each cleaned matrix
col_list = []      # column labels of each raw matrix
for fname in folder_list:
    temp_vec.append(int(os.path.basename(fname)[:-1]))
    cov = pd.read_csv(os.path.join(fname, "COV.csv"), index_col=0)
    col_list.append(cov.columns)
    cov_list.append(cov)

# Keep only the variables present in every matrix. Sorting makes the ordering
# reproducible across sessions (plain set ordering of strings is not).
common_cols = np.asarray(sorted(set.intersection(*map(set, col_list))))

cov_list_clean = []
for temp_c, cov_raw in zip(temp_vec, cov_list):
    # Restrict rows and columns to the shared variables, in the same order.
    cov = cov_raw.reindex(index=common_cols, columns=common_cols)
    cov_list_clean.append(cov)

    # Same matrix with every off-diagonal entry zeroed.
    cov_d = pd.DataFrame(np.diag(np.diag(cov.to_numpy())), index=cov.index, columns=cov.columns)

    sign, logdet = np.linalg.slogdet(cov.to_numpy())
    if sign <= 0:
        print(f"Warning: Covariance matrix is not positive definite! ({temp_c} C)")
    det_list.append(logdet)
    k_list.append(cov.shape[0])
    # NOTE: the cleaned matrix is deliberately NOT appended to cov_list; doing so
    # would double the raw list and duplicate what cov_list_clean already holds.

    _, logdet_d = np.linalg.slogdet(cov_d.to_numpy())
    cov_d_list.append(cov_d)
    det_d_list.append(logdet_d)

In [None]:
# Heatmap of a single matrix from cov_list; change the index to show another one.
# (Originally sketched as a loop: for t, temp in enumerate(temp_vec).)
t = 3
temp = temp_vec[t]

cc_mat = cov_list[t]  # .iloc[order, order]
fig = px.imshow(cc_mat, color_continuous_scale="RdBu_r", range_color=[-1, 1])

# All layout settings in one call: small square figure, no axis titles,
# black background, no colour bar and tight margins.
fig.update_layout(
    width=400,
    height=400,
    xaxis_title="",  # "tissue type"
    yaxis_title="",  # "tissue type"
    font=dict(color="white", family="Arial, sans-serif", size=18),
    plot_bgcolor="black",
    paper_bgcolor="black",
    coloraxis_showscale=False,
    margin=dict(l=5, r=5, t=5, b=5),
)

# Hide tick labels and stop plotly from growing the margins on both axes.
for update_axes in (fig.update_xaxes, fig.update_yaxes):
    update_axes(showticklabels=False, automargin=False)

# Optional: gaps between cells (in pixels) show up as black lines on this background.
# fig.update_traces(xgap=1, ygap=1)

# fig.write_image(fig_folder + f"tissue_cov_{int(temp)}C.png", scale=2)
# fig.write_html(fig_folder + f"tissue_cov_{int(temp)}C.html")

fig.show()

In [None]:
# Echo the list of log-determinants in the notebook output.
det_list

In [None]:
# Gaussian entropy in bits, computed from a covariance log-determinant:
#   H = [k/2 * ln(2*pi*e) + 1/2 * ln|Sigma|] / ln(2)
# e_vec uses the full matrices, e_vec_d the diagonal-only ones.
k_vec = np.asarray(k_list)
gauss_term = k_vec / 2 * np.log(2*np.pi*np.exp(1))
nats_per_bit = np.log(2)
e_vec = (gauss_term + 0.5 * np.asarray(det_list)) / nats_per_bit
e_vec_d = (gauss_term + 0.5 * np.asarray(det_d_list)) / nats_per_bit

# Plot styling.
colorscale = "RdBu_r"
range_color = [19, 35]
marker_size = 14
axis_labels = ["temperature", "regulatory information (bits)"]

# Entropy gap between the diagonal-only and the full model, per temperature.
entropy_gap = e_vec_d - e_vec
fig = px.scatter(
    x=temp_vec,
    y=entropy_gap,
    color=temp_vec,
    color_continuous_scale=colorscale,
    range_color=range_color,
)
fig.update_traces(marker=dict(line=dict(color="white", width=0.5)))

fig = format_2d_plotly(
    fig,
    marker_size=marker_size,
    axis_labels=axis_labels,
    font_size=20,
    show_gridlines=True,
)

# fig.update_layout(width=800, height=600)
fig.show()

# Save a static image and an interactive copy.
fig.write_image(fig_folder + "reg_info.png", scale=2)
fig.write_html(fig_folder + "reg_info.html")


In [None]:
# Echo the list of matrix dimensions in the notebook output.
k_list

In [None]:
from src.seq.hooke_latent_projections.project_ccs_data import construct_X
from tqdm import tqdm 

# Build model predictions along an evenly spaced grid of timepoints using the
# "NA" experiment covariate.
# NOTE(review): beta_array, metadata_df and spline_lookup_df are not assigned
# above this cell in the notebook; they must already exist in the kernel.

# dis = 2.0
expt = "NA"  # "expthotfish2"
null_dict = {"expt": "NA"}
cov_col_list = beta_array.columns.tolist()

# Time grid spanning the observed range of timepoints.
nt = 100
observed_times = metadata_df["timepoint"]
time_ref_vals = np.linspace(np.min(observed_times), np.max(observed_times), nt)

# One covariate row per grid timepoint.
x_list = [
    construct_X(
        timepoint=t_ref,
        cov_dict=null_dict,
        cov_col_list=cov_col_list,
        spline_lookup_df=spline_lookup_df,
    )
    for t_ref in tqdm(time_ref_vals)
]

# Stack the rows into a covariate matrix and multiply by the coefficients.
X = pd.concat(x_list)
ref_latent_df = X.dot(beta_array.T).reset_index(drop=True)

ref_latent_df.head()

In [None]:
from sklearn.decomposition import PCA
n_components = 25

# Column names for the PCA coordinates: PCA_00, PCA_01, ...
pca_cols = [f"PCA_{p:02}" for p in range(n_components)]

# Fit the PCA on the latent table.
pca = PCA(n_components=n_components)
pca.fit(latents_df)

# Project the data and the reference trajectory into the same PC space.
ccs_pca = pd.DataFrame(pca.transform(latents_df), columns=pca_cols, index=latents_df.index)
ref_pca = pd.DataFrame(pca.transform(ref_latent_df), columns=pca_cols, index=ref_latent_df.index)

# Cumulative variance explained. explained_variance_ratio_ holds fractions, so
# scale by 100 to match the "%" axis label.
cum_var_pct = 100 * np.cumsum(pca.explained_variance_ratio_)
fig = px.line(x=np.arange(n_components) + 1, y=cum_var_pct, markers=True)

fig.update_layout(
    title="PCA: cumulative variance explained",
    xaxis_title="PC",
    yaxis_title="cumulative % variance explained",
)

fig.show()

# fig.write_image(fig_path + "pc_var.png", scale=2)
# fig.write_html(fig_path + "pc_var.html")

#### Plot wildtype curve

In [None]:
# Copy metadata columns from time_df onto the PCA table
# (destination column name, source column name).
for pca_name, time_name in (
    ("inference_flag", "inference_flag"),
    ("timepoint", "timepoint"),
    ("temperature", "temp"),
    ("expt", "expt"),
    ("pseudostage", "pseudostage"),
):
    ccs_pca[pca_name] = time_df[time_name].copy()

# Camera / perspective settings shared by the 3D figures.
zoom_factor = 0.021
z_rotation = 25
elevation = -10
marker_size = 6

# Fixed axis limits so every figure uses the same frame.
xrange = [-25, 25]
yrange = [-12, 30]
zrange = [-11, 15]

# Plot settings.
ref_filter = ccs_pca["inference_flag"] == 1
colormap = "RdBu_r"
color_range = [17, 39]
axis_labels = ["seq PC 1", "seq PC 2", "seq PC 3"]

# Points with inference_flag == 1, coloured by timepoint, on the first three PCs.
pc_x, pc_y, pc_z = pca_cols[:3]
fig = px.scatter_3d(ccs_pca.loc[ref_filter], x=pc_x, y=pc_y, z=pc_z,
                    color="timepoint", opacity=1)

for axis_name, axis_range in (("xaxis", xrange), ("yaxis", yrange), ("zaxis", zrange)):
    fig.layout.scene[axis_name].range = axis_range

fig.update_traces(marker=dict(size=5))

# Overlay the reference trajectory as a white line.
ref_curve = go.Scatter3d(
    x=ref_pca[pc_x],
    y=ref_pca[pc_y],
    z=ref_pca[pc_z],
    mode="lines",
    line=dict(color="white", width=4),
    showlegend=False,
)
fig.add_traces(ref_curve)

fig = format_3d_plotly(fig, axis_labels=axis_labels, marker_size=marker_size, font_size=18)
fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)

fig.show()

# NOTE(review): fig_path is not assigned anywhere in the visible notebook
# (only fig_folder is); it must already exist in the kernel.
fig.write_image(fig_path + "hf2_ref_pca_trajectory.png", scale=2)
fig.write_html(fig_path + "hf2_ref_pca_trajectory.html")

In [None]:

# Render one PNG per step in which the reference curve is drawn only up to the
# current time, with a marker at its leading end.
tmin = np.min(time_ref_vals)
tmax = np.max(time_ref_vals)
time_vec = np.linspace(tmin, tmax, 50)

spline_frame_path = os.path.join(fig_path, "hf_seq_spline_pca_frames", "")
os.makedirs(spline_frame_path, exist_ok=True)

pc_x, pc_y, pc_z = pca_cols[:3]

for frame_idx, frame_time in enumerate(tqdm(time_vec)):
    # Background scatter, identical in every frame.
    fig = px.scatter_3d(ccs_pca.loc[ref_filter], x=pc_x, y=pc_y, z=pc_z,
                        color="timepoint", opacity=1)
    for axis_name, axis_range in (("xaxis", xrange), ("yaxis", yrange), ("zaxis", zrange)):
        fig.layout.scene[axis_name].range = axis_range

    # Reference points reached so far, plus a mask selecting only the latest one.
    elapsed = time_ref_vals <= frame_time
    head_mask = np.zeros(len(elapsed), dtype=bool)
    head_mask[np.flatnonzero(elapsed)[-1]] = True

    curve_so_far = ref_pca.loc[elapsed]
    fig.add_traces(go.Scatter3d(x=curve_so_far[pc_x], y=curve_so_far[pc_y], z=curve_so_far[pc_z],
                                mode="lines", line=dict(color="white", width=4), showlegend=False))

    fig = format_3d_plotly(fig, axis_labels=axis_labels, marker_size=marker_size, font_size=18)

    # Added after formatting, and skipped on the first frame.
    if frame_idx > 0:
        head = ref_pca.loc[head_mask]
        fig.add_traces(go.Scatter3d(x=head[pc_x].values, y=head[pc_y].values, z=head[pc_z].values,
                                    mode="markers",
                                    marker=dict(color="white", size=9, line=dict(color="black", width=1)),
                                    showlegend=False))

    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)

    fig.write_image(os.path.join(spline_frame_path, f"hotfish_seq_pca_time{frame_idx:02}.png"), scale=2)

# Show the last frame inline.
fig.show()

In [None]:
# Render the full reference figure from 50 camera angles between 0 and 360
# degrees (both endpoints included) and save each view as a PNG.
angle_vec = np.linspace(0, 360, 50)

frame_path = os.path.join(fig_path, "hf_ref_seq_pca_rot_frames", "")
os.makedirs(frame_path, exist_ok=True)

# Build the figure once; only the camera changes between frames.
pc_x, pc_y, pc_z = pca_cols[:3]
fig = px.scatter_3d(ccs_pca.loc[ref_filter], x=pc_x, y=pc_y, z=pc_z,
                    color="timepoint", opacity=1)

for axis_name, axis_range in (("xaxis", xrange), ("yaxis", yrange), ("zaxis", zrange)):
    fig.layout.scene[axis_name].range = axis_range

fig.update_traces(marker=dict(size=5))
fig.add_traces(go.Scatter3d(x=ref_pca[pc_x], y=ref_pca[pc_y], z=ref_pca[pc_z],
                            mode="lines", line=dict(color="white", width=4), showlegend=False))

fig = format_3d_plotly(fig, axis_labels=axis_labels, marker_size=marker_size, font_size=18)

for a, angle in enumerate(tqdm(angle_vec)):
    # Offset the base z-rotation by the current angle.
    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation + angle,
                        elev_rotation=elevation)
    frame_file = os.path.join(frame_path, f"hotfish_seq_pca_rot{a:02}.png")
    fig.write_image(frame_file, scale=2)
    # fig.write_html(os.path.join(fig_path, f"hotfish_pca_rot{a:02}.html"))

# Show the last view inline.
fig.show()

In [None]:
# Render frames in which 'hotfish2' points are revealed in order of pseudostage,
# on top of the full reference curve.
hot_filter = ccs_pca["expt"] == 'hotfish2'

hf_frame_path = os.path.join(fig_path, "hf_seq_pca_scatter_frames", "")
os.makedirs(hf_frame_path, exist_ok=True)

# Thresholds run from one unit below the smallest hotfish2 pseudostage to the largest.
ccs_stage_vec = ccs_pca["pseudostage"].to_numpy()
hot_stages = ccs_pca.loc[hot_filter, "pseudostage"]
smin = np.min(hot_stages) - 1
smax = np.max(hot_stages)
stage_vec = np.linspace(smin, smax, 25)

pc_x, pc_y, pc_z = pca_cols[:3]

for s, stage in enumerate(tqdm(stage_vec)):
    # Frame 0 draws every point fully transparent; later frames draw, fully
    # opaque, only the points at or below the current threshold.
    first_frame = s == 0
    opacity = 0 if first_frame else 1
    stage_filter = ccs_stage_vec <= (np.inf if first_frame else stage)

    plot_filter = stage_filter & hot_filter

    fig = px.scatter_3d(ccs_pca.loc[plot_filter], x=pc_x, y=pc_y, z=pc_z, opacity=opacity,
                        color="temperature", color_continuous_scale=colormap, range_color=color_range)

    for axis_name, axis_range in (("xaxis", xrange), ("yaxis", yrange), ("zaxis", zrange)):
        fig.layout.scene[axis_name].range = axis_range

    # Full reference curve in every frame.
    fig.add_traces(go.Scatter3d(x=ref_pca[pc_x], y=ref_pca[pc_y], z=ref_pca[pc_z],
                                mode="lines", line=dict(color="white", width=4), showlegend=False))

    fig = format_3d_plotly(fig, axis_labels=axis_labels, marker_size=marker_size, font_size=18)
    fig = rotate_figure(fig, zoom_factor=zoom_factor, z_rotation=z_rotation, elev_rotation=elevation)

    fig.write_image(os.path.join(hf_frame_path, f"hotfish_seq_pca_stage{s:02}.png"), scale=2)

# Show the last frame inline.
fig.show()

### Look at inferred pseudotime vs experimental timepoint

In [None]:
# Scatter of pseudostage against nearest-neighbour time for the rows selected
# by hot_filter, with a dashed linear trendline.
# NOTE(review): hot_filter must be row-aligned with time_df at this point;
# it is reassigned several times in this notebook — confirm before running.
time_df["temperature"] = time_df["temp"].copy()

hf_df = time_df.loc[hot_filter]
fig = px.scatter(hf_df, x="mean_nn_time", y="pseudostage", color="temperature", symbol="timepoint",
                 color_continuous_scale=colormap, range_color=color_range)

# Linear fit on all selected rows. Non-finite rows are dropped first because a
# single NaN would otherwise propagate into both fitted coefficients.
x = hf_df["mean_nn_time"]
y = hf_df["pseudostage"]
finite = np.isfinite(x) & np.isfinite(y)
coeffs = np.polyfit(x[finite], y[finite], 1)  # degree-1 (linear) fit
poly = np.poly1d(coeffs)

# Evaluate the fit slightly beyond the data range for a smooth line.
x_line = np.linspace(x.min()-2, x.max()+2, 100)
y_line = poly(x_line)

# Add the overall trendline as a trace.
fig.add_trace(
    go.Scatter(
        x=x_line,
        y=y_line,
        mode="lines", showlegend=False,
        line=dict(color="white", width=2.5, dash="dash")
    )
)

fig.layout.xaxis.range = [18, 45]
fig.layout.yaxis.range = [18, 45]

# Axis label typo fixed ("transcritpional" -> "transcriptional").
axis_labels = ["nn transcriptional age (hpf)", "pseudostage (hpf)"]
fig = format_2d_plotly(fig, axis_labels=axis_labels, font_size=20, marker_size=10)

# fig.update_coloraxes(colorbar_title="temperature")
fig.show()

fig.write_image(fig_path + "hf2_pseudotime_vs_nn_time.png", scale=2)
fig.write_html(fig_path + "hf2_pseudotime_vs_nn_time.html")

### Plot vs expected

In [None]:
# Summarise per cohort (temperature x timepoint x experiment): mean and std of
# each value column, flattened to names like "pseudostage_mean".
time_df["stage_shift"] = time_df["pseudostage"] - time_df["timepoint"]

group_keys = ["temperature", "timepoint", "expt"]
value_cols = ["mean_nn_time", "pseudostage", "stage_shift"]
cohort_stats = time_df.loc[:, group_keys + value_cols].groupby(group_keys).agg(["mean", "std"])
cohort_stats.columns = ["_".join(col) for col in cohort_stats.columns]
cohort_stage_df = cohort_stats.reset_index()

# From here on hot_filter refers to rows of the cohort table.
hot_filter = cohort_stage_df["expt"] == "hotfish2"

# Predicted stage from the linear formula in timepoint and temperature.
cohort_stage_df["predicted_stage"] = 6 + (cohort_stage_df["timepoint"]-6)*(0.055*cohort_stage_df["temperature"]-0.57)

ref_vec = np.linspace(14, 44)  # x = y reference line
marker_size = 14

fig = px.scatter(
    cohort_stage_df.loc[hot_filter],
    x="predicted_stage",
    y="pseudostage_mean",
    error_y="pseudostage_std",
    color="temperature",
    symbol="timepoint",
    color_continuous_scale=colormap,
    range_color=[19, 35],
)

# Error bars without end caps.
fig.update_traces(error_y=dict(width=0))

identity_line = go.Scatter(x=ref_vec, y=ref_vec, mode="lines",
                           line=dict(color="white", width=2.5, dash="dash"), showlegend=False)
fig.add_trace(identity_line)

# NOTE(review): the y label mentions nn-transcriptional age while the plotted
# column is pseudostage_mean — confirm which is intended.
axis_labels = ["expected stage (hpf)", "molecular staging <br> (nn-transcriptional age)"]

fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20)

fig.show()

fig.write_image(fig_path + "cohort_pseudotime_seq_stage_all.png", scale=2)
fig.write_html(fig_path + "cohort_pseudotime_seq_stage_all.html")

In [None]:
# Cohort means of pseudostage against nearest-neighbour time, with std error
# bars in both directions and an x = y reference line.
fig = px.scatter(
    cohort_stage_df.loc[hot_filter],
    x="mean_nn_time_mean",
    y="pseudostage_mean",
    error_x="mean_nn_time_std",
    error_y="pseudostage_std",
    color="temperature",
    symbol="timepoint",
    color_continuous_scale=colormap,
    range_color=[19, 35],
)

# Error bars without end caps on either axis.
fig.update_traces(error_y=dict(width=0), error_x=dict(width=0))

identity_line = go.Scatter(x=ref_vec, y=ref_vec, mode="lines",
                           line=dict(color="white", width=2.5, dash="dash"), showlegend=False)
fig.add_trace(identity_line)

axis_labels = ["nn-transcriptional age", "pseudotime"]
fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20)

fig.show()

fig.write_image(fig_path + "mean_pseudotime_vs_nn.png", scale=2)
fig.write_html(fig_path + "mean_pseudotime_vs_nn.html")

In [None]:
# Coefficient of variation of pseudostage per cohort (std / mean), plotted
# against the cohort's mean stage shift.
cohort_stage_df["pseudostage_cv"] = cohort_stage_df["pseudostage_std"] / cohort_stage_df["pseudostage_mean"]

fig = px.scatter(
    cohort_stage_df.loc[hot_filter],
    x="stage_shift_mean",
    y="pseudostage_cv",
    color="temperature",
    symbol="timepoint",
    color_continuous_scale=colormap,
    range_color=[19, 35],
)

fig.update_traces(error_y=dict(width=0))

axis_labels = ["stage shift", "stage dispersion"]
fig = format_2d_plotly(fig, marker_size=marker_size, axis_labels=axis_labels, font_size=20)

fig.show()

fig.write_image(fig_path + "cohort_shift_noise_seq.png", scale=2)
fig.write_html(fig_path + "cohort_shift_noise_seq.html")

### Look to see if we see temp-shift in hotfish experiment

In [None]:
# Pseudostage against experimental timepoint for the 'hotfish2' rows of
# time_df, coloured by temperature.
hot_filter = time_df["expt"] == 'hotfish2'
# time_df["expt"].unique()

fig = px.scatter(time_df.loc[hot_filter], x="timepoint", y="pseudostage", color="temp")

# Figure size, title and axis titles in a single layout update.
fig.update_layout(
    width=800,
    height=600,
    title="predicted stage vs experiment time (hotfish)",
    xaxis_title="experimental timepoint (hpf)",
    yaxis_title="transcriptional pseudo-stage (hpf)",
)
fig.update_coloraxes(colorbar_title="temperature (C)")

fig.show()

fig.write_image(fig_path + "hf2_pseudotime_vs_timepoint.png", scale=2)
fig.write_html(fig_path + "hf2_pseudotime_vs_timepoint.html")

# fig.show()

In [None]:
# Pseudostage against nearest-neighbour time for the rows selected by
# hot_filter, coloured by temperature.
fig = px.scatter(time_df.loc[hot_filter], x="mean_nn_time", y="pseudostage", color="temp")

# fig.update_layout(xaxis=dict(range=[0, 120]),
#                   yaxis=dict(range=[0, 120]))
fig.update_layout(width=800, height=600)

# The title now describes what is plotted: x is the nearest-neighbour stage,
# not the experimental timepoint.
fig.update_layout(
    title="pseudo-stage vs nearest-neighbor stage (hotfish)",
    xaxis_title="nearest-neighbor stage (hpf)",
    yaxis_title="transcriptional pseudo-stage (hpf)")

fig.update_coloraxes(colorbar_title="temperature (C)")
fig.show()

# NOTE(review): these file names are also written by the trendline version of
# this plot earlier in the notebook, so whichever cell runs last overwrites
# the other's output — consider a distinct name.
fig.write_image(fig_path + "hf2_pseudotime_vs_nn_time.png", scale=2)
fig.write_html(fig_path + "hf2_pseudotime_vs_nn_time.html")

### Let's use PCA to visualize latent space

In [None]:
import itertools

# Load the Hooke model outputs (spline lookup, covariance, coefficients).
# NOTE(review): model_name is not assigned in the visible notebook, and this
# rebinds the model_path set at the top of the notebook.
hooke_data_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/seq_data/emb_projections/hooke_model_files/"
model_path = os.path.join(hooke_data_path, model_name, "")

# Spline lookup table passed to construct_X.
spline_lookup_df = pd.read_csv(model_path + "time_splines.csv")

# Covariance matrix and (transposed) coefficient matrix.
cov_array = pd.read_csv(model_path + "COV.csv", index_col=0)
beta_array = pd.read_csv(model_path + "B.csv", index_col=0).T

# Tidy the coefficient column names: drop the parentheses around the
# intercept and collapse " = c" into "=".
beta_array = beta_array.rename(columns={"(Intercept)": "Intercept"})
cols_from = beta_array.columns
cols_from_clean = [name.replace(" = c", "=") for name in cols_from]
beta_array.columns = cols_from_clean

beta_array.head()

In [None]:
from sklearn.decomposition import PCA
n_components = 25
# n_cell_cutoff = 250
# n_filter = metadata_df["count_per_embryo"]>=n_cell_cutoff
# latents_df_filtered = latents_df.loc[n_filter, :]

# Fit the PCA on the residual table, then project both the raw latents and the
# reference trajectory into that PC space.
pca = PCA(n_components=n_components)
pca.fit(residual_df)

ccs_pca_array = pca.transform(latents_df)
ref_pca = pca.transform(ref_latent_array)

# Per-component variance explained. explained_variance_ratio_ holds fractions,
# so scale by 100 to match the "%" axis label; PCs are numbered from 1.
var_pct = 100 * pca.explained_variance_ratio_
fig = px.line(x=np.arange(n_components) + 1, y=var_pct, markers=True)

fig.update_layout(
    title="PCA: variance explained per component",
    xaxis_title="PC",
    yaxis_title="% variance explained")

fig.show()

fig.write_image(fig_path + "pc_var.png", scale=2)
fig.write_html(fig_path + "pc_var.html")

In [None]:
from src.seq.hooke_latent_projections.project_ccs_data import construct_X
from tqdm import tqdm 

# Build model predictions (a) along an even time grid and (b) at each row's
# pseudostage, then subtract (b) from the latents to get residuals.

# dis = 2.0
expt = "NA"  # "expthotfish2"
null_dict = {"expt": "NA"}
cov_col_list = beta_array.columns.tolist()

# Time grid spanning the observed timepoints, plus the per-row pseudostages.
nt = 100
observed_times = metadata_df["timepoint"]
time_ref_vals = np.linspace(np.min(observed_times), np.max(observed_times), nt)
time_pd_vals = time_df["pseudostage"].to_numpy()

# One covariate row per time value, for each of the two sets of times.
x_list = [
    construct_X(timepoint=t_val, cov_dict=null_dict, cov_col_list=cov_col_list,
                spline_lookup_df=spline_lookup_df)
    for t_val in tqdm(time_ref_vals)
]
x_list_pd = [
    construct_X(timepoint=t_val, cov_dict=null_dict, cov_col_list=cov_col_list,
                spline_lookup_df=spline_lookup_df)
    for t_val in tqdm(time_pd_vals)
]

# Stack into covariate matrices.
X = pd.concat(x_list)
X_pd = pd.concat(x_list_pd)

# Predictions = covariates times coefficients.
ref_latent_array = X.dot(beta_array.T)
pd_latent_array = X_pd.dot(beta_array.T)

# Element-wise difference on the raw arrays; column names come from the predictions.
residual_values = latents_df.to_numpy() - pd_latent_array.to_numpy()
residual_df = pd.DataFrame(residual_values, columns=pd_latent_array.columns)

ref_latent_array.head()

In [None]:
# Shape check of the prediction and latent tables. Only the last expression
# of a notebook cell is echoed, so the first line produces no output.
pd_latent_array.shape
latents_df.shape

#### Seems like we have a significant QC issue with Gene3... there are ~5-10x fewer cells?!

In [None]:
# Overlaid, semi-transparent histograms of log10(count_per_embryo), one per experiment.
metadata_df["log_counts"] = np.log10(metadata_df["count_per_embryo"])

fig = px.histogram(metadata_df, x="log_counts", color="expt", opacity=0.75)
fig.update_layout(barmode="overlay")  # draw the groups on top of each other
# fig.update_xaxes(type="log")
fig.show()

In [None]:
# 3D scatter of the first three PCA coordinates for 'hotfish2' rows, coloured
# by temperature, with the reference trajectory overlaid.
# NOTE(review): n_filter is not assigned by any active line in the visible
# notebook (it appears only in commented-out form) — confirm it exists.
time_df_ft = time_df.loc[n_filter]
hot_filter =time_df_ft.loc[:, "expt"]=='hotfish2'

# NOTE(review): hot_filter is built from the filtered table but is used to
# index ccs_pca_array; if n_filter drops any rows the lengths will presumably
# not match — TODO confirm.
fig = px.scatter_3d(x=ccs_pca_array[hot_filter, 0], y=ccs_pca_array[hot_filter, 1], z=ccs_pca_array[hot_filter, 2],
                    color=time_df_ft.loc[hot_filter, "temp"], opacity=0.7,
                    size=time_df_ft.loc[hot_filter, "timepoint"])
# Sets a constant marker size on the existing traces; this runs after the
# size= argument above, so it presumably overrides the per-point sizes.
fig.update_traces(marker=dict(size=5))
# Overlay the reference trajectory (first three columns of ref_pca).
fig.add_traces(go.Scatter3d(x=ref_pca[:, 0], y=ref_pca[:, 1], z=ref_pca[:, 2], 
                            marker=dict(size=3), line=dict(color="black")))
# fig.add_traces(go.Scatter3d(x=ref3_pca[:, 0], y=ref3_pca[:, 1], z=ref3_pca[:, 2]))#, 
                            # marker=dict(color=time_vals, size=3), line=dict(color="black")))

fig.show()

# Save a static image and an interactive copy.
fig.write_image(fig_path + "hf2_hotfish_pca_trajectory.png", scale=2)
fig.write_html(fig_path + "hf2_hotfish_pca_trajectory.html")

In [None]:
# Boolean mask over metadata_df rows whose "expt" label contains "chem"
# (case-insensitive), and the experiment labels of the selected rows.
expt_labels = metadata_df["expt"].tolist()
chem_filter = np.asarray(["chem" in label.lower() for label in expt_labels], dtype=bool)
chem_i_vec = metadata_df.loc[chem_filter, "expt"]

In [None]:
# fig = px.scatter_3d(x=ccs_pca_array[chem_filter, 0], y=ccs_pca_array[chem_filter, 1], z=ccs_pca_array[chem_filter, 2],
#                     color=chem_i_vec)
# fig.update_traces(marker=dict(size=5))
# fig.add_traces(go.Scatter3d(x=ref_pca[:, 0], y=ref_pca[:, 1], z=ref_pca[:, 2], 
#                             marker=dict(color=time_vals, size=3), line=dict(color="black")))

# fig.show()

In [None]:
# 3D PCA scatter of the rows selected by chem_filter, coloured by "target",
# with the reference trajectory overlaid.
# NOTE(review): crisp_filter applies the same "chem" test as chem_filter and is
# never used below — possibly a copy-paste leftover; confirm the intent.
crisp_filter = np.asarray([1 if "chem" in exp.lower() else 0 for exp in metadata_df["expt"].tolist()])==1
# Colour labels for the selected rows.
chem_i_vec = metadata_df.loc[chem_filter, "target"]

fig = px.scatter_3d(x=ccs_pca_array[chem_filter, 0], y=ccs_pca_array[chem_filter, 1], z=ccs_pca_array[chem_filter, 2],
                    color=chem_i_vec)
fig.update_traces(marker=dict(size=5))
# NOTE(review): time_vals is not assigned anywhere in the visible notebook —
# confirm it exists in the kernel.
fig.add_traces(go.Scatter3d(x=ref_pca[:, 0], y=ref_pca[:, 1], z=ref_pca[:, 2], 
                            marker=dict(color=time_vals, size=3), line=dict(color="black")))

fig.show()

In [None]:
# metadata_df["perturbation"].unique()
# ctrl_labels = np.asarray(["EtOH", "DMSO","ctrl-inj", "reference", "ctrl-uninj", "novehicle"])
# ctrl_filter = np.isin(metadata_df["perturbation"], ctrl_labels)
# bead_filter = (metadata_df["dis_protocol"]==2).to_numpy()
# np.sum(ctrl_filter & bead_filter)
# metadata_df.loc[metadata_df["target"]=="Control", "perturbation"].unique()

In [None]:
# 3D PCA scatter of every row of ccs_pca_array, coloured by "target", with the
# reference trajectory overlaid.
# NOTE(review): n_filter is not assigned by any active line in the visible
# notebook; if it drops rows, the colour vector will presumably be shorter
# than ccs_pca_array — TODO confirm.
metadata_df_ft = metadata_df.loc[n_filter, :]
# inf_filter = metadata_df_ft["target"]=="Control"

fig = px.scatter_3d(x=ccs_pca_array[:, 0], y=ccs_pca_array[:, 1], z=ccs_pca_array[:, 2],
                    color=metadata_df_ft.loc[:, "target"]
                   )
fig.update_traces(marker=dict(size=5))
# NOTE(review): time_vals is not assigned anywhere in the visible notebook.
fig.add_traces(go.Scatter3d(x=ref_pca[:, 0], y=ref_pca[:, 1], z=ref_pca[:, 2], 
                            marker=dict(color=time_vals, size=3), line=dict(color="black")))

fig.show()

In [None]:
import umap

# 3-component UMAP fitted on the residual table; the reference trajectory is
# then mapped through the fitted model.
umap_kwargs = dict(n_components=3, n_neighbors=7, min_dist=1, metric='euclidean')
umap_model = umap.UMAP(**umap_kwargs)

embedding = umap_model.fit_transform(residual_df)
ref = umap_model.transform(ref_latent_array)

In [None]:
# 3D UMAP scatter of the rows selected by hot_filter, coloured by pseudostage,
# with the embedded reference trajectory overlaid.
hot_embedding = embedding[hot_filter, 0], embedding[hot_filter, 1], embedding[hot_filter, 2]
fig = px.scatter_3d(
    x=hot_embedding[0],
    y=hot_embedding[1],
    z=hot_embedding[2],
    color=time_df.loc[hot_filter, "pseudostage"],
    hover_data=[time_df.loc[hot_filter, "pseudostage"]],
)
fig.update_traces(marker=dict(size=5))

# NOTE(review): time_vals is not assigned anywhere in the visible notebook.
ref_trace = go.Scatter3d(
    x=ref[:, 0],
    y=ref[:, 1],
    z=ref[:, 2],
    marker=dict(color=time_vals, size=3),
    line=dict(color="black"),
)
fig.add_traces(ref_trace)

fig.show()

fig.write_image(fig_path + "hf2_hotfish_umap_trajectory.png", scale=2)
fig.write_html(fig_path + "hf2_hotfish_umap_trajectory.html")