# Segment, localise and track


In [28]:
from pathlib import Path
import numpy as np
import zarr
import torch
from tqdm.auto import tqdm
from cellpose import models

import btrack
import btrack.io
import btrack.utils

from macrohet.label import segment, localise, track

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

# On CUDA, report the name of GPU 0 and its current memory usage in GB
if device.type == 'cuda':
    _bytes_per_gb = 1024**3
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / _bytes_per_gb, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / _bytes_per_gb, 1), 'GB')

ModuleNotFoundError: No module named 'cellpose'

### Load images

# WIP - deciding which example data to use and how to host it

I think downsizing the images by a factor of 5 will make sharing and archiving via GitHub/Zenodo much more practical.

In [25]:
# Location of the example Zarr store, relative to this notebook
zarr_path = Path("..") / "data" / "example_data.zarr"

In [27]:
# Open the store read-only as a group (open_group rather than open_array,
# because the image data is looked up by key inside the group)
zarr_root = zarr.open_group(store=zarr_path, mode="r")

# NOTE(review): this lookup requires an "images" entry in the store; the
# captured output of this cell shows a KeyError, so confirm the store layout
images = zarr_root["images"]

KeyError: 'images'

### Create segmentation map

In [None]:
# NOTE(review): `model` is not assigned in any cell shown in this notebook;
# it has to be created before this cell can run

# Channel 0 (GFP) of every frame is what gets segmented
gfp_frames = images[:, 0, ...]

# Segment frame by frame and stack the per-frame masks into one array
masks = np.stack(
    [segment(frame, model=model) for frame in tqdm(gfp_frames, desc='Segmenting')]
)

### Save segmentation out 

#### Option 1: using btrack and h5 compression

In [None]:
# Destination file for the segmentation masks
segmentation_output_fn = '../data/segmentation.h5'

# Write the stacked masks via btrack's HDF5 handler; the handler is used as a
# context manager so the file is closed when the block exits.
# The filename passed here must be the variable defined above
# (`segmentation_output_fn`).
with btrack.io.HDF5FileHandler(
    segmentation_output_fn,
    'w',
    obj_type='obj_type_1',
) as writer:
    writer.write_segmentation(masks)

#### Option 2: using Zarr (NGFF-style layout)

In [None]:
# Root of the Zarr store and the location of the label array inside it
segmentation_output_fn = '../data/example_data.zarr'
label_group_path = Path(segmentation_output_fn) / 'labels' / '0'

# make sure the directory for the label array exists on disk
label_group_path.mkdir(parents=True, exist_ok=True)

# Write the stacked segmentation masks with zstd compression.
# `overwrite` is deliberately not passed: in the zarr v2 API used here,
# save_array already forwards overwrite=True itself, so passing it again
# raises "got multiple values for keyword argument 'overwrite'".
zarr.save_array(
    store=label_group_path,
    arr=masks,  # this should be your stacked segmentation array
    compressor=zarr.Blosc(cname='zstd', clevel=5),
)

# record the label array's path in the attributes of the Zarr root group
zarr_root = zarr.open_group(segmentation_output_fn, mode='a')
zarr_root.attrs['labels'] = [{"path": "labels/0", "type": "label"}]

### Quantify intracellular Mtb for each cell segment

In [None]:
# Thresholds used when quantifying each segment
mtb_load_thresh = 480  # determined via blind thresholding
segment_size_thresh = 5_000  # compared against each object's 'area' property

In [None]:
# NOTE(review): `mtb_channel` is not assigned in any cell shown in this
# notebook; set it to the index of the Mtb channel in `images` first

# Boolean map of pixels in the Mtb channel at or above the load threshold
manual_mtb_thresh = images[:, mtb_channel, ...] >= mtb_load_thresh

# Layers of the composite intensity image, stacked along a new last axis:
# index 0 = GFP, index 1 = RFP, index 2 = thresholded Mtb mask
intensity_layers = (
    images[:, 0, ...],
    images[:, 1, ...],
    manual_mtb_thresh.astype(bool),
)
intensity_image = np.stack(intensity_layers, axis=-1)

# Localise objects from the masks, measuring them against the composite image
objects = localise(masks, intensity_image)

# Keep only objects whose area exceeds the size threshold (too small = not a cell)
objects = [
    candidate
    for candidate in objects
    if candidate.properties['area'] > segment_size_thresh
]

# Annotate each object with its infection status and thresholded Mtb area.
# NOTE(review): this relies on assignment to `obj.properties` merging the new
# key into the existing properties rather than replacing them - confirm
# against btrack's PyTrackObject.
for obj in objects:
    # index 2 of 'mean_intensity' is the manual Mtb threshold layer
    mtb_signal = obj.properties['mean_intensity'][2]
    obj.properties = {"Infected": bool(mtb_signal > 0)}
    # mean of the Mtb layer times the area; presumably the number of
    # above-threshold pixels in the segment - TODO confirm
    obj.properties = {
        "Mtb area px": obj.properties['mean_intensity'][2] * obj.properties['area']
    }

### Save out single-cell quantifications prior to tracking across time

In [None]:
# Destination file for the per-cell objects
objects_output_fn = '../data/objects.h5'

# Write the filtered, annotated objects via btrack's HDF5 handler; using it as
# a context manager closes the file when the block exits.
with btrack.io.HDF5FileHandler(
    objects_output_fn,
    'w',
    obj_type='obj_type_1',
) as writer:
    writer.write_objects(objects)

### Track single-cell objects across the time lapse

In [None]:
# Tracking configuration file handed to the tracker
config_fn = '../models/tracking_model.json'

# Link the localised objects across frames
tracks = track(
    objects,
    masks,
    config_fn,
    search_radius=20,
)

### Save out tracks

#### Option 1: using btrack and h5 compression

In [None]:
# Destination file for the tracks
tracks_output_fn = '../data/tracks.h5'

# Write the tracks via btrack's HDF5 handler; using it as a context manager
# closes the file when the block exits.
with btrack.io.HDF5FileHandler(
    tracks_output_fn,
    'w',
    obj_type='obj_type_1',
) as writer:
    writer.write_tracks(tracks)

#### Option 2: using Zarr (NGFF-style layout)

In [None]:
# One record per timepoint of every track:
# (track_id, t, y, x, properties), where track_id is the track's position
# in `tracks`.
# NOTE(review): assumes each track can be iterated to yield objects exposing
# .t, .y, .x and .properties - confirm against what `track` returns
tracklets = [
    (idx, point.t, point.y, point.x, point.properties)
    for idx, trk in enumerate(tracks)
    for point in trk
]

# Main track array: the first four fields of each record, [track_id, t, y, x]
track_array = np.array([record[:4] for record in tracklets], dtype=np.float32)

# Per-timepoint feature arrays, in the same row order as track_array
_point_props = [record[4] for record in tracklets]
_scalar_features = (
    "area",
    "orientation",
    "major_axis_length",
    "minor_axis_length",
)
features = {
    name: np.array([props[name] for props in _point_props], dtype=np.float32)
    for name in _scalar_features
}
# mean_intensity holds several values per object, so it is stacked rather
# than built with np.array
features["mean_intensity"] = np.stack(
    [props["mean_intensity"] for props in _point_props]
).astype(np.float32)

# Open the Zarr store for appending
store = zarr.open(zarr_path, mode="a")

# Main Napari-compatible tracks array (replaces any existing "tracks" dataset)
store.create_dataset(
    "tracks",
    data=track_array,
    compressor=zarr.Blosc(),
    overwrite=True,
)

# One dataset per feature inside the "features" group
feat_grp = store.require_group("features")
for key, arr in features.items():
    feat_grp.create_dataset(
        key,
        data=arr,
        compressor=zarr.Blosc(),
        overwrite=True,
    )

# Metadata describing the columns of the tracks array
store.attrs["tracks_metadata"] = {
    "format_version": "0.1",
    "type": "napari_tracks",
    "columns": ["track_id", "time", "y", "x"]
}
