In [74]:
import operator
import numpy as np
import matplotlib.pyplot as plt
from skimage import filters, morphology
from skimage.feature import peak_local_max
from scipy import ndimage
from scipy.ndimage import label
from skimage.io import imsave
import plotly.graph_objects as go
import plotly.io as pio


### Functions

In [75]:
def find_indices(array, condition=operator.eq, value=1):
    # Return list of indices for cells meeting a logical condition
    indices = np.argwhere(condition(array, value))
    return [tuple(idx) for idx in indices]

def check_neighbors(array, x, y, z, mode="count", prev=None):
    assert mode in ["count", "retrieve"], "Mode must be either 'count' or 'retrieve'"
    count = 0
    neighbors = []
    # Iterate through all possible neighbors
    for i in [-1, 0, 1]:
        for j in [-1, 0, 1]:
            for k in [-1, 0, 1]:
                # Skip the cell itself
                if i == 0 and j == 0 and k == 0:
                    continue
                # Neighbor cell coords
                nx, ny, nz = x + i, y + j, z + k
                # Check if the neighbor is within bounds
                if 0 <= nx < array.shape[0] and 0 <= ny < array.shape[1] and 0 <= nz < array.shape[2]:
                    if array[nx, ny, nz] != 0 and (nx, ny, nz) != prev:
                        count += 1
                        neighbors.append((nx, ny, nz))
    if mode == "count":
        return count
    if mode == "retrieve":
        return neighbors
    
    return None


def traverse_all_neighbors(array, start_indices, stop_indices, echo=False):
    segments = []
    visited = set()

    def dfs(x, y, z, path):
        if (x, y, z) in visited:
            return False
        visited.add((x, y, z))
        path.append((x, y, z))
        if (x, y, z) in stop_indices:
            segments.append(path.copy())
            return True
        for neighbor in check_neighbors(array, x, y, z, prev=None, mode="retrieve"):
            if dfs(neighbor[0], neighbor[1], neighbor[2], path):
                return True
        path.pop()
        return False

    for start in start_indices:
        x, y, z = start
        path = []
        if dfs(x, y, z, path):
            print(f"Path found from {start}: {path}")
        else:
            print(f"Stopped traversal early from {start}, no path found to a stop index.")

    return segments

def traverse_single_neighbors(array, start_indices, stop_indices):
    segments = []
    skeleton = array.copy()
    # Start at an endpoint and traverse skeleton until a knot is encountered
    for start in start_indices:
        x, y, z = start
        path = [start]
        prev = None
        while (x, y, z) not in stop_indices:
            next_neighbor = check_neighbors(skeleton, x, y, z, mode="retrieve", prev=prev)
            if next_neighbor is None:
                break
            prev = (x, y, z)
            skeleton[x, y, z] = 0 #blank out visited location
            x, y, z = next_neighbor
            path.append((x, y, z))
        if (x, y, z) in stop_indices:
            segments.append(path)
        else:
            print(f"Stopped traversal early at {(x, y, z)}, no single nonzero neighbor found.")

    return segments

def find_nodes(array, indices, condition=operator.eq, neighbor_criterion=1):
    cells_matching_criterion = []

    # Iterate through each cell in the array
    for ii in range(len(indices)):
        z = indices[ii][0]
        x = indices[ii][1]
        y = indices[ii][2]
        if array[z, x, y] != 0 and condition(check_neighbors(array, z, x, y, mode="count"), neighbor_criterion):
            cells_matching_criterion.append((z, x, y))
    return cells_matching_criterion

# Check if coordinate pairs are within a given tolerance
def check_tolerance(coord1, coord2, tolerance):
    return all(abs(a - b) <= tolerance for a, b in zip(coord1, coord2))

# Filter coords in list_a if they are too similar to coords in list_b
def filter_coordinates(list_a, list_b, tolerance=5):
    filtered_list = []
    for coord_a in list_a:
        if not any(check_tolerance(coord_a, coord_b, tolerance) for coord_b in list_b):
            filtered_list.append(coord_a)
    return filtered_list

def remove_close_coordinates(coords, tolerance=5):
    filtered_coords = []
    for coord in coords:
        if all(not check_tolerance(coord, existing, tolerance) for existing in filtered_coords):
            filtered_coords.append(coord)
    return filtered_coords

### Load Skeleton

In [76]:
# Load 3d data
skeleton = np.load('output/pvd_skeleton.npy')  # Load your 3D neuron data

# Print descriptives
print(f"min: {np.amin(skeleton)} max:{np.amax(skeleton)} shape:{skeleton.shape} type:{type(skeleton)} ")

min: 0 max:1 shape:(188, 2044, 2042) type:<class 'numpy.ndarray'> 


In [77]:
# Get indices of relevant cells
skeleton_idx = find_indices(skeleton)

In [78]:
# Find tips (cells with a single neighbor)
tips = find_nodes(skeleton, skeleton_idx, condition=operator.eq, neighbor_criterion=1)
print(f"number of endpoints: {len(tips)}")

number of endpoints: 174


In [79]:
# Find knots (cells with >= 3 neighbors)
knots = find_nodes(skeleton, skeleton_idx, condition=operator.ge, neighbor_criterion=3)
print(f"number of knots: {len(knots)}")

number of knots: 296


In [80]:
# Remove knots that are too close together
knots = remove_close_coordinates(knots, tolerance=1)
print(f"number of knots after filtering: {len(knots)}")

number of knots after filtering: 171


In [81]:
# Filter tips to make sure they aren't too near knots
tips = filter_coordinates(tips, knots, tolerance=5)
print(f"number of tips after filtering: {len(tips)}")

number of tips after filtering: 132


### Extract Segments

In [82]:
outer_segments = traverse_all_neighbors(skeleton, tips, knots)

Path found from (15, 1736, 518): [(15, 1736, 518), (14, 1735, 519), (14, 1734, 519), (14, 1733, 520), (14, 1732, 521), (14, 1731, 521), (15, 1730, 522), (15, 1729, 523), (14, 1728, 524), (14, 1727, 525), (14, 1726, 526), (14, 1725, 527), (14, 1724, 528), (14, 1723, 529), (14, 1722, 529), (13, 1721, 530), (13, 1720, 531), (13, 1719, 531), (13, 1718, 532), (12, 1717, 533), (12, 1716, 533), (12, 1715, 533), (12, 1714, 534), (12, 1713, 534), (12, 1712, 534), (13, 1711, 535)]
Path found from (15, 1797, 541): [(15, 1797, 541), (16, 1797, 540), (17, 1798, 539), (18, 1798, 539), (19, 1799, 538), (20, 1800, 537), (21, 1801, 536)]
Path found from (15, 1812, 465): [(15, 1812, 465), (16, 1812, 466), (16, 1812, 467), (16, 1813, 468), (17, 1813, 469), (18, 1813, 470), (17, 1814, 471), (18, 1815, 472), (18, 1816, 473), (18, 1816, 474), (18, 1817, 475), (19, 1817, 476), (19, 1818, 477), (19, 1819, 478), (19, 1820, 479), (19, 1821, 480), (20, 1822, 481), (21, 1823, 482), (21, 1824, 483), (21, 1825, 484

### Visualize

In [83]:
# Prepare skeleton data for plotting
image_stack = np.transpose(skeleton, (1, 2, 0))
x, y, z = image_stack.shape
Y, X, Z = np.meshgrid(np.arange(y), np.arange(x), np.arange(z))
colors = image_stack.ravel()

# Extract background points
visible_mask = colors != 0

# Extract coordinates for skeleton tips
tips_z = [point[0] for point in tips]
tips_x = [point[1] for point in tips]
tips_y = [point[2] for point in tips]

# Extract coordinates for skeleton knots
knot_z = [point[0] for point in knots]
knot_x = [point[1] for point in knots]
knot_y = [point[2] for point in knots]


In [84]:
# Extract segment coordinates and set up segment plotting
def extract_coordinates(list_of_lists):
    coordinates = []
    for sublist in list_of_lists:
        z = [coord[0] for coord in sublist]
        x = [coord[1] for coord in sublist]
        y = [coord[2] for coord in sublist]
        coordinates.append((z, x, y))
    return coordinates

def create_scatter3d_traces(coordinates):
    traces = []
    colors = [
        'red', 'green', 'orange', 'purple', 'brown', 'pink', 'cyan', 'magenta'
    ]  # Define more colors as needed
    for i, (z, x, y) in enumerate(coordinates):
        trace = go.Scatter3d(
            x=x,
            y=y,
            z=z,
            mode='markers',
            marker=dict(
                size=3,
                color=colors[i % len(colors)],  # Cycle through colors
                opacity=1
            ),
            name=f'Segment {i+1}'
        )
        traces.append(trace)
    return traces

In [85]:
# Extract coordinates
seg_coordinates = extract_coordinates(outer_segments)

# Create Scatter3d traces
seg_traces = create_scatter3d_traces(seg_coordinates)

In [86]:
# Visualize segments
fig = go.Figure(data=seg_traces)

# Skeleton tips
fig.add_trace(go.Scatter3d(
    x=tips_x,
    y=tips_y,
    z=tips_z,
    mode='markers',  # Use 'lines' for a line plot or 'markers+lines' for both
    marker=dict(
        size=5,
        color='black',  # You can customize the color
        opacity=1
    )
))

# Skeleton knots
fig.add_trace(go.Scatter3d(
    x=knot_x,
    y=knot_y,
    z=knot_z,
    mode='markers',  # Use 'lines' for a line plot or 'markers+lines' for both
    marker=dict(
        size=5,
        color='blue',
        opacity=1
    )
))

# Original skeleton structure
fig.add_trace(go.Scatter3d(
    x=X.ravel()[visible_mask],
    y=Y.ravel()[visible_mask],
    z=Z.ravel()[visible_mask],
    mode='markers',
    marker=dict(
        size=2,
        color='black',
        colorscale='Viridis',
        opacity=.1
    )
))

fig.update_layout(
    title='C. Elegans PVD Neuron',
    scene=dict(
        xaxis_title='X (pixels)',
        yaxis_title='Y (pixels)',
        zaxis_title='Z (image index)',
        aspectmode='manual',
        aspectratio=dict(x=1, y=1, z=.27),  # Adjust z-axis scale if desired
        zaxis=dict(range=[0, skeleton.shape[0]]),  # Set z-axis bounds
        xaxis=dict(range=[0, skeleton.shape[1]]),  # Set x-axis bounds
        yaxis=dict(range=[0, skeleton.shape[2]]),   # Set y-axis bounds

    ),
    autosize=True
)

# Save the plot to an HTML file
pio.write_html(fig, file='plots/skeleton_segmentation.html', auto_open=True)