## Import packages

In [1]:
%load_ext autoreload
%autoreload 2

In [5]:
# Notebook at full width in the browser
from IPython.display import display, HTML

display(HTML("<style>.container { width:100% !important; }</style>"))

import time
from pathlib import Path

import skimage
import pandas as pd
import numpy as np
import napari
from tqdm import tqdm
import networkx as nx
import plotly.io as pio

# pio.renderers.default = "notebook_connected"
pio.renderers.default = "iframe"

import motile
from motile.plot import draw_track_graph, draw_solution
# from utils import InOutSymmetry, MinTrackLength

import traccuracy
from traccuracy import run_metrics
from traccuracy.matchers import CTCMatched
from traccuracy.metrics import CTCMetrics, DivisionMetrics
# from KL_load_data import load_raw_masks

from tqdm import tqdm
import zarr

# Pretty tqdm progress bars
! jupyter nbextension enable --py widgetsnbextension

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [8]:
# zarrpath = "/mnt/efs/shared_data/YeastiBois/zarr_files/Tien/glass_60x_023_RawMasks.zarr"

zarrpath = '/mnt/efs/shared_data/YeastiBois/zarr_files/Tien/glass_60x_023_RawMasksProbs.zarr'
zarrfile =  zarr.open(zarrpath,'r')
mask = zarrfile['masks'] #segmentation mask
raw = zarrfile['raw']
probs = zarrfile['probs']
# unique = np.unique(mask) #number of unique labeles in the segmentation mask
# fullmask = mask[:] #load the full mask into memory
# nonzero_unique = unique[1:] #zero is empty space

raw_MembraneChannel = raw[:,:,0]

# mask_frame = mask[0].astype(int)
#remove small objects to clean up masks
mask_clean = np.stack([skimage.morphology.remove_small_objects(m.astype(int), min_size=1000) for m in tqdm(mask)])

# NapariViewer = napari.Viewer()
# NapariViewer.add_image(raw_MembraneChannel,name='raw')
# NapariViewer.add_labels(mask_clean,name='mask')

100%|████████████████████████████████████████████████████████████████████████| 15/15 [00:09<00:00,  1.66it/s]


In [5]:
def visualize_tracks(viewer, y, links=None, name=""):
    """Utility function to visualize segmentation and tracks

    Args:
        viewer: napari viewer
        y: labels: list of 2D arrays, each array is a label image.
        links: np.ndarray, each row is a link (parent, child, parent_frame, child_frame).

    Returns:
        tracks: np.ndarray, shape (N, 4)
    """
    max_label = max(links.max(), y.max()) if links is not None else y.max()
    colorperm = np.random.default_rng(42).permutation(np.arange(1, max_label + 2))
    tracks = []
    for t, frame in enumerate(y):
        centers = skimage.measure.regionprops(frame)
        for c in centers:
            tracks.append(
                [colorperm[c.label], t, int(c.centroid[0]), int(c.centroid[1])]
            )
    tracks = np.array(tracks)
    tracks = tracks[tracks[:, 0].argsort()]

    graph = {}
    if links is not None:
        divisions = links[links[:, 3] != 0]
        for d in divisions:
            if (
                colorperm[d[0]] not in tracks[:, 0]
                or colorperm[d[3]] not in tracks[:, 0]
            ):
                continue
            graph[colorperm[d[0]]] = [colorperm[d[3]]]

    viewer.add_labels(y, name=f"{name}_detections")
    # viewer.layers[f"{name}_detections"].contour = 3
    viewer.add_tracks(tracks, name=f"{name}_tracks", graph=graph)
    return tracks

In [6]:
# viewer = napari.viewer.current_viewer()
# if viewer:
#     viewer.close()
# viewer = napari.Viewer()
# viewer.add_image(img)
# #visualize_tracks(viewer, labels, links.to_numpy(), "ground_truth")
# visualize_tracks(viewer, labels)
# #viewer.add_labels(det, name="detections")
# #viewer.grid.enabled = True

In [7]:
viewer = napari.viewer.current_viewer()
if viewer:
    viewer.close()

#### Build the ground truth graph, as well as a candidate graph from the detections

In [22]:
def build_graph(detections, max_distance, node_features=None, drift=(0, 0, 0)):
    """Build a candidate graph from a list of detections.

     Args:
        detections: list of 2D arrays, each array is a label image.
            Labels are expected to be consecutive integers starting from 1, background is 0.
        max distance: maximum distance between centroids of two detections to place a candidate edge.
        node_features: dict of arrays corresponding to global node ids
        drift: (y, x) tuple for drift correction in euclidian distance feature.
    Returns:
        G: motile.TrackGraph containing the candidate graph.
    """

    print("Build candidate graph")
    G = nx.DiGraph()

    print("add nodes")
    for t, d in tqdm(enumerate(detections)):
        regions = skimage.measure.regionprops(d)
        positions = []
        for i, r in enumerate(regions):
            draw_pos = int(r.centroid[1])
            if draw_pos in positions:
                draw_pos += 3  # To avoid overlapping nodes
            positions.append(draw_pos)
            
            features = {
                k: np.round(v[r.label-1], decimals=2).item() if v is not None else 1
                for k,v in node_features.items()
            }
            G.add_node(
                r.label-1,
                time=t,
                show=r.label,
                draw_position=draw_pos,
                z=int(r.centroid[0]),
                y=int(r.centroid[1]),
                x=int(r.centroid[2]),
                **features,
            )

    print("add edges")
    n_e = 0
    for t, (d0, d1) in tqdm(enumerate(zip(detections, detections[1:]))):
        r0 = skimage.measure.regionprops(d0)
        c0 = [np.array(r.centroid) for r in r0]

        r1 = skimage.measure.regionprops(d1)
        c1 = [np.array(r.centroid) for r in r1]

        for _r0, _c0 in zip(r0, c0):
            for _r1, _c1 in zip(r1, c1):
                dist = np.linalg.norm(_c0 + np.array(drift) - _c1)
                if dist < max_distance:
                    G.add_edge(
                        _r0.label - 1,
                        _r1.label - 1,
                        # before: 1 - normalized euclidian distance
                        feature=np.round(
                            np.linalg.norm(_c0 + np.array(drift) - _c1) / max_distance,
                            decimals=3,
                        ).item(),
                        edge_id=n_e,
                        show="?",
                    )
                    n_e += 1

    G = motile.TrackGraph(G, frame_attribute="time")

    return G

In [20]:

#ids are unique over the course of the video, cellpose relabeled from 1 every timestep
#global IDs

offset = 1
relabeled = []
for frame in tqdm(mask_clean):
    frame, _, _ = skimage.segmentation.relabel_sequential(frame, offset=offset)
    assert 0 in frame
    offset += len(np.unique(frame)) - 1
    relabeled.append(frame)

labels_global = np.stack(relabeled)
print(offset)

 20%|██████████████▌                                                          | 3/15 [00:03<00:15,  1.27s/it]


KeyboardInterrupt: 

In [13]:
labels_global.shape
labels_global.max()

1117

In [15]:
#create node feature
#det_probs = normalized size of regions
det_probs = []
for frame in tqdm(labels_global):
    regions = skimage.measure.regionprops(frame)
    for r in regions:
        det_probs.append(r.num_pixels)
det_probs = np.array(det_probs) / np.array(det_probs).max() #normalize by max

#Probs = probabilities
avg_probs = []
for _l, _p in tqdm(zip(labels_global, probs)):
    assert _p.min() >= 0 and _p.max() <= 1
    indices = np.unique(_l)
    for i in indices:
        if i == 0:
            continue
        a = _p[_l == i].mean()
        avg_probs.append(a)

100%|████████████████████████████████████████████████████████████████████████| 15/15 [00:02<00:00,  7.33it/s]
15it [00:24,  1.61s/it]


In [30]:
# graph for how many frames
s = slice(0, 10)#10 frames
labels_crop = labels_global[s]
img_crop = raw_MembraneChannel[s]
candidate_graph = build_graph(
    labels_crop,
    max_distance=100,
    node_features={
        "size": det_probs,
        "avg_prob": avg_probs,
    },
    drift=(0,0,0)
)

Build candidate graph
add nodes


10it [00:08,  1.25it/s]


add edges


9it [00:06,  1.46it/s]


In [24]:
len(candidate_graph.nodes)

702

In [19]:
# Show candidate graphs

fig_candidate = draw_track_graph(
    candidate_graph,
    position_attribute="draw_position",
    width=1000,
    height=500,
    label_attribute="show",
    alpha_attribute="feature",
    node_size=25,
)
fig_candidate = fig_candidate.update_layout(
    title={
        "text": "Candidate graph",
        "y": 0.98,
        "x": 0.5,
    }
)
fig_candidate.show()

Here is a utility function to gauge some statistics of a solution.

In [25]:
def print_solution_stats(solver, graph, gt_graph=None):
    """Prints the number of nodes and edges for candidate, ground truth graph, and solution graph.

    Args:
        solver: motile.Solver, after calling solver.solve()
        graph: motile.TrackGraph, candidate graph
        gt_graph: motile.TrackGraph, ground truth graph
    """
    time.sleep(0.1)  # to wait for ilpy prints
    print(
        f"\nCandidate graph\t\t{len(graph.nodes):3} nodes\t{len(graph.edges):3} edges"
    )
    if gt_graph:
        print(
            f"Ground truth graph\t{len(gt_graph.nodes):3} nodes\t{len(gt_graph.edges):3} edges"
        )

    node_selected = solver.get_variables(motile.variables.NodeSelected)
    edge_selected = solver.get_variables(motile.variables.EdgeSelected)
    nodes = 0
    for node in candidate_graph.nodes:
        if solver.solution[node_selected[node]] > 0.5:
            nodes += 1
    edges = 0
    for u, v in candidate_graph.edges:
        if solver.solution[edge_selected[(u, v)]] > 0.5:
            edges += 1
    print(f"Solution graph\t\t{nodes:3} nodes\t{edges:3} edges")

### Recolor detections in napari according to solution and compare to ground truth

In [26]:
def solution2graph(solver, base_graph, detections, label_key="show"):
    """Convert a solver solution to a graph and corresponding dense selected detections.

    Args:
        solver: A solver instance
        base_graph: The base graph
        detections: The detections
        label_key: The key of the label in the detections
    Returns:
        track_graph: Solution as motile.TrackGraph
        graph: Solution as networkx graph
        selected_detections: Dense label array containing only selected detections
    """
    graph = nx.DiGraph()
    node_indicators = solver.get_variables(motile.variables.NodeSelected)
    edge_indicators = solver.get_variables(motile.variables.EdgeSelected)

    selected_detections = np.zeros_like(detections)

    # Build nodes
    for node, index in node_indicators.items():
        if solver.solution[index] > 0.5:
            node_features = base_graph.nodes[node]
            graph.add_node(node, **node_features)
            t = node_features[base_graph.frame_attribute]
            selected_detections[t][
                detections[t] == node_features[label_key]
            ] = node_features[label_key]

    # Build edges
    for edge, index in edge_indicators.items():
        if solver.solution[index] > 0.5:
            # print(base_graph.edges[edge])
            graph.add_edge(*edge, **base_graph.edges[edge])

    # Add cell division markers on edges for traccuracy
    for (u, v), features in graph.edges.items():
        out_edges = graph.out_edges(u)
        if len(out_edges) == 2:
            features["is_intertrack_edge"] = 1
        elif len(out_edges) == 1:
            features["is_intertrack_edge"] = 0
        else:
            raise ValueError()

    track_graph = motile.TrackGraph(graph, frame_attribute="time")

    return track_graph, graph, selected_detections

In [27]:
def recolor_segmentation(segmentation, graph, det_attribute="show"):
    """Recolor a segmentation based on a graph, such that each cell and its daughter cells have a unique color.

    Args:
        segmentation (np.ndarray): Predicted dense segmentation.
        graph (motile.TrackGraph): A directed graph representing the tracks.
        det_attribute (str): The attribute of the graph nodes that corresponds to ids in `segmentation`.

    Returns:
        out (np.ndarray): A recolored segmentation.
    """
    out = []
    n_tracks = 1
    color_lookup_tables = []

    for t in range(0, len(segmentation)):
        new_frame = np.zeros_like(segmentation[t])
        color_lut = {}
        for node_id in graph.nodes_by_frame(t):
            det_id = graph.nodes[node_id][det_attribute]
            if node_id not in graph.nodes:
                continue

            in_edges = []
            for u, v in graph.edges:
                if v == node_id:
                    in_edges.append((u, v))
            if not in_edges:
                new_frame[segmentation[t] == det_id] = n_tracks
                color_lut[det_id] = n_tracks
                n_tracks += 1
            else:
                for v_tm1, u_t0 in in_edges:
                    new_frame[
                        segmentation[t] == graph.nodes[u_t0][det_attribute]
                    ] = color_lookup_tables[t - 1][graph.nodes[v_tm1][det_attribute]]
                    color_lut[graph.nodes[u_t0][det_attribute]] = color_lookup_tables[
                        t - 1
                    ][graph.nodes[v_tm1][det_attribute]]

        color_lookup_tables.append(color_lut)
        out.append(new_frame)

    out = np.stack(out)
    return out

## Exercise 2.2 - ILP with track birth and death
<div class="alert alert-block alert-info"><h3>Exercise 2.2: Adapt the network flow from Exercise 2.1 such that tracks can start and end at arbitrary time points.</h3>

Hint: You will have to add both costs and constraints to the template below.
</div>

In [31]:
def solve(graph):
    """ILP allowing for appearance and disappearance of cells.

    Args:
        graph (motile.TrackGraph): The candidate graph.

    Returns:
        solver (motile.Solver): The solver.
    """

    solver = motile.Solver(graph)

    solver.add_costs(
        motile.costs.NodeSelection(
            weight=-1,
            attribute="size",
            constant=0,
        ),
        name="size",
    )
    solver.add_costs(
        motile.costs.NodeSelection(
            weight=-1,
            attribute="avg_prob",
            constant=0,
        ),
        name="avg_prob",
    )
    # weight * attribute + constant
    solver.add_costs(
        motile.costs.EdgeSelection(
            weight = 0.01,#+0.5 or 1...
            attribute="feature",
            constant=0,
        )
    )
    solver.add_costs(motile.costs.Appear(constant=0.5))
    # solver.add_costs(motile.costs.Split(constant=1))
    solver.add_costs(motile.costs.Disappear(constant=0.5))
    
    solver.add_constraints(motile.constraints.MaxParents(1))
    solver.add_constraints(motile.constraints.MaxChildren(1))

    solution = solver.solve()

    return solver

Run the optimization, and compare the found solution to the ground truth.

In [32]:
with_birth = solve(candidate_graph)
print_solution_stats(with_birth, candidate_graph)

Could not create Gurobi backend: Gurobi error in ilpy/impl/solvers/GurobiBackend.cpp:22: PIP license can only be used from gurobipy interface

Candidate graph		702 nodes	2712 edges
Solution graph		620 nodes	498 edges


In [33]:
# fig_birth = draw_solution(
#     candidate_graph,
#     with_birth,
#     position_attribute="draw_position",
#     width=1000,
#     height=500,
#     label_attribute="show",
#     node_size=25,
# )
# fig_birth = fig_birth.update_layout(
#     title={
#         "text": f"ILP formulation (no divisions) - cost: {with_birth.solution.get_value()}",
#         "y": 0.98,
#         "x": 0.5,
#     }
# )
# fig_birth.show()

In [34]:
recolored_birth = recolor_segmentation(
    labels_crop, graph=solution2graph(with_birth, candidate_graph, labels_crop)[0]
)

viewer = napari.viewer.current_viewer()
if viewer:
    viewer.close()
viewer = napari.Viewer()
viewer.add_image(img_crop)
viewer.add_labels(labels_crop)
# visualize_track
# visualize_tracks(viewer, recolored_birth)

viewer.add_labels(recolored_birth)
viewer.grid.enabled = True



In [36]:
#save tracks

# imgdata = img_crop
# labeldata = labels_crop
# savepath = '/home/tienc/Documents/trackvids/tracked.zarr/'
# with zarr.open(savepath,'w') as zarrsave:
#     zarrsave.create_dataset('img',data=imgdata)
#     zarrsave.create_dataset('label',data=labeldata)
# print('done')

done


In [None]:
viewer = napari.viewer.current_viewer()
if viewer:
    viewer.close()

In [None]:
_, birth_graph, birth_det = solution2graph(with_birth, candidate_graph, det)
get_metrics(gt_nx_graph, labels, birth_graph, birth_det)

In [None]:
viewer = napari.viewer.current_viewer()
if viewer:
    viewer.close()

In [None]:
_, ilp_graph, ilp_det = solution2graph(full_ilp, candidate_graph, det)
get_metrics(gt_nx_graph, labels, ilp_graph, ilp_det)

## Exercise 2.4 (Bonus)
<div class="alert alert-block alert-info"><h3>Exercise 2.4: Try to improve the ILP-based tracking from exercise 2.3</h3>

For example
- Tune the hyperparameters.
- Better edge features than drift-corrected euclidian distance.
- Tune the detection algorithm to avoid false negatives.

</div>