In [38]:
import numpy as np
from skimage import morphology
from skimage.measure import label
import plotly.graph_objects as go

# Step 1: Create a sample 3D branched structure
def create_sample_structure():
    structure = np.zeros((10, 10, 10), dtype=int)
    # Create a branched structure manually
    structure[2:8, 4, 4] = 1  # Main branch
    structure[2:5, 4, 5] = 1  # Branch 1
    structure[5:8, 4, 5] = 1  # Branch 2
    structure[2:5, 5, 4] = 1  # Branch 3
    structure[5:8, 5, 4] = 1  # Branch 4
    structure[4, 4, 2:5] = 1  # Branch 5
    structure[4, 5, 2:5] = 1  # Branch 6
    structure[4, 4, 5:8] = 1  # Branch 7
    structure[4, 5, 5:8] = 1  # Branch 8
    return structure

structure = create_sample_structure()

# Step 2: Skeletonize the structure
skeleton = morphology.skeletonize_3d(structure)

# Step 3: Identify junctions and endpoints
def get_junctions_and_endpoints(skeleton):
    neighbors_count = np.zeros_like(skeleton)
    for offset in [(-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), (0, 0, 1),
                   (-1, -1, 0), (-1, 1, 0), (1, -1, 0), (1, 1, 0),
                   (0, -1, -1), (0, 1, -1), (0, -1, 1), (0, 1, 1),
                   (-1, 0, -1), (-1, 0, 1), (1, 0, -1), (1, 0, 1),
                   (-1, -1, -1), (-1, -1, 1), (-1, 1, -1), (-1, 1, 1),
                   (1, -1, -1), (1, -1, 1), (1, 1, -1), (1, 1, 1)]:
        neighbors_count += np.roll(np.roll(np.roll(skeleton, offset[0], axis=0), offset[1], axis=1), offset[2], axis=2)
    
    endpoints = (skeleton & (neighbors_count == 1)).astype(np.uint8)
    junctions = (skeleton & (neighbors_count > 2)).astype(np.uint8)
    return junctions, endpoints

junctions, endpoints = get_junctions_and_endpoints(skeleton)

# Step 4: Segment the skeleton into branches
def segment_skeleton(skeleton, junctions):
    labeled_skeleton, num_features = label(skeleton - junctions, return_num=True)
    return labeled_skeleton, num_features

labeled_skeleton, num_features = segment_skeleton(skeleton, junctions)

# Print the number of features
print(f'Number of features: {num_features}')

# Step 5: Plot the result using Plotly
def plot_3d_structure(labeled_skeleton):
    x, y, z = np.where(labeled_skeleton > 0)
    labels = labeled_skeleton[x, y, z]

    fig = go.Figure()

    for label in np.unique(labels):
        mask = (labels == label)
        fig.add_trace(go.Scatter3d(
            x=x[mask],
            y=y[mask],
            z=z[mask],
            mode='markers',
            marker=dict(
                size=4,
                color=label,  # Color by label
                colorscale='Viridis',  # Use a colorscale to distinguish branches
                opacity=0.8
            ),
            name=f'Branch {label}'
        ))

    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        ),
        title='3D Branched Structure with Labeled Branches'
    )

    fig.show()

plot_3d_structure(labeled_skeleton)


Number of features: 4



Downcasting int32 to uint8 without scaling because max value 1 fits in uint8



In [33]:
num_features

1