In [24]:
import operator
import pickle
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import tifffile
from scipy.spatial import KDTree
from scipy.spatial.distance import euclidean

In [3]:
# Load the pickled segment lists (output/outer_segments_<t>.pkl), one list per timepoint.
# NOTE(review): pickle.load can execute arbitrary code -- only load files produced by this pipeline.
segments = []
timepoints = [0,1,2,3]

for timepoint in timepoints:
    with open(f'output/outer_segments_{timepoint}.pkl', 'rb') as f:
        segments.append(pickle.load(f))
    # Report the list just appended. Indexing with segments[timepoint] is only correct
    # when `timepoints` starts at 0 and has no gaps.
    print(f"Number of outer segments in T{timepoint}: {len(segments[-1])}")

Number of outer segments in T0: 74
Number of outer segments in T1: 78
Number of outer segments in T2: 83
Number of outer segments in T3: 76


In [4]:
# One name per timepoint: segments_<t> is the segment list loaded for T<t>.
segments_0, segments_1, segments_2, segments_3 = (segments[t] for t in range(4))

In [5]:
# Segment count for T0..T3, one number per line.
print(*(len(per_t) for per_t in (segments_0, segments_1, segments_2, segments_3)), sep='\n')

74
78
83
76


In [6]:
def segment_endpoints(segment_list):
    """Return [[first_point, last_point], ...] for every segment in segment_list.

    The matching cells compare segments only through these two points.
    Raises IndexError for an empty segment.
    """
    return [[segment[0], segment[-1]] for segment in segment_list]


segments_0_ends = segment_endpoints(segments_0)
segments_1_ends = segment_endpoints(segments_1)
segments_2_ends = segment_endpoints(segments_2)
segments_3_ends = segment_endpoints(segments_3)

In [13]:
import numpy as np

def line_similarity(line1, line2):
    """
    Dissimilarity score between two lines given as [start, end] pairs of (z, x, y) points.

    The score is the Euclidean distance between the two start points plus the
    Euclidean distance between the two end points (lower means more similar).
    Endpoints are paired by position, so the same line stored in reverse order
    does not score 0.
    """
    def endpoint_gap(p, q):
        pz, px, py = p
        qz, qx, qy = q
        return np.sqrt((pz - qz)**2 + (px - qx)**2 + (py - qy)**2)

    start_gap = endpoint_gap(line1[0], line2[0])
    end_gap = endpoint_gap(line1[1], line2[1])
    return start_gap + end_gap

def find_best_match(line, other_lines):
    """
    Locate the line in other_lines with the lowest line_similarity score to `line`.

    Returns (index, score). Ties keep the earliest index. For an empty
    other_lines the result is (None, inf).
    """
    winner, lowest = None, float('inf')
    for position, candidate in enumerate(other_lines):
        candidate_score = line_similarity(line, candidate)
        if not candidate_score < lowest:
            continue
        winner, lowest = position, candidate_score
    return winner, lowest

def assign_confidence(score, threshold):
    """
    Map a similarity score to a confidence label.

    'High' when score < threshold, 'Medium' when score < 2 * threshold,
    otherwise 'Low' (including inf and NaN scores).
    """
    for limit, label in ((threshold, 'High'), (2 * threshold, 'Medium')):
        if score < limit:
            return label
    return 'Low'

def _match_next(line, candidates, threshold):
    """Best-match `line` against `candidates` and rate the match.

    Returns (matched_line, matched_index, confidence). When `candidates` is empty,
    matched_line and matched_index are None and the confidence is 'Low'
    (find_best_match reports an infinite score).
    """
    match_index, score = find_best_match(line, candidates)
    confidence = assign_confidence(score, threshold)
    match_line = None if match_index is None else candidates[match_index]
    return match_line, match_index, confidence


def find_matches_and_confidence(list1, list2, list3, list4, threshold=10):
    """Chain best matches for each line of list1 through list2, list3 and list4.

    Each line of list1 is matched into list2. That match is then matched into
    list3, and the list3 match into list4. Every stage uses find_best_match
    (line_similarity scores) and is rated with assign_confidence(score, threshold).

    Parameters
    ----------
    list1, list2, list3, list4 : sequences of lines in the form accepted by
        line_similarity ([start, end] points), one sequence per consecutive timepoint.
    threshold : float
        Forwarded to assign_confidence: < threshold is 'High', < 2*threshold is
        'Medium', otherwise 'Low'.

    Returns
    -------
    list of dict, one per line of list1, with keys 'line' and 'index1', and for
    k in 2, 3, 4: 'match{k}' (the matched line or None), 'index{k}' (its index in
    list{k} or None) and 'confidence{k}'. Once a stage has no match, every later
    stage is recorded as None / None / 'Low'.

    NOTE: matching is greedy and not one-to-one. Several lines can end up with
    the same match (e.g. two T0 segments mapping to the same T1 index).
    """
    results = []
    for idx1, line in enumerate(list1):
        result = {'line': line, 'index1': idx1}
        # Each stage matches the previous stage's matched line into the next list.
        previous = line
        for stage, candidates in ((2, list2), (3, list3), (4, list4)):
            # Stage 2 always searches from the list1 line. Later stages only
            # search when the previous stage produced a match.
            if stage == 2 or previous is not None:
                match_line, match_index, confidence = _match_next(previous, candidates, threshold)
            else:
                match_line, match_index, confidence = None, None, 'Low'
            result[f'match{stage}'] = match_line
            result[f'index{stage}'] = match_index
            result[f'confidence{stage}'] = confidence
            previous = match_line
        results.append(result)
    return results


# Chain endpoint matches T0 -> T1 -> T2 -> T3 and dump every match record.
results = find_matches_and_confidence(segments_0_ends, segments_1_ends, segments_2_ends, segments_3_ends)

for seg_no, match in enumerate(results):
    print(f"segment {seg_no}: {match}")


segment 0: {'line': [(12, 1736, 517), (12, 1710, 534)], 'index1': 0, 'match2': [(8, 1727, 524), (8, 1708, 538)], 'index2': 1, 'confidence2': 'Medium', 'match3': [(5, 1714, 536), (6, 1708, 537)], 'index3': 0, 'confidence3': 'Low', 'match4': [(5, 1715, 535), (5, 1706, 538)], 'index4': 1, 'confidence4': 'High'}
segment 1: {'line': [(12, 1820, 525), (27, 1805, 533)], 'index1': 1, 'match2': [(32, 1775, 571), (31, 1781, 560)], 'index2': 64, 'confidence2': 'Low', 'match3': [(37, 1773, 569), (29, 1780, 561)], 'index3': 77, 'confidence3': 'High', 'match4': [(35, 1767, 572), (30, 1781, 556)], 'index4': 72, 'confidence4': 'Medium'}
segment 2: {'line': [(14, 1803, 536), (26, 1800, 536)], 'index1': 2, 'match2': [(32, 1775, 571), (31, 1781, 560)], 'index2': 64, 'confidence2': 'Low', 'match3': [(37, 1773, 569), (29, 1780, 561)], 'index3': 77, 'confidence3': 'High', 'match4': [(35, 1767, 572), (30, 1781, 556)], 'index4': 72, 'confidence4': 'Medium'}
segment 3: {'line': [(14, 1972, 312), (13, 1962, 331

In [14]:
def filter_results(results):
    """
    Keep only match records whose confidence2, confidence3 and confidence4 are all non-'Low'.

    Returns a new list containing the same dict objects, in their original order.
    """
    confidence_keys = ('confidence2', 'confidence3', 'confidence4')
    return [record for record in results
            if all(record[key] != 'Low' for key in confidence_keys)]

In [16]:
# Drop any chain with a 'Low' stage. The printed number is the position in the
# filtered list, not the original T0 index (that is stored under 'index1').
filtered_results = filter_results(results)

for seg_no, kept in enumerate(filtered_results):
    print(f"segment {seg_no}: {kept}")

segment 0: {'line': [(15, 1082, 818), (15, 1044, 832)], 'index1': 5, 'match2': [(10, 1079, 815), (10, 1043, 832)], 'index2': 3, 'confidence2': 'Medium', 'match3': [(8, 1079, 814), (8, 1042, 832)], 'index3': 6, 'confidence3': 'High', 'match4': [(5, 1077, 814), (5, 1041, 832)], 'index4': 0, 'confidence4': 'High'}
segment 1: {'line': [(16, 1441, 698), (14, 1395, 738)], 'index1': 7, 'match2': [(11, 1436, 695), (9, 1392, 739)], 'index2': 6, 'confidence2': 'Medium', 'match3': [(8, 1435, 695), (6, 1392, 738)], 'index3': 7, 'confidence3': 'High', 'match4': [(7, 1433, 694), (5, 1391, 739)], 'index4': 8, 'confidence4': 'High'}
segment 2: {'line': [(17, 937, 847), (17, 906, 847)], 'index1': 10, 'match2': [(12, 936, 847), (11, 905, 847)], 'index2': 10, 'confidence2': 'Medium', 'match3': [(10, 937, 846), (10, 905, 847)], 'index3': 14, 'confidence3': 'High', 'match4': [(9, 933, 847), (8, 904, 847)], 'index4': 15, 'confidence4': 'High'}
segment 3: {'line': [(18, 1039, 985), (17, 1001, 994)], 'index1'

In [18]:
# Full point lists of every confidently matched segment, one list per timepoint.
# Position k in each list refers to the same matched chain in filtered_results.
t0_segments = [segments_0[record['index1']] for record in filtered_results]
t1_segments = [segments_1[record['index2']] for record in filtered_results]
t2_segments = [segments_2[record['index3']] for record in filtered_results]
t3_segments = [segments_3[record['index4']] for record in filtered_results]

In [21]:
# Point counts of matched segment #3 at T0..T3.
for per_timepoint in (t0_segments, t1_segments, t2_segments, t3_segments):
    print(len(per_timepoint[3]))

39
34
31
31


In [22]:
# Extract segment coordinates and set up segment plotting
def extract_coordinates(list_of_lists):
    """Split each point list into per-axis lists.

    Each input item is a sequence of (z, x, y) points. For each item the result
    holds a tuple (z_values, x_values, y_values) of plain lists.
    """
    def split_axes(points):
        return (
            [pt[0] for pt in points],
            [pt[1] for pt in points],
            [pt[2] for pt in points],
        )

    return [split_axes(points) for points in list_of_lists]

def create_scatter3d_traces(coordinates, color_list):
    """Build one marker-only go.Scatter3d trace per segment.

    Parameters
    ----------
    coordinates : list of (z, x, y) tuples of per-axis coordinate lists, as
        returned by extract_coordinates.
    color_list : sequence of plotly color values. Trace i uses
        color_list[i % len(color_list)], so traces with equal positions get equal
        colors. If the sequence is empty, the marker color is left unset for
        plotly to choose; previously this raised ZeroDivisionError.

    Returns
    -------
    list of go.Scatter3d named 'Segment 1', 'Segment 2', ...
    """
    n_colors = len(color_list)
    traces = []
    for i, (z, x, y) in enumerate(coordinates):
        color = color_list[i % n_colors] if n_colors else None  # cycle through the palette
        traces.append(go.Scatter3d(
            x=x,
            y=y,
            z=z,
            mode='markers',
            marker=dict(size=3, color=color, opacity=1),
            name=f'Segment {i+1}',
        ))
    return traces

In [35]:
# Per-timepoint (z, x, y) coordinate lists of every matched segment.
seg_coordinates_0, seg_coordinates_1, seg_coordinates_2, seg_coordinates_3 = (
    extract_coordinates(per_t)
    for per_t in (t0_segments, t1_segments, t2_segments, t3_segments)
)

# One palette for all timepoints: trace k of each timepoint is matched chain k,
# so a matched segment keeps the same color from T0 to T3.
lvl_1_colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown', 'pink', 'cyan', 'magenta']

t0_traces, t1_traces, t2_traces, t3_traces = (
    create_scatter3d_traces(coords, lvl_1_colors)
    for coords in (seg_coordinates_0, seg_coordinates_1, seg_coordinates_2, seg_coordinates_3)
)

# All segment traces, ordered T0 first through T3 last.
all_traces = [*t0_traces, *t1_traces, *t2_traces, *t3_traces]

In [27]:
timepoint = 0  # timepoint whose skeleton is drawn underneath the matched segments

# Load 3d data
skeleton = np.load(f'output/pvd_skeleton_{timepoint}.npy')  # Load your 3D neuron data

# Print descriptives
print(f"min: {np.amin(skeleton)} max:{np.amax(skeleton)} shape:{skeleton.shape} type:{type(skeleton)} ")

# Move skeleton axis 0 (plotted as z) to the end, giving an (x, y, z)-ordered view.
image_stack = np.transpose(skeleton, (1, 2, 0))
x, y, z = image_stack.shape
# Voxel index grids with X[i, j, k] == i, Y[i, j, k] == j and Z[i, j, k] == k, shaped
# like image_stack. These are the same values as a dense np.meshgrid, but copy=False
# returns broadcast views. The dense form would be three x*y*z int arrays (~6 GB
# each as int64 for a 188 x 2044 x 2042 stack). The grids are only read downstream.
X, Y, Z = np.meshgrid(np.arange(x), np.arange(y), np.arange(z), indexing='ij', copy=False)
colors = image_stack.ravel()
# Foreground mask: True for every non-zero (skeleton) voxel, flattened in the same
# C order as X.ravel(), Y.ravel() and Z.ravel().
visible_mask = colors != 0

min: 0 max:1 shape:(188, 2044, 2042) type:<class 'numpy.ndarray'> 


In [36]:
# Visualize the matched segments of all timepoints over the skeleton loaded above.
fig = go.Figure(data=all_traces)

# Original skeleton structure: the coordinates of every non-zero voxel. The volume's
# axes are plotted as (z, x, y), matching the axis ranges below. np.nonzero yields
# the same point set as masking flattened full-volume index grids, without
# allocating those grids.
skel_z, skel_x, skel_y = np.nonzero(skeleton)
fig.add_trace(go.Scatter3d(
    x=skel_x,
    y=skel_y,
    z=skel_z,
    mode='markers',
    marker=dict(
        size=2,
        color='black',  # constant color, so no colorscale is needed
        opacity=.1
    ),
    name='Skeleton',
))

fig.update_layout(
    title='C. Elegans PVD Neuron',
    scene=dict(
        xaxis_title='X (pixels)',
        yaxis_title='Y (pixels)',
        zaxis_title='Z (image index)',
        aspectmode='manual',
        aspectratio=dict(x=1, y=1, z=.27),  # Adjust z-axis scale if desired
        zaxis=dict(range=[0, skeleton.shape[0]]),  # Set z-axis bounds
        xaxis=dict(range=[0, skeleton.shape[1]]),  # Set x-axis bounds
        yaxis=dict(range=[0, skeleton.shape[2]]),  # Set y-axis bounds
    ),
    autosize=True
)

# Save the plot to an HTML file. write_html does not create missing directories.
Path('plots').mkdir(parents=True, exist_ok=True)
pio.write_html(fig, file='plots/matched_segments_preview.html', auto_open=True)