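### Define key loading and plotting functions

Note: this cell uses `dataRoot` and `imNameList`, which are defined in the later cell that builds the file list. Run that cell first.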

In [16]:
import dash
import plotly.express as px
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
import plotly.graph_objs as go
import pandas as pd
import numpy as np
import json
import glob2 as glob
from skimage.measure import regionprops
import itertools
import os
from scipy.interpolate import LinearNDInterpolator
import alphashape





def load_nucleus_dataset(filename, data_root=None):
    """Load per-nucleus stats for one dataset, keep pec-fin nuclei, and normalize gene columns.

    Reads ``<data_root><filename>_nucleus_props.csv``, using its first column as
    the index. Keeps only rows whose ``pec_fin_flag`` equals 1. Rescales every
    column belonging to a gene so that its maximum over the kept rows is 1.

    A "gene" is any name ``g`` for which a ``g_cell_mean_nn`` column exists.
    Every column whose name starts with ``g + "_"`` is treated as one of that
    gene's columns.

    Parameters
    ----------
    filename : str
        Dataset name: the CSV file name without the ``_nucleus_props.csv`` suffix.
    data_root : str, optional
        Directory prefix holding the CSV files. Defaults to the module-level
        ``dataRoot``, which is defined in a separate notebook cell. It is
        concatenated directly with the file name, so it should end with a path
        separator.

    Returns
    -------
    pandas.DataFrame
        A new DataFrame containing only the pec-fin nuclei, with gene columns
        normalized. Columns whose maximum is 0 or NaN are left unchanged.

    Raises
    ------
    FileNotFoundError
        If the props CSV does not exist. This is a subclass of ``Exception``,
        so existing ``except Exception`` handlers still catch it.
    """
    if data_root is None:
        data_root = dataRoot  # module-level global defined in another cell

    propPath = data_root + filename + '_nucleus_props.csv'
    if not os.path.isfile(propPath):
        raise FileNotFoundError(
            f"Selected dataset( {filename} ) dataset has no nucleus data. Have you run extract_nucleus_stats?")

    df = pd.read_csv(propPath, index_col=0)

    # Boolean mask plus an explicit copy gives an independent frame. The column
    # assignments below then don't write into a view of the original table,
    # which would trigger pandas' SettingWithCopy warning.
    df = df.loc[(df["pec_fin_flag"] == 1).to_numpy()].copy()

    suffix = "_cell_mean_nn"
    gene_names = [col.replace(suffix, "") for col in df.columns if suffix in col]

    # Collect each gene's columns exactly once. Prefix matching, rather than a
    # bare substring test, stops a short gene name from also matching another
    # gene's columns or unrelated columns such as "pec_fin_flag".
    gene_cols = [col for col in df.columns
                 if any(col.startswith(g + "_") for g in gene_names)]

    for col in gene_cols:
        c_max = df[col].max()  # NaN-skipping max over the kept nuclei
        # Dividing an all-zero (or all-NaN) column would yield NaN/inf; leave it as is.
        if pd.notna(c_max) and c_max != 0:
            df[col] = df[col] / c_max

    return df

# Load the first dataset so the app starts with data available.
# NOTE: imNameList is built in a later notebook cell; that cell must have run first.
# (A module-level `global` statement is a no-op, so none is used here.)
df = load_nucleus_dataset(imNameList[0])

# Genes available for plotting: every column carrying a "_cell_mean_nn" suffix.
colnames = df.columns
list_raw = [item for item in colnames if "_cell_mean_nn" in item]
gene_names = [item.replace("_cell_mean_nn", "") for item in list_raw]

# Plot types offered to the user.
# NOTE(review): create_figure only draws "3D Scatter" and "Volume Plot";
# there is no "Multiplot" branch yet.
plot_list = ["3D Scatter", "Volume Plot", "Multiplot"]

########################
# App
app = dash.Dash(__name__)

def _fin_outline_trace(xyz_array):
    """Return a translucent gray alpha-hull mesh through all nucleus positions (the fin outline)."""
    return go.Mesh3d(x=xyz_array[:, 0], y=xyz_array[:, 1], z=xyz_array[:, 2],
                     alphahull=9,
                     opacity=0.1,
                     color='gray')


def _scatter_figure(df, plot_gene, cmap, xyz_array, threshold=0.3):
    """3D scatter: nuclei at/above ``threshold`` drawn large and opaque, the rest small and faint."""
    expr = df[plot_gene]
    high = df.loc[(expr >= threshold).to_numpy()]
    low = df.loc[(expr < threshold).to_numpy()]

    fig = px.scatter_3d(high, x="X", y="Y", z="Z", opacity=0.5, color=plot_gene,
                        color_continuous_scale=cmap, range_color=(0, 1))
    # Applied before the other traces are added, so it only resizes the high-expression markers.
    fig.update_traces(marker=dict(size=8))

    fig.add_trace(go.Scatter3d(x=low["X"], y=low["Y"], z=low["Z"],
                               mode='markers',
                               marker=dict(size=5,
                                           color=low[plot_gene],
                                           colorscale=cmap,
                                           opacity=0.1,
                                           # same 0-1 colour range as the high-expression trace
                                           cmin=0,
                                           cmax=1)))
    fig.update_layout(coloraxis_colorbar_title_text='Normalized Expression',
                      showlegend=False)
    fig.add_trace(_fin_outline_trace(xyz_array))
    return fig


def _volume_figure(df, plot_gene, cmap, xyz_array, n_grid=30, alpha=9):
    """Volume render of expression interpolated onto an ``n_grid``^3 grid, masked to the fin's alpha shape."""
    # Regular grid spanning the bounding box of the nuclei.
    xx = np.linspace(df["X"].min(), df["X"].max(), num=n_grid)
    yy = np.linspace(df["Y"].min(), df["Y"].max(), num=n_grid)
    zz = np.linspace(df["Z"].min(), df["Z"].max(), num=n_grid)
    X, Y, Z = np.meshgrid(xx, yy, zz)

    # Linear interpolation of expression onto the grid (NaN outside the convex hull of the nuclei).
    interp = LinearNDInterpolator(xyz_array, np.asarray(df[plot_gene]))
    G = interp(X, Y, Z)

    # Mask grid points outside the fin's alpha shape. Coordinates are divided by
    # the per-axis grid maxima first, presumably so the fixed alpha value works
    # on a roughly unit scale (TODO confirm).
    scale = np.array([X.max(), Y.max(), Z.max()])
    alpha_fin = alphashape.alphashape(xyz_array / scale, alpha)
    grid_norm = np.column_stack((X.ravel(), Y.ravel(), Z.ravel())) / scale
    inside = alpha_fin.contains(grid_norm)
    values = G.ravel().copy()
    values[~inside] = np.nan

    fig = go.Figure(data=go.Volume(
        x=X.ravel(),
        y=Y.ravel(),
        z=Z.ravel(),
        value=values,
        opacity=.1,  # needs to be small to see through all surfaces
        isomin=0.2,
        isomax=0.8,
        surface_count=15,  # needs to be a large number for good volume rendering
        colorscale=cmap,
        # Volume traces have their own colorbar (they do not use layout.coloraxis),
        # so the title must be set on the trace to be displayed.
        colorbar=dict(title=dict(text='Normalized Expression')),
    ))
    fig.add_trace(_fin_outline_trace(xyz_array))
    fig.update_layout(coloraxis_colorbar_title_text='Normalized Expression')
    return fig


def create_figure(df, gene_name=None, plot_type=None):
    """Build a 3D Plotly figure of one gene's expression across fin nuclei.

    Parameters
    ----------
    df : pandas.DataFrame
        Nucleus table with "X", "Y", "Z" columns plus "<gene>_mean_nn" and
        "<gene>_cell_mean_nn" columns, e.g. the output of load_nucleus_dataset.
    gene_name : str, optional
        Gene to plot. Defaults to the first gene found in ``df``. The available
        genes, and therefore the colormap order, are read from ``df``'s
        "_cell_mean_nn" columns.
    plot_type : {None, "3D Scatter", "Volume Plot"}
        ``None`` behaves like "3D Scatter".

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If ``plot_type`` is not supported (e.g. "Multiplot"), or if
        ``gene_name`` is not one of ``df``'s genes.
    """
    colormaps = ["ice", "inferno", "viridis"]
    suffix = "_cell_mean_nn"
    df_genes = [col.replace(suffix, "") for col in df.columns if suffix in col]

    if gene_name is None:
        gene_name = df_genes[0]
    plot_gene = gene_name + "_mean_nn"
    # Cycle through the palette so datasets with more than three genes don't raise IndexError.
    cmap = colormaps[df_genes.index(gene_name) % len(colormaps)]

    xyz_array = np.asarray(df[["X", "Y", "Z"]])

    if plot_type is None or plot_type == "3D Scatter":
        return _scatter_figure(df, plot_gene, cmap, xyz_array)
    if plot_type == "Volume Plot":
        return _volume_figure(df, plot_gene, cmap, xyz_array)
    raise ValueError(
        f"Unsupported plot_type {plot_type!r}; expected None, '3D Scatter' or 'Volume Plot'.")



## Get list of datasets and the genes measured in each

In [17]:
dataRoot = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/pecFin/HCR_Data/built_zarr_files/"
props_suffix = '_nucleus_props.csv'

# Each dataset is represented by one "<name>_nucleus_props.csv" file in dataRoot.
fileList = sorted(glob.glob(dataRoot + '*' + props_suffix))

# Dataset names: the file name with the props suffix stripped.
imNameList = [os.path.basename(path)[:-len(props_suffix)] for path in fileList]

# Map each dataset name to the genes it contains. Only the header row is needed,
# so no data rows are read (nrows=0) instead of loading the whole table.
gene_name_dict = {}
for imName, propPath in zip(imNameList, fileList):
    if not os.path.isfile(propPath):
        raise Exception(
            f"Selected dataset( {imName} ) dataset has no nucleus data. Have you run extract_nucleus_stats?")

    header_cols = pd.read_csv(propPath, index_col=0, nrows=0).columns
    gene_name_dict[imName] = [col.replace("_cell_mean_nn", "")
                              for col in header_cols if "_cell_mean_nn" in col]

## Test that plotting is working

In [19]:
# Smoke-test the plotting pipeline on the first dataset using the volume renderer.
imageName = imNameList[0]
fileName = fileList[0]
plot_type = "Volume Plot"
gene_name = gene_name_dict[imageName][0]

df = load_nucleus_dataset(imageName)

# Keep the module-level gene list in sync with the freshly loaded dataset,
# so any code relying on the global `gene_names` sees this dataset's genes.
colnames = df.columns
list_raw = list(filter(lambda col: "_cell_mean_nn" in col, colnames))
gene_names = [raw.replace("_cell_mean_nn", "") for raw in list_raw]

f = create_figure(df, gene_name=gene_name, plot_type=plot_type)
f.update_layout(template="plotly_white")

## Record spinning fin plots

In [None]:
# experiment with for loop that rotates fin about z axis
import plotly.graph_objects as go
import  moviepy.editor as mpy
import io 
from PIL import Image

# Drop the figure left over from the test cell (if present) before building a fresh one.
globals().pop('f', None)

# update_layout returns the figure itself, so creation and styling can be chained.
fig = create_figure(df, gene_name=gene_name, plot_type=plot_type).update_layout(
    template="plotly_white")

def plotly_fig2array(fig):
    """Render a Plotly figure to PNG and return it as an RGB array of shape (height, width, 3).

    Relies on Plotly's static image export (``fig.to_image``).
    """
    png_bytes = fig.to_image(format="png")
    # Decode inside a context manager so the PIL image is closed once its
    # pixels have been copied into the array. Convert to RGB, because the
    # PNG may carry an alpha channel and video/GIF frame consumers expect
    # 3-channel frames.
    with Image.open(io.BytesIO(png_bytes)) as img:
        return np.asarray(img.convert("RGB"))

# Initial camera eye position in Plotly scene coordinates;
# make_frame rotates this point about the z axis to spin the view.
x_eye = -1.25
y_eye = 2
z_eye = 0.5

# Square, axis-free frames titled with the gene being shown.
fig.update_layout(title_text=gene_name,
                  scene_camera_eye=dict(x=x_eye, y=y_eye, z=z_eye),
                  width=500, height=500,
                  scene_xaxis_visible=False,
                  scene_yaxis_visible=False,
                  scene_zaxis_visible=False)


def rotate_z(x, y, z, theta):
    """Rotate the point (x, y, z) by ``theta`` radians about the z axis.

    The rotation is counter-clockwise when viewed from +z, done by multiplying
    x + iy by e^(i*theta). Returns the rotated (x, y, z); z passes through
    unchanged. Works elementwise on numpy arrays as well as scalars.
    """
    rotated = (x + 1j * y) * np.exp(1j * theta)
    return np.real(rotated), np.imag(rotated), z

# Clip length in seconds. The camera makes exactly one full turn over the clip,
# so the looping GIF has no visible jump between its last and first frame
# (previously 1 rad/s for 2 s turned only ~115 degrees).
duration = 2
spin_rate = 2 * np.pi / duration  # radians per second


def make_frame(t):
    """Return the RGB frame at time ``t`` seconds, with the camera rotated ``-spin_rate * t`` about z.

    Side effect: updates the camera of the global ``fig`` before rasterising it.
    """
    xe, ye, ze = rotate_z(x_eye, y_eye, z_eye, -spin_rate * t)
    # This camera update is what would normally go inside a Plotly go.Frame definition.
    fig.update_layout(scene_camera_eye=dict(x=xe, y=ye, z=ze))
    return plotly_fig2array(fig)


animation = mpy.VideoClip(make_frame, duration=duration)
animation.write_gif("test.gif", fps=20)

MoviePy - Building file test.gif with imageio.


t:  12%|█▎        | 5/40 [00:16<02:09,  3.70s/it, now=None]