In [310]:
import numpy as np
import pickle
import operator
from collections import deque
from itertools import chain
import plotly.graph_objects as go
import plotly.io as pio

### Functions

In [203]:
def find_indices(array, condition=operator.eq, value=1):
    """Return the coordinates of every element for which ``condition(element, value)`` holds.

    ``condition`` is applied to the whole array in one call (e.g. ``operator.eq`` or
    ``operator.ge`` broadcast elementwise over a NumPy array). The resulting mask is
    passed to ``np.argwhere``.

    Returns:
        A list of index tuples, one entry per array dimension, in ``np.argwhere``
        (row-major) order.
    """
    mask = condition(array, value)
    return list(map(tuple, np.argwhere(mask)))

def check_neighbors(array, x, y, z, mode="count", prev=None):
    """Examine the 26-connected neighbourhood of the cell at ``array[x, y, z]``.

    A neighbour qualifies when it meets all three conditions:
    - it lies inside the first three dimensions of ``array``;
    - it holds a non-zero value;
    - it is not equal to ``prev``.

    Args:
        array: 3-D array to inspect (indexed as ``array[x, y, z]``).
        x, y, z: integer coordinates of the centre cell; the centre itself is
            never reported.
        mode: ``"count"`` returns how many neighbours qualify;
            ``"retrieve"`` returns their coordinates.
        prev: optional coordinate tuple to exclude (``None`` excludes nothing).

    Returns:
        An ``int`` in ``"count"`` mode, or a list of ``(x, y, z)`` tuples in
        ``"retrieve"`` mode. The list is ordered by offset -1, 0, +1 on each axis,
        with the first axis varying slowest.

    Raises:
        AssertionError: if ``mode`` is not ``"count"`` or ``"retrieve"``
            (the check is skipped under ``python -O``).
    """
    assert mode in ["count", "retrieve"], "Mode must be either 'count' or 'retrieve'"
    steps = (-1, 0, 1)
    shape = array.shape

    def _inside(cell):
        # Bounds are checked before indexing so negative offsets never wrap around.
        return (0 <= cell[0] < shape[0]
                and 0 <= cell[1] < shape[1]
                and 0 <= cell[2] < shape[2])

    # All 26 offsets around the centre (the zero offset is the cell itself).
    offsets = [(dx, dy, dz)
               for dx in steps for dy in steps for dz in steps
               if (dx, dy, dz) != (0, 0, 0)]

    neighbors = [
        cell
        for cell in ((x + dx, y + dy, z + dz) for dx, dy, dz in offsets)
        if _inside(cell) and array[cell] != 0 and cell != prev
    ]

    if mode == "count":
        return len(neighbors)
    if mode == "retrieve":
        return neighbors
    return None


def traverse_all_neighbors_bfs(array, start_indices, stop_indices):
    """From every start cell, trace the shortest path to the nearest stop cell.

    For each coordinate in ``start_indices``, a breadth-first search walks the
    non-zero cells of ``array`` (26-connectivity, via ``check_neighbors``). It stops
    as soon as it dequeues a cell listed in ``stop_indices``. Each start is searched
    independently; searches do not share visited cells.

    Args:
        array: 3-D array whose non-zero cells are traversable.
        start_indices: iterable of 3-element coordinate tuples to start from.
        stop_indices: iterable of 3-element coordinates at which a search ends.
            Tuples, lists and 1-D arrays are accepted; each is compared as a tuple.

    Returns:
        ``(segments, next_start)``:
        - ``segments`` holds one path per start that reached a stop cell. Each path
          is a list of coordinate tuples from the start to the stop cell, inclusive.
        - ``next_start`` holds the stop cell reached by each of those searches, in
          the same order.
        Starts whose connected component contains no stop cell contribute to
        neither list.
    """
    # Hash once so each membership test is O(1) instead of a scan of a list.
    stop_set = {tuple(stop) for stop in stop_indices}
    segments = []
    next_start = []

    def bfs(start):
        # parent[cell] is the cell from which `cell` was first reached. The dict also
        # serves as the visited set. It lets the path be rebuilt once at the end,
        # instead of copying a growing path list on every enqueue.
        parent = {start: None}
        queue = deque([start])

        while queue:
            current = queue.popleft()
            x, y, z = current

            if current in stop_set:
                # Walk the back-pointers from the stop cell to the start, then reverse.
                path = [current]
                while parent[path[-1]] is not None:
                    path.append(parent[path[-1]])
                path.reverse()
                segments.append(path)
                next_start.append(current)
                return

            for neighbor in check_neighbors(array, x, y, z, mode="retrieve"):
                if neighbor not in parent:
                    parent[neighbor] = current
                    queue.append(neighbor)

    for start in start_indices:
        bfs(start)

    return segments, next_start


def find_nodes(array, indices, condition=operator.eq, neighbor_criterion=1):
    """Return the coordinates in ``indices`` whose neighbour count satisfies ``condition``.

    A coordinate ``(z, x, y)`` is kept when both of these hold:
    - ``array[z, x, y]`` is non-zero;
    - ``condition(n, neighbor_criterion)`` is truthy, where ``n`` is the number of
      non-zero cells among its 26 neighbours (``check_neighbors`` in ``"count"`` mode).

    Args:
        array: 3-D array to inspect.
        indices: sequence of coordinates; only the first three entries of each are used.
        condition: binary predicate, e.g. ``operator.eq`` or ``operator.ge``.
        neighbor_criterion: value the neighbour count is compared against.

    Returns:
        List of ``(z, x, y)`` tuples, in the order they appear in ``indices``.
    """
    coords = ((idx[0], idx[1], idx[2]) for idx in indices)
    return [
        (z, x, y)
        for z, x, y in coords
        if array[z, x, y] != 0
        and condition(check_neighbors(array, z, x, y, mode="count"), neighbor_criterion)
    ]

def check_tolerance(coord1, coord2, tolerance):
    """Return True if ``coord1`` and ``coord2`` differ by at most ``tolerance`` on every axis.

    Components are paired with ``zip``, so any extra components of the longer
    coordinate are ignored, and two empty coordinates count as close.
    """
    for a, b in zip(coord1, coord2):
        if not abs(a - b) <= tolerance:
            return False
    return True

def filter_coordinates(list_a, list_b, tolerance=5):
    """Return the coordinates of ``list_a`` that are not close to any coordinate of ``list_b``.

    "Close" means ``check_tolerance`` holds, i.e. every paired component differs by
    at most ``tolerance``. The order of ``list_a`` is preserved.

    Each element of ``list_a`` is compared against ``list_b`` until a close match is
    found. The worst-case cost is therefore ``len(list_a) * len(list_b)`` comparisons.
    """
    def _is_near_b(coord):
        return any(check_tolerance(coord, other, tolerance) for other in list_b)

    return [coord for coord in list_a if not _is_near_b(coord)]

def remove_close_coordinates(coords, tolerance=5):
    """Thin out ``coords`` so that no two kept coordinates are within ``tolerance`` of each other.

    Coordinates are scanned in order. Each one is kept only if it is not close
    (per ``check_tolerance``) to any coordinate kept before it, so the first member
    of a cluster wins.
    """
    kept = []
    for candidate in coords:
        if any(check_tolerance(candidate, prior, tolerance) for prior in kept):
            continue
        kept.append(candidate)
    return kept

def subtract_segments(skeleton, segments):
    """Zero out every voxel that belongs to any of ``segments``, in place.

    Args:
        skeleton: 3-D NumPy array; modified in place.
        segments: iterable of segments. Each segment is a sequence of coordinate
            tuples whose first three entries index ``skeleton`` (named z, x, y here),
            e.g. the first return value of ``traverse_all_neighbors_bfs()``.

    Returns:
        The same ``skeleton`` object, for convenience.
    """
    for segment in segments:
        if len(segment) == 0:
            # Nothing to clear; avoids indexing with empty lists.
            continue
        zs = [point[0] for point in segment]
        xs = [point[1] for point in segment]
        ys = [point[2] for point in segment]
        # A single fancy-indexed assignment clears the whole segment at once;
        # repeating it per point would only redo identical work.
        skeleton[zs, xs, ys] = 0

    return skeleton


### Load Skeleton

In [191]:
# Load 3d data
# Axes are treated as (z, x, y) throughout this notebook: axis 0 is the image index
# (see the plot axis labels). The recorded output shows shape (188, 2044, 2042),
# min 0 and max 1.
skeleton = np.load('output/pvd_skeleton.npy')  # Load your 3D neuron data

# Print descriptives
print(f"min: {np.amin(skeleton)} max:{np.amax(skeleton)} shape:{skeleton.shape} type:{type(skeleton)} ")

min: 0 max:1 shape:(188, 2044, 2042) type:<class 'numpy.ndarray'> 


### Find Nodes

In [192]:
# Get indices of relevant cells
# find_indices() defaults to operator.eq with value=1, so this returns the coordinates
# of every voxel equal to 1 as (z, x, y) tuples, in row-major order.
skeleton_idx = find_indices(skeleton)

In [193]:
# Find tips (cells with a single neighbor)
# A tip is a non-zero voxel with exactly one non-zero voxel among its 26 neighbours.
# NOTE(review): the saved output below reads "number of endpoints", a label this
# print no longer uses; re-run the cell to refresh it.
tips = find_nodes(skeleton, skeleton_idx, condition=operator.eq, neighbor_criterion=1)
print(f"number of tips: {len(tips)}")

number of endpoints: 174


In [194]:
# Find knots (cells with >= 3 neighbors)
# Candidate branch points: non-zero voxels with at least three non-zero 26-neighbours.
knots = find_nodes(skeleton, skeleton_idx, condition=operator.ge, neighbor_criterion=3)
print(f"number of knots: {len(knots)}")

number of knots: 296


In [195]:
# Remove knots that are too close together
# A knot is kept only if, for every knot already kept, it differs by more than 1 voxel
# on at least one axis. Scan order (row-major, from find_indices) decides which knot
# of a cluster survives.
knots = remove_close_coordinates(knots, tolerance=1)
print(f"number of knots after filtering: {len(knots)}")

number of knots after filtering: 171


In [196]:
# Filter tips to make sure they aren't too near knots
# Drops any tip that lies within 5 voxels, on every axis, of some (already filtered) knot.
tips = filter_coordinates(tips, knots, tolerance=5)
print(f"number of tips after filtering: {len(tips)}")

number of tips after filtering: 132


### Extract Segments

In [248]:
# Start at tips and stop at knots
# outer_segments: one BFS path (tip ... knot, inclusive) per tip that can reach a knot.
# Tips that cannot reach any knot produce no segment.
# lvl_2_start: the knot reached by each of those paths, in the same order.
outer_segments, lvl_2_start = traverse_all_neighbors_bfs(skeleton, tips, knots)

In [288]:
# Isolate core segment, then add to segments master list
# core_segment = every skeleton voxel that is NOT within `tolerance` voxels (on every
# axis) of any outer-segment voxel. filter_coordinates() runs with its default
# tolerance=5, so voxels near the outer segments are excluded as well, not only the
# outer-segment voxels themselves. TODO confirm this gap is intended.
# Cost: filter_coordinates() compares each skeleton voxel against the outer-segment
# voxels in pure Python, up to len(skeleton_idx) * len(outer_segments_flat)
# check_tolerance calls. That is why this cell is slow.
# `segments` = outer segments followed by the core segment as the last element.
outer_segments_flat = list(chain(*outer_segments))
core_segment = filter_coordinates(skeleton_idx, outer_segments_flat) # This is slow
segments = outer_segments + [core_segment]
print(f"number of segments: {len(segments)}")

### Save Segments

In [311]:
# Write core segment to disk
# core_segment is a list of (z, x, y) tuples, so np.save stores it as an (N, 3)
# coordinate array, not as a volume.
np.save('output/pvd_core_segment.npy', core_segment)

# Write segments list to disk
# NOTE(review): despite the 'outer_segments.pkl' filename, `segments` also contains
# the core segment as its last element. Only unpickle this file from trusted sources.
with open('output/outer_segments.pkl', 'wb') as f:
    pickle.dump(segments, f)

### Visualize

In [268]:
# Prepare skeleton data for plotting.
# core_segment holds the (z, x, y) voxel coordinates of the core segment (a list of
# tuples from filter_coordinates()); a dense (z, x, y) volume is accepted as well.
# The coordinates are used directly, for two reasons:
# - np.transpose(..., (1, 2, 0)) cannot be applied to an (N, 3) coordinate list;
# - full-volume meshgrids would allocate three arrays the size of the whole image
#   stack only to keep the few non-zero voxels.
core_array = np.asarray(core_segment)
if core_array.ndim == 3:
    # Dense volume: take the coordinates of its non-zero voxels.
    Z, X, Y = np.nonzero(core_array)
else:
    # Coordinate list: one (z, x, y) row per voxel; reshape keeps an empty list valid.
    core_array = core_array.reshape(-1, 3)
    Z, X, Y = core_array[:, 0], core_array[:, 1], core_array[:, 2]

# Every point above is a voxel to draw. The all-True mask keeps the figure cell's
# X.ravel()[visible_mask] indexing valid.
visible_mask = np.ones(X.shape, dtype=bool)

# Extract coordinates for skeleton tips
tips_z = [point[0] for point in tips]
tips_x = [point[1] for point in tips]
tips_y = [point[2] for point in tips]

# Extract coordinates for skeleton knots
knot_z = [point[0] for point in knots]
knot_x = [point[1] for point in knots]
knot_y = [point[2] for point in knots]

# # Extract coordinates for second-level start points (knots reached from tips)
# lvl_2_z = [point[0] for point in lvl_2_start]
# lvl_2_x = [point[1] for point in lvl_2_start]
# lvl_2_y = [point[2] for point in lvl_2_start]


In [242]:
# Extract segment coordinates and set up segment plotting
def extract_coordinates(list_of_lists):
    """Split each sequence of (z, x, y) points into per-axis lists.

    Returns:
        A list with one ``(zs, xs, ys)`` tuple of lists per input sequence, in
        input order.
    """
    def _split(points):
        zs = [p[0] for p in points]
        xs = [p[1] for p in points]
        ys = [p[2] for p in points]
        return zs, xs, ys

    return [_split(points) for points in list_of_lists]

def create_scatter3d_traces(coordinates, color_list):
    """Build one marker-only ``go.Scatter3d`` trace per ``(z, x, y)`` coordinate group.

    Colours are taken from ``color_list`` in order and wrap around when there are
    more groups than colours. Traces are named "Segment 1", "Segment 2", ...

    Args:
        coordinates: iterable of ``(zs, xs, ys)`` tuples, e.g. from ``extract_coordinates``.
        color_list: non-empty sequence of plotly colour specs.

    Returns:
        List of ``go.Scatter3d`` traces, in the order of ``coordinates``.
    """
    return [
        go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode='markers',
            marker=dict(
                size=3,
                color=color_list[idx % len(color_list)],  # cycle through colours
                opacity=1,
            ),
            name=f'Segment {idx+1}',
        )
        for idx, (zs, xs, ys) in enumerate(coordinates)
    ]

In [308]:
# Extract coordinates
# One (zs, xs, ys) triple of lists per segment: outer segments first, core segment last.
seg_coordinates = extract_coordinates(segments)
# lvl_2_coords = extract_coordinates(second_segments)

# Set colors
# create_scatter3d_traces() cycles through this list when there are more segments than colours.
lvl_1_colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown', 'pink', 'cyan', 'magenta']
# Only used by the commented-out second-level traces.
lvl_2_colors = [color for color in reversed(lvl_1_colors)]

# Create Scatter3d traces
lvl_1_traces = create_scatter3d_traces(seg_coordinates, lvl_1_colors)
# lvl_2_traces = create_scatter3d_traces(lvl_2_coords, lvl_2_colors)

In [309]:
# Visualize segments
# One coloured marker trace per segment (outer segments, then the core segment).
fig = go.Figure(data=lvl_1_traces)

# for trace in lvl_2_traces:
#     fig.add_trace(trace)

# # 2nd level tips
# fig.add_trace(go.Scatter3d(
#     x=lvl_2_x,
#     y=lvl_2_y,
#     z=lvl_2_z,
#     mode='markers',  # Use 'lines' for a line plot or 'markers+lines' for both
#     marker=dict(
#         size=20,
#         color='red',  # You can customize the color
#         opacity=1
#     )
# ))

# Skeleton tips
# Added without a `name`, so plotly gives it a default legend label.
fig.add_trace(go.Scatter3d(
    x=tips_x,
    y=tips_y,
    z=tips_z,
    mode='markers',  # Use 'lines' for a line plot or 'markers+lines' for both
    marker=dict(
        size=6,
        color='black',  # You can customize the color
        opacity=1
    )
))

# # Skeleton knots
# fig.add_trace(go.Scatter3d(
#     x=knot_x,
#     y=knot_y,
#     z=knot_z,
#     mode='markers',  # Use 'lines' for a line plot or 'markers+lines' for both
#     marker=dict(
#         size=8,
#         color='blue',
#         opacity=1
#     )
# ))

# Original skeleton structure
# NOTE(review): X/Y/Z and visible_mask are derived from core_segment in the
# plotting-data cell. This trace therefore shows the core-segment voxels, not the
# whole skeleton.
fig.add_trace(go.Scatter3d(
    x=X.ravel()[visible_mask],
    y=Y.ravel()[visible_mask],
    z=Z.ravel()[visible_mask],
    mode='markers',
    marker=dict(
        size=2,
        color='black',
        colorscale='Viridis',  # no effect: colorscale only applies when `color` is a numeric array
        opacity=.1
    )
))

# Axis ranges span the full skeleton volume. Plot axes map to skeleton axes as
# x -> axis 1, y -> axis 2, z -> axis 0 (image index).
fig.update_layout(
    title='C. Elegans PVD Neuron',
    scene=dict(
        xaxis_title='X (pixels)',
        yaxis_title='Y (pixels)',
        zaxis_title='Z (image index)',
        aspectmode='manual',
        aspectratio=dict(x=1, y=1, z=.27),  # Adjust z-axis scale if desired
        zaxis=dict(range=[0, skeleton.shape[0]]),  # Set z-axis bounds
        xaxis=dict(range=[0, skeleton.shape[1]]),  # Set x-axis bounds
        yaxis=dict(range=[0, skeleton.shape[2]]),   # Set y-axis bounds

    ),
    autosize=True
)

# Save the plot to an HTML file
# auto_open=True also opens the written file in a web browser; presumably the
# 'plots/' directory must already exist.
pio.write_html(fig, file='plots/skeleton_segmentation.html', auto_open=True)