## Notebook to make plots of gene expression in the pectoral fin atlas

In [None]:
import plotly.express as px
import plotly.graph_objs as go
import pandas as pd
import numpy as np
import json
import glob2 as glob
import os

### Preliminaries: set paths and plot options. Define functions

#### Set read and write paths

In [3]:
# set path to data
dataRoot = "../dat/"

# designate directory to save figures (can be outside of the repository)
figPath = "../fig/"
# exist_ok=True creates the directory (and any missing parents) only when it
# does not exist yet, without a separate, race-prone existence check
os.makedirs(figPath, exist_ok=True)

#### Designate plotting options

In [4]:
# Scale factor passed to fig.write_image(scale=...) when saving figures:
# 1 is the default resolution, higher values give higher resolution.
figure_resolution = 5

# Name of the CSV file (inside dataRoot) to load and plot.
# Alternative dataset: "reference_df_cell.csv".
# NOTE: load_nucleus_dataset drops rows with Z > 0.46 when the name contains "_hr".
plot_file = "reference_df_cell_hr.csv"

#### Define helper variables and functions

In [10]:
# Load list of continuous colormaps to choose from
colormaps = px.colors.named_colorscales()  # ["ice", "inferno", "viridis"]
# overwrite entries 2 and 3 so that "inferno" and "matter" sit at known indices
colormaps[2] = "inferno"
colormaps[3] = "matter"

# designate list of gene names; this module-level name is read by
# load_nucleus_dataset (no `global` statement is needed at module scope)
gene_names = ['Col11a2', 'Emilin3a', 'Fgf10a', 'Hand2', 'Myod1', 'Prdm1a', 'Robo3', 'Sox9a', 'Tbx5a']

# helper function to load dataset and normalize gene expression levels
def load_nucleus_dataset(dataRoot, filename, genes=None):
    """Load a nucleus table from CSV and scale each gene column by its maximum.

    Parameters
    ----------
    dataRoot : str
        Directory that contains the data file.
    filename : str
        Name of the CSV file inside ``dataRoot``. Its first column is used as
        the index. When the name contains "_hr", only rows with Z <= 0.46 are
        kept (rows with a missing Z are dropped as well).
    genes : iterable of str, optional
        Names of the columns to normalize. Defaults to the module-level
        ``gene_names`` list.

    Returns
    -------
    pandas.DataFrame
        The loaded (and possibly filtered) table in which every listed gene
        column has been divided by its own maximum. Columns whose maximum is
        0 or undefined (all-NaN / no rows) are left untouched so they are not
        turned into NaN by a division by zero.

    Raises
    ------
    KeyError
        If a listed gene (or "Z" for "_hr" files) is not a column of the file.
    """
    if genes is None:
        genes = gene_names

    propPath = os.path.join(dataRoot, filename)
    df_raw = pd.read_csv(propPath, index_col=0)

    if "_hr" in filename:  # removes nuisance cluster
        # boolean mask instead of np.where + iloc; copy so the column
        # assignments below never write into a view of df_raw
        df = df_raw.loc[df_raw["Z"] <= 0.46].copy()
    else:
        df = df_raw

    for colname in genes:  # normalize expression levels of each gene
        c_max = df[colname].max()
        # skip columns that cannot be scaled (max of 0, or no valid values)
        if pd.notna(c_max) and c_max != 0:
            df[colname] = df[colname] / c_max

    return df

#### Define main plotting function

In [17]:
def create_figure(df, plot_gene='Emilin3a', cmap="inferno", high_low_thresh=0.3, title_text='Normalized Expression'):
    """Build a 3D scatter plot of nuclei colored by the expression of one gene.

    Nuclei are split at ``high_low_thresh``: those at or above it are drawn
    with large, mostly opaque markers; those below it with smaller, more
    transparent markers. Both groups are colored on a fixed 0-1 range.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with "X", "Y" and "Z" coordinate columns and a ``plot_gene``
        column. Rows missing any of the three coordinates are ignored.
    plot_gene : str
        Name of the column used for coloring and for the high/low split.
    cmap : str
        Continuous color scale used for both groups of markers.
    high_low_thresh : float
        Expression value separating "high" (>=) from "low" (<) nuclei.
    title_text : str
        Title shown on the color bar.

    Returns
    -------
    plotly figure
        The figure holding the high-expression trace followed by the
        low-expression trace.
    """
    # column names holding the spatial coordinates
    xs, ys, zs = "X", "Y", "Z"

    # remove any observations with missing spatial coordinates
    df = df.dropna(subset=[xs, ys, zs])

    # identify high- and low- expressing nuclei. We use different plot formats for each
    # (rows whose expression is NaN satisfy neither comparison and are not drawn)
    expression = df[plot_gene]
    high_flags = expression >= high_low_thresh
    low_flags = expression < high_low_thresh

    # plot high-expressing nuclei
    fig = px.scatter_3d(df.loc[high_flags], x=xs, y=ys, z=zs, opacity=0.8, color=plot_gene,
                        color_continuous_scale=cmap, range_color=(0, 1), template="plotly")

    # must run before the second trace is added so that only the "high"
    # nuclei get marker size 10
    fig.update_traces(marker=dict(size=10))

    # add low-expressing nuclei
    low_df = df.loc[low_flags]
    fig.add_trace(go.Scatter3d(x=low_df[xs],
                               y=low_df[ys],
                               z=low_df[zs],
                               mode='markers',
                               marker=dict(
                                   size=6,  # note the smaller marker size
                                   color=low_df[plot_gene],
                                   colorscale=cmap,
                                   opacity=0.3,  # note the lower opacity
                                   cmin=0,
                                   cmax=1)
                               ))

    fig.update_layout(coloraxis_colorbar_title_text=title_text, showlegend=False)

    return fig

### Make and save plots

In [20]:
# load dataset and make a plot
df = load_nucleus_dataset(dataRoot, plot_file)
fig = create_figure(df)
fig.show()
# os.path.join keeps the output path valid whether or not figPath ends with a separator
fig.write_image(os.path.join(figPath, "test.png"), scale=figure_resolution)

In [22]:
# re-save the current figure; os.path.join keeps the path valid whether or not
# figPath ends with a separator
fig.write_image(os.path.join(figPath, "test.png"), scale=figure_resolution)