# Segment, localise and track


In [1]:
from macrohet import dataio, tile, notify
import numpy as np
from tqdm.auto import tqdm
from cellpose import models
import btrack 
import torch
import os
import dask.array as da
import glob
import zarr
import logging
# Pick the compute device: CUDA when torch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print()

# When running on CUDA, report the name and memory usage of GPU 0.
if device.type == 'cuda':
    bytes_per_gib = 1024**3
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    allocated_gib = round(torch.cuda.memory_allocated(0)/bytes_per_gib, 1)
    print('Allocated:', allocated_gib, 'GB')
    reserved_gib = round(torch.cuda.memory_reserved(0)/bytes_per_gib, 1)
    print('Cached:   ', reserved_gib, 'GB')

# Only request the GPU when the device detection above actually found one,
# instead of hard-coding gpu=True.
use_gpu = device.type == 'cuda'

# Alternative: a personally trained cellpose model
# model_path = '/home/dayn/analysis/models/cellpose/PS0000/macrohet_seg'
# model = models.CellposeModel(gpu=use_gpu,
#                              pretrained_model=model_path)

# Cellpose model used for segmentation: the built-in 'cyto3' model
model = models.Cellpose(gpu=use_gpu, model_type='cyto3')

# Configure logging: console output through the root logger plus a log file.
log_dir = "logs"  # directory where the log file is saved
os.makedirs(log_dir, exist_ok=True)

log_format = "%(asctime)s [%(levelname)s]: %(message)s"

# basicConfig() silently does nothing when the root logger already has
# handlers (e.g. installed by an earlier import or a previous run of this
# cell). force=True (Python 3.8+) removes those handlers first, so the format
# below is really applied and re-running the cell does not stack handlers.
logging.basicConfig(
    level=logging.INFO,
    format=log_format,
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)

# Add a FileHandler so the same records are also written to a file
log_file = os.path.join(log_dir, "assay_processing.log")
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter(log_format)
file_handler.setFormatter(formatter)
logging.getLogger().addHandler(file_handler)

# Define a function to log progress and potential errors
def log_progress(position, message):
    """Write an INFO-level log entry of the form 'Position <position>: <message>'.

    Parameters
    ----------
    position
        Identifier of the acquisition position the entry refers to.
    message
        Text describing the progress step or the error.
    """
    entry = "Position {}: {}".format(position, message)
    logging.info(entry)


Using device: cuda

NVIDIA RTX A6000
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


INFO:cellpose.core:** TORCH CUDA version installed and working. **
INFO:cellpose.core:>>>> using GPU
INFO:cellpose.models:>> cyto3 << model set to be used
INFO:cellpose.models:>>>> model diam_mean =  30.000 (ROIs rescaled to this size during training)


### Define functions to tidy up main block of code

In [2]:
# --- analysis parameters ---

# objects whose area does not exceed this value are dropped after localisation
segment_size_thresh = 5000
# Mtb-channel pixels at or above this intensity are counted as Mtb signal
Mtb_load_thresh = 480

# factor applied to the object coordinates and to the tracking volume
scale_factor = 1 / 5.04

# region properties measured for every object and used for tracking
features = [
    "area",
    "major_axis_length",
    "minor_axis_length",
    "orientation",
    "mean_intensity",
]

# tracker configuration file to use (this one has prob_not_assign = 0.1)
config_fn = '/home/dayn/analysis/models/btrack/particle_config_pnassign.json'
# alternative tracker configuration file
# config_fn = '/home/dayn/analysis/btrack/models/particle_config.json'

def segment(frame, model = model, channels = [0,0], diameter = 350,#325
            min_size = 5000 #2500
           ):
    """Segment a single frame with a cellpose model and return the masks.

    Parameters
    ----------
    frame
        Image passed to ``model.eval``.
    model
        Cellpose model to evaluate; defaults to the module-level ``model``.
    channels
        Channel specification forwarded to ``model.eval``.
    diameter
        Object diameter forwarded to ``model.eval``.
    min_size
        Minimum object size forwarded to ``model.eval``.

    Returns
    -------
    The masks, i.e. the first element of what ``model.eval`` returns.
    """
    # The default cellpose model returns (masks, flows, styles, diams) while a
    # personally trained model returns (masks, flows, styles). Unpacking a
    # fixed number of values fails for one of the two, so only the first
    # element is taken, which is the masks in both cases.
    outputs = model.eval(frame,
                         channels = channels,
                         diameter = diameter,
                         min_size = min_size,
                         )
    return outputs[0]


def localise(masks, intensity_image, properties=tuple(features), use_weighted_centroid = False):
    """Convert a segmentation into btrack objects with measured properties.

    Parameters
    ----------
    masks
        Segmentation passed to ``btrack.utils.segmentation_to_objects``.
    intensity_image
        Intensity image used for the intensity-based measurements.
    properties
        Names of the properties to measure; defaults to the module-level
        ``features``.
    use_weighted_centroid
        Forwarded unchanged to ``segmentation_to_objects``.

    Returns
    -------
    The objects returned by ``btrack.utils.segmentation_to_objects``.
    """
    # the same scale factor is applied along both axes
    isotropic_scale = (scale_factor, scale_factor)
    return btrack.utils.segmentation_to_objects(
        segmentation=masks,
        intensity_image=intensity_image,
        properties=properties,
        scale=isotropic_scale,
        use_weighted_centroid=use_weighted_centroid,
    )


def track(objects, masks, config_fn, search_radius = 20):
    """Link localised objects into tracks with a btrack BayesianTracker.

    Parameters
    ----------
    objects
        Objects to track; the property names of the first object define the
        features handed to the tracker.
    masks
        Segmentation whose last two dimensions define the tracking volume.
    config_fn
        Path of the tracker configuration file.
    search_radius
        Value assigned to the tracker's ``max_search_radius``.

    Returns
    -------
    The tracks held by the tracker after optimisation.
    """
    # tracking volume: the last two mask dimensions, scaled like the objects
    scaled_extents = tuple((0, size*scale_factor) for size in masks.shape[-2:])

    # the context manager takes care of closing the tracker session
    with btrack.BayesianTracker() as tracker:
        tracker.configure(config_fn)
        tracker.max_search_radius = search_radius
        # use both motion and visual information when linking
        tracker.tracking_updates = ["MOTION", "VISUAL"]
        # track on every property the objects carry, so that both channels
        # are included in the track measurements
        tracker.features = list(objects[0].properties.keys())
        tracker.append(objects)
        tracker.volume = scaled_extents
        # run the tracking (in interactive mode)
        tracker.track(step_size=25)
        # generate hypotheses and run the global optimizer
        tracker.optimize()
        return tracker.tracks


def _otsu_threshold(image, nbins=256):
    """Return Otsu's threshold of one grayscale image, using NumPy only."""
    values = np.asarray(image).ravel()
    if values.size == 0:
        raise ValueError("Cannot compute an Otsu threshold of an empty image")
    lowest, highest = values.min(), values.max()
    # a constant image has no between-class variance; return its only value
    if lowest == highest:
        return lowest

    counts, edges = np.histogram(values, bins=nbins)
    counts = counts.astype(float)
    centers = (edges[:-1] + edges[1:]) / 2

    # class weights and class means for every possible split of the histogram
    # (the first and last bins always contain the minimum and maximum, so the
    # cumulative weights are never zero)
    weight_low = np.cumsum(counts)
    weight_high = np.cumsum(counts[::-1])[::-1]
    mean_low = np.cumsum(counts * centers) / weight_low
    mean_high = (np.cumsum((counts * centers)[::-1]) / weight_high[::-1])[::-1]

    # Otsu's criterion: maximise the between-class variance
    between_class = weight_low[:-1] * weight_high[1:] * (mean_low[:-1] - mean_high[1:]) ** 2
    return centers[np.argmax(between_class)]


def otsu_threshold_stack(images):
    """
    Function to characterise intra-Mφ Mtb load
    Computes Otsu's threshold value and returns a binary segmentation for
    each image in a time series of grayscale images.

    The threshold is computed with a NumPy-only helper because the
    ``threshold_otsu`` name used previously is not imported in this file.
    NOTE(review): the helper uses a 256-bin histogram; results may differ
    slightly from scikit-image's ``threshold_otsu`` on integer images.

    Parameters:
    -----------
    images : ndarray
        A 3D array of shape (n_images, height, width) containing a time series
        of grayscale images. Frames may be lazy (anything exposing
        ``.compute()``, possibly nested) or plain arrays.

    Returns:
    --------
    ndarray
        A boolean array of shape (n_images, height, width) containing the
        binary segmentation for each image in the time series.
    """
    segmentations = np.zeros(images.shape, dtype=bool)
    for i, image in tqdm(enumerate(images),
                         total=len(images),
                         leave=False,
                         desc='Otsu segmenting'):
        # resolve lazy frames however deeply they are nested; frames that are
        # already in memory are used as they are
        loaded_image = image
        while hasattr(loaded_image, 'compute'):
            loaded_image = loaded_image.compute()
        loaded_image = np.asarray(loaded_image)
        segmentations[i] = loaded_image > _otsu_threshold(loaded_image)

    return segmentations

### Load experiment of choice

The Opera Phenix is a high-throughput confocal microscope that acquires very large 5-dimensional (TCZXY) images over several fields of view in any one experiment. Therefore, a lazy-loading approach is chosen to mosaic, view and annotate these images. This approach depends upon Dask and DaskFusion. The first step is to load the main metadata file (typically called `Index.idx.xml` and located in the main `Images` directory) that contains the image filenames and associated TCZXY information used to organise the images.

In [3]:
# root directory of the experiment to process
base_dir = '/mnt/SYNO/macrohet_syno/ND0002/'
# main metadata file of the acquisition, relative to base_dir
index_relpath = 'acquisition/Images/Index.idx.xml'
metadata_fn = os.path.join(base_dir, index_relpath)
# read the image metadata
metadata = dataio.read_harmony_metadata(metadata_fn)
# display the result as the cell output
metadata

Reading metadata XML file...


0it [00:00, ?it/s]

Extracting metadata complete!


Unnamed: 0,id,State,URL,Row,Col,FieldID,PlaneID,TimepointID,ChannelID,FlimID,...,PositionZ,AbsPositionZ,MeasurementTimeOffset,AbsTime,MainExcitationWavelength,MainEmissionWavelength,ObjectiveMagnification,ObjectiveNA,ExposureTime,OrientationMatrix
0,0103K1F1P1R1,Ok,r01c03f01p01-ch1sk1fk1fl1.tiff,1,3,1,1,0,1,1,...,-2E-06,0.135466397,0,2023-11-30T17:22:09.49+00:00,640,706,40,1.1,0.2,"[[1.000989,0,0,10.0],[0,-1.000989,0,-6.8],[0,0..."
1,0103K1F1P1R2,Ok,r01c03f01p01-ch2sk1fk1fl1.tiff,1,3,1,1,0,2,1,...,-2E-06,0.135466397,0,2023-11-30T17:22:09.723+00:00,488,522,40,1.1,0.1,"[[1.000989,0,0,10.0],[0,-1.000989,0,-6.8],[0,0..."
2,0103K1F1P2R1,Ok,r01c03f01p02-ch1sk1fk1fl1.tiff,1,3,1,2,0,1,1,...,0,0.135468394,0,2023-11-30T17:22:10.067+00:00,640,706,40,1.1,0.2,"[[1.000989,0,0,10.0],[0,-1.000989,0,-6.8],[0,0..."
3,0103K1F1P2R2,Ok,r01c03f01p02-ch2sk1fk1fl1.tiff,1,3,1,2,0,2,1,...,0,0.135468394,0,2023-11-30T17:22:10.287+00:00,488,522,40,1.1,0.1,"[[1.000989,0,0,10.0],[0,-1.000989,0,-6.8],[0,0..."
4,0103K1F1P3R1,Ok,r01c03f01p03-ch1sk1fk1fl1.tiff,1,3,1,3,0,1,1,...,2E-06,0.135470405,0,2023-11-30T17:22:10.627+00:00,640,706,40,1.1,0.2,"[[1.000989,0,0,10.0],[0,-1.000989,0,-6.8],[0,0..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
388615,0612K150F9P1R2,Ok,r06c12f09p01-ch2sk150fk1fl1.tiff,6,12,9,1,149,2,1,...,-2E-06,0.1351538,268191.66,2023-12-03T20:06:16.08+00:00,488,522,40,1.1,0.1,"[[1.000989,0,0,10.0],[0,-1.000989,0,-6.8],[0,0..."
388616,0612K150F9P2R1,Ok,r06c12f09p02-ch1sk150fk1fl1.tiff,6,12,9,2,149,1,1,...,0,0.135155797,268191.66,2023-12-03T20:06:16.423+00:00,640,706,40,1.1,0.2,"[[1.000989,0,0,10.0],[0,-1.000989,0,-6.8],[0,0..."
388617,0612K150F9P2R2,Ok,r06c12f09p02-ch2sk150fk1fl1.tiff,6,12,9,2,149,2,1,...,0,0.135155797,268191.66,2023-12-03T20:06:16.657+00:00,488,522,40,1.1,0.1,"[[1.000989,0,0,10.0],[0,-1.000989,0,-6.8],[0,0..."
388618,0612K150F9P3R1,Ok,r06c12f09p03-ch1sk150fk1fl1.tiff,6,12,9,3,149,1,1,...,2E-06,0.135157794,268191.66,2023-12-03T20:06:17+00:00,640,706,40,1.1,0.2,"[[1.000989,0,0,10.0],[0,-1.000989,0,-6.8],[0,0..."


### View assay layout and mask information (optional)

The Opera Phenix acquires many time lapse series from a range of positions. The first step is to inspect the image metadata, presented in the form of an `Assaylayout/experiment_ID.xml` file, to show which positions correspond to which experimental assays.

In [4]:
# Locate the assay layout XML of this experiment. The matches are sorted so
# the choice is deterministic when several files are present, and a missing
# file is reported explicitly instead of surfacing as an IndexError.
layout_pattern = os.path.join(base_dir, 'acquisition/Assaylayout/*.xml')
layout_candidates = sorted(glob.glob(layout_pattern))
if not layout_candidates:
    raise FileNotFoundError(f"No assay layout file matches {layout_pattern}")
metadata_path = layout_candidates[0]
assay_layout = dataio.read_harmony_metadata(metadata_path, assay_layout=True,)# mask_exist=True,  image_dir = image_dir, image_metadata = metadata)
# display the layout as the cell output
assay_layout

Reading metadata XML file...
Extracting metadata complete!


Unnamed: 0_level_0,Unnamed: 1_level_0,Strain,Compound,Concentration,ConcentrationEC
Row,Column,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
3,1,UNI,CTRL,0.0,EC0
3,2,UNI,CTRL,0.0,EC0
3,3,WT,CTRL,0.0,EC0
3,4,WT,CTRL,0.0,EC0
3,5,WT,PZA,60.0,EC50
3,6,WT,PZA,60.0,EC50
3,7,WT,RIF,0.1,EC50
3,8,WT,RIF,0.1,EC50
3,9,WT,INH,0.04,EC50
3,10,WT,INH,0.04,EC50


In [6]:
# (row, column) positions that have already been processed; only consulted by
# the optional skip check in the processing loop
already_processed_acq_IDs = [
    (3, 4), (4, 3), (4, 4),
    (3, 5), (3, 6), (3, 7), (3, 8), (3, 9), (3, 10), (3, 11), (3, 12),
    (4, 5), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10),
    (3, 3),
]

# Segment, localise and track

In [5]:
# Channel indices used by the processing loop:
# - mtb_channel and gfp_channel index the channel axis of the image stack
# - manual_mtb_thresh_channel indexes the last axis of the intensity image
#   built in the loop, where the thresholded Mtb mask is stacked third
mtb_channel, gfp_channel, manual_mtb_thresh_channel = 0, 1, 2

# Test whether I have the capacity to stack these channels together like this

In [None]:
# Segment, localise and track every position listed in the assay layout.
# Positions are processed independently: a failure is logged together with
# its traceback and the loop carries on with the next position.

# Create the output directory up front, so that a missing folder cannot make
# a save fail after the (slow) segmentation has already been computed.
output_dir = os.path.join(base_dir, 'labels/cpv3')
os.makedirs(output_dir, exist_ok=True)

for (row, column), info in tqdm(assay_layout.iterrows(),
                                desc='Progress through positions',
                                total=len(assay_layout)):
    # bound before the try block so the error handler can always report it
    acq_ID = (row, column)
    try:
        log_progress(acq_ID, "Starting new acquisition")

        # optional skips, uncomment as required
        # if info['Strain'] == 'UNI':
        #     log_progress(acq_ID, "Skipping uninfected acquisition for now")
        #     continue
        # if acq_ID in already_processed_acq_IDs:
        #     log_progress(acq_ID, "Skipping already processed")
        #     continue

        # load the images of this position from its zarr store
        image_dir = os.path.join(base_dir, f'acquisition/zarr/{acq_ID}.zarr')
        zarr_store = zarr.open(image_dir, mode='r')
        images = zarr_store.images
        # create a max projection (over axis 2)
        images = np.max(images, axis = 2)
        log_progress(acq_ID, "Images loaded and stacked")

        # segment every frame of the GFP channel only
        log_progress(acq_ID, "Starting segmentation")
        masks = np.stack([segment(frame)
                          for frame in tqdm(images[:,gfp_channel,...],
                                            desc = 'Segmenting')])
        log_progress(acq_ID, "Finished segmentation")

        # back up the masks straight away
        with btrack.io.HDF5FileHandler(os.path.join(output_dir, f'{acq_ID}_cpv3_mask_backup.h5'),
                                       'w',
                                       obj_type='obj_type_1'
                                       ) as writer:
            writer.write_segmentation(masks)
        log_progress(acq_ID, "Saved out masks")

        log_progress(acq_ID, "Measuring Mtb area")
        # characterise Mtb growth using a hardcoded threshold; the comparison
        # already yields the boolean mask, so no np.where is needed
        manual_mtb_thresh = images[:,mtb_channel,...] >= Mtb_load_thresh

        log_progress(acq_ID, "Creating intensity image for localisation")
        # stack channel 0, channel 1 and the thresholded Mtb mask on the last
        # axis; the mask therefore sits at index manual_mtb_thresh_channel
        intensity_image = np.stack([images[:,0,...],
                                    images[:,1,...],
                                    manual_mtb_thresh], axis = -1)

        log_progress(acq_ID, "Localising objects")
        objects = localise(masks,
                           intensity_image,
                           )

        log_progress(acq_ID, "Filtering small objects")
        objects = [o for o in objects if o.properties['area'] > segment_size_thresh]

        log_progress(acq_ID, "Adding infection labels to objects")
        for obj in objects:
            # mean of the binary Mtb mask over the object
            mtb_fraction = obj.properties['mean_intensity'][manual_mtb_thresh_channel]
            # NOTE(review): the second assignment is expected to add to the
            # properties set by the first one rather than replace them -
            # confirm against the btrack version in use
            obj.properties = {"Infected": True} if mtb_fraction > 0 else {"Infected": False}
            obj.properties = {"Mtb area px": mtb_fraction*obj.properties['area']}

        # back up the labelled objects
        with btrack.io.HDF5FileHandler(os.path.join(output_dir, f'{acq_ID}_cpv3_objects_backup.h5'),
                                       'w',
                                       obj_type='obj_type_1'
                                       ) as writer:
            writer.write_objects(objects)

        log_progress(acq_ID, "Beginning tracking")
        tracks = track(objects, masks, config_fn, search_radius = 20)

        log_progress(acq_ID, "Saving tracking")
        with btrack.io.HDF5FileHandler(os.path.join(output_dir, f'{acq_ID}_cpv3_tracks_backup.h5'),
                                       'w',
                                       obj_type='obj_type_1'
                                       ) as writer:
            writer.write_tracks(tracks)

        # NOTE(review): this combined file is written to the current working
        # directory, not to base_dir - confirm that this is intended
        with btrack.io.HDF5FileHandler(f'{acq_ID}_cpv3_full_backup.h5',
                                       'w',
                                       obj_type='obj_type_1'
                                       ) as writer:
            writer.write_tracks(tracks)
            writer.write_objects(objects)
            writer.write_segmentation(masks)

        # Log successful completion
        log_progress(acq_ID, "Processing completed successfully")

    except Exception as e:
        # Log the error with its traceback and move on to the next position
        logging.exception(f"Position {acq_ID}: Processing failed: {str(e)}")

# You can also log information before and after the loop
logging.info("Processing completed")

# Notify if required
notify.send_sms("Processing completed")


Progress through positions:   0%|          | 0/42 [00:00<?, ?it/s]

INFO:root:Position (3, 1): Starting new acquisition


# And now try to redo PS0000

### Load experiment of choice

The Opera Phenix is a high-throughput confocal microscope that acquires very large 5-dimensional (TCZXY) images over several fields of view in any one experiment. Therefore, a lazy-loading approach is chosen to mosaic, view and annotate these images. This approach depends upon Dask and DaskFusion. The first step is to load the main metadata file (typically called `Index.idx.xml` and located in the main `Images` directory) that contains the image filenames and associated TCZXY information used to organise the images.

In [None]:
# root directory of the experiment to process
base_dir = '/mnt/DATA/macrohet/PS0000/'
# main metadata file of the acquisition, relative to base_dir
index_relpath = 'acquisition/Images/Index.idx.xml'
metadata_fn = os.path.join(base_dir, index_relpath)
# read the image metadata
metadata = dataio.read_harmony_metadata(metadata_fn)
# display the result as the cell output
metadata

### View assay layout and mask information (optional)

The Opera Phenix acquires many time lapse series from a range of positions. The first step is to inspect the image metadata, presented in the form of an `Assaylayout/experiment_ID.xml` file, to show which positions correspond to which experimental assays.

In [None]:
# Locate the assay layout XML of this experiment. The matches are sorted so
# the choice is deterministic when several files are present, and a missing
# file is reported explicitly instead of surfacing as an IndexError.
layout_pattern = os.path.join(base_dir, 'acquisition/Assaylayout/*.xml')
layout_candidates = sorted(glob.glob(layout_pattern))
if not layout_candidates:
    raise FileNotFoundError(f"No assay layout file matches {layout_pattern}")
metadata_path = layout_candidates[0]
assay_layout = dataio.read_harmony_metadata(metadata_path, assay_layout=True,)# mask_exist=True,  image_dir = image_dir, image_metadata = metadata)
# display the layout as the cell output
assay_layout

# Segment, localise and track

In [None]:
# Channel indices used by the processing loop:
# - mtb_channel and gfp_channel index the channel axis of the image stack
# - manual_mtb_thresh_channel indexes the last axis of the intensity image
#   built in the loop, where the thresholded Mtb mask is stacked third
# NOTE(review): the loop builds the intensity image from channels 0 and 1
# only - confirm that the image stack really has a channel at index 2 for
# gfp_channel
mtb_channel, gfp_channel, manual_mtb_thresh_channel = 1, 2, 2

In [None]:
# Segment, localise and track every position listed in the assay layout.
# Positions are processed independently: a failure is logged together with
# its traceback and the loop carries on with the next position.

# Create the output directory up front, so that a missing folder cannot make
# a save fail after the (slow) segmentation has already been computed.
output_dir = os.path.join(base_dir, 'labels/macrohet_seg_model')
os.makedirs(output_dir, exist_ok=True)

for (row, column), info in tqdm(assay_layout.iterrows(),
                                desc='Progress through positions',
                                total=len(assay_layout)):
    # bound before the try block so the error handler can always report it
    acq_ID = (row, column)
    # files written (and checked) for this position
    seg_backup_fn = os.path.join(output_dir, f'{acq_ID}_first_pass_seg_backup.h5')
    tracks_fn = os.path.join(output_dir, f'{acq_ID}_first_pass_warea.h5')
    try:
        log_progress(acq_ID, "Starting new acquisition")

        # optional skips, uncomment as required
        # if info['Strain'] == 'UNI':
        #     log_progress(acq_ID, "Skipping uninfected acquisition for now")
        #     continue
        # if acq_ID in already_processed_acq_IDs:
        #     log_progress(acq_ID, "Skipping already processed")
        #     continue

        # a position whose final output already exists is not processed again
        if os.path.exists(tracks_fn):
            log_progress(acq_ID, "Skipping already processed")
            continue

        # compile the mosaic of this position from the raw images and load it
        image_dir = os.path.join(base_dir, 'acquisition/Images')
        images = tile.compile_mosaic(image_dir,
                                     metadata,
                                     row, column,
                                     set_plane='max_proj'
                                     ).compute().compute()
        # keep index 0 of the third axis
        images = images[:,:,0,...]
        log_progress(acq_ID, "Images loaded and stacked")

        log_progress(acq_ID, "Starting segmentation")
        if os.path.exists(seg_backup_fn):
            # reuse the segmentation saved by an earlier run
            with btrack.io.HDF5FileHandler(seg_backup_fn,
                                           'r',
                                           obj_type='obj_type_1'
                                           ) as reader:
                masks = reader.segmentation
            log_progress(acq_ID, "Loaded previously calculated segmentation")
        else:
            # segment every frame of the GFP channel only
            masks = np.stack([segment(frame)
                              for frame in tqdm(images[:,gfp_channel,...],
                                                desc = 'Segmenting')])
            log_progress(acq_ID, "Finished segmentation")
            # Back up new masks only: a segmentation that was just loaded
            # from this very file is not written out again.
            with btrack.io.HDF5FileHandler(seg_backup_fn,
                                           'w',
                                           obj_type='obj_type_1'
                                           ) as writer:
                writer.write_segmentation(masks)
            log_progress(acq_ID, "Saved out masks")

        log_progress(acq_ID, "Measuring Mtb area")
        # characterise Mtb growth using a hardcoded threshold
        manual_mtb_thresh = np.where(images[:,mtb_channel,...] >= Mtb_load_thresh, 1, 0)

        log_progress(acq_ID, "Creating intensity image for localisation")
        # stack channel 0, channel 1 and the thresholded Mtb mask on the last
        # axis; the mask therefore sits at index manual_mtb_thresh_channel
        intensity_image = np.stack([images[:,0,...],
                                    images[:,1,...],
                                    manual_mtb_thresh], axis = -1)

        log_progress(acq_ID, "Localising objects")
        objects = localise(masks,
                           intensity_image,
                           )

        log_progress(acq_ID, "Filtering small objects")
        objects = [o for o in objects if o.properties['area'] > segment_size_thresh]

        log_progress(acq_ID, "Adding infection labels to objects")
        for obj in objects:
            # mean of the binary Mtb mask over the object
            mtb_fraction = obj.properties['mean_intensity'][manual_mtb_thresh_channel]
            obj.properties = {"Infected": True} if mtb_fraction > 0 else {"Infected": False}

        log_progress(acq_ID, "Beginning tracking")
        tracks = track(objects, masks, config_fn, search_radius = 20)

        log_progress(acq_ID, "Saving tracking")
        with btrack.io.HDF5FileHandler(tracks_fn,
                                       'w',
                                       obj_type='obj_type_1'
                                       ) as writer:
            writer.write_tracks(tracks)
            writer.write_segmentation(masks)

        # Log successful completion
        log_progress(acq_ID, "Processing completed successfully")

    except Exception as e:
        # Log the error with its traceback and move on to the next position
        logging.exception(f"Position {acq_ID}: Processing failed: {str(e)}")

# You can also log information before and after the loop
logging.info("Processing completed")

# Notify if required
# notify.send_sms("Processing completed")
