# Complete pre-analysis cell labelling pipeline

1. Alignment
2. Segmentation
3. Object localisation
4. Tracking


In [1]:
import os
import glob
import enum
import re
import numpy as np
import btrack
import pandas as pd
from pystackreg import StackReg
from skimage.io import imsave, imread
from tqdm.auto import tqdm
from octopuslite import DaskOctopusLiteLoader
from skimage import transform as tf
from skimage.transform import resize ### tidy up these dependencies
from stardist.models import StarDist2D 
from stardist.plot import render_label
from csbdeep.utils import normalize
from scipy import ndimage as nd
from scipy.special import softmax
from cellx import load_model
from cellx.tools.image import InfinitePaddedImage
from skimage.measure import label, regionprops
from skimage.morphology import binary_erosion, remove_small_objects
from natsort import natsorted

# Pretrained StarDist2D model ('2D_versatile_fluo'); used in the pipeline cell to
# segment objects in both the GFP and RFP channels via `seg_model.predict_instances`.
seg_model = StarDist2D.from_pretrained('2D_versatile_fluo')

def image_generator(files, crop = None):
    """Lazily read each path in ``files`` with ``imread``, yielding one image at a time.

    Args:
        files: iterable of image file paths.
        crop: optional crop specification; when not None each image is passed
            through ``crop_image(img, crop)`` before being yielded.

    NOTE(review): ``crop_image`` is not defined or imported in this notebook, so
    calling with a non-None ``crop`` will raise NameError unless it is provided
    elsewhere — confirm where it is expected to come from.
    """
    for path in files:
        frame = imread(path)
        yield frame if crop is None else crop_image(frame, crop)

def normalize_channels(x):
    """Normalise each channel (last axis) of ``x`` independently with ``normalize``.

    ``normalize`` is the notebook's z-score helper: (x - mean) / std per channel.

    Args:
        x: array whose last axis indexes channels, e.g. an (H, W, C) crop.

    Returns:
        A float32 array with the same shape as ``x``. A float32 input is normalised
        in place and returned (same object). Any other dtype is first copied to
        float32, so integer inputs are not truncated when the z-scores are written
        back into the array.
    """
    # np.asarray is a no-op for float32 arrays (keeps the in-place behaviour callers
    # saw before); for integer input it makes a float copy so the normalised values
    # are not cast back to ints on assignment
    x = np.asarray(x, dtype=np.float32)
    for dim in range(x.shape[-1]):
        x[..., dim] = normalize(x[..., dim])
    return x

def normalize(x):
    """Return a float32 z-scored copy of ``x``: (x - mean) / std over all elements.

    The standard deviation is floored at 1 / (number of elements), so a constant
    input maps to zeros instead of dividing by zero.

    NOTE(review): this definition shadows ``csbdeep.utils.normalize`` imported at the
    top of the notebook. Every later ``normalize(...)`` call therefore uses this
    z-score version rather than csbdeep's percentile normalisation. That includes the
    inputs passed to ``seg_model.predict_instances``. Confirm that is intended.
    """
    as_float = x.astype(np.float32)
    std_floor = 1. / np.prod(x.shape)
    centred = as_float - np.mean(as_float)
    return centred / np.max([np.std(as_float), std_floor])

def classify_objects(image, gfp, rfp, objects, obj_type, classifier=None):
    """Predict a cell-state label for every object and store it on the object.

    For each frame, an 80x80 px window around every object is sliced from a
    two-channel stack: ``image`` plus the GFP or RFP channel, chosen by ``obj_type``.
    The window is resized to 64x64, normalised per channel with
    ``normalize_channels``, and the batch is passed to the classifier.

    Args:
        image: per-frame array indexed as ``image[n, ...]`` and materialised with
            ``.compute()`` (a dask array here; the brightfield channel in this notebook).
        gfp, rfp: per-frame fluorescence arrays with the same frame indexing.
        objects: objects exposing ``t``, ``x``, ``y`` and writable ``label`` /
            ``properties`` attributes (btrack objects in this notebook).
        obj_type: 1 pairs ``image`` with ``gfp``; any other value pairs it with ``rfp``.
        classifier: object with a ``predict(batch)`` method. Defaults to the
            notebook-global ``model`` for backward compatibility.

    Returns:
        The same ``objects`` list. Every object whose crop could be built gets
        ``label`` set to the argmax class index and ``properties`` assigned a dict of
        ``prob_<state>`` softmax values. Objects with an unusable crop are left untouched.
    """
    # define stages of cell cycle to classify (dependent on model type)
    LABELS = ["interphase", "prometaphase", "metaphase", "anaphase", "apoptosis"]

    # fluorescence channel paired with `image` (same for every frame)
    fp = gfp if obj_type == 1 else rfp

    # group objects by frame once; filtering the full list for every frame is
    # O(frames x objects), which dominated runtime for ~1e6 objects
    objects_by_frame = {}
    for obj in objects:
        objects_by_frame.setdefault(obj.t, []).append(obj)

    for n in tqdm(range(image.shape[0])):
        frame_objects = objects_by_frame.get(n, [])
        if not frame_objects:
            # nothing to classify; avoid loading this frame
            continue

        # two-channel (H, W, 2) stack for this frame
        frame = np.stack(
            [image[n, ...].compute(), fp[n, ...].compute()],
            axis=-1,
        )

        # padded view (reflect mode) so windows overlapping the frame edge can be
        # sliced — behaviour assumed from cellx InfinitePaddedImage's mode argument
        vol = InfinitePaddedImage(frame, mode='reflect')

        crops = []
        to_update = []
        for obj in frame_objects:
            # 80x80 px window centred on the object
            xs = slice(int(obj.x - 40), int(obj.x + 40), 1)
            ys = slice(int(obj.y - 40), int(obj.y + 40), 1)

            crop = vol[ys, xs, :]
            crop = resize(crop, (64, 64), preserve_range=True).astype(np.float32)

            if crop.shape == (64, 64, 2):
                crops.append(normalize_channels(crop))
                to_update.append(obj)
            else:
                # unexpected crop shape: report it and leave this object unclassified
                print(crop.shape)

        if not crops:
            continue

        predictor = classifier if classifier is not None else model
        pred = predictor.predict(np.stack(crops, axis=0))

        # one prediction per successfully cropped object — not per object in the
        # frame, since rejected crops were skipped above
        assert pred.shape[0] == len(to_update)

        for obj, obj_pred in zip(to_update, pred):
            pred_softmax = softmax(obj_pred)
            probabilities = {f"prob_{k}": pred_softmax[ki] for ki, k in enumerate(LABELS)}
            obj.label = np.argmax(obj_pred)
            obj.properties = probabilities

    return objects

Found model '2D_versatile_fluo' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.479071, nms_thresh=0.3.


# Experiment info

In [2]:
experiment_table_path = '/home/nathan/data/kraken/ras/experiment_info_final.csv'
# header=1: use the sheet's second row as the column names
expt_info = pd.read_csv(experiment_table_path, header=1)

In [3]:
# Map the spreadsheet's verbose headers onto the short names used in later cells
column_renames = {
    'EXP n˚': 'Experiments',
    'POSITION': 'Positions',
    'CELL TYPE': 'Condition',
    'Useable (in radial analysis)': 'Valid',
}
expt_info = expt_info.rename(columns=column_renames)

In [4]:
# Inspect the renamed experiment table (pandas truncates the rendered rows)
expt_info

Unnamed: 0,Experiments,Positions,Condition,Well,EXPT NOTES,POS NOTES,Valid,BF CHANNEL,GFP CHANNEL,RFP CHANNEL,...,Focus?,ALIGNED?,SEGMENTED?,Localised?,TRACKED?,segmentation notes,SEG Model,TRACK MODEL,BLISTERING?,COMPETITION?
0,08.11.2021,stopped due to focus issue,,,,,,,,,...,,FALSE,FALSE,False,False,,,,,
1,ND0000,Pos0,MDCK Rasv12 -,,stopped due to focus issue,uninduced,False,,Ras,mutant(ras)-h2b,...,,FALSE,FALSE,False,False,,,,,
2,ND0000,Pos1,50:50 wt:ras+,,stopped due to focus issue,induced,False,,Ras + wt-h2b,mutant(ras)-h2b,...,,FALSE,FALSE,False,False,,,,,
3,ND0000,Pos2,MDCK Rasv12 +,,stopped due to focus issue,induced,False,,Ras,mutant(ras)-h2b,...,,FALSE,FALSE,False,False,,,,,
4,ND0000,Pos3,50:50 wt:ras+,,stopped due to focus issue,induced,False,,Ras + wt-h2b,mutant(ras)-h2b,...,,FALSE,FALSE,False,False,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
363,ND0025,Pos9,97.5:2.5 wt:ras+,6.0,,induced 3x seed dens,True,,Ras + wt-h2b,mutant(ras)-h2b,...,,,,,,,,,,
364,ND0025,Pos10,97.5:2.5 wt:ras+,6.0,,induced 3x seed dens,True,,Ras + wt-h2b,mutant(ras)-h2b,...,,,,,,,,,,
365,ND0025,Pos11,97.5:2.5 wt:ras+,6.0,,induced 3x seed dens,True,,Ras + wt-h2b,mutant(ras)-h2b,...,,,,,,,,,,
366,ND0025,Pos12,97.5:2.5 wt:ras+,6.0,,induced 3x seed dens,True,,Ras + wt-h2b,mutant(ras)-h2b,...,,,,,,,,,,


##### Select only the valid 90:10 wt:ras+ experiments

In [5]:
# Keep only positions marked usable that belong to the 90:10 wt:ras+ condition
is_valid = expt_info['Valid'] == True  # noqa: E712 — column holds True/False/NaN values
is_ninety_ten = expt_info['Condition'] == "90:10 wt:ras+"
expt_pos_list = expt_info.loc[is_valid & is_ninety_ten, ['Experiments', 'Positions', 'Condition']]
expt_pos_list

Unnamed: 0,Experiments,Positions,Condition
176,ND0013,Pos3,90:10 wt:ras+
177,ND0013,Pos4,90:10 wt:ras+
178,ND0013,Pos5,90:10 wt:ras+
179,ND0013,Pos6,90:10 wt:ras+
180,ND0013,Pos7,90:10 wt:ras+
181,ND0013,Pos8,90:10 wt:ras+
182,ND0013,Pos9,90:10 wt:ras+
183,ND0013,Pos10,90:10 wt:ras+
188,ND0014,Pos0,90:10 wt:ras+
189,ND0014,Pos1,90:10 wt:ras+


In [6]:
root_dir = '/home/nathan/data/kraken/ras'

# Only these positions are processed; set to None to run every selected position.
# {'Pos6'} reproduces the previous hard-coded `if pos != 'Pos6': continue` filter.
POSITIONS_TO_PROCESS = {'Pos6'}

# Frames whose mean pixel value in any non-mask channel falls outside these bounds
# are treated as under/over-exposed and moved into the `<pos>_blanks` directory.
MAX_MEAN_PIXEL = 200
MIN_MEAN_PIXEL = 2

# Cell-state classifier; classify_objects reads it through the global name `model`.
# Loaded once here instead of once per position.
model = load_model('../models/cellx_classifier_stardist.h5')


def _find_dodgy_frames(images):
    """Return sorted indices of frames whose mean intensity in any non-mask channel
    is above MAX_MEAN_PIXEL or below MIN_MEAN_PIXEL."""
    dodgy_frames = set()
    for channel in tqdm(images.channels, desc='Finding mean values of image channels'):
        if 'MASK' in channel.name:
            continue
        for frame_idx, img in enumerate(image_generator(images.files(channel.name))):
            mean_value = np.mean(img)
            if mean_value > MAX_MEAN_PIXEL or mean_value < MIN_MEAN_PIXEL:
                dodgy_frames.add(frame_idx)
    return sorted(dodgy_frames)


def _move_blank_frames(images, dodgy_frames):
    """Move every file (all channels) of the given frames from `_images` to `_blanks`."""
    # files are matched on the 9-digit zero-padded frame number in their name
    frame_tags = [str(frame_idx).zfill(9) for frame_idx in dodgy_frames]
    for channel in images.channels:
        for fn in images.files(channel.name):
            # rename at most once: a file matching several tags would otherwise be
            # renamed again after it had already moved (FileNotFoundError)
            if any(tag in fn for tag in frame_tags):
                os.rename(fn, fn.replace('_images', '_blanks'))


def _segment_frame(gfp_fn):
    """Segment one GFP frame and its matching RFP frame into a uint8 class mask.

    Mask values:
    - 0: background.
    - 1: eroded GFP object.
    - 3: GFP object with area >= 2000 px or eccentricity > 0.95 (probably a
      misclassified Ras cytoplasm, kept for future use).
    - 2: eroded RFP object.

    RFP objects are painted last, so they overwrite GFP pixels where both overlap.

    NOTE: masks written before the off-by-one fix below are missing the
    highest-numbered object per channel. Existing channel099 files are skipped, so
    delete them to regenerate.
    """
    gfp = imread(gfp_fn)
    # NOTE(review): `normalize` is this notebook's z-score helper, which shadows the
    # csbdeep.utils.normalize import; kept so masks stay consistent with positions
    # that were already segmented.
    # higher probability threshold for GFP as the fluorescence signal is strong
    labels, _ = seg_model.predict_instances(normalize(gfp), prob_thresh=0.75)
    mask = np.zeros(labels.shape, dtype=np.uint8)
    # remove small, unrealistically sized nuclear objects from the segmentation
    labels = remove_small_objects(labels, min_size=200)
    # start at 1 to skip background; +1 so the highest label id is included
    for label_id in range(1, np.amax(labels) + 1):
        segment = labels == label_id
        seg_props = regionprops(label(segment), cache=False)
        if not seg_props:
            # id not present in the label image (e.g. removed by remove_small_objects)
            continue
        if seg_props[0].area >= 2000 or seg_props[0].eccentricity > 0.95:
            # large/elongated objects are not eroded, as that would only add time
            mask[segment] = 3
        else:
            # erode so that touching objects do not merge into one
            mask[binary_erosion(segment)] = 1

    rfp = imread(gfp_fn.replace('channel001', 'channel002'))
    # much lower threshold as the RFP signal is dim
    labels, _ = seg_model.predict_instances(normalize(rfp), prob_thresh=0.2)
    # the low threshold picks up hot pixels; removing small objects also reduces the
    # number of per-label erosions below
    labels = remove_small_objects(labels, min_size=200)
    for label_id in range(1, np.amax(labels) + 1):
        # erode each object separately so neighbours do not merge
        mask[binary_erosion(labels == label_id)] = 2
    return mask


def _write_objects(path, mode, obj_type, objects):
    """Write `objects` under `obj_type` into the HDF5 file at `path` ('w' or 'a' mode)."""
    with btrack.dataio.HDF5FileHandler(path, mode, obj_type=obj_type) as hdf:
        hdf.write_objects(objects)


def _track_objects(objects, config_file, obj_type, tracks_path):
    """Track one object population with btrack and export it to `tracks_path` under `obj_type`."""
    with btrack.BayesianTracker() as tracker:
        tracker.configure_from_file(config_file)
        tracker.max_search_radius = 40
        tracker.append(objects)
        # volume matches the crop=(1200, 1600) used when localising objects
        tracker.volume = ((0, 1600), (0, 1200), (-1e5, 1e5))
        # track in interactive mode
        tracker.track_interactive(step_size=100)
        # generate hypotheses and run the global optimizer
        tracker.optimize()
        tracker.export(tracks_path, obj_type=obj_type)


for _, expt_pos in tqdm(expt_pos_list.iterrows(), desc='Progress of experiment annotation', total=len(expt_pos_list)):
    expt = expt_pos['Experiments']
    pos = expt_pos['Positions']
    if POSITIONS_TO_PROCESS is not None and pos not in POSITIONS_TO_PROCESS:
        continue
    pos_dir = f'{root_dir}/{expt}/{pos}'
    print(f'Starting alignment for {expt}/{pos}')

    # --- 1. Alignment -------------------------------------------------------
    # move raw .tif files into their own `<pos>_images` subdirectory
    image_path = f'{pos_dir}/{pos}_images'
    if not os.path.exists(image_path):
        os.mkdir(image_path)
        for fn in sorted(glob.glob(f'{pos_dir}/*.tif')):
            os.rename(fn, os.path.join(image_path, os.path.basename(fn)))

    # move under/over-exposed frames into `<pos>_blanks` (only on first run)
    blanks_path = f'{pos_dir}/{pos}_blanks'
    if not os.path.exists(blanks_path):
        os.mkdir(blanks_path)
        images = DaskOctopusLiteLoader(image_path, remove_background=False)
        dodgy_frames = _find_dodgy_frames(images)
        print('Number of under/over-exposed frames:', len(dodgy_frames))
        _move_blank_frames(images, dodgy_frames)

    transform_path = f'{pos_dir}/transform_tensor.npy'
    if not os.path.exists(transform_path):
        # crop central window out of reference image with blanks removed
        reference_image = DaskOctopusLiteLoader(image_path, crop=(500, 500))['gfp'].compute()
        print('Registering alignment for', pos, expt)
        # register each frame to the previous one as translation matrices
        sr = StackReg(StackReg.TRANSLATION)
        transform_tensor = sr.register_stack(reference_image, reference='previous')
        # clip transformation tensor to eliminate any rare jumps, (1688-1600)/2=44
        transform_tensor = np.clip(transform_tensor, a_min=-44, a_max=44)
        np.save(transform_path, transform_tensor)
        print('Alignment complete for', expt, pos)

    # --- 2. Segmentation ----------------------------------------------------
    print('Starting segmentation for', expt, pos)
    images = DaskOctopusLiteLoader(image_path, remove_background=True)
    if 'MASK_GFP' in [channel.name for channel in images.channels]:
        print('Masks already preprocessing, skipping to next experiment')
        continue
    gfp_files = images.files('gfp')
    for gfp_fn in tqdm(gfp_files, total=len(gfp_files)):
        # masks are saved alongside the images as channel099
        mask_fn = gfp_fn.replace('channel001', 'channel099')
        if os.path.exists(mask_fn):
            continue
        imsave(mask_fn, _segment_frame(gfp_fn), check_contrast=False)
    print('Segmentation complete for', expt, pos)

    # --- 3. Object localisation and classification --------------------------
    print('Starting object localisation for', expt, pos)
    images = DaskOctopusLiteLoader(image_path,
                                   transforms=transform_path,
                                   crop=(1200, 1600),
                                   remove_background=True)

    # localise each population separately so fluorescence intensities are measured
    # in its own channel; mask value 3 (large/elongated GFP objects) is excluded
    objects_gfp = btrack.utils.segmentation_to_objects(
        images['mask'] == 1,
        images['gfp'],
        properties=('area', 'eccentricity', 'mean_intensity'),
        assign_class_ID=True,
    )
    objects_rfp = btrack.utils.segmentation_to_objects(
        (images['mask'] == 2) * 2,
        images['rfp'],
        properties=('area', 'eccentricity', 'mean_intensity'),
        assign_class_ID=True,
    )

    bf = images['brightfield']
    gfp = images['gfp']
    rfp = images['rfp']

    print('Classifying objects in', expt, pos)
    objects_gfp = classify_objects(bf, gfp, rfp, objects_gfp, obj_type=1)
    objects_rfp = classify_objects(bf, gfp, rfp, objects_rfp, obj_type=2)

    # one file per population, plus a combined file (created with GFP, then RFP appended)
    _write_objects(f'{pos_dir}/objects_type_1.h5', 'w', 'obj_type_1', objects_gfp)
    _write_objects(f'{pos_dir}/objects_type_2.h5', 'w', 'obj_type_2', objects_rfp)
    _write_objects(f'{pos_dir}/objects.h5', 'w', 'obj_type_1', objects_gfp)
    _write_objects(f'{pos_dir}/objects.h5', 'a', 'obj_type_2', objects_rfp)

    print('Object localisation complete for', expt, pos)

    # --- 4. Tracking --------------------------------------------------------
    print('Starting tracking for', expt, pos)
    tracks_path = f'{pos_dir}/tracks.h5'
    _track_objects(objects_gfp, '../models/MDCK_config_wildtype.json', 'obj_type_1', tracks_path)
    _track_objects(objects_rfp, '../models/MDCK_config_scribble_sparse.json', 'obj_type_2', tracks_path)

Progress of experiment annotation:   0%|          | 0/13 [00:00<?, ?it/s]

Starting alignment for ND0013/Pos6
Starting segmentation for ND0013 Pos6


  0%|          | 0/2787 [00:00<?, ?it/s]

Segmentation complete for ND0013 Pos6
Starting object localisation for ND0013 Pos6
Using cropping: (1200, 1600)


[INFO][2022/05/15 06:34:01 AM] Localizing objects from segmentation...
[INFO][2022/05/15 06:34:01 AM] Found intensity_image data
[INFO][2022/05/15 06:34:01 AM] Calculating weighted centroids using intensity_image
[INFO][2022/05/15 07:09:53 AM] Objects are of type: <class 'dict'>
[INFO][2022/05/15 07:10:03 AM] ...Found 1329084 objects in 2787 frames.
[INFO][2022/05/15 07:10:03 AM] Localizing objects from segmentation...
[INFO][2022/05/15 07:10:03 AM] Found intensity_image data
[INFO][2022/05/15 07:10:03 AM] Calculating weighted centroids using intensity_image
[INFO][2022/05/15 07:41:37 AM] Objects are of type: <class 'dict'>
[INFO][2022/05/15 07:41:38 AM] ...Found 202177 objects in 2787 frames.


Classifying objects in ND0013 Pos6


  0%|          | 0/2787 [00:00<?, ?it/s]

  0%|          | 0/2787 [00:00<?, ?it/s]

[INFO][2022/05/15 10:30:09 AM] Opening HDF file: /home/nathan/data/kraken/ras/ND0013/Pos6/objects_type_1.h5...
[INFO][2022/05/15 10:30:19 AM] Writing objects/obj_type_1
[INFO][2022/05/15 10:30:19 AM] Writing labels/obj_type_1
[INFO][2022/05/15 10:30:19 AM] Loading objects/obj_type_1 (1329084, 5) (1329084 filtered: None)
[INFO][2022/05/15 10:30:34 AM] Writing properties/obj_type_1/area (1329084,)
[INFO][2022/05/15 10:30:34 AM] Writing properties/obj_type_1/eccentricity (1329084,)
[INFO][2022/05/15 10:30:34 AM] Writing properties/obj_type_1/mean_intensity (1329084,)
[INFO][2022/05/15 10:30:35 AM] Writing properties/obj_type_1/class id (1329084,)
[INFO][2022/05/15 10:30:35 AM] Writing properties/obj_type_1/prob_interphase (1329084,)
[INFO][2022/05/15 10:30:35 AM] Writing properties/obj_type_1/prob_prometaphase (1329084,)
[INFO][2022/05/15 10:30:35 AM] Writing properties/obj_type_1/prob_metaphase (1329084,)
[INFO][2022/05/15 10:30:35 AM] Writing properties/obj_type_1/prob_anaphase (1329084

Object localisation complete for ND0013 Pos6
Starting tracking for ND0013 Pos6


[INFO][2022/05/15 10:31:13 AM] Set volume to ((0, 1600), (0, 1200), (-100000.0, 100000.0))
[INFO][2022/05/15 10:31:13 AM] Starting tracking... 
[INFO][2022/05/15 10:31:14 AM] Tracking objects in frames 0 to 99 (of 2787)...
[INFO][2022/05/15 10:31:14 AM]  - Timing (Bayesian updates: 0.66ms, Linking: 0.17ms)
[INFO][2022/05/15 10:31:14 AM]  - Probabilities (Link: 1.00000, Lost: 0.80191)
[INFO][2022/05/15 10:31:14 AM]  - Stats (Active: 70, Lost: 1122, Conflicts resolved: 156)
[INFO][2022/05/15 10:31:14 AM] Tracking objects in frames 100 to 199 (of 2787)...
[INFO][2022/05/15 10:31:14 AM]  - Timing (Bayesian updates: 0.66ms, Linking: 0.16ms)
[INFO][2022/05/15 10:31:14 AM]  - Probabilities (Link: 1.00000, Lost: 1.00000)
[INFO][2022/05/15 10:31:14 AM]  - Stats (Active: 69, Lost: 2114, Conflicts resolved: 352)
[INFO][2022/05/15 10:31:14 AM] Tracking objects in frames 200 to 299 (of 2787)...
[INFO][2022/05/15 10:31:14 AM]  - Timing (Bayesian updates: 1.09ms, Linking: 0.22ms)
[INFO][2022/05/15 10

[INFO][2022/05/15 10:35:54 AM]  - Stats (Active: 820, Lost: 138707, Conflicts resolved: 48752)
[INFO][2022/05/15 10:35:54 AM] Tracking objects in frames 2400 to 2499 (of 2787)...
[INFO][2022/05/15 10:36:18 AM]  - Timing (Bayesian updates: 221.47ms, Linking: 4.32ms)
[INFO][2022/05/15 10:36:18 AM]  - Probabilities (Link: 0.99998, Lost: 1.00000)
[INFO][2022/05/15 10:36:18 AM]  - Stats (Active: 782, Lost: 146017, Conflicts resolved: 52117)
[INFO][2022/05/15 10:36:18 AM] Tracking objects in frames 2500 to 2599 (of 2787)...
[INFO][2022/05/15 10:36:42 AM]  - Timing (Bayesian updates: 222.36ms, Linking: 4.32ms)
[INFO][2022/05/15 10:36:42 AM]  - Probabilities (Link: 1.00000, Lost: 0.99746)
[INFO][2022/05/15 10:36:42 AM]  - Stats (Active: 791, Lost: 151878, Conflicts resolved: 55047)
[INFO][2022/05/15 10:36:42 AM] Tracking objects in frames 2600 to 2699 (of 2787)...
[INFO][2022/05/15 10:37:05 AM]  - Timing (Bayesian updates: 223.41ms, Linking: 4.17ms)
[INFO][2022/05/15 10:37:05 AM]  - Probabilit

[INFO][2022/05/15 10:38:01 AM]  - Probabilities (Link: 1.00000, Lost: 1.00000)
[INFO][2022/05/15 10:38:01 AM]  - Stats (Active: 76, Lost: 14455, Conflicts resolved: 769)
[INFO][2022/05/15 10:38:01 AM] Tracking objects in frames 1200 to 1299 (of 2787)...
[INFO][2022/05/15 10:38:01 AM]  - Timing (Bayesian updates: 0.92ms, Linking: 0.20ms)
[INFO][2022/05/15 10:38:01 AM]  - Probabilities (Link: 0.99945, Lost: 1.00000)
[INFO][2022/05/15 10:38:01 AM]  - Stats (Active: 81, Lost: 15776, Conflicts resolved: 887)
[INFO][2022/05/15 10:38:01 AM] Tracking objects in frames 1300 to 1399 (of 2787)...
[INFO][2022/05/15 10:38:02 AM]  - Timing (Bayesian updates: 1.38ms, Linking: 0.25ms)
[INFO][2022/05/15 10:38:02 AM]  - Probabilities (Link: 0.97856, Lost: 1.00000)
[INFO][2022/05/15 10:38:02 AM]  - Stats (Active: 95, Lost: 17205, Conflicts resolved: 1043)
[INFO][2022/05/15 10:38:02 AM] Tracking objects in frames 1400 to 1499 (of 2787)...
[INFO][2022/05/15 10:38:02 AM]  - Timing (Bayesian updates: 1.65ms,

# Finishing ND0013