# Benchmarking Cell Tracking Challenge Data

In [1]:
import os
import pprint
import urllib.request
import zipfile

from tqdm import tqdm

from traccuracy import run_metrics
from traccuracy.loaders import load_ctc_data
from traccuracy.matchers import CTCMatcher, IOUMatcher
from traccuracy.metrics import CTCMetrics, DivisionMetrics

pp = pprint.PrettyPrinter(indent=4)

In [2]:
url = "http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DL-HeLa.zip"
data_dir = 'downloads'

if not os.path.exists(data_dir):
    os.mkdir(data_dir)

filename = url.split('/')[-1]
file_path = os.path.join(data_dir, filename)
ds_name = filename.split('.')[0]

In [3]:
# Add a utility to make a progress bar when downloading the file
class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)

if not os.path.exists(file_path):
    print(f"Downloading {ds_name} data from the CTC website")
    # Downloading data
    with DownloadProgressBar(unit='B', unit_scale=True,
                             miniters=1, desc=url.split('/')[-1]) as t:
        urllib.request.urlretrieve(url, file_path, reporthook=t.update_to)
    # Unzip the data
    # TODO add a progress bar to zip as well
    with zipfile.ZipFile(file_path, 'r') as zip_ref:
        zip_ref.extractall(data_dir)

In [4]:

gt_data = load_ctc_data(
    'downloads/Fluo-N2DL-HeLa/01_GT/TRA',
    'downloads/Fluo-N2DL-HeLa/01_GT/TRA/man_track.txt',
    name='Hela-01_GT'
)
pred_data = load_ctc_data(
    'sample-data/Fluo-N2DL-HeLa/01_RES',
    'sample-data/Fluo-N2DL-HeLa/01_RES/res_track.txt',
    name='Hela-01_RES'
)
# remove triple divisions from solution
for node, deg in pred_data.graph.out_degree():
    if deg > 2:
        # randomly remove edges
        out_edges = list(pred_data.graph.out_edges(node))
        pred_data.graph.remove_edges_from(out_edges[2:])

Loading TIFFs: 100%|██████████| 91/91 [00:00<00:00, 462.79it/s]
Computing node attributes: 100%|██████████| 92/92 [00:00<00:00, 279.78it/s]
1 non-connected masks at t=23.
2 non-connected masks at t=52.
Loading TIFFs: 100%|██████████| 91/91 [00:00<00:00, 729.92it/s]
Computing node attributes: 100%|██████████| 92/92 [00:00<00:00, 272.27it/s]


Run CTC metrics with additional evaluation of division events.

Use an IOU matcher which supports a minimum threshold for overlap and run division metrics.

In [5]:
matcher = IOUMatcher(iou_threshold=0.01, one_to_one=True)
matched = matcher.compute_mapping(gt_data, pred_data)
metric = DivisionMetrics()
metric.compute(matched)
metric = CTCMetrics()
metric.compute(matched)

Matching frames: 100%|██████████| 92/92 [00:00<00:00, 134.14it/s]
Evaluating nodes: 100%|██████████| 8600/8600 [00:00<00:00, 611787.90it/s]
Evaluating FP edges: 100%|██████████| 8533/8533 [00:00<00:00, 872309.73it/s]
Evaluating FN edges: 100%|██████████| 8562/8562 [00:00<00:00, 1101077.14it/s]


<traccuracy.metrics._base.Results at 0x12430fbe0>

In [6]:
import traccuracy
import numpy as np
import networkx as nx
import motile_tracker.data_model


def ensure_unique_labels(
    segmentation: np.ndarray,
) -> np.ndarray:
    """Relabels the segmentation in place to ensure that label ids are unique across
    time. This means that every detection will have a unique label id.
    Useful for combining predictions made in each frame independently, or multiple
    segmentation outputs that repeat label IDs.

    Args:
        segmentation (np.ndarray): Segmentation with dimensions ([h], t, [z], y, x).

    Returns:
    """
    segmentation = segmentation.astype(np.uint64)
    label_map = {}
    curr_max = 0
    for idx in range(segmentation.shape[0]):
        frame = segmentation[idx]
        old_labels = np.unique(frame).tolist()
        for old in old_labels:
            if old != 0:
                label_map[(int(old), int(idx))] = old + curr_max
        frame[frame != 0] += curr_max
        curr_max = int(np.max(frame))
        segmentation[idx] = frame
    return segmentation, label_map

def traccuracy_graph_to_solution_tracks(traccuracy_graph: traccuracy.TrackingGraph) -> motile_tracker.data_model.SolutionTracks:
    graph = traccuracy_graph.graph
    seg = traccuracy_graph.segmentation
    time_attr = traccuracy_graph.frame_key
    label_attr = traccuracy_graph.label_key
    loc_keys = traccuracy_graph.location_keys
    id_map = None

    if seg is not None:
        assert label_attr is not None
        # ensure unique labels
        if len(set(np.unique(seg[0])).intersection(set(np.unique(seg[1])))) > 1:
            print("relabeling segmentation to be unique")
            # update segmentation labels to be unique
            seg, label_map = ensure_unique_labels(seg)
            print(label_map)
            # update label attribute to be accurate
            for node in graph.nodes:
                old_label = int(graph.nodes[node][label_attr])
                time = int(graph.nodes[node][time_attr])
                new_label = label_map[(old_label, time)] if (old_label, time) in label_map else old_label
                graph.nodes[node][label_attr] = new_label
        # set node ids to the label attribute
        id_map = {node: graph.nodes[node][label_attr] for node in graph.nodes}
        print(id_map)
        # can set copy to false if not in a notebook and rerunning the cell
        graph = nx.relabel_nodes(graph, id_map, copy=True)
    else:
        # set node ids to be integers
        if not isinstance(next(graph.nodes), int):
            id_map = {node: i for i, node in enumerate(graph.nodes)}
            graph = nx.relabel_nodes(graph, id_map, copy=True)
    
    solution_tracks = motile_tracker.data_model.SolutionTracks(
        graph,
        segmentation=seg,
        time_attr=time_attr,
        pos_attr=loc_keys,
        ndim=len(loc_keys) + 1,
    )
    return solution_tracks, id_map


In [7]:
import motile_tracker
import napari
import motile_tracker.application_menus
import motile_tracker.data_views
from motile_tracker.data_views.views.tree_view.tree_widget import TreeWidget

viewer = napari.Viewer()
menu_widget = motile_tracker.application_menus.MenuWidget(viewer)
gt_viewer = motile_tracker.data_views.TracksViewer(viewer)
pred_viewer = motile_tracker.data_views.TracksViewer(viewer)

gt_tree_widget = TreeWidget(gt_viewer)
pred_tree_widget = TreeWidget(pred_viewer)
viewer.window.add_dock_widget(gt_tree_widget, area="bottom")
viewer.window.add_dock_widget(pred_tree_widget, area="bottom")
viewer.window.add_dock_widget(menu_widget, area="right")

Making new tracking view controller


<napari._qt.widgets.qt_viewer_dock_widget.QtViewerDockWidget at 0x2d8d0f0a0>

In [8]:

gt_tracks, gt_id_map = traccuracy_graph_to_solution_tracks(matched.gt_graph)
gt_viewer.tracks_list.add_tracks(gt_tracks, name="gt_iou_matched")


relabeling segmentation to be unique
{(1, 0): 1, (10, 0): 10, (12, 0): 12, (16, 0): 16, (22, 0): 22, (25, 0): 25, (30, 0): 30, (41, 0): 41, (50, 0): 50, (51, 0): 51, (55, 0): 55, (63, 0): 63, (76, 0): 76, (86, 0): 86, (94, 0): 94, (98, 0): 98, (102, 0): 102, (112, 0): 112, (125, 0): 125, (133, 0): 133, (137, 0): 137, (141, 0): 141, (151, 0): 151, (161, 0): 161, (169, 0): 169, (177, 0): 177, (186, 0): 186, (189, 0): 189, (197, 0): 197, (200, 0): 200, (204, 0): 204, (225, 0): 225, (257, 0): 257, (270, 0): 270, (284, 0): 284, (288, 0): 288, (305, 0): 305, (315, 0): 315, (326, 0): 326, (333, 0): 333, (353, 0): 353, (363, 0): 363, (366, 0): 366, (1, 1): 367, (10, 1): 376, (12, 1): 378, (16, 1): 382, (22, 1): 388, (25, 1): 391, (30, 1): 396, (41, 1): 407, (50, 1): 416, (51, 1): 417, (55, 1): 421, (63, 1): 429, (76, 1): 442, (86, 1): 452, (94, 1): 460, (98, 1): 464, (102, 1): 468, (112, 1): 478, (125, 1): 491, (133, 1): 499, (137, 1): 503, (141, 1): 507, (151, 1): 517, (161, 1): 527, (169, 1)

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [9]:
pred_tracks, pred_id_map = traccuracy_graph_to_solution_tracks(matched.pred_graph)
pred_viewer.tracks_list.add_tracks(pred_tracks, name="pred_iou_matched")

relabeling segmentation to be unique
{(1, 0): 1, (4, 0): 4, (5, 0): 5, (6, 0): 6, (9, 0): 9, (10, 0): 10, (17, 0): 17, (24, 0): 24, (27, 0): 27, (28, 0): 28, (31, 0): 31, (38, 0): 38, (41, 0): 41, (52, 0): 52, (55, 0): 55, (58, 0): 58, (61, 0): 61, (68, 0): 68, (75, 0): 75, (78, 0): 78, (83, 0): 83, (88, 0): 88, (111, 0): 111, (118, 0): 118, (125, 0): 125, (128, 0): 128, (135, 0): 135, (142, 0): 142, (143, 0): 143, (167, 0): 167, (170, 0): 170, (181, 0): 181, (193, 0): 193, (206, 0): 206, (211, 0): 211, (220, 0): 220, (221, 0): 221, (224, 0): 224, (231, 0): 231, (234, 0): 234, (243, 0): 243, (246, 0): 246, (259, 0): 259, (1, 1): 260, (4, 1): 263, (5, 1): 264, (6, 1): 265, (9, 1): 268, (10, 1): 269, (17, 1): 276, (24, 1): 283, (27, 1): 286, (28, 1): 287, (31, 1): 290, (38, 1): 297, (41, 1): 300, (52, 1): 311, (55, 1): 314, (58, 1): 317, (61, 1): 320, (68, 1): 327, (75, 1): 334, (78, 1): 337, (83, 1): 342, (88, 1): 347, (111, 1): 370, (118, 1): 377, (125, 1): 384, (128, 1): 387, (135, 1)

In [10]:
def get_unique_keys(graph: nx.Graph):
    keys = set()
    for node in graph.nodes():
        keys = keys.union(set(graph.nodes[node].keys()))
    return keys

print(get_unique_keys(matched.pred_graph.graph))
print(get_unique_keys(matched.gt_graph.graph))

{'y', 'segmentation_id', 't', <NodeFlag.CTC_TRUE_POS: 'is_ctc_tp'>, <NodeFlag.FP_DIV: 'is_fp_division'>, 'x', <NodeFlag.TP_DIV: 'is_tp_division'>}
{'y', 'segmentation_id', 't', <NodeFlag.CTC_TRUE_POS: 'is_ctc_tp'>, 'x', <NodeFlag.TP_DIV: 'is_tp_division'>, <NodeFlag.FN_DIV: 'is_fn_division'>, <NodeFlag.CTC_FALSE_NEG: 'is_ctc_fn'>}


In [18]:
from traccuracy import NodeFlag
from napari.utils import DirectLabelColormap
import pandas as pd
node_error_colormap = {
    NodeFlag.CTC_FALSE_POS: [1., 0., 0., 1.],  # red
    NodeFlag.NON_SPLIT: [1., 0., 0., 1.],  # red
    NodeFlag.CTC_FALSE_NEG: [1., 0., 0., 1.],  # red
    NodeFlag.FN_DIV: [1., 0., 0., 1.],  # red
    NodeFlag.FP_DIV: [1., 0., 0., 1.],  # red
    NodeFlag.WC_DIV: [1., 0., 0., 1.],  # red
}
from motile_tracker.data_views import TracksViewer, TreeWidget
def color_nodes_by_error(tracks_viewer: TracksViewer, tree_widget: TreeWidget):
    layers = tracks_viewer.tracking_layers
    tracks = layers.tracks
    points_layer = layers.points_layer
    # for tree widget
    df = tree_widget.track_df
    tree_plot_color_list = df["color"].to_list()
    # for seg layer
    colormap_dict = layers.seg_layer.colormap.color_dict
    # points colors updated per point directly on layer

    for idx, node in enumerate(points_layer.nodes):
        node_color = [1., 1., 1., 1.]  # white
        for node_error, color in node_error_colormap.items():
            if tracks._get_node_attr(node, node_error) is not None:
                node_color = color
        
        # if node_color is not None:
        # update points layer
        points_layer.face_color[idx] = node_color
        # update seg layer
        colormap_dict[node] = node_color

        # update tree widget
        # get the row where node_id == node and set color column value to 255*color
        # I hate this code I am so sorry
        target_mask = (df["node_id"] == node).to_list()
        target_idx = target_mask.index(True)
        scaled_color = np.array(node_color) * 255
        tree_plot_color_list[target_idx] = scaled_color
    df["color"] = tree_plot_color_list
    
    # update seg layer colormap
    layers.seg_layer.colormap = DirectLabelColormap(color_dict=colormap_dict)
    
    layers._refresh
    # refresh tree widget
    tree_widget.tree_widget.set_data(df, tree_widget.feature)
    tree_widget.tree_widget._update_viewed_data(tree_widget.view_direction)


In [19]:
node_error_colormap = {
    NodeFlag.CTC_FALSE_POS: [1., 0., 0., 1.],  # red
    NodeFlag.NON_SPLIT: [1., 0., 0., 1.],  # red
    NodeFlag.CTC_FALSE_NEG: [1., 0., 0., 1.],  # red
    NodeFlag.FN_DIV: [1., 0., 0., 1.],  # red
    NodeFlag.FP_DIV: [1., 0., 0., 1.],  # red
    NodeFlag.WC_DIV: [0., 0., 1., 1.],  # blue
}
color_nodes_by_error(gt_viewer, gt_tree_widget)
node_error_colormap = {
    NodeFlag.CTC_FALSE_POS: [0., 1., 0., 1.],  # green
    NodeFlag.NON_SPLIT: [0., 1., 0., 1.],  # green
    NodeFlag.CTC_FALSE_NEG: [0., 1., 0., 1.],  # green
    NodeFlag.FN_DIV: [0., 1., 0., 1.],  # green
    NodeFlag.FP_DIV: [0., 1., 0., 1.],  # green
    NodeFlag.WC_DIV: [0., 0., 1., 1.],  # blue
}
color_nodes_by_error(pred_viewer, pred_tree_widget)