# Skeletonization

This notebook contains code to skeletonize segmentations and calculate ERL, number of merge errors and number of split errors

## Imports

In [2]:
import numpy as np 
import kimimaro
import h5py
from cloudvolume import Skeleton
from glob import glob
from zmesh import Mesher
import trimesh
import skeletor as sk

# Works due to changes in .py files directly (from em_util.io import seg_relabel -> from em_util.em_util.io import seg_relabel)
# and because em_util in same folder
from em_erl.erl import skel_to_erlgraph
from em_erl.eval import compute_erl_score

## Functions

In [3]:
# Stolen from PytorchConnectomics/em_erl
# https://github.com/PytorchConnectomics/em_erl/blob/main/em_erl/eval.py
def compute_segment_lut(
    segment, node_position, mask=None, chunk_num=1, data_type=np.uint32
):
    """
    The function `compute_segment_lut` is a low memory version of a lookup table
    computation for node segments in a 3D volume.

    :param node_position: A numpy array containing the coordinates of each node. The shape of the array
    is (N, 3), where N is the number of nodes and each row represents the (z, y, x) coordinates of a
    node
    :param segment: either a 3D volume or a string representing the
    name of a file containing segment data.
    :param chunk_num: The parameter `chunk_num` is the number of chunks into which the volume is divided
    for reading. It is used in the `read_vol` function to specify which chunk to read, defaults to 1
    (optional)
    :param data_type: The parameter `data_type` is the data type of the array used to store the node segment
    lookup table. In this case, it is set to `np.uint32`, which means the array will store unsigned
    32-bit integers
    :return: a list of numpy arrays, where each array represents the node segment lookup table for a
    specific segment.
    """
    if not isinstance(segment, str):
        # load the whole segment
        node_lut = segment[
            node_position[:, 0], node_position[:, 1], node_position[:, 2]
        ]
        mask_id = []
        if mask is not None:
            if isinstance(mask, str):
                mask = read_vol(mask)
            mask_id = segment[mask > 0]
    else:
        # read segment by chunk (when memory is limited)
        assert ".h5" in segment
        node_lut = np.zeros(node_position.shape[0], data_type)
        mask_id = [[]] * chunk_num
        start_z = 0
        for chunk_id in range(chunk_num):
            seg = read_vol(segment, None, chunk_id, chunk_num)
            last_z = start_z + seg.shape[0]
            ind = (node_position[:, 0] >= start_z) * (node_position[:, 0] < last_z)
            pts = node_position[ind]
            node_lut[ind] = seg[pts[:, 0] - start_z, pts[:, 1], pts[:, 2]]
            if mask is not None:
                if isinstance(mask, str):
                    mask_z = read_vol(mask, None, chunk_id, chunk_num)
                else:
                    mask_z = mask[start_z:last_z]
                mask_id[chunk_id] = seg[mask_z > 0]
            start_z = last_z
        if mask is not None:
            # remove irrelevant seg ids (not used by nodes)
            node_lut_unique = np.unique(node_lut)
            for chunk_id in range(chunk_num):
                mask_id[chunk_id] = mask_id[chunk_id][
                    np.in1d(mask_id[chunk_id], node_lut_unique)
                ]
        mask_id = np.concatenate(mask_id)
    return node_lut, mask_id

# Analogous function for meshes
def compute_mesh_lut(
    meshes, node_position, anisotropy=(4, 4, 40), data_type=np.uint32
):
    """
    Compute lookup table for node segments in a 3D volume of meshes.

    :param meshes: A dictionary of trimesh mesh objects, eah with a unique mesh id.
    :param node_position: A numpy array containing the coordinates of each node. The shape of the array
    is (N, 3), where N is the number of nodes and each row represents the (z, y, x) coordinates of a
    node
    :param anisotropy: Voxel anisotropy of original volume. Used to scale mesh positions.
    :param data_type: The parameter `data_type` is the data type of the array used to store the node segment
    lookup table. In this case, it is set to `np.uint32`, which means the array will store unsigned
    32-bit integers.

    :return: a list of numpy arrays, where each array represents the node segment lookup table for a
    specific segment.
    """

    scaling_factor = (1/anisotropy[0], 1/anisotropy[1], 1/anisotropy[2])

    node_lut = np.zeros(node_position.shape[0], dtype=data_type)
    for mesh_id in meshes:
        mesh = meshes[mesh_id]
        mesh.apply_scale(scaling_factor)

        contains = mesh.contains(node_position)
        node_lut[contains] = mesh_id

    return node_lut


# For Microns data
def load_data(path):
    """
    Helper function to load all npy and npz files in a directory into a single numpy array.
    Requires all files to have the same shape.

    Parameters
    ----------
    path : str
        The path to the directory containing the npy and npz files.

    Returns
    -------
    np.ndarray
        A numpy array containing the loaded data.
    """
    files = sorted(glob(path + "*.npy"))

    data_shape = np.load(files[0]).shape

    out = np.zeros((data_shape[0], data_shape[1], len(files)))

    for i, f in enumerate(files):
        out[:, :, i] = np.load(f)

    return out.astype(int)

## For FFN Data

### Load results from ffn segmentation

In [4]:
with np.load('data/fib25_results/fib25/training2/0/0/seg-0_0_0.npz') as data:
    segmentation = data['segmentation']

### Skeletonize

In [5]:
skels = kimimaro.skeletonize(
  segmentation, 
  teasar_params={
    "scale": 1.5, 
    "const": 300, # physical units
    "pdrf_scale": 100000,
    "pdrf_exponent": 4,
    "soma_acceptance_threshold": 3500, # physical units
    "soma_detection_threshold": 750, # physical units
    "soma_invalidation_const": 300, # physical units
    "soma_invalidation_scale": 2,
    "max_paths": 300, # default None
  },
  # object_ids=[ ... ], # process only the specified labels
  # extra_targets_before=[ (27,33,100), (44,45,46) ], # target points in voxels
  # extra_targets_after=[ (27,33,100), (44,45,46) ], # target points in voxels
  dust_threshold=1000, # skip connected components with fewer than this many voxels
  anisotropy=(10,10,10), # default True
  fix_branching=True, # default True
  fix_borders=True, # default True
  fill_holes=True, # default False
  fix_avocados=True, # default False
  progress=True, # default False, show progress bar
  parallel=1, # <= 0 all cpu, 1 single process, 2+ multiprocess
  parallel_chunk_size=100, # how many skeletons to process before updating progress bar
)


Filling Holes: 100%|██████████| 3982/3982 [00:02<00:00, 1766.52it/s]
Fixing Avocados: 100%|██████████| 3/3 [00:00<00:00,  6.53it/s]
Avocado Pass:   0%|          | 0/20 [00:00<?, ?it/s]
Skeletonizing Labels: 100%|██████████| 251/251 [00:52<00:00,  4.77it/s]


### Load ground truth segmentation

In [6]:
with h5py.File('data/fib25_results/flyEM/groundtruth.h5', 'r') as f:
    gt_seg = np.array(f['stack'])
    transforms = np.array(f['transforms'])

### Skeletonize

In [7]:
gt_skels = kimimaro.skeletonize(
  gt_seg, 
  teasar_params={
    "scale": 1.5, 
    "const": 300, # physical units
    "pdrf_scale": 100000,
    "pdrf_exponent": 4,
    "soma_acceptance_threshold": 3500, # physical units
    "soma_detection_threshold": 750, # physical units
    "soma_invalidation_const": 300, # physical units
    "soma_invalidation_scale": 2,
    "max_paths": 300, # default None
  },
  # object_ids=[ ... ], # process only the specified labels
  # extra_targets_before=[ (27,33,100), (44,45,46) ], # target points in voxels
  # extra_targets_after=[ (27,33,100), (44,45,46) ], # target points in voxels
  dust_threshold=1000, # skip connected components with fewer than this many voxels
  anisotropy=(10,10,10), # default True
  fix_branching=True, # default True
  fix_borders=True, # default True
  fill_holes=True, # default False
  fix_avocados=True, # default False
  progress=True, # default False, show progress bar
  parallel=1, # <= 0 all cpu, 1 single process, 2+ multiprocess
  parallel_chunk_size=100, # how many skeletons to process before updating progress bar
)

Filling Holes: 100%|██████████| 36726/36726 [00:00<00:00, 48573.01it/s]
Avocado Pass:   0%|          | 0/20 [00:00<?, ?it/s]
Skeletonizing Labels: 100%|██████████| 2745/2745 [00:28<00:00, 97.95it/s] 


## Expected Run Length calculation

### Get erl score from em_erl

In [8]:
# get gt graph from gt skeleton
gt_graph = skel_to_erlgraph(gt_skels)

# get predicted graph from predicted skeletons
pred_graph = skel_to_erlgraph(skels)

In [9]:
# get positions of nodes in coordinates with voxel size (10,10,10)
node_position = pred_graph.get_nodes_position((10,10,10))

gt_node_position = gt_graph.get_nodes_position((10,10,10))

In [10]:
# compute segment lookup-table. For each skeleton node, which segment of ground truth would it belong to
node_segment_lut, mask_segment_id = compute_segment_lut(gt_seg, node_position, None)

gt_node_segment_lut, gt_mask_segment_id = compute_segment_lut(gt_seg, gt_node_position, None)

In [11]:
# Calculate score
score = compute_erl_score(erl_graph=pred_graph,
    node_segment_lut=node_segment_lut,
    mask_segment_id=mask_segment_id,
    merge_threshold=0,
    verbose=True)

score.compute_erl(None)

score.print_erl()

all skel
ERL	: 1.63
gt ERL	: 7539.24
#skel	: 249
-----------------


In [12]:
# Compare with score for ground truth compared to itself for checking
gt_score = compute_erl_score(erl_graph=gt_graph,
    node_segment_lut=gt_node_segment_lut,
    mask_segment_id=gt_mask_segment_id,
    merge_threshold=0,
    verbose=True)

gt_score.compute_erl(None)
gt_score.print_erl()

all skel
ERL	: 5709.76
gt ERL	: 5709.76
#skel	: 2327
-----------------


as we can see, erl is computed but shows that segmentation is trash. This doesn't fit expectations as visually, segmentation seemed ok. My guess is that there is an issue in the data format for ground truth segmentation: When examining data in playground.ipynb, slicing across z gives good results for both prediction and gt, but slicing through y or x gives only good results for prediction. In some way, ground truth segmentation is not aligned with raw image.

## For Microns Data

### Ground Truth

In [13]:
gt_seg = load_data("data/microns/seg_v1078/")

# Some procesing to reduce memory usage
gt_seg = gt_seg - (np.min(gt_seg[np.nonzero(gt_seg)]) - 1)
gt_seg = gt_seg.astype(np.uint32)

In [14]:
gt_skels = kimimaro.skeletonize(
        gt_seg, 
        teasar_params={
            "scale": 1.5, 
            "const": 300, # physical units
            "pdrf_scale": 100000,
            "pdrf_exponent": 4,
            "soma_acceptance_threshold": 3500, # physical units
            "soma_detection_threshold": 750, # physical units
            "soma_invalidation_const": 300, # physical units
            "soma_invalidation_scale": 2,
            "max_paths": 300, # default None
        },
        dust_threshold=1000, # skip connected components with fewer than this many voxels
        anisotropy=(4,4,40), # default True
        fix_branching=True, # default True
        fix_borders=True, # default True
        fill_holes=True, # default False
        fix_avocados=True, # default False
        progress=True, # default False, show progress bar
        parallel=1, # <= 0 all cpu, 1 single process, 2+ multiprocess
        parallel_chunk_size=100, # how many skeletons to process before updating progress bar
        )

Filling Holes: 100%|██████████| 1061/1061 [00:02<00:00, 439.11it/s]
Fixing Avocados: 100%|██████████| 1/1 [00:00<00:00,  3.71it/s]
Avocado Pass:   0%|          | 0/20 [00:00<?, ?it/s]
Skeletonizing Labels: 100%|██████████| 602/602 [01:38<00:00,  6.11it/s]


### Earlier unproofread Data

In [15]:
old_seg = load_data('data/microns/seg_v117/')

# Some procesing to reduce memory usage
old_seg = old_seg - (np.min(old_seg[np.nonzero(old_seg)]) - 1)
old_seg = old_seg.astype(np.uint32)

In [16]:
old_skels = kimimaro.skeletonize(
        old_seg, 
        teasar_params={
            "scale": 1.5, 
            "const": 300, # physical units
            "pdrf_scale": 100000,
            "pdrf_exponent": 4,
            "soma_acceptance_threshold": 3500, # physical units
            "soma_detection_threshold": 750, # physical units
            "soma_invalidation_const": 300, # physical units
            "soma_invalidation_scale": 2,
            "max_paths": 300, # default None
        },
        dust_threshold=1000, # skip connected components with fewer than this many voxels
        anisotropy=(4, 4, 40), # default True
        fix_branching=True, # default True
        fix_borders=True, # default True
        fill_holes=True, # default False
        fix_avocados=True, # default False
        progress=True, # default False, show progress bar
        parallel=1, # <= 0 all cpu, 1 single process, 2+ multiprocess
        parallel_chunk_size=100, # how many skeletons to process before updating progress bar
        )

Filling Holes: 100%|██████████| 1061/1061 [00:02<00:00, 440.07it/s]
Fixing Avocados: 100%|██████████| 1/1 [00:00<00:00,  5.49it/s]
Avocado Pass:   0%|          | 0/20 [00:00<?, ?it/s]
Skeletonizing Labels: 100%|██████████| 602/602 [01:35<00:00,  6.29it/s]


In [17]:
gt_graph = skel_to_erlgraph(gt_skels)
proof_graph = skel_to_erlgraph(old_skels)

In [18]:
nodes_position = proof_graph.get_nodes_position((4, 4, 40))

node_segment_lut, mask_segment_id = compute_segment_lut(gt_seg, nodes_position)

In [19]:
score = compute_erl_score(erl_graph=proof_graph,
    node_segment_lut=node_segment_lut,
    mask_segment_id=mask_segment_id,
    merge_threshold=0,
    verbose=True)

score.compute_erl(None)
score.print_erl()

all skel
ERL	: 11977.55
gt ERL	: 12036.00
#skel	: 484
-----------------


In [20]:
node_segment_lut

array([3403396225, 3403396225, 3403396225, ..., 3186486657, 3186486657,
       3186486657], dtype=uint32)

## Calculate ERL for meshes

### Make meshes from segmentations

In [21]:
# Make mesher for voxel size (4, 4, 40) [(x, y, z)]
old_mesher = Mesher((4,4,40))
old_mesher.mesh(old_seg, close=True)

# Extract meshes for further processing
old_meshes = dict()
for obj_id in old_mesher.ids():
  mesh = old_mesher.get(
      obj_id, 
      normals=False, # whether to calculate normals or not
      # No simplification, so that lookup table is more accurate
      reduction_factor=0,

      # max_errors and voxel_centered may improve lut calculation

      # Max tolerable error in physical distance
      # note: if max_error is not set, the max error
      # will be set equivalent to one voxel along the 
      # smallest dimension.
      max_error=8,
      # whether meshes should be centered in the voxel
      # on (0,0,0) [False] or (0.5,0.5,0.5) [True]
      voxel_centered=False, 
    )
  # Transform to trimesh object
  old_meshes[obj_id] = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
  # Free up memory
  old_mesher.erase(obj_id)
  del mesh

### Same for gt meshes

In [22]:
# Make mesher for voxel size (4, 4, 40) [(x, y, z)]
gt_mesher = Mesher((4,4,40))
gt_mesher.mesh(gt_seg, close=True)


gt_meshes = dict()
for obj_id in gt_mesher.ids():
  mesh = gt_mesher.get(
      obj_id, 
      normals=False, # whether to calculate normals or not

      # tries to reduce triangles by this factor
      # 0 disables simplification
      reduction_factor=0, 

      # Max tolerable error in physical distance
      # note: if max_error is not set, the max error
      # will be set equivalent to one voxel along the 
      # smallest dimension.
      max_error=8,
      # whether meshes should be centered in the voxel
      # on (0,0,0) [False] or (0.5,0.5,0.5) [True]
      voxel_centered=False, 
    )
  gt_meshes[obj_id] = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
  # Free up memory
  gt_mesher.erase(obj_id) # delete high res mesh
  del mesh

In [23]:
mesh_lut = compute_mesh_lut(gt_meshes, nodes_position)

In [24]:
# Calculate overlap between mesh lut and segmentation lut
(mesh_lut == node_segment_lut).sum()/mesh_lut.size

0.8652134693535938

In [25]:
# Mask segment ID is not currently implemented with mesh lut
score = compute_erl_score(erl_graph=gt_graph,
    node_segment_lut=mesh_lut,
    mask_segment_id=[],
    merge_threshold=0,
    verbose=True)

score.compute_erl(None)
score.print_erl()

all skel
ERL	: 3.07
gt ERL	: 12043.16
#skel	: 484
-----------------


## Calculate number of merges/splits

In [36]:
# Funkelab implementation

# get unique segments of gt_graph
skeleton_segment, count = np.unique(
        np.hstack([gt_graph.nodes[:, :1], node_segment_lut.reshape(-1, 1)]),
        axis=0,
        return_counts=True,
    )

# how many skeletons are in each gt segment
segments, num_segment_skeletons = np.unique(
        skeleton_segment[:, 1], 
        return_counts=True
    )

merges = {}
# segments with merge errors have more than one gt skeleton assigned
merging_segments = segments[num_segment_skeletons > 1]
# keep those segments that merge
merging_segments_mask = np.isin(skeleton_segment[:, 1], merging_segments)
merged_segments = skeleton_segment[:, 1][merging_segments_mask]

merged_skeletons = skeleton_segment[:, 0][merging_segments_mask]

# collect ids of segments that are merged together
for segment, skeleton in zip(merged_segments, merged_skeletons):
    if segment not in merges:
        merges[segment] = []
    merges[segment].append(skeleton)

In [34]:
splits = {}
n_splits = 0
for skeleton_id in old_skels:
    
    skeletons = old_skels[skeleton_id]
    
    # for coordinates of each edge, collect which segment they are assigned to
    for u, v in skeletons.edges:
        segment_u = node_segment_lut[u]
        segment_v = node_segment_lut[v]

    # if the edges are not both assigned to the same segment, there is a split error
    if segment_u != segment_v:
        n_splits += 1

        # collect which segment was split to which ids
        if skeleton_id not in splits:
            splits[skeleton_id] = []
        splits[skeleton_id].append((segment_u, segment_v))

### Count Mergers and Splits, return number of problematic skeletons

In [None]:
# Reformulation of above code to collect all problematic segments that require human proofreading

problem_segments = []
n_merges = 0
n_splits = 0


# Mergers
# TODO: why different number of mergers than funkelab implementation?
skeleton_segment, count = np.unique(
        np.hstack([gt_graph.nodes[:, :1], node_segment_lut.reshape(-1, 1)]),
        axis=0,
        return_counts=True,
    )
segments, num_segment_skeletons = np.unique(
        skeleton_segment[:, 1], return_counts=True
    )

merged_segments = segments[num_segment_skeletons > 1]
merging_segments_mask = np.isin(skeleton_segment[:, 1], merged_segments)

merged_segments = skeleton_segment[:, 1][merging_segments_mask]

merged_skeletons = np.unique(
    skeleton_segment[np.isin(skeleton_segment[:, 1], merged_segments), 0]
)

n_mergers = len(merged_segments)

problem_segments += list(merged_segments)


# Splits TODO: probably inefficient
for skeleton_id in proof_skels:
    
    skeletons = proof_skels[skeleton_id]
    
    for u, v in skeletons.edges:
        segment_u = node_segment_lut[u]
        segment_v = node_segment_lut[v]

    if segment_u != segment_v:
        n_splits += 1

        if skeleton_id not in problem_segments:
            problem_segments += [skeleton_id]
