## Notebook to experiment with strategies for integrating mRNA values

In [None]:
# import packages
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader

# Load test dataset of nucleus centroids (one row per nucleus, read from CSV)
nc_path = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/pecFin/HCR_Data/built_zarr_files/2022_12_15 HCR Hand2 Tbx5a Fgf10a_1_nucleus_props.csv"
nucleus_df = pd.read_csv(nc_path)

# load image and label datasets
level = 0  # resolution level; later cells use it to index image_data, label_data and the multiscale dataset list
filename = "2022_12_15 HCR Hand2 Tbx5a Fgf10a_1.zarr"
readPath = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/pecFin/HCR_Data/built_zarr_files/" + filename
# NOTE(review): "labels" is appended directly to the ".zarr" filename with no path
# separator, giving "...zarrlabels" -- confirm this matches the on-disk name.
readPathLabels = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/pecFin/HCR_Data/built_zarr_files/" + filename + "labels"

#############
# Main image
#############

# open the image location read-only and build a reader for it
# NOTE(review): `store` is not referenced again in the visible cells.
store = parse_url(readPath, mode="r").store
reader = Reader(parse_url(readPath))

# nodes may include images, labels etc
nodes = list(reader())

# first node will be the image pixel data
image_node = nodes[0]
image_data = image_node.data

#############
# Labels
#############

# open the label location read-only and build a reader for it
# NOTE(review): `store_lb` is not referenced again in the visible cells.
store_lb = parse_url(readPathLabels, mode="r").store
reader_lb = Reader(parse_url(readPathLabels))

# nodes may include images, labels etc
nodes_lb = list(reader_lb())

# the label data is taken from the SECOND node (index 1), not the first
# NOTE(review): presumably node 0 is a parent/image node here -- confirm.
label_node = nodes_lb[1]
label_data = label_node.data

# extract key image attributes from the image node's root zarr attributes
omero_attrs = image_node.root.zarr.root_attrs['omero']
channel_metadata = omero_attrs['channels']  # list of per-channel dicts; later cells read their "label" and "color" keys
multiscale_attrs = image_node.root.zarr.root_attrs['multiscales']
axis_names = multiscale_attrs[0]['axes']
dataset_info = multiscale_attrs[0]['datasets']  # one entry per resolution level; each carries that level's coordinateTransformations (scale)

### Set up per-channel columns and extract the labeled nucleus regions

In [None]:
from skimage.measure import label, regionprops, regionprops_table

# Channel display names, taken from the "label" field of each channel's metadata
channel_names = [channel["label"] for channel in channel_metadata]

# Add two NaN-filled placeholder columns per channel: "<label>" for the
# integrated signal and "<label>_mean" for the per-pixel mean
for ch in channel_names:
    for column_name in (ch, ch + "_mean"):
        nucleus_df[column_name] = np.nan

# Scale factors attached to the chosen resolution level
level_info = multiscale_attrs[0]["datasets"][level]
scale_vec = level_info["coordinateTransformations"][0]["scale"]

# The image stays lazy here; the label volume is computed into memory because
# regionprops is run on it directly
image_array = image_data[level]
label_array = label_data[level].compute()

# One region object per labeled nucleus
regions = regionprops(label_array)
    

### For each region, extract centroid info and integrate mRNA levels

In [None]:
# list of colors to use when plotting mRNA levels (one entry per channel)
colormaps = [channel["color"] for channel in channel_metadata]

# clean up data frame: drop every column whose name contains "Unnamed"
# (presumably leftover index columns from a CSV round-trip).
# .copy() makes nucleus_df_clean an independent frame, so the writes below are
# real assignments rather than writes into a slice of nucleus_df.
colnames = nucleus_df.columns
clean_indices = [i for i, item in enumerate(colnames) if "Unnamed" not in item]
nucleus_df_clean = nucleus_df.iloc[:, clean_indices].copy()

# compute each channel of image array separately to avoid dask error
# NOTE(review): the last channel is skipped, exactly as the original loop bound
# (len(channel_names) - 1) did -- presumably it is the nuclear stain; confirm.
n_mRNA_channels = len(channel_names) - 1
im_array_list = [
    np.asarray(image_array[ch, :, :, :].compute()) for ch in range(n_mRNA_channels)
]

# Look up column positions once so every write in the loop is a single
# positional .iloc[row, col] assignment. Chained forms such as
# df[col].iloc[i] = v or df[[...]].iloc[i] = v assign into a temporary object
# and can silently leave the frame unchanged.
zyx_cols = [nucleus_df_clean.columns.get_loc(axis) for axis in ("Z", "Y", "X")]
sum_cols = [
    nucleus_df_clean.columns.get_loc(channel_names[ch])
    for ch in range(n_mRNA_channels)
]
mean_cols = [
    nucleus_df_clean.columns.get_loc(channel_names[ch] + "_mean")
    for ch in range(n_mRNA_channels)
]

# scale_vec may hold one entry per image axis (including a leading channel
# axis). Keep only the trailing entries that line up with the axes of the
# label volume so it broadcasts against a region centroid.
spatial_scale = np.asarray(scale_vec, dtype=float)[-label_array.ndim:]

# iterate through regions to extract key info
# (row rgi of the table is assumed to correspond to region rgi)
for rgi, rg in enumerate(regions):
    # get coordinates: centroid in voxel units -> scaled units.
    # NOTE(review): the original wrote to nucleus_df[["Z","Y","X"]].iloc[rgi],
    # which never modified any frame; the value is now stored in
    # nucleus_df_clean, the table every later cell reads. Confirm this is the
    # intended target.
    nucleus_df_clean.iloc[rgi, zyx_cols] = np.multiply(rg.centroid, spatial_scale)

    # voxel coordinates belonging to this nucleus, as a fancy-index tuple
    nc_coords = rg.coords.astype(int)
    n_pix = nc_coords.shape[0]
    nc_index = tuple(nc_coords.T)

    # iterate through every loaded channel (not a hard-coded count)
    for ch, im_ch in enumerate(im_array_list):
        mRNA_integral = np.sum(im_ch[nc_index])

        nucleus_df_clean.iloc[rgi, sum_cols[ch]] = mRNA_integral
        nucleus_df_clean.iloc[rgi, mean_cols[ch]] = mRNA_integral / n_pix
    

In [None]:
# Sanity check: show the first three rows of the cleaned per-nucleus table
print(nucleus_df_clean.head(3))

### Make scatter plots showing mRNA levels for each gene

In [None]:
import plotly.graph_objects as go

# Channel to display, and the column holding its per-nucleus mean signal
channel_ind = 0
mRNA_channel = channel_names[channel_ind] + "_mean"

# Replace the metadata colors of channels 1 and 2 with named color scales
colormaps[1], colormaps[2] = "greens", "blues"

# Keep only the rows whose pec_fin_flag equals 1
fin_nuclei = np.where(nucleus_df_clean["pec_fin_flag"] == 1)
nucleus_df_fin = nucleus_df_clean.iloc[fin_nuclei]

# (alternative view, disabled) histogram of the selected column:
# fig = px.histogram(nucleus_df_fin, x=mRNA_channel)
# fig.show()

# 3D scatter of the selected nuclei, colored by the chosen channel's mean signal
fig = px.scatter_3d(
    nucleus_df_fin,
    x="X",
    y="Y",
    z="Z",
    opacity=0.75,
    color=mRNA_channel,
    color_continuous_scale=colormaps[channel_ind],
)
fig.update_traces(marker={"size": 5})
fig.show()


### Plot only the top 25% of nuclei by mean expression

In [None]:
# 75th-percentile threshold of the selected column among the plotted nuclei
mRNA_75 = nucleus_df_fin[mRNA_channel].quantile(0.75)

# Row positions (within nucleus_df_fin) at or above that threshold
mRNA_fin_nuclei = np.where(nucleus_df_fin[mRNA_channel] >= mRNA_75)

# Base layer: every row of nucleus_df_fin drawn in gray
fig = px.scatter_3d(nucleus_df_fin, x="X", y="Y", z="Z", opacity=0.5)
fig.update_traces(marker={"size": 7, "color": "gray"})

# Overlay layer: only the rows above the threshold, colored by their value
top_x, top_y, top_z, top_expr = (
    nucleus_df_fin[column].iloc[mRNA_fin_nuclei]
    for column in ("X", "Y", "Z", mRNA_channel)
)
top_trace = go.Scatter3d(
    x=top_x,
    y=top_y,
    z=top_z,
    mode='markers',
    opacity=0.75,
    marker={
        "color": top_expr,
        "colorscale": colormaps[channel_ind],
        "size": 7,
    },
)
fig.add_trace(top_trace)
fig.show()

### Now, try "cell" based integration

In [None]:
# Working volumes with the same shape as the label volume:
#   cell_label_array -- all zeros (float), to be filled with labels
#   dist_array       -- all +inf, to be filled with distances
volume_shape = label_array.shape
cell_label_array = np.zeros(volume_shape, dtype=float)
dist_array = np.full(volume_shape, np.inf)

# Sorted set of distinct label values present in the volume
unique_labels = np.unique(label_array)


In [None]:
from scipy.ndimage import distance_transform_edt
import dask

unique_labels = unique_labels[1:]
for lb in range(0,1):
    mask_array = dask.array.from_array(label_array==lb, chunks='auto')
    d_array = distance_transform_edt(mask_array[:,:,:])