In [1]:
import os

from pathlib import Path

import napari
import numpy as np
from tqdm import tqdm
from rich.pretty import pprint


from ultrack import track, to_tracks_layer, tracks_to_zarr
from ultrack.imgproc import normalize
from ultrack.utils import estimate_parameters_from_labels, labels_to_edges
from ultrack.utils.array import array_apply, create_zarr
from ultrack.config import MainConfig
from ultrack.core.solve.processing import solve
from ultrack.imgproc.segmentation import reconstruction_by_dilation, Cellpose
from ultrack.utils.cuda import import_module, to_cpu, torch_default_device

import dask.array as da

import matplotlib.pyplot as plt

In [2]:
import torch
torch.cuda.is_available()

True

In [3]:
# frames_start = 115
# frames_stop = 125
# row_start = 4250
# row_stop = 4750
# col_start = 4200
# col_stop = 4900

frames_start = 0
frames_stop = -1
row_start = 0
row_stop = -1
col_start = 0
col_stop = -1

In [3]:
zarr_masks_path = r'D:\kasia\tracking\E6_exp\E6_small_masks.zarr'

dask_masks = da.from_zarr(zarr_masks_path,0)[frames_start:frames_stop, row_start:row_stop, col_start:col_stop]
dask_masks.shape

(10, 500, 700)

In [4]:
ch0_path = r'D:\kasia\tracking\E6_exp\E6_C0.zarr'
ch1_path = r'D:\kasia\tracking\E6_exp\E6_C1.zarr'

ch0_da = da.from_zarr(ch0_path,1)[frames_start:frames_stop, row_start:row_stop, col_start:col_stop]
ch1_da = da.from_zarr(ch1_path,1)[frames_start:frames_stop, row_start:row_stop, col_start:col_stop]

print(ch0_da.shape)
print(ch1_da.shape)

(10, 500, 700)
(10, 500, 700)


In [30]:
chunks = (1, 512, 512)

normalized = create_zarr(ch1_da.shape, np.float16, "normalized.zarr", chunks=chunks, overwrite=True)
array_apply(
    ch1_da,
    out_array=normalized,
    func=normalize,
    gamma=0.5,
)

Applying normalize ...: 100%|██████████| 10/10 [00:00<00:00, 14.89it/s]


In [35]:
torch_default_device()

device(type='cpu')

In [50]:
cellpose_labels = create_zarr(ch1_da.shape, np.uint32, "240311_cellpose_labels.zarr", chunks=chunks, overwrite=True)

array_apply(
    normalized,
    out_array=cellpose_labels,
    func=Cellpose(model_type="cyto2", device=torch_default_device()),
    tile=False,
    normalize=False,
    diameter = 50,
)

Applying Cellpose ...: 100%|██████████| 10/10 [00:11<00:00,  1.12s/it]


In [51]:
zarr_masks_path2 = r'D:\kasia\tracking\E6_exp\code\tracks_interactions\examples\240311_cellpose_labels.zarr'

dask_masks2 = da.from_zarr(zarr_masks_path2)
dask_masks2.shape

(10, 500, 700)

In [38]:
# test if it can be run from a list of labels

detection2, edges2 = labels_to_edges(dask_masks2, 
                                   detection_store_or_path="detection2.zarr", 
                                   edges_store_or_path="edges2.zarr", 
                                   overwrite=True, 
                                   sigma=4.0)

Converting labels to edges:   0%|          | 0/10 [00:00<?, ?it/s]

Converting labels to edges: 100%|██████████| 10/10 [00:00<00:00, 36.17it/s]


In [45]:
config = MainConfig()
config.segmentation_config.min_area = 400
config.segmentation_config.max_area = 2000
config.segmentation_config.min_frontier = 0.1
config.segmentation_config.n_workers = 10

config.linking_config.max_distance = 25
config.linking_config.n_workers = 10

config.tracking_config.appear_weight = -1
config.tracking_config.disappear_weight = -1
config.tracking_config.division_weight = -0.1
config.tracking_config.power = 4
config.tracking_config.bias = -0.001
config.tracking_config.solution_gap = 0.0

config.tracking_config.window_size = 50
config.tracking_config.overlap_size = 5

pprint(config)

In [43]:
detection = da.from_zarr('detection.zarr')#(r"D:\kasia\tracking\E6_exp\code\tests\detection.zarr")
edges = da.from_zarr('edges.zarr')#(r"D:\kasia\tracking\E6_exp\code\tests\edges.zarr")
detection2 = da.from_zarr('detection2.zarr')#(r"D:\kasia\tracking\E6_exp\code\tests\detection.zarr")
edges2 = da.from_zarr('edges2.zarr')#(r"D:\kasia\tracking\E6_exp\code\tests\edges.zarr")

In [None]:
# tracking directly from the labels
# may require more memory to keep edges in
track(
    labels=dask_masks,
    sigma=4.0,
    config=config,
    overwrite=True, #"solutions" - to recalculate from the database
)

In [53]:
track(
    labels= [dask_masks, dask_masks2],
    config=config,
    overwrite=True,
)

Converting labels to edges: 100%|██████████| 10/10 [00:00<00:00, 17.06it/s]
Adding nodes to database: 100%|██████████| 10/10 [00:10<00:00,  1.00s/it]
Linking nodes.: 100%|██████████| 9/9 [00:10<00:00,  1.13s/it]


Using Gurobi solver
Solving ILP batch 0
Constructing ILP ...
Solving ILP ...
Saving solution ...
Done!


In [54]:
tracks_df, graph = to_tracks_layer(config)
labels = tracks_to_zarr(config, tracks_df)

Exporting segmentation masks:   0%|          | 0/10 [00:00<?, ?it/s]

Exporting segmentation masks: 100%|██████████| 10/10 [00:00<00:00, 117.65it/s]


In [55]:
viewer = napari.Viewer()
viewer.add_image(ch0_da, name="ch0", colormap="green", blending="additive")
viewer.add_image(ch1_da, name="ch1", colormap="red", blending="additive")
viewer.add_labels(dask_masks)
viewer.add_image(edges, blending="additive", colormap="magma")
viewer.add_image(edges2, blending="additive", colormap="magma")
viewer.add_labels(labels)


  qapp = get_app()


<Labels layer 'labels' at 0x1ed110d9660>

In [52]:
viewer.add_labels(dask_masks2)

<Labels layer 'dask_masks2' at 0x1ed2ab49ab0>