In [4]:
import operator
import numpy as np
import matplotlib.pyplot as plt
from skimage import filters, morphology
from skimage.feature import peak_local_max
from scipy import ndimage
from scipy.ndimage import label
from skimage.io import imsave
import plotly.graph_objects as go
import plotly.io as pio


### Functions

In [5]:
def find_indices(array, condition=operator.eq, value=1):
    """Locate every cell of ``array`` for which ``condition(cell, value)`` holds.

    Parameters
    ----------
    array : numpy.ndarray
        Array to search.
    condition : callable, optional
        Binary comparison applied element-wise as ``condition(array, value)``
        (defaults to equality).
    value : scalar, optional
        Right-hand operand of the comparison (defaults to 1).

    Returns
    -------
    list of tuple
        One index tuple per matching cell, in the row-major order produced
        by ``np.argwhere``.
    """
    mask = condition(array, value)
    return list(map(tuple, np.argwhere(mask)))

def check_neighbors(array, x, y, z, mode="count", prev=None):
    """Inspect the 26-neighbourhood of cell ``(x, y, z)`` in a 3-D array.

    A neighbour is reported when it lies inside the array, is non-zero, and
    is not equal to ``prev``.

    Parameters
    ----------
    array : numpy.ndarray
        Volume indexed as ``array[x, y, z]``.
    x, y, z : int
        Indices of the centre cell along axes 0, 1 and 2.
    mode : {"count", "retrieve"}
        ``"count"`` returns how many neighbours qualify; ``"retrieve"``
        returns their coordinates.
    prev : tuple or None, optional
        A coordinate to exclude from the result.

    Returns
    -------
    int or list of tuple
        The neighbour count, or the list of neighbour coordinates ordered by
        ascending offset along axis 0, then axis 1, then axis 2.
    """
    assert mode in ["count", "retrieve"], "Mode must be either 'count' or 'retrieve'"
    offsets = (-1, 0, 1)
    found = []
    for dx in offsets:
        nx = x + dx
        # Whole slab is out of bounds: nothing to inspect for this offset
        if not 0 <= nx < array.shape[0]:
            continue
        for dy in offsets:
            ny = y + dy
            if not 0 <= ny < array.shape[1]:
                continue
            for dz in offsets:
                # The centre cell is not its own neighbour
                if dx == 0 and dy == 0 and dz == 0:
                    continue
                nz = z + dz
                if not 0 <= nz < array.shape[2]:
                    continue
                if array[nx, ny, nz] != 0 and (nx, ny, nz) != prev:
                    found.append((nx, ny, nz))
    # Unknown modes fall through to None, as a final fallback
    return {"count": len(found), "retrieve": found}.get(mode)

def find_nodes(array, indices, condition=operator.eq, neighbor_criterion=1):
    """Select the cells whose non-zero neighbour count satisfies a criterion.

    Parameters
    ----------
    array : numpy.ndarray
        3-D volume; only non-zero cells are considered.
    indices : sequence of 3-tuples
        Candidate coordinates, each read as ``(axis0, axis1, axis2)``.
    condition : callable, optional
        Comparison applied as ``condition(neighbour_count, neighbor_criterion)``
        (defaults to equality).
    neighbor_criterion : int, optional
        Value the neighbour count is compared against (defaults to 1).

    Returns
    -------
    list of tuple
        Coordinates from ``indices`` that pass the test, in input order.
    """
    matches = []
    for coord in indices:
        z, x, y = coord[0], coord[1], coord[2]
        if array[z, x, y] != 0:
            n_neighbors = check_neighbors(array, z, x, y, mode="count")
            if condition(n_neighbors, neighbor_criterion):
                matches.append((z, x, y))
    return matches

# Check if coordinate pairs are within a given tolerance
def check_tolerance(coord1, coord2, tolerance):
    """Tell whether two coordinates agree to within ``tolerance`` on every axis.

    Components are paired positionally with ``zip``, so any extra trailing
    components of the longer coordinate are ignored.

    Returns
    -------
    bool
        True when each paired absolute difference is ``<= tolerance``.
    """
    for first, second in zip(coord1, coord2):
        # Written as "not <=" (rather than ">") so incomparable values such
        # as NaN are still treated as out of tolerance.
        if not abs(first - second) <= tolerance:
            return False
    return True

# Filter coords in list_a if they are too similar to coords in list_b
def filter_coordinates(list_a, list_b, tolerance=5):
    """Drop coordinates from ``list_a`` that lie close to any coordinate in ``list_b``.

    Two coordinates count as close when ``check_tolerance`` reports that all
    of their paired components differ by at most ``tolerance``.

    Returns
    -------
    list
        The surviving elements of ``list_a``, in their original order.
    """
    def is_isolated(candidate):
        # A single close reference point is enough to reject the candidate
        for reference in list_b:
            if check_tolerance(candidate, reference, tolerance):
                return False
        return True

    return [candidate for candidate in list_a if is_isolated(candidate)]


def remove_floating_segments(array):
    """Keep only the largest 26-connected foreground component of a volume.

    Parameters
    ----------
    array : array_like
        3-D volume in which non-zero cells are foreground.

    Returns
    -------
    numpy.ndarray
        Integer array of the same shape holding 1 inside the largest
        connected component and 0 everywhere else. When the input contains
        no foreground at all, an all-zero array is returned. When several
        components share the maximum size, the one with the lowest label
        is kept.
    """
    # Define a 3x3x3 structuring element for 26-connectivity
    structuring_element = np.ones((3, 3, 3), dtype=int)
    # Label all 26-connected regions
    labeled_array, num_features = label(array, structure=structuring_element)
    # Without this guard an empty input would select label 0 (the background)
    # as "largest" and return a volume filled with ones.
    if num_features == 0:
        return np.zeros(labeled_array.shape, dtype=int)
    # Find the size of each connected component
    sizes = np.bincount(labeled_array.ravel())
    # Ignore the background component (label 0)
    sizes[0] = 0
    # Find the label of the largest connected component
    largest_label = sizes.argmax()
    # Retain only the largest component
    return (labeled_array == largest_label).astype(int)

### Load Processed Image Stack

In [6]:
# Load the 3D volume from a .npy file; the path is relative to the working directory
data_3d = np.load('output/pvd_test.npy')  # 3D neuron data

# Print value range, shape and container type as a quick sanity check
print(f"min: {np.amin(data_3d)} max:{np.amax(data_3d)} shape:{data_3d.shape} type:{type(data_3d)} ")

min: 0 max:1 shape:(188, 2044, 2042) type:<class 'numpy.ndarray'> 


### Preprocess

In [7]:
# Dilate by 1px using a ball-shaped structuring element of radius 1.
# NOTE(review): this rebinds data_3d, so re-running the cell without reloading
# dilates the already-dilated volume again.
data_3d = morphology.binary_dilation(data_3d, morphology.ball(radius=1))

# Apply Gaussian smoothing (sigma=3) to the dilated volume
smoothed = filters.gaussian(data_3d, sigma=3)

# Threshold back to binary: cells above the Otsu threshold of the smoothed volume are kept
thresh = filters.threshold_otsu(smoothed)
binary = smoothed > thresh

### Skeletonize

In [8]:
# Skeletonize the thresholded volume with skimage's 'lee' method
skeleton = morphology.skeletonize(binary, method='lee')

In [10]:
# Keep only the largest 26-connected piece of the skeleton; detached fragments are dropped
skeleton_main = remove_floating_segments(skeleton)

In [18]:
# Write the cleaned skeleton (largest component only) to disk as a .npy file.
# NOTE(review): the recorded output for this cell is a KeyboardInterrupt, so the
# file on disk may be missing or incomplete — re-run to confirm.
np.save('output/pvd_skeleton.npy', skeleton_main)

KeyboardInterrupt: 

In [11]:
# Index tuples of every skeleton cell (cells equal to 1, the find_indices default)
skeleton_idx = find_indices(skeleton_main)

In [12]:
# Find tips: skeleton cells with exactly one non-zero cell among their 26 neighbors
tips = find_nodes(skeleton_main, skeleton_idx, condition=operator.eq, neighbor_criterion=1)
print(f"number of endpoints: {len(tips)}")

number of endpoints: 174


In [13]:
# Find knots: skeleton cells with three or more non-zero cells among their 26 neighbors
knots = find_nodes(skeleton_main, skeleton_idx, condition=operator.ge, neighbor_criterion=3)
print(f"number of knots: {len(knots)}")

number of knots: 296


In [14]:
# Filter tips to make sure they aren't too near knots: a tip is dropped when
# every one of its coordinates is within 5 cells of the same coordinate of some knot.
# Note that this rebinds `tips`, replacing the unfiltered list.
tips = filter_coordinates(tips, knots, tolerance=5)
print(f"number of tips after filtering: {len(tips)}")

number of tips after filtering: 131


### Visualize

In [15]:
# Prepare skeleton data for plotting: move the first axis to the end
image_stack = skeleton_main.transpose(1, 2, 0)
x, y, z = image_stack.shape
# Dense index grids with the same shape and axis order as image_stack.
# NOTE(review): these are three full-size integer arrays; for a large volume
# this is memory-hungry — np.nonzero(image_stack) may be worth considering.
X, Y, Z = np.meshgrid(np.arange(x), np.arange(y), np.arange(z), indexing='ij')
colors = image_stack.ravel()

# Mask selecting the non-zero (skeleton) cells of the flattened stack
visible_mask = colors != 0

# Split the tip coordinates into one list per tuple position
tips_z, tips_x, tips_y = ([pt[axis] for pt in tips] for axis in range(3))

# Split the knot coordinates the same way
knot_z, knot_x, knot_y = ([pt[axis] for pt in knots] for axis in range(3))

In [16]:
# Plot the skeleton with its tips and knots as an interactive 3D scatter.
# Every trace is named so the legend is readable instead of "trace 0/1/2".
fig = go.Figure()

# Skeleton tips
fig.add_trace(go.Scatter3d(
    x=tips_x,
    y=tips_y,
    z=tips_z,
    mode='markers',  # Use 'lines' for a line plot or 'markers+lines' for both
    name='Tips',
    marker=dict(
        size=5,
        color='red',
        opacity=.9
    )
))

# Skeleton knots
fig.add_trace(go.Scatter3d(
    x=knot_x,
    y=knot_y,
    z=knot_z,
    mode='markers',  # Use 'lines' for a line plot or 'markers+lines' for both
    name='Knots',
    marker=dict(
        size=4,
        color='blue',
        opacity=.9
    )
))

# Skeleton structure (every non-zero cell, drawn faintly behind the markers).
# A colorscale only applies to numeric colour arrays, so none is set for the
# constant 'black' colour used here.
fig.add_trace(go.Scatter3d(
    x=X.ravel()[visible_mask],
    y=Y.ravel()[visible_mask],
    z=Z.ravel()[visible_mask],
    mode='markers',
    name='Skeleton',
    marker=dict(
        size=2,
        color='black',
        opacity=.1
    )
))

# Update layout. Axis bounds come from the array that is actually plotted
# (skeleton_main), whose axes are ordered (z, x, y).
fig.update_layout(
    title='C. Elegans PVD Neuron',
    scene=dict(
        xaxis_title='X (pixels)',
        yaxis_title='Y (pixels)',
        zaxis_title='Z (image index)',
        aspectmode='manual',
        aspectratio=dict(x=1, y=1, z=.1),  # Adjust z-axis scale if desired
        zaxis=dict(range=[0, skeleton_main.shape[0]]),  # Set z-axis bounds
        xaxis=dict(range=[0, skeleton_main.shape[1]]),  # Set x-axis bounds
        yaxis=dict(range=[0, skeleton_main.shape[2]]),  # Set y-axis bounds
    ),
    autosize=True
)

# Show the figure
fig.show()

In [17]:
# Save the plot to a standalone HTML file (path relative to the working directory);
# auto_open=True asks plotly to open the written file once it is saved
pio.write_html(fig, file='plots/skeletonization.html', auto_open=True)