# Segmentation

This notebook is for segmenting timelapse microscopy data, with associated sinhgle-cell labels and tracks, showing the infection of human macrophages with Mycobacterium Tuberculosis (Mtb), acquired on an Opera Phenix confocal microscope. 

In [1]:
import napari
import cellpose
from octopuslite import utils, tile

### Load experiment of choice

The Opera Phenix is a high-throughput confocal microscope that acquires very large 5-dimensional (TCZXY) images over several fields of view in any one experiment. Therefore, a lazy-loading approach is chosen to mosaic, view and annotate these images. This approach depends upon Dask and DaskFusion. The first step is to load the main metadata file (typically called `Index.idx.xml` and located in the main `Images` directory) that contains the image filenames and associated TCXZY information used to organise the images.

In [2]:
image_dir = '/mnt/DATA/sandbox/pierre_live_cell_data/outputs/Replication_IPSDM_GFP/Images/'
metadata_fn = '/mnt/DATA/sandbox/pierre_live_cell_data/outputs/Replication_IPSDM_GFP/Index.idx.xml'
metadata = utils.read_harmony_metadata(metadata_fn)

Reading metadata XML file...


Extracting HarmonyV5 metadata:   0%|          | 0/113400 [00:00<?, ?it/s]

Extracting metadata complete!


### View assay layout and mask information (optional)

The Opera Phenix acquires many time lapse series from a range of positions. The first step is to inspect the image metadata, presented in the form of an `Assaylayout/experiment_ID.xml` file, to show which positions correspond to which experimental assays.

In [3]:
metadata_path = '/mnt/DATA/sandbox/pierre_live_cell_data/outputs/Replication_IPSDM_GFP/Assaylayout/20210602_Live_cell_IPSDMGFP_ATB.xml'
utils.read_harmony_metadata(metadata_path, assay_layout=True)

Reading metadata XML file...
Extracting metadata complete!


Unnamed: 0,Unnamed: 1,Strain,Compound,Concentration,ConcentrationEC
3,4,RD1,CTRL,0.0,EC0
3,5,WT,CTRL,0.0,EC0
3,6,WT,PZA,60.0,EC50
3,7,WT,RIF,0.1,EC50
3,8,WT,INH,0.04,EC50
3,9,WT,BDQ,0.02,EC50
4,4,RD1,CTRL,0.0,EC0
4,5,WT,CTRL,0.0,EC0
4,6,WT,PZA,60.0,EC50
4,7,WT,RIF,0.1,EC50


### Define row and column of choice

In [25]:
row = '6'
column = '9'

### Now to lazily mosaic the images using Dask prior to viewing them.

1x (75,2,3) [TCZ] image stack takes approximately 1 minute to stitch together, so only load the one field of view I want.

In [26]:
images = tile.compile_mosaic(image_dir, 
                             metadata, 
                             row, 
                             column, 
                             set_channel=1, 
                             set_plane = 1, )
images

Unnamed: 0,Array,Chunk
Bytes,5.11 GiB,7.75 MiB
Shape,"(75, 1, 1, 6048, 6048)","(1, 1, 1, 2016, 2016)"
Count,2925 Tasks,675 Chunks
Type,uint16,numpy.ndarray
"Array Chunk Bytes 5.11 GiB 7.75 MiB Shape (75, 1, 1, 6048, 6048) (1, 1, 1, 2016, 2016) Count 2925 Tasks 675 Chunks Type uint16 numpy.ndarray",1  75  6048  6048  1,

Unnamed: 0,Array,Chunk
Bytes,5.11 GiB,7.75 MiB
Shape,"(75, 1, 1, 6048, 6048)","(1, 1, 1, 2016, 2016)"
Count,2925 Tasks,675 Chunks
Type,uint16,numpy.ndarray


# Segment 
Let us start simple, only segmenting the lowest Z plane where the largest regions of cells are and only ch1 (GFP) where the GFP signal is.

In [4]:
!nvcc --version
!nvidia-smi

from cellpose import core, utils, io, models, metrics

use_GPU = core.use_gpu()
yn = ['NO', 'YES']
print(f'>>> GPU activated? {yn[use_GPU]}')

model = models.Cellpose(gpu=True, model_type='cyto')

def segment(img):
    masks, flows, styles, diams = model.eval(img, diameter=200, channels=[0,0],
                                             flow_threshold=None, cellprob_threshold=0)
    return masks

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Sun_Jul_28_19:07:16_PDT_2019
Cuda compilation tools, release 10.1, V10.1.243
Tue Jan 17 18:28:21 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.86.01    Driver Version: 515.86.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:65:00.0  On |                  Off |
| 30%   37C    P8    32W / 300W |   1050MiB / 49140MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
 

# Testing different segmentation parameters 

In [8]:
import itertools
import numpy as np
import napari
import cellpose
import octopuslite
from octopuslite import tile
from tqdm.auto import tqdm
import numpy as np
import datetime 
from skimage.io import imsave, imshow, imread
from skimage.measure import label, regionprops
import skimage as ski
from skimage.morphology import remove_small_objects, remove_small_holes, binary_erosion
import os
import itertools, os
from tqdm.auto import tqdm
import scipy.ndimage as ndi
import sys
sys.path.append('../../unet_segmentation_metrics/')
import umetrics
import matplotlib.pyplot as plt

In [9]:
mask_dict = np.load('mask_dict.npy', allow_pickle = True)
modified_mask_dict = np.load('modified_mask_dict.npy', allow_pickle = True)

In [19]:
mask_dict = mask_dict.item()
modified_mask_dict = modified_mask_dict.item()

In [31]:
### average cell diameter
diameters = [200,250,300]
### flow threshold, larger value means more ROIs (maybe ill fitting), lower means fewer ROIs 
flow_thresholds = [0.8 ]#0.0, 0.6, ]
### cellprob_threshold, larger is is fewer ROIs, lower means more...? 
# cellprobs_thresholds = [-0.2, 0.0, 0.2]

# NOTE: this minimum size is too great. Need to re-do modified_mask_dict.npy and save as h5. Should this also be a sum_proj???

In [None]:
# mask_dict = dict()
# modified_mask_dict = dict()
params = list(itertools.product(diameters, flow_thresholds))
for n, (diameter, flow_threshold) in tqdm(enumerate(params), total = len(params)):
#     if n < 4:
#         continue
    mask_stack_ = []
    modified_mask_stack = []
    for timepoint in tqdm(images, total = len(images), leave = False):
        ### extract GFP channel and lowest Z plane from single time point
        gfp_z0_frame = timepoint[0,0,...]
        masks, flows, styles, diams = model.eval(gfp_z0_frame, diameter=diameter, channels=[0,0],
                                             flow_threshold=flow_threshold, cellprob_threshold=0)        
        
        
        pred = masks
        pred = remove_small_objects(pred, min_size=10000)
        ### need to iterate over each individual segment and erode
        new_mask = np.zeros(pred.shape, dtype = np.uint8)
        for segment_ID in tqdm(range(1, np.max(pred)), leave = False):
            segment = pred == segment_ID 
            eroded_segment = binary_erosion(segment)
            new_mask += eroded_segment.astype(np.uint8)
        pred = remove_small_holes(new_mask).astype(np.uint8)
        
        modified_mask_stack.append(pred)
        mask_stack_.append(masks)
    mask_images_ = da.stack(mask_stack_, axis = 0) 
    modified_mask_images = da.stack(modified_mask_stack, axis = 0) 
    mask_dict[(diameter, flow_threshold)] = mask_images_
    modified_mask_dict[(diameter, flow_threshold)] = modified_mask_images
    np.save('mask_dict.npy', mask_dict)
    np.save('modified_mask_dict.npy', modified_mask_dict)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/695 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/703 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/697 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/701 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/748 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/763 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/736 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/714 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/749 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/727 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/771 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/748 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/766 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/714 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/696 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/764 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/749 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/709 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/712 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/718 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/700 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/724 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/688 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/728 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/730 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/725 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/682 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/696 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/684 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/687 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/676 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/648 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/624 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/612 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/639 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/698 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/601 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/596 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/580 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/597 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/554 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/542 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/554 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/511 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/511 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/497 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/492 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/490 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/478 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/476 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/432 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/460 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/394 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/419 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/401 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/388 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/380 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/367 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/379 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/365 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/366 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/379 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/339 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/337 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/336 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/320 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/310 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/301 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/295 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/293 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/308 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/287 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/287 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/270 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/279 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/634 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/644 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/642 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/656 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/660 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/656 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/668 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/657 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/667 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/656 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/706 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/662 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/677 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/679 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/654 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/667 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/686 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/662 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/645 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/665 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/633 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/614 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/610 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/691 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/606 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/608 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/608 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/604 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/644 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/616 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/608 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/580 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/586 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/552 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/581 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/573 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/571 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/540 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/552 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/540 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/536 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/513 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/525 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/477 [00:00<?, ?it/s]

  return func(*args, **kwargs)


  0%|          | 0/483 [00:00<?, ?it/s]

  return func(*args, **kwargs)


# Now do sum_proj for comparison

In [None]:
images = tile.compile_mosaic(image_dir, 
                             metadata, 
                             row, 
                             column, 
                             set_channel=1, 
                             set_plane = 'sum_proj', )
images = images.compute().compute().astype(np.uint16)

In [None]:
### average cell diameter
diameters = [200,250,300]
### flow threshold, larger value means more ROIs (maybe ill fitting), lower means fewer ROIs 
flow_thresholds = [0.0, 0.6,]
### cellprob_threshold, larger is is fewer ROIs, lower means more...? 
# cellprobs_thresholds = [-0.2, 0.0, 0.2]

In [None]:
mask_dict_z_proj = dict()
modified_mask_dict_z_proj = dict()
params = list(itertools.product(diameters, flow_thresholds))
for n, (diameter, flow_threshold) in tqdm(enumerate(params), total = len(params)):
#     if n < 4:
#         continue
    mask_stack_ = []
    modified_mask_stack = []
    for timepoint in tqdm(images, total = len(images), leave = False):
        ### extract GFP channel and lowest Z plane from single time point
        gfp_z0_frame = timepoint[0,0,...]
        masks, flows, styles, diams = model.eval(gfp_z0_frame, diameter=diameter, channels=[0,0],
                                             flow_threshold=flow_threshold, cellprob_threshold=0)        
        
        
        pred = masks
        pred = remove_small_objects(pred, min_size=2500)
        ### need to iterate over each individual segment and erode
        new_mask = np.zeros(pred.shape, dtype = np.uint8)
        for segment_ID in tqdm(range(1, np.max(pred)), leave = False):
            segment = pred == segment_ID 
            eroded_segment = binary_erosion(segment)
            new_mask += eroded_segment.astype(np.uint8)
#         pred = remove_small_holes(new_mask).astype(np.uint8)
        
        modified_mask_stack.append(pred)
        mask_stack_.append(masks)
    mask_images_ = da.stack(mask_stack_, axis = 0) 
    modified_mask_images = da.stack(modified_mask_stack, axis = 0) 
    mask_dict_z_proj[(diameter, flow_threshold)] = mask_images_
    modified_mask_dict_z_proj[(diameter, flow_threshold)] = modified_mask_images
    np.save('mask_dict_z_proj.npy', mask_dict_z_proj)
    np.save('modified_mask_dict_z_proj.npy', modified_mask_dict_z_proj)

# Tidy up segmentation