### Define key loading and plotting functions

In [1]:
import dash
import plotly.express as px
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
import plotly.graph_objs as go
import pandas as pd
import numpy as np
import json
import glob2 as glob
from skimage.measure import regionprops
import itertools
import os
from plotly.offline import init_notebook_mode, iplot

# Initialise plotly's offline notebook mode before any figures are created.
init_notebook_mode(connected=True)

# Root folder holding the fin atlases, and the atlas table to read from it.
data_root = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/pecFin/HCR_Data/fin_atlas/"
atlas_name = "reference_df_cell_hr.csv"

# Folder that receives the exported projection figures. The trailing slash is
# kept because file names are appended to it by plain string concatenation.
fig_path = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/pecFin/HCR_figures/atlas_projections/"
# exist_ok=True replaces the separate isdir() test, so there is no gap between
# checking for the folder and creating it.
os.makedirs(fig_path, exist_ok=True)

# Genes whose atlas columns are projected and plotted.
gene_name_list = ['Col11a2', 'Emilin3a', 'Fgf10a', 'Hand2', 'Myod1', 'Prdm1a', 'Robo3', 'Sox9a', 'Tbx5a']

# Load the atlas; the first CSV column becomes the row index.
atlas_df_raw = pd.read_csv(data_root + atlas_name, index_col=0)

## Make PD, AP, and DV projections

**First, we need to make regular grids for each projection**

In [57]:
from scipy.interpolate import LinearNDInterpolator

# Restrict the atlas to rows whose Z coordinate does not exceed 0.46.
z_within_cutoff = atlas_df_raw["Z"] <= 0.46
keep_indices = np.where(z_within_cutoff)
atlas_df = atlas_df_raw.iloc[keep_indices].copy()

# Coordinate extents of the retained rows, computed once and reused below.
x_lo, x_hi = np.min(atlas_df["X"]), np.max(atlas_df["X"])
y_lo, y_hi = np.min(atlas_df["Y"]), np.max(atlas_df["Y"])
z_lo, z_hi = np.min(atlas_df["Z"]), np.max(atlas_df["Z"])

# Physical span of each anatomical axis (Z -> PD, X -> AP, Y -> DV).
scale_PD = z_hi - z_lo
scale_AP = x_hi - x_lo
scale_DV = y_hi - y_lo

# Every axis uses the same number of grid points. Scaling the AP/DV
# resolution by the axis spans was tried and abandoned (the DV scale came out
# too large empirically), so the PD resolution is simply reused.
grid_res_pd = 100
grid_res_ap = grid_res_pd
grid_res_dv = grid_res_pd

# DV projection: grid over X (AP) and Z (PD).
dv_axis1 = np.linspace(x_lo, x_hi, grid_res_ap)
dv_axis2 = np.linspace(z_lo, z_hi, grid_res_pd)
dv_grid1, dv_grid2 = np.meshgrid(dv_axis1, dv_axis2)

# AP projection: grid over Y (DV) and Z (PD); the Z axis is shared with DV.
ap_axis1 = np.linspace(y_lo, y_hi, grid_res_dv)
ap_axis2 = dv_axis2.copy()
ap_grid1, ap_grid2 = np.meshgrid(ap_axis1, ap_axis2)

# PD projection: grid over X (AP) and Y (DV), reusing the axes built above.
pd_axis1 = dv_axis1.copy()
pd_axis2 = ap_axis1.copy()
pd_grid1, pd_grid2 = np.meshgrid(pd_axis1, pd_axis2)

In [77]:
from scipy.interpolate import CloughTocher2DInterpolator
from scipy.ndimage import gaussian_filter
from astropy.convolution import convolve
from astropy.convolution import Gaussian2DKernel

sm_sig = 1  # NOTE(review): not referenced in this cell; kept in case another cell reads it
kernel_dv = Gaussian2DKernel(x_stddev=2)
kernel_ap = Gaussian2DKernel(x_stddev=2.5)

# Plot coordinates do not depend on the gene, so they are computed once here
# instead of on every loop iteration.
# NOTE(review): the "axis / max - 0.5" form only spans [-0.5, 0.5] when the
# axis starts at 0 -- confirm that the atlas X and Y coordinates are
# zero-based (the Z axes below subtract their minimum explicitly).
dv_axis1_plot = dv_axis1/np.max(dv_axis1) - 0.5
dv_axis2_plot = (dv_axis2 - np.min(dv_axis2))
dv_axis2_plot = dv_axis2_plot/np.max(dv_axis2_plot)

ap_axis1_plot = ap_axis1/np.max(ap_axis1) - 0.5
ap_axis2_plot = (ap_axis2 - np.min(ap_axis2))
ap_axis2_plot = ap_axis2_plot/np.max(ap_axis2_plot)

pd_axis1_plot = pd_axis1/np.max(pd_axis1) - 0.5
pd_axis2_plot = pd_axis2/np.max(pd_axis2) - 0.5


def _smoothed_projection(coord1, coord2, values, grid1, grid2, kernel):
    """Interpolate scattered values onto a grid, smooth them, and scale to a max of 1.

    Args:
        coord1, coord2: 1-D arrays with the two coordinates of each sample.
        values: 1-D array of sample values (must not contain NaN).
        grid1, grid2: meshgrid arrays on which the projection is evaluated.
        kernel: convolution kernel passed to ``convolve``.

    Returns:
        Array shaped like ``grid1`` holding the smoothed projection divided by
        its maximum, with NaN wherever the interpolator produced no value.
    """
    interpolator = LinearNDInterpolator(np.column_stack((coord1, coord2)), values)
    raw = interpolator(grid1, grid2)
    smoothed = convolve(raw, kernel)
    smoothed = smoothed / np.nanmax(smoothed)
    # Restore NaN wherever the interpolator gave no value, so smoothing does
    # not extend the projection beyond the interpolated region.
    smoothed[np.isnan(raw)] = np.nan
    return smoothed


def _contour_figure(x, y, z, colorbar, x_title, y_title, y_range, width, height,
                    **contour_kwargs):
    """Build a white-themed contour figure with contour levels 0 to 1 in steps of 0.1.

    Args:
        x, y: 1-D plot coordinates for the columns and rows of ``z``.
        z: 2-D array of values to contour.
        colorbar: colorbar settings forwarded to ``go.Contour``.
        x_title, y_title: axis titles.
        y_range: ``[low, high]`` range of the y axis (the x axis is always
            ``[-0.5, 0.5]``).
        width, height: figure size in pixels.
        **contour_kwargs: extra keyword arguments for ``go.Contour``.

    Returns:
        The assembled ``go.Figure``.
    """
    fig = go.Figure()
    fig.add_trace(go.Contour(x=x,
                             y=y,
                             z=z,
                             colorbar=colorbar,
                             contours=dict(
                                 start=0,
                                 end=1,
                                 size=.1,
                             ),
                             **contour_kwargs))
    fig.update_layout(template="plotly_white")
    fig.update_layout(width=width, height=height)
    fig.update_xaxes(title=x_title, range=[-0.5, 0.5])
    fig.update_yaxes(title=y_title, range=y_range)
    return fig


# Atlas coordinates as plain arrays; they are the same for every gene.
atlas_x = atlas_df["X"].to_numpy()
atlas_y = atlas_df["Y"].to_numpy()
atlas_z = atlas_df["Z"].to_numpy()

for gene_name in gene_name_list:

    # Drop cells with no measurement for this gene before interpolating.
    gene_vec_raw = atlas_df[gene_name].to_numpy()
    has_value = ~np.isnan(gene_vec_raw)
    gene_vals = gene_vec_raw[has_value]
    x_vals = atlas_x[has_value]
    y_vals = atlas_y[has_value]
    z_vals = atlas_z[has_value]

    # Colorbar settings shared by the three figures of this gene. The nested
    # title dict replaces the deprecated titleside/titlefont attributes.
    cb_dict = dict(
        title=dict(
            text=gene_name + " Expression",
            side='right',
            font=dict(
                size=14,
                family='Arial, sans-serif'),
        ),
    )

    ##########
    # DV projection: expression over X (AP) and Z (PD)
    gene_dv = _smoothed_projection(x_vals, z_vals, gene_vals,
                                   dv_grid1, dv_grid2, kernel_dv)
    dv_fig = _contour_figure(dv_axis1_plot, dv_axis2_plot, gene_dv, cb_dict,
                             x_title="AP Position", y_title="PD Position",
                             y_range=[0, 1], width=512+128, height=512)
    # NOTE(review): the contour trace is not attached to a shared coloraxis,
    # so this call probably has no visible effect -- confirm; zmin/zmax on the
    # trace would be the way to pin the colour range.
    dv_fig.update_coloraxes(cmin=0, cmax=1)
    dv_fig.write_image(fig_path + gene_name + "_DV_projection.png")

    ##########
    # AP projection: expression over Y (DV) and Z (PD)
    gene_ap = _smoothed_projection(y_vals, z_vals, gene_vals,
                                   ap_grid1, ap_grid2, kernel_ap)
    ap_fig = _contour_figure(ap_axis1_plot, ap_axis2_plot, gene_ap, cb_dict,
                             x_title="DV Position", y_title="PD Position",
                             y_range=[0, 1], width=300, height=512,
                             line_smoothing=1.3)
    ap_fig.write_image(fig_path + gene_name + "_AP_projection.png")

    ##########
    # PD projection: expression over X (AP) and Y (DV)
    # NOTE(review): this reuses kernel_ap; no PD-specific kernel is defined.
    gene_pd = _smoothed_projection(x_vals, y_vals, gene_vals,
                                   pd_grid1, pd_grid2, kernel_ap)
    pd_fig = _contour_figure(pd_axis1_plot, pd_axis2_plot, gene_pd, cb_dict,
                             x_title="AP Position", y_title="DV Position",
                             y_range=[-0.5, 0.5], width=512, height=300)
    pd_fig.write_image(fig_path + gene_name + "_PD_projection.png")



# fig.update_layout(coloraxis_showscale=False)

# #     fig.show()
# fig.write_image(frame_dir + imageName + "_" + "{:03d}".format(iter_i) + ".png")
#     time.sleep(0.25)

In [22]:
# Sanity check: show the (rows, columns) shape of the AP projection grid.
print(np.shape(ap_grid1))

(50, 18)


In [None]:
# Export one PNG per gene and camera angle, showing the 3-D point cloud
# coloured by expression while the camera orbits the sample.
# NOTE(review): figureRoot, imageName, gene_name_dict and df_fin are not
# defined in any visible cell -- presumably they come from an earlier version
# of this notebook; confirm before running this cell.

# 25 camera azimuths covering one full turn, starting at 1.25*pi.
angle_vec = np.linspace(1.25*np.pi, 3.25*np.pi, 25)

# make save directory (exist_ok avoids a separate isdir() check)
gene_frame_dir = figureRoot + imageName + "_gene_expression/"
os.makedirs(gene_frame_dir, exist_ok=True)

# NOTE: this rebinds the notebook-wide gene_name_list.
gene_name_list = gene_name_dict[imageName]
mRNA_suffix = "_cell_mean_nn"
color_scale_list = ["ice", "inferno", "viridis"]
# Colour range applied to every frame.
top = 0.8
bottom = 0.2

# Fixed z component of the camera direction (before the 2x scaling below).
za = -0.2

total_counter = 0
for g, gene_name in enumerate(gene_name_list):
    col_name = gene_name + mRNA_suffix
    mRNA_col = df_fin[col_name].copy()
    # Wrap around the palette list so more than three genes do not raise
    # an IndexError.
    color_scale = color_scale_list[g % len(color_scale_list)]

    for angle in angle_vec:
        # numpy supplies cos/sin here because the notebook never imports math.
        vec = np.asarray([np.cos(angle), np.sin(angle), za])
        vec = vec*2
        camera = dict(
            eye=dict(x=vec[0], y=vec[1], z=vec[2]))

        fig = px.scatter_3d(df_fin, x="X", y="Y", z="Z", opacity=1, color=mRNA_col, range_color=[bottom, top],
                            color_continuous_scale=color_scale, template="plotly_white")

        fig.update_layout(scene_camera=camera, scene_dragmode='orbit')

        fig.update_layout(title_text=gene_name, title_x=0.5)

        # Hide axis titles and tick labels so only the point cloud is shown.
        fig.update_layout(scene = dict(
                        xaxis_title='',
                        yaxis_title='',
                        zaxis_title='',
                        xaxis = dict(showticklabels=False),
                        yaxis = dict(showticklabels=False),
                        zaxis = dict(showticklabels=False)))

        fig.update_layout(coloraxis_showscale=False)

        # Frames are numbered consecutively across all genes.
        fig.write_image(gene_frame_dir + imageName + "_mRNA_" + "{:03d}".format(total_counter) + ".png")
        total_counter += 1