In [5]:
import numpy as np
import operator
import pickle
from itertools import chain
from scipy.spatial import KDTree
from scipy.ndimage import label
import tifffile

In [6]:
def find_indices(array, condition=operator.eq, value=1):
    """Return the coordinates of every element for which ``condition(array, value)`` holds.

    ``condition`` is a binary element-wise comparison applied as
    ``condition(array, value)``; with the defaults it selects elements equal to 1.

    Returns a list with one index tuple per matching element, in the row-major
    order produced by ``np.argwhere``.
    """
    hits = np.argwhere(condition(array, value))
    return list(map(tuple, hits))

def assign_segments_to_indices(indices, segments):
    """Assign each query coordinate to the segment that owns its nearest segment point.

    Parameters
    ----------
    indices : sequence of coordinate tuples, or an (n, d) array
        Query coordinates, e.g. the output of ``find_indices``.
    segments : sequence of sequences of coordinates
        ``segments[k]`` holds the points belonging to segment ``k``; every point
        must have the same dimensionality ``d`` as the query coordinates.

    Returns
    -------
    list of int
        ``result[i]`` is the index into ``segments`` of the segment containing the
        point nearest (Euclidean distance, via ``KDTree.query``) to ``indices[i]``.

    Raises
    ------
    ValueError
        If ``segments`` contains no points at all.
    """
    # Flatten all segment points, and keep a label array parallel to them so the
    # owning segment of a matched point is a direct array lookup (no dict keyed
    # by millions of coordinate tuples).
    points = [coord for segment in segments for coord in segment]
    if not points:
        raise ValueError("segments contain no points to match against")
    labels = np.repeat(np.arange(len(segments)), [len(segment) for segment in segments])

    if len(indices) == 0:
        return []

    tree = KDTree(np.asarray(points))
    # One vectorised query over all coordinates instead of a Python-level loop.
    _, nearest = tree.query(np.asarray(indices))
    return labels[nearest].tolist()

def label_data_for_imagej_color(indices, segment_assignments, shape, num_segments):
    """Paint each listed voxel with its segment's colour in an RGB uint8 volume.

    Parameters
    ----------
    indices : sequence of coordinate tuples, or an (n, len(shape)) array
        Voxel coordinates to colour. Assumed unique, as produced by ``find_indices``.
        With duplicates, NumPy does not define which assignment wins.
    segment_assignments : sequence of int
        Segment index for each entry of ``indices`` (same length).
    shape : tuple of int
        Spatial shape of the output volume.
    num_segments : int
        Number of colours to generate; must exceed every value in
        ``segment_assignments``.

    Returns
    -------
    (labeled_array, colors)
        ``labeled_array`` has shape ``(*shape, 3)`` and dtype uint8. Unlisted voxels
        stay black. ``colors`` is the ``(num_segments, 3)`` uint8 colour table.

    Raises
    ------
    ValueError
        If ``indices`` and ``segment_assignments`` differ in length.
    """
    labeled_array = np.zeros((*shape, 3), dtype=np.uint8)

    # Private generator seeded with 0. It yields exactly the colours that
    # np.random.seed(0) + np.random.randint would, without resetting the
    # caller's global RNG state.
    rng = np.random.RandomState(0)
    colors = rng.randint(0, 255, size=(num_segments, 3), dtype=np.uint8)

    coords = np.asarray(indices, dtype=np.intp)
    assignments = np.asarray(segment_assignments, dtype=np.intp)
    if len(coords) != len(assignments):
        raise ValueError(
            f"got {len(coords)} coordinates but {len(assignments)} segment assignments"
        )
    if coords.size == 0:
        return labeled_array, colors

    # coords.T gives one index array per spatial axis, so all voxels are
    # painted in a single fancy-indexed assignment.
    labeled_array[tuple(coords.T)] = colors[assignments]
    return labeled_array, colors

In [7]:
timepoints = 4

# Load the pickled list of outer+core segments for each timepoint.
# NOTE: pickle.load can execute arbitrary code, so only run this on trusted files
# (here: this pipeline's own output/ directory).
segments = []
for timepoint in range(timepoints):
    segment_path = f'output/all_segments_{timepoint}.pkl'
    with open(segment_path, 'rb') as pkl_file:
        segments.append(pickle.load(pkl_file))

In [8]:
# For every timepoint: assign each foreground voxel to its nearest segment,
# paint it with that segment's colour, and write an RGB TIFF.
for ii in range(timepoints):
    # Load preprocessed 3d data
    data_3d = np.load(f'output/pvd_binary_{ii}.npy')

    data_idx = find_indices(data_3d)
    print(f"timepoint {ii} relevant voxels: {len(data_idx)}")

    segment_assignments = assign_segments_to_indices(data_idx, segments[ii])

    shape = data_3d.shape
    # The colour table is indexed by segment index, so it needs one row per
    # segment. Counting only the distinct assigned labels under-sizes it
    # whenever some segment receives no voxels.
    num_segments = len(segments[ii])
    labeled_array_color, label_colors = label_data_for_imagej_color(data_idx, segment_assignments, shape, num_segments)

    # Save the labeled array as a TIFF file. State the RGB interpretation
    # explicitly instead of relying on tifffile's shape-based default.
    tifffile.imwrite(f'output/labeled_data_color_{ii}.tif', labeled_array_color, photometric='rgb')
    print(f"timepoint {ii} saved.")

timepoint 0 relevant voxels: 1554900
timepoint 0 saved.
timepoint 1 relevant voxels: 1363523
timepoint 1 saved.
timepoint 2 relevant voxels: 1326163
timepoint 2 saved.
timepoint 3 relevant voxels: 1215627
timepoint 3 saved.


In [9]:
# Recompute the nearest-segment assignment for every foreground voxel and keep
# it in memory (one list per timepoint) for the per-segment analysis below.
# TODO: fold this into the previous cell; it repeats the same load + assignment work.

matched_voxel_segments = []

for ii in range(timepoints):
    # Load preprocessed 3d data
    data_3d = np.load(f'output/pvd_binary_{ii}.npy')

    data_idx = find_indices(data_3d)
    print(f"timepoint {ii} relevant voxels: {len(data_idx)}")

    segment_assignments = assign_segments_to_indices(data_idx, segments[ii])

    matched_voxel_segments.append(segment_assignments)

    # Nothing is written to disk here; the results are only kept in matched_voxel_segments.
    print(f"timepoint {ii} assigned.")

timepoint 0 relevant voxels: 1554900
timepoint 0 saved.
timepoint 1 relevant voxels: 1363523
timepoint 1 saved.
timepoint 2 relevant voxels: 1326163
timepoint 2 saved.
timepoint 3 relevant voxels: 1215627
timepoint 3 saved.


In [10]:
# Voxel coordinates assigned to segment 25 at timepoint 0. Timepoint 0's volume
# is loaded explicitly so the coordinates line up with matched_voxel_segments[0].
# A bare `ii` here would be whatever value the last loop left behind.
# np.argwhere(data == 1) yields the same voxels, in the same order, as find_indices(data).
data_idx = np.argwhere(np.load('output/pvd_binary_0.npy') == 1)
labels_t0 = np.asarray(matched_voxel_segments[0])
assert len(labels_t0) == len(data_idx), "timepoint 0: voxel/label count mismatch"
indices = data_idx[labels_t0 == 25]

In [11]:
# Voxel coordinates assigned to segment 26: one (k, ndim) array per timepoint.
segment_26 = []

for ii in range(timepoints):
    # matched_voxel_segments[ii] is parallel to the voxels of pvd_binary_{ii},
    # so each timepoint must use its own coordinates. Reusing one timepoint's
    # coordinates would pair labels with the wrong voxels.
    # np.argwhere(data == 1) yields the same voxels, in the same order, as find_indices(data).
    voxels_ii = np.argwhere(np.load(f'output/pvd_binary_{ii}.npy') == 1)
    labels_ii = np.asarray(matched_voxel_segments[ii])
    assert len(labels_ii) == len(voxels_ii), f"timepoint {ii}: voxel/label count mismatch"
    segment_26.append(voxels_ii[labels_ii == 26])

In [12]:
# Report how many voxels were assigned to segment 26 at each timepoint.
for ii in range(timepoints):
    voxel_count = len(segment_26[ii])
    print(voxel_count)

2031
2106
1875
2594
