## Attempt automatic fin segmentation using plane fitting

### Import packages and load image data

In [5]:
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
import napari
import numpy as np
from napari_animation import Animation
import random
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

import plotly.express as px

### Now, load an image dataset along with nucleus masks inferred using cellpose.

In [6]:
import warnings
warnings.filterwarnings("ignore")

# set parameters
filename = "2022_12_22 HCR Sox9a Tbx5a Emilin3a_1.zarr"
readPath = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/pecFin/HCR_Data/built_zarr_files_small/" + filename
readPathLabels = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/pecFin/HCR_Data/built_zarr_files_small/" + filename + "labels"
level = 1

#############
# Main image
#############

# read the image data
store = parse_url(readPath, mode="r").store
reader = Reader(parse_url(readPath))

# nodes may include images, labels etc
nodes = list(reader())

# first node will be the image pixel data
image_node = nodes[0]
image_data = image_node.data

#############
# Labels
#############

# read the image data
store_lb = parse_url(readPathLabels, mode="r").store
reader_lb = Reader(parse_url(readPathLabels))

# nodes may include images, labels etc
nodes_lb = list(reader_lb())

# first node will be the image pixel data
label_node = nodes_lb[1]
label_data = label_node.data

# extract key image attributes
omero_attrs = image_node.root.zarr.root_attrs['omero']
channel_metadata = omero_attrs['channels']  # list of channels and relevant info
multiscale_attrs = image_node.root.zarr.root_attrs['multiscales']
axis_names = multiscale_attrs[0]['axes']
dataset_info = multiscale_attrs[0]['datasets']  # list containing scale factors for each axis

Failed to parse metadata
Traceback (most recent call last):
  File "/Users/nick/miniforge3/envs/napari-env/lib/python3.10/site-packages/ome_zarr/reader.py", line 366, in __init__
    rgb = [(int(color[i : i + 2], 16) / 255) for i in range(0, 6, 2)]
  File "/Users/nick/miniforge3/envs/napari-env/lib/python3.10/site-packages/ome_zarr/reader.py", line 366, in <listcomp>
    rgb = [(int(color[i : i + 2], 16) / 255) for i in range(0, 6, 2)]
ValueError: invalid literal for int() with base 16: 're'
no parent found for <ome_zarr.reader.Label object at 0x16a0418d0>: None


### Extract 3D nucleus locations from labels

In [7]:
# extract useful info
from skimage.measure import label, regionprops

scale_vec = multiscale_attrs[0]["datasets"][level]["coordinateTransformations"][0]["scale"]
channel_names = [channel_metadata[i]["label"] for i in range(len(channel_metadata))]
colormaps = [channel_metadata[i]["color"] for i in range(len(channel_metadata))]

label_array = np.asarray(label_data[level].compute())
regions = regionprops(label_array,image_data[level][3, :, :, :])

centroid_array = np.empty((len(regions), 3))
for rgi, rg in enumerate(regions):
    centroid_array[rgi, :] = np.multiply(rg.centroid, scale_vec)

### Use nearest neighbore distances to estimate a "nearness" threshold
See https://towardsdatascience.com/3d-model-fitting-for-point-clouds-with-ransac-and-python-2ab87d5fd363

In [84]:
import sklearn
from sklearn.neighbors import KDTree
tree = KDTree(centroid_array, leaf_size=2)  
nearest_dist, nearest_ind = tree.query(centroid_array, k=8)
mean_nn_dist_vec = np.mean(nearest_dist[:,1:],axis=0)
threshold = np.mean(mean_nn_dist_vec)
print(mean_nn_dist_vec)
print(threshold)

[4.79803079 5.70459094 6.4030969  7.02434099 7.51600651 7.96008691
 8.42642902]
6.833226007999531


### Fit random plane to data and visualize

In [85]:
# draw 3 random points and use to define a plane 
xyz = np.roll(centroid_array, 2, axis=1) 
idx_samples = random.sample(range(len(xyz)), 3)
pts = xyz[idx_samples]
# use these to define two vectors and find normal vector
vecA = pts[1] - pts[0]
vecB = pts[2] - pts[0]
normal = np.cross(vecA, vecB)
# calculate plane coefficients
a,b,c = normal #/ np.linalg.norm(normal)
d=-np.sum(normal*pts[1])
# calculate distance for all points
distance = abs(a * xyz[:,0] + b * xyz[:,1] + c * xyz[:,2] + d
            ) / np.sqrt(a ** 2 + b ** 2 + c ** 2)
# apply threshold
idx_candidates = np.where(np.abs(distance) <= threshold)[0]
inliers = xyz[idx_candidates]
mask = np.ones(len(xyz), dtype=bool)
mask[idx_candidates] = False
outliers=xyz[mask]
print(np.min(abs(distance)))
print(normal)

0.0
[-2328.20288984   598.81322542   775.04798314]


### Plot segmentation results

In [87]:
import pandas as pd
import plotly.graph_objects as go

xx, yy = np.meshgrid(range(round(np.max(xyz[:, 0]))), range(round(np.max(xyz[:, 1]))))

# calculate corresponding z
z = (-normal[0] * xx - normal[1] * yy - d) * 1. /normal[2]
df_in = pd.DataFrame(xyz,columns=["X","Y","Z"])
df_out = pd.DataFrame(outliers,columns=["X","Y","Z"])

fig = px.scatter_3d(df_in, x="X", y="Y", z="Z",opacity=0.5)
fig.update_traces(marker=dict(color='blue'))
#fig.add_trace(go.Scatter3d(x=df_out.iloc[:,0], y=df_out.iloc[:,1], z=df_out.iloc[:,2], mode='markers', opacity=0.5))
#fig = px.scatter_3d(x=outliers[:, 0], y=outliers[:, 1], z=outliers[:, 2], color='red', opacity=0.75)
#fig.add_trace(go.Surface(z=z, x=xx, y=yy,colorscale='algae'))

fig.show()