# Complete pre-analysis cell labelling pipeline

1. Alignment
2. Segmentation
3. Object localisation and cell-cycle classification
4. Tracking


In [11]:
import os
import glob
import enum
import re
import numpy as np
import btrack
from pystackreg import StackReg
from skimage.io import imsave, imread
from tqdm.auto import tqdm
from octopuslite import DaskOctopusLiteLoader
from skimage import transform as tf
from stardist.models import StarDist2D 
from stardist.plot import render_label
from csbdeep.utils import normalize
from scipy import ndimage as nd
from scipy.special import softmax
from cellx import load_model
from cellx.tools.image import InfinitePaddedImage
from skimage.transform import resize

# Pretrained StarDist 2D model ('2D_versatile_fluo'), loaded once here and reused
# by the segmentation step for both the GFP and RFP channels of every frame.
seg_model = StarDist2D.from_pretrained('2D_versatile_fluo')

def normalize_channels(x):
    """Normalise each channel (last axis) of ``x`` independently.

    Each ``x[..., c]`` is passed through the module-level ``normalize``
    (mean-subtraction and division by the floored standard deviation).

    Float arrays are normalised in place and returned. Non-float arrays are
    first converted to float32, because writing the normalised values back
    into an integer array would silently truncate them to whole numbers.

    Args:
        x: array whose last axis indexes channels.

    Returns:
        The normalised array: ``x`` itself for float input, otherwise a new
        float32 array.
    """
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(np.float32)

    for dim in range(x.shape[-1]):
        x[..., dim] = normalize(x[..., dim])

    return x

def normalize(x):
    """Standardise ``x`` to zero mean and unit standard deviation.

    The statistics are computed on a float32 copy, so ``x`` is not modified.
    The standard deviation is floored at ``1 / N`` (``N`` = number of
    elements), so a constant input yields zeros instead of a division by zero.

    NOTE: this definition replaces the ``normalize`` imported from
    ``csbdeep.utils`` at the top of the file. Every later call to
    ``normalize`` in this module uses this version.
    """
    data = x.astype(np.float32)
    centre = data.mean()
    spread = np.max([data.std(), 1.0 / np.prod(x.shape)])
    return (data - centre) / spread

def classify_objects(image, gfp, rfp, objects, obj_type, classifier=None):
    """Predict a cell-cycle label for each object and store it on the object.

    For every frame ``n`` of ``image``, the objects with ``obj.t == n`` are
    processed as a batch. For each one, an 80x80 window centred on
    ``(obj.x, obj.y)`` is cut from a two-channel stack (``image`` plus the
    fluorescent channel chosen by ``obj_type``). The window is resized to
    64x64 and normalised per channel, and all windows of the frame are sent
    to ``classifier.predict`` together.

    Args:
        image: per-frame array indexed as ``image[n, ...]`` that supports
            ``.compute()`` (e.g. a dask array). Its first axis is iterated as
            frames, and it is used as channel 0 of each crop.
        gfp: array indexed like ``image``; used as channel 1 when
            ``obj_type == 1``.
        rfp: array indexed like ``image``; used as channel 1 otherwise.
        objects: objects exposing ``t``, ``x``, ``y``, ``label`` and
            ``properties`` attributes.
        obj_type: ``1`` selects ``gfp``; any other value selects ``rfp``.
        classifier: object with a ``predict`` method. Defaults to the
            notebook-level global ``model`` so existing calls keep working.

    Returns:
        The same ``objects`` list. Each classified object has ``label`` set to
        the argmax class index and ``properties`` set to
        ``{"prob_<stage>": softmax probability}``. Objects whose crop did not
        come out as 64x64x2 are left unchanged (the crop shape is printed).

    Raises:
        ValueError: if the classifier returns a different number of
            predictions than crops submitted.
    """
    # define stages of cell cycle to classify (dependent on model type)
    LABELS = ["interphase", "prometaphase", "metaphase", "anaphase", "apoptosis"]
    # half side length of the window cut around each object, and the size the
    # window is resized to before classification
    half_window = 40
    input_size = (64, 64)

    # fall back to the notebook-level ``model`` for backward compatibility
    if classifier is None:
        classifier = model

    # the fluorescent channel is the same for every frame, so pick it once
    fp = gfp if obj_type == 1 else rfp

    # group objects by frame once instead of rescanning the list every frame
    objects_by_frame = {}
    for obj in objects:
        objects_by_frame.setdefault(obj.t, []).append(obj)

    for n in tqdm(range(image.shape[0])):
        frame_objects = objects_by_frame.get(n)
        if not frame_objects:
            # nothing to classify: skip the (potentially costly) .compute()
            continue

        # two-channel (image, fluorescent) stack for this frame
        frame = np.stack(
            [image[n, ...].compute(), fp[n, ...].compute()],
            axis=-1,
        )
        # padded view (mode='reflect') so windows may extend past the border
        vol = InfinitePaddedImage(frame, mode='reflect')

        crops = []
        to_update = []
        for obj in frame_objects:
            # floor keeps the window exactly 2 * half_window wide; int() would
            # truncate toward zero and shrink it for coordinates < half_window
            cx = int(np.floor(obj.x))
            cy = int(np.floor(obj.y))
            xs = slice(cx - half_window, cx + half_window, 1)
            ys = slice(cy - half_window, cy + half_window, 1)

            crop = vol[ys, xs, :]
            crop = resize(crop, input_size, preserve_range=True).astype(np.float32)

            if crop.shape == (*input_size, 2):
                crops.append(normalize_channels(crop))
                to_update.append(obj)
            else:
                print(crop.shape)

        if not crops:
            continue

        pred = classifier.predict(np.stack(crops, axis=0))

        # predictions line up with ``to_update`` (objects whose crop was kept),
        # not with every object in the frame
        if pred.shape[0] != len(to_update):
            raise ValueError(
                f"Expected {len(to_update)} predictions for frame {n}, "
                f"got {pred.shape[0]}"
            )

        for obj, obj_pred in zip(to_update, pred):
            probs = softmax(obj_pred)
            obj.label = int(np.argmax(obj_pred))
            obj.properties = {f"prob_{k}": probs[ki] for ki, k in enumerate(LABELS)}

    return objects

Found model '2D_versatile_fluo' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.479071, nms_thresh=0.3.


In [None]:
root_dir = '/home/nathan/data/kraken/ras'


def _natural_sort_key(name):
    """Sort key that orders names by the integers they contain (e.g. Pos2 < Pos10)."""
    return [int(y) for y in re.findall(r'\d+', name)]


def _paint_eroded_labels(mask, labels, value):
    """Write ``value`` into ``mask`` inside every eroded label of ``labels``.

    Every label value from 1 up to and including ``labels.max()`` is painted.
    Label 0 is never painted, so background pixels are left untouched.
    """
    for i in range(1, int(np.amax(labels)) + 1):
        # erode each label on its own so touching objects do not merge in the mask
        mask[nd.binary_erosion(labels == i)] = value


expt_list = sorted(
    [expt for expt in os.listdir(root_dir)
     if 'ND' in expt and os.path.isdir(os.path.join(root_dir, expt))],
    key=_natural_sort_key,
)

# The classifier is the same for every position, so load it once. It is bound
# to the global name `model`, which `classify_objects` reads.
model = load_model('../models/cellx_classifier_stardist.h5')

for expt in tqdm(expt_list):
    pos_list = sorted(
        [pos for pos in os.listdir(f'{root_dir}/{expt}')
         if 'Pos' in pos and os.path.isdir(f'{root_dir}/{expt}/{pos}')],
        key=_natural_sort_key,
    )

    for pos in tqdm(pos_list):
        pos_dir = f'{root_dir}/{expt}/{pos}'

        print(f'Starting alignment for {expt}/{pos}')

        ### on first run, move the raw tiffs into <pos>/<pos>_images
        image_path = f'{pos_dir}/{pos}_images'
        if not os.path.exists(image_path):
            os.mkdir(image_path)
            for file in sorted(glob.glob(f'{pos_dir}/*.tif')):
                # build the destination from the basename; str.replace on the
                # full path would also rewrite any other occurrence of `pos`
                os.rename(file, os.path.join(image_path, os.path.basename(file)))

        ### on first run, move under/over-exposed frames into <pos>/<pos>_blanks
        blanks_path = f'{pos_dir}/{pos}_blanks'
        if not os.path.exists(blanks_path):
            os.mkdir(blanks_path)
            images = DaskOctopusLiteLoader(image_path, remove_background=False)

            # NOTE(review): `image_generator`, `max_pixel` and `min_pixel` are not
            # defined in this notebook section. They must be defined in another
            # cell before this branch runs, otherwise it raises NameError.
            mean_arrays = {}
            dodgy_frames = set()
            for channel in tqdm(images.channels, desc='Finding mean values of image channels'):
                if 'MASK' in channel.name:
                    continue
                # mean pixel value of each frame in this channel
                mean_arrays[channel.name] = [
                    np.mean(img) for img in image_generator(images.files(channel.name))
                ]
                for frame, mean_value in enumerate(mean_arrays[channel.name]):
                    if max_pixel < mean_value or mean_value < min_pixel:
                        dodgy_frames.add(frame)
            dodgy_frame_list = list(dodgy_frames)
            print('Number of under/over-exposed frames:', len(dodgy_frame_list))

            # frame indices zero-padded to 9 digits, matched against filenames
            dodgy_stamps = [str(i).zfill(9) for i in dodgy_frame_list]
            for channel in images.channels:
                for f in images.files(channel.name):
                    name = os.path.basename(f)
                    # match on the basename only, and move each file at most once
                    if any(stamp in name for stamp in dodgy_stamps):
                        os.rename(f, os.path.join(blanks_path, name))

        # crop central window out of reference image with blanks removed
        reference_image = DaskOctopusLiteLoader(image_path, crop=(500, 500))['gfp'].compute()

        ### Register alignment
        print('Registering alignment for', pos, expt)
        # translation-only registration of each frame to the previous one
        sr = StackReg(StackReg.TRANSLATION)
        transform_tensor = sr.register_stack(reference_image, reference='previous')

        ### clip transformation tensor to eliminate any rare jumps, (1688-1600)/2=44
        transform_tensor = np.clip(transform_tensor, a_max=44, a_min=-44)

        transform_path = f'{pos_dir}/transform_tensor.npy'
        np.save(transform_path, transform_tensor)

        print('Alignment complete for', expt, pos)

        print('Starting segmentation for', expt, pos)
        images = DaskOctopusLiteLoader(image_path, remove_background=True)
        gfp_files = images.files('gfp')

        for frame, fn in tqdm(enumerate(gfp_files), total=len(gfp_files)):
            # NOTE(review): `normalize` here is the function defined earlier in
            # this file, which shadows csbdeep's percentile-based `normalize`.
            # Confirm this is the intended input scaling for StarDist.

            # mask values: 0 = background, 1 = gfp object, 2 = rfp object
            gfp = imread(fn)
            labels, _ = seg_model.predict_instances(normalize(gfp), prob_thresh=0.75)
            mask = np.zeros(labels.shape)
            _paint_eroded_labels(mask, labels, 1)

            rfp = imread(fn.replace('channel001', 'channel002'))
            labels, _ = seg_model.predict_instances(normalize(rfp), prob_thresh=0.2)
            # painted second, so rfp overwrites gfp wherever they overlap
            _paint_eroded_labels(mask, labels, 2)

            # save the mask alongside the raw images as channel099
            mask_fn = fn.replace('channel001', 'channel099')
            imsave(mask_fn, mask.astype(np.uint8), check_contrast=False)

        print('Segmentation complete for', expt, pos)

        print('Starting object localisation for', expt, pos)

        # NOTE(review): the transform tensor saved above is not applied here
        # (the `transforms=` argument is commented out); confirm this is intended.
        images = DaskOctopusLiteLoader(image_path,
                                       # transforms=transform_path,
                                       crop=(1200, 1600),
                                       remove_background=True)

        objects = btrack.utils.segmentation_to_objects(
            images['mask'],
            images['gfp'],
            properties=('area', 'eccentricity', 'mean_intensity'),
            assign_class_ID=True,
        )

        # split by mask value (1 = gfp, 2 = rfp)
        objects_gfp = [obj for obj in objects if obj.properties['class id'] == 1]
        objects_rfp = [obj for obj in objects if obj.properties['class id'] == 2]

        bf = images['brightfield']
        gfp = images['gfp']
        rfp = images['rfp']

        print('Classifying objects in', expt, pos)
        objects_gfp = classify_objects(bf, gfp, rfp, objects_gfp, obj_type=1)
        objects_rfp = classify_objects(bf, gfp, rfp, objects_rfp, obj_type=2)

        # one file per object type
        for type_id, objs in ((1, objects_gfp), (2, objects_rfp)):
            with btrack.dataio.HDF5FileHandler(
                f'{pos_dir}/objects_type_{type_id}.h5', 'w', obj_type=f'obj_type_{type_id}',
            ) as hdf:
                hdf.write_objects(objs)

        # combined file: created with type 1, then type 2 appended
        for mode, type_id, objs in (('w', 1, objects_gfp), ('a', 2, objects_rfp)):
            with btrack.dataio.HDF5FileHandler(
                f'{pos_dir}/objects.h5', mode, obj_type=f'obj_type_{type_id}',
            ) as hdf:
                hdf.write_objects(objs)

        print('Object localisation complete for', expt, pos)

        print('Starting tracking for', expt, pos)

        tracks_path = f'{pos_dir}/tracks.h5'
        tracks_by_type = {}
        for obj_type, config_file, objs in (
            ('obj_type_1', '../models/MDCK_config_wildtype.json', objects_gfp),
            ('obj_type_2', '../models/MDCK_config_scribble_sparse.json', objects_rfp),
        ):
            # initialise a tracker session using a context manager
            with btrack.BayesianTracker() as tracker:
                tracker.configure_from_file(config_file)
                tracker.max_search_radius = 40

                tracker.append(objs)

                # matches the crop=(1200, 1600) images above — confirm axis order
                tracker.volume = ((0, 1600), (0, 1200), (-1e5, 1e5))

                # track them (in interactive mode)
                tracker.track_interactive(step_size=100)

                # generate hypotheses and run the global optimizer
                tracker.optimize()

                # both object types are exported to the same tracks.h5
                tracker.export(tracks_path, obj_type=obj_type)

                # tracks in napari format (optional, for interactive inspection)
                visaulise_tracks, properties, graph = tracker.to_napari(ndim=2)

                tracks_by_type[obj_type] = tracker.tracks

        gfp_tracks = tracks_by_type['obj_type_1']
        rfp_tracks = tracks_by_type['obj_type_2']

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

Starting alignment for ND0013/Pos0
Using cropping: (500, 500)
Registering alignment for Pos0 ND0013
Alignment complete for ND0013 Pos0
Starting segmentation for ND0013 Pos0


  0%|          | 0/2785 [00:00<?, ?it/s]