In [3]:
import numpy as np
import operator
import pickle
from scipy.spatial import KDTree
from itertools import chain
import tifffile

In [4]:
def find_indices(array, condition=operator.eq, value=1):
    """Return the coordinates of every cell where ``condition(array, value)`` is true.

    Parameters
    ----------
    array : array-like
        Data to test; it is passed unchanged as the first argument of
        ``condition``.
    condition : callable, optional
        Binary predicate applied as ``condition(array, value)``. Defaults to
        ``operator.eq``.
    value : scalar, optional
        Second argument handed to ``condition``. Defaults to ``1``.

    Returns
    -------
    list of tuple
        One tuple of indices per matching cell, in the order produced by
        ``np.argwhere``.
    """
    mask = condition(array, value)
    return list(map(tuple, np.argwhere(mask)))

In [5]:
# One slot per pickle file; each slot is replaced by the object stored in
# output/matched_segments_<ii>.pkl.
outer_segments = [[],[],[],[]]

# Load list of outer+core segments
# NOTE(review): pickle.load can execute arbitrary code embedded in the file,
# so only run this cell on files produced by a trusted pipeline.
for ii in range(len(outer_segments)):
    with open(f'output/matched_segments_{ii}.pkl', 'rb') as f:
        outer_segments[ii] = pickle.load(f)
    # Message fixed: the word "outer" was duplicated.
    print(f"number of outer segments: {len(outer_segments[ii])}")

number of outer outer_segments: 29
number of outer outer_segments: 29
number of outer outer_segments: 29
number of outer outer_segments: 29


In [3]:
# NOTE(review): `timepoint` is not read by any of the visible cells — confirm
# whether it is still needed.
timepoint = 0

# Read the preprocessed volume from disk.
data_3d = np.load(file='output/pvd_test.npy')

# Quick sanity summary: value range, shape and container type of the volume.
print(f"min: {np.amin(data_3d)} max:{np.amax(data_3d)} shape:{data_3d.shape} type:{type(data_3d)} ")

min: 0 max:1 shape:(188, 2044, 2042) type:<class 'numpy.ndarray'> 


In [40]:
# Coordinates of every voxel in the volume whose value equals 1
# (the defaults of find_indices, spelled out explicitly here).
data_idx = find_indices(data_3d, condition=operator.eq, value=1)
print(f"number of voxels: {len(data_idx)}")

number of voxels: 1256635


In [None]:
# Concatenate the entries loaded from each pickle file into one flat list.
outer_segments_flat = list(chain.from_iterable(outer_segments))

# Isolate core segment, then add to segments master list.
# NOTE(review): the lines below are disabled, yet they are the only visible
# place where `segments` would be assigned; `filter_coordinates` and
# `skeleton_idx` are not defined in the visible cells either.
# core_segment = filter_coordinates(skeleton_idx, outer_segments_flat) # This is slow
# segments = outer_segments + [core_segment]
# print(f"number of segments: {len(segments)}")

In [74]:
def assign_segments_to_indices(indices, segments):
    """Assign each coordinate in ``indices`` to the segment owning its nearest point.

    All coordinates of all segments are pooled into one KD-tree. Every entry
    of ``indices`` is then matched to its nearest pooled coordinate, and the
    index (position in ``segments``) of the segment owning that coordinate is
    reported.

    Parameters
    ----------
    indices : iterable of coordinate sequences, or 2-D ndarray
        Points to classify, one coordinate per entry/row.
    segments : iterable of iterables of coordinate sequences
        ``segments[k]`` is the collection of coordinates belonging to segment
        ``k``.

    Returns
    -------
    list of int
        ``result[i]`` is the segment index assigned to ``indices[i]``.
        Empty when ``indices`` is empty.

    Raises
    ------
    ValueError
        If ``segments`` contains no coordinates at all, since no tree can be
        built.

    Notes
    -----
    When the same coordinate occurs in several segments, the segment that
    appears last in ``segments`` owns it.
    """
    # Pool every segment coordinate and remember which segment owns it.
    # A later segment overwrites an earlier one for a shared coordinate.
    flattened_segments = []
    coord_to_segment = {}
    for segment_index, segment in enumerate(segments):
        for coord in segment:
            flattened_segments.append(coord)
            coord_to_segment[tuple(coord)] = segment_index

    if not flattened_segments:
        raise ValueError("segments contains no coordinates; cannot build a KDTree")

    # Materialise generic iterables so they can be measured and converted.
    points = indices if isinstance(indices, np.ndarray) else list(indices)
    if len(points) == 0:
        return []

    # owner[j] is the segment of pooled coordinate j, resolved through the
    # mapping so duplicated coordinates keep the "last segment wins" rule.
    owner = np.array([coord_to_segment[tuple(coord)] for coord in flattened_segments])

    tree = KDTree(flattened_segments)

    # One batched nearest-neighbour query instead of one Python-level call
    # per point; the per-point results are the same.
    _, nearest_flattened_index = tree.query(np.asarray(points))

    return owner[nearest_flattened_index].tolist()

In [75]:
# Label every voxel with the index of its nearest segment.
# NOTE(review): `segments` has no live assignment in the visible cells (its
# only assignment is commented out), so this call depends on kernel state.
segment_assignments = assign_segments_to_indices(indices=data_idx, segments=segments)

In [92]:
def label_data_for_imagej_color(indices, segment_assignments, shape, num_segments):
    """Paint each listed coordinate with the random colour of its segment.

    Parameters
    ----------
    indices : iterable of 3-element coordinate sequences, or 2-D ndarray
        Coordinates to colour; only the first three components of each
        coordinate are used.
    segment_assignments : iterable of int
        ``segment_assignments[i]`` selects the row of the colour table used
        for ``indices[i]``. If the two inputs differ in length, the extra
        entries of the longer one are ignored.
    shape : tuple of int
        Shape of the volume; the output has this shape plus a trailing axis
        of length 3.
    num_segments : int
        Number of rows in the generated colour table.

    Returns
    -------
    labeled_array : ndarray of uint8, shape ``(*shape, 3)``
        Zero everywhere except at the listed coordinates.
    colors : ndarray of uint8, shape ``(num_segments, 3)``
        The colour table, with values drawn from ``[0, 255)``.

    Notes
    -----
    The global NumPy random state is re-seeded with 0 on every call so that
    the colour table is reproducible.
    NOTE(review): random colours are not guaranteed to be distinct from each
    other or from the all-zero background.
    NOTE(review): if a coordinate is listed more than once, which of its
    colours ends up in the output is not guaranteed by NumPy.
    """
    labeled_array = np.zeros((*shape, 3), dtype=np.uint8)

    # Reproducible colour table (same seeding and draw as before).
    np.random.seed(0)
    colors = np.random.randint(0, 255, size=(num_segments, 3), dtype=np.uint8)

    # Materialise generic iterables so they can be measured and sliced.
    coord_seq = indices if isinstance(indices, np.ndarray) else list(indices)
    label_seq = (
        segment_assignments
        if isinstance(segment_assignments, np.ndarray)
        else list(segment_assignments)
    )

    # Pair coordinates and labels up to the shorter input, like zip().
    count = min(len(coord_seq), len(label_seq))
    if count:
        coords = np.asarray(coord_seq)[:count]
        labels = np.asarray(label_seq)[:count]
        # One fancy-indexed assignment replaces the per-voxel Python loop.
        labeled_array[coords[:, 0], coords[:, 1], coords[:, 2]] = colors[labels]

    return labeled_array, colors


shape = data_3d.shape

# The colour table is indexed by segment label, so it needs one row for every
# label up to the largest one in use. Counting distinct labels (the previous
# `len(set(...)) + 2`) can be too small when some labels are never assigned,
# and otherwise creates colours that no voxel uses.
num_segments = max(segment_assignments, default=-1) + 1
labeled_array_color, label_colors = label_data_for_imagej_color(data_idx, segment_assignments, shape, num_segments)

# Save the labeled array as a TIFF file
tifffile.imwrite('labeled_data_color.tif', labeled_array_color)

In [91]:
# Display the dimensions of the colour-labelled volume.
np.shape(labeled_array_color)

(188, 2044, 2042, 3)

In [96]:
from scipy.ndimage import label

def process_color_clusters(color_array, colors, threshold=100):
    """Blank out every connected same-colour cluster larger than ``threshold``.

    For each colour in ``colors``, the voxels of exactly that colour are
    grouped into connected components using a full 3x3x3 neighbourhood.
    Every component with more than ``threshold`` voxels is overwritten with
    ``[0, 0, 0]``.

    Parameters
    ----------
    color_array : ndarray
        Volume whose last axis holds the colour components. The 3x3x3
        structuring element requires three leading axes.
        **Modified in place.**
    colors : iterable of colour triplets
        Colours to examine, one at a time.
    threshold : int, optional
        Largest component size (in voxels) that is kept. Defaults to 100.

    Returns
    -------
    ndarray
        The same object as ``color_array``, after modification.
    """
    # Connectivity for labelling; identical for every colour, so built once.
    structure = np.ones((3, 3, 3), dtype=int)

    for color in colors:
        # Voxels whose every colour component matches this colour.
        color_mask = np.all(color_array == color, axis=-1)
        if not color_mask.any():
            continue  # colour not present: nothing to label

        components, num_features = label(color_mask, structure=structure)

        # Size of every component in one pass (index 0 is the background),
        # instead of a full-volume comparison per component.
        sizes = np.bincount(components.ravel(), minlength=num_features + 1)
        too_big = sizes > threshold
        too_big[0] = False  # never treat the unlabeled background as a cluster

        if too_big.any():
            # Map each voxel's component id to its "too big" flag and erase
            # all oversized components of this colour at once.
            color_array[too_big[components]] = [0, 0, 0]

    return color_array

In [97]:
# Remove bulky same-colour clusters (more than 100 voxels).
# process_color_clusters edits its input in place and returns it, so
# `trimmed_array` and `labeled_array_color` refer to the same array afterwards.
trimmed_array = process_color_clusters(
    color_array=labeled_array_color,
    colors=label_colors,
    threshold=100,
)

KeyboardInterrupt: 