# Phenotype classification using CellX

This notebook shows how to take segmented time-lapse microscopy images and use H2B fluorescence markers to classify the mitotic state of each cell.

The sections of this notebook are as follows:

1. Load images
2. Localise the objects
3. Classify the objects
4. Batch process

The data used in this notebook are time-lapse microscopy images with H2B-GFP/RFP markers, which show the spatial extent of the nucleus and its mitotic state.

This notebook uses the Dask-based `DaskOctopusLiteLoader` image loader (`octopuslite`) from the CellX/Lowe lab project.

In [1]:
from octopuslite import DaskOctopusLiteLoader
import btrack
from tqdm.auto import tqdm
import numpy as np
from scipy.special import softmax
import os
import matplotlib.pyplot as plt
from skimage.io import imread, imshow
from cellx import load_model
from cellx.tools.image import InfinitePaddedImage
from skimage.transform import resize
%matplotlib inline
plt.rcParams['figure.figsize'] = [18,8]

## 1. Load segmentation images

#### *Important:* from this point on, apply cropping and alignment consistently.
Using a previously generated alignment transformation greatly helps the tracking notebook, which depends on the object localisation performed here. Cropping the images ensures that no border effects from the translational shift are seen.
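The loader applies the alignment transform and the crop for you. As an illustration only (not the `octopuslite` implementation), the sketch below shifts a frame by an integer translation taken from a 3×3 homogeneous matrix and then centre-crops it, showing that the crop discards the zero-filled border exposed by the shift. It assumes `transform[0, 2]` is the x (column) shift and `transform[1, 2]` the y (row) shift; `translate_and_crop` is a hypothetical helper.

```python
import numpy as np


def translate_and_crop(frame, transform, crop):
    """Hypothetical helper: shift a 2D frame by the rounded translation in a
    3x3 homogeneous matrix (exposed pixels become 0), then centre-crop it.
    Assumes the shift is smaller than the frame in both directions."""
    # ASSUMPTION: transform[0, 2] is the x (column) shift, transform[1, 2] the y (row) shift
    dx = int(round(transform[0, 2]))
    dy = int(round(transform[1, 2]))
    h, w = frame.shape
    shifted = np.zeros_like(frame)
    # destination and source index ranges for a shift of (dy, dx)
    dst_r = slice(max(dy, 0), h + min(dy, 0))
    src_r = slice(max(-dy, 0), h + min(-dy, 0))
    dst_c = slice(max(dx, 0), w + min(dx, 0))
    src_c = slice(max(-dx, 0), w + min(-dx, 0))
    shifted[dst_r, dst_c] = frame[src_r, src_c]
    # centre-crop to (rows, cols); the discarded margin holds the zero-filled border
    rows, cols = crop
    r0, c0 = (h - rows) // 2, (w - cols) // 2
    return shifted, shifted[r0:r0 + rows, c0:c0 + cols]
```

For example, shifting a 20×30 frame of ones by (x=5, y=-3) leaves zeros along the left and bottom edges of `shifted`, while a 10×16 centre crop contains only original pixels.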

In [2]:
# load images
expt = 'ND0022'
pos = 'Pos12'
root_dir = '/home/nathan/data/kraken/ras'
image_path = f'{root_dir}/{expt}/{pos}/{pos}_images'
transform_path = f'{root_dir}/{expt}/{pos}/gfp_clipped_transform_tensor.npy'
images = DaskOctopusLiteLoader(image_path, 
                               transforms=transform_path,
                               crop=(1200,1600), 
                               remove_background=True)

Using cropping: (1200, 1600)


## 2. Localise the objects
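Each object also needs a class label (GFP or RFP). One option is to measure an intensity regionprops parameter (e.g. `mean_intensity` or `max_intensity`) on the segmentation image itself, which means passing the segmentation images twice: once to find the centroids and once as the image whose pixel values are measured (this is the approach used in the parallel batch section). Here we instead set `assign_class_ID = True`, which stores the class as the `class id` property of each object.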
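The "measure the mask as its own intensity image" idea can be sketched in plain NumPy. Assuming an instance-label image (one integer ID per segment, 0 for background) and a class mask whose pixel values are the class (1 = GFP, 2 = RFP), the class of each segment is the maximum mask value inside it, which is what `max_intensity` returns when the mask is the intensity image. `class_per_segment` is a hypothetical helper, not a btrack function.

```python
import numpy as np


def class_per_segment(instances, class_mask):
    """Map each segment ID in `instances` (0 = background) to the maximum
    value of `class_mask` inside that segment."""
    classes = {}
    for seg_id in np.unique(instances):
        if seg_id == 0:
            continue  # skip background
        classes[int(seg_id)] = int(class_mask[instances == seg_id].max())
    return classes
```

For a segment that carries a single class value, the mean and the maximum agree, so either regionprops parameter gives the same class.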

In [4]:
images.channels

[<Channels.BRIGHTFIELD: 0>,
 <Channels.GFP: 1>,
 <Channels.RFP: 2>,
 <Channels.IRFP: 3>,
 <Channels.MASK_B: 94>,
 <Channels.MASK_A: 95>,
 <Channels.MASK_IRFP: 96>,
 <Channels.MASK_RFP: 97>,
 <Channels.MASK_GFP: 98>,
 <Channels.MASK_2CH: 99>]

In [22]:
objects = btrack.utils.segmentation_to_objects(
    images['mask_2ch'],
    properties = ('area', 'eccentricity'),
    assign_class_ID = True,
)

[INFO][2022/03/30 06:56:33 PM] Localizing objects from segmentation...
[INFO][2022/03/30 07:02:30 PM] Objects are of type: <class 'dict'>
[INFO][2022/03/30 07:02:34 PM] ...Found 446249 objects in 1098 frames.


#### Measured values from the raw images can also be assigned to each segment using `skimage.measure.regionprops` parameters
The raw image to be measured (here GFP) must be passed as the second argument. The `intensity_image` property cannot currently be saved to the object file.
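When an intensity image is supplied, the log reports that centroids are weighted by it. A minimal NumPy sketch of that calculation for one segment, together with its area and mean intensity; `segment_measurements` is hypothetical, uses (row, column) coordinates, and btrack's own implementation may differ in detail.

```python
import numpy as np


def segment_measurements(instances, intensity, seg_id):
    """Area, mean intensity and intensity-weighted centroid (row, col) of one
    segment. Assumes the segment's total intensity is positive."""
    rows, cols = np.nonzero(instances == seg_id)
    values = intensity[rows, cols].astype(np.float64)
    area = rows.size
    mean_intensity = values.mean()
    # each pixel's coordinates are weighted by its intensity
    weighted_centroid = (
        (rows * values).sum() / values.sum(),
        (cols * values).sum() / values.sum(),
    )
    return area, mean_intensity, weighted_centroid
```

A two-pixel segment with intensities 1 and 3 therefore has its centroid pulled three quarters of the way towards the brighter pixel.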

In [24]:
detailed_objects = btrack.utils.segmentation_to_objects(
    images['mask_2ch'], 
    images['gfp'],
    properties = ('area', 'mean_intensity', 'intensity_image'), 
)

[INFO][2022/03/30 07:10:06 PM] Localizing objects from segmentation...
[INFO][2022/03/30 07:10:06 PM] Found intensity_image data
[INFO][2022/03/30 07:10:06 PM] Calculating weighted centroids using intensity_image
[INFO][2022/03/30 07:24:37 PM] Objects are of type: <class 'dict'>
[INFO][2022/03/30 07:24:40 PM] ...Found 446249 objects in 1098 frames.


In [None]:
detailed_objects[0]

Example `intensity_image` of the first object (measured from the GFP channel passed above)

In [None]:
imshow(detailed_objects[0].properties['intensity_image'])

## 2b. Differentiate the objects based on class ID

In [6]:
objects_gfp = [obj for obj in objects if obj.properties['class id'] == 1]
objects_rfp = [obj for obj in objects if obj.properties['class id'] == 2]
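Each list comprehension above scans every object, and `classify_objects` scans the full object list again for every frame (`o.t == n`). A single pass with `collections.defaultdict` can group objects by any key instead. This sketch uses `types.SimpleNamespace` stand-ins for btrack objects (assumption: only `.t` and `.properties['class id']` are needed); `group_by` is a hypothetical helper.

```python
from collections import defaultdict
from types import SimpleNamespace


def group_by(objects, key):
    """Group objects into a dict of lists in a single pass, keyed by key(obj)."""
    groups = defaultdict(list)
    for obj in objects:
        groups[key(obj)].append(obj)
    return dict(groups)


# stand-ins for btrack objects: only the attributes used here are provided
objects = [
    SimpleNamespace(t=0, properties={"class id": 1}),
    SimpleNamespace(t=0, properties={"class id": 2}),
    SimpleNamespace(t=1, properties={"class id": 1}),
]
by_class = group_by(objects, key=lambda o: o.properties["class id"])
by_frame = group_by(objects, key=lambda o: o.t)
```

Inside a frame loop, `by_frame.get(n, [])` would then replace the per-frame list comprehension, turning a scan of all objects per frame into a dictionary lookup.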

## 3. Classify the objects 

Load model

In [7]:
model = load_model('../models/cellx_classifier_stardist.h5')

Define normalisation functions

In [8]:
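def normalize(x):
    """Z-score normalise a single-channel image."""
    xf = x.astype(np.float32)
    mx = np.mean(xf)
    # floor the standard deviation so that flat (constant) crops do not divide by zero
    sd = np.max([np.std(xf), 1./np.prod(x.shape)])
    return (xf - mx) / sd

def normalize_channels(x):
    """Normalise each channel of an (H, W, C) image independently, in place."""
    for dim in range(x.shape[-1]):
        x[..., dim] = normalize(x[..., dim])
    return x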
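The per-channel loop can also be written with NumPy reductions over the two spatial axes. A sketch, assuming (H, W, C) input like the crops used by the classifier; unlike `normalize_channels`, it returns a new float32 array rather than writing into its input. `normalize_channels_vectorized` is a hypothetical name.

```python
import numpy as np


def normalize_channels_vectorized(x):
    """Z-score each channel of an (H, W, C) image over its spatial axes,
    flooring the standard deviation at 1 / (H * W) as `normalize` does."""
    xf = x.astype(np.float32)
    mean = xf.mean(axis=(0, 1), keepdims=True)
    std = xf.std(axis=(0, 1), keepdims=True)
    std = np.maximum(std, 1.0 / (x.shape[0] * x.shape[1]))
    return (xf - mean) / std
```

A constant channel normalises to all zeros because of the floor, and any other channel ends up with zero mean and unit standard deviation.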

Define classifier function

In [9]:
def classify_objects(image, gfp, rfp, objects, obj_type):
    
    # define stages of cell cycle to classify (dependent on model type)
    LABELS = ["interphase", "prometaphase", "metaphase", "anaphase", "apoptosis"]
    
    # select H2B channel to aid in classification (GFP for obj_type 1, RFP for obj_type 2)
    fp = gfp if obj_type == 1 else rfp
    
    # iterate over frames
    for n in tqdm(range(image.shape[0])):
        
        # only select objects in this frame
        _objects = [o for o in objects if o.t == n]
        if not _objects:
            continue
        
        # empty placeholder lists
        crops = []
        to_update = []
        
        # create (H, W, 2) stack by computing this frame of each dask array
        frame = np.stack(
            [image[n, ...].compute(), fp[n, ...].compute()], 
            axis=-1,
        ) 
        
        # padded (reflect mode) view of the frame so crops can extend past the image border
        vol = InfinitePaddedImage(frame, mode='reflect')
        
        # iterate over objects 
        for obj in _objects:
            
            # 80x80 pixel window centred on the object
            xs = slice(int(obj.x-40), int(obj.x+40), 1)
            ys = slice(int(obj.y-40), int(obj.y+40), 1)
            
            # crop and resize to the 64x64 network input
            crop = vol[ys, xs, :]
            crop = resize(crop, (64, 64), preserve_range=True).astype(np.float32)
            
            # normalise image; skip (and report) any crop with an unexpected shape
            if crop.shape == (64, 64, 2):
                crops.append(normalize_channels(crop))
                to_update.append(obj)
            else:
                print(crop.shape)
                
        if not crops:
            continue
            
        # use classification model to predict
        pred = model.predict(np.stack(crops, axis=0))
        
        # check there is one prediction per successfully cropped object
        assert pred.shape[0] == len(to_update)
        
        # assign labels to objects (iterate over to_update, not _objects,
        # so that skipped crops do not misalign predictions and objects)
        for idx, obj in enumerate(to_update):
            
            # assigning details of prediction
            pred_label = np.argmax(pred[idx, ...])
            pred_softmax = softmax(pred[idx, ...])
            probs = {f"prob_{k}": pred_softmax[ki] for ki, k in enumerate(LABELS)}
            
            # write out; existing properties such as area and class id are retained
            obj.label = pred_label
            obj.properties = probs

    return objects
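`InfinitePaddedImage` lets the crop slices run past the image edge. As an illustrative stand-in (not the CellX implementation), the same 80×80 window can be cut from a frame padded with `np.pad(..., mode="reflect")`. NumPy's "reflect" mode does not repeat the edge pixel, which may differ from the CellX convention; `crop_window` is a hypothetical helper.

```python
import numpy as np


def crop_window(frame, x, y, half=40):
    """Cut a (2*half, 2*half, C) window centred on (x, y) from an (H, W, C)
    frame, reflect-padding the frame so windows near the border stay full size."""
    padded = np.pad(frame, ((half, half), (half, half), (0, 0)), mode="reflect")
    # window start as in int(obj.y - 40), shifted into padded coordinates
    r0 = int(y - half) + half
    c0 = int(x - half) + half
    return padded[r0:r0 + 2 * half, c0:c0 + 2 * half, :]
```

Unlike `slice(int(obj.y - 40), int(obj.y + 40))`, whose length can drop to 79 when `int` truncates a negative coordinate towards zero, this window always has exactly `2 * half` rows and columns before resizing.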

#### Load the raw images for the classifier; the H2B channel used depends on `obj_type` (GFP for `obj_type = 1`, RFP for `obj_type = 2`)

In [10]:
bf = images['brightfield']
gfp = images['gfp']
rfp = images['rfp']

#### Classify objects

In [None]:
objects_gfp = classify_objects(bf, gfp, rfp, objects_gfp, obj_type = 1)


In [14]:
objects_rfp = classify_objects(bf, gfp, rfp, objects_rfp, obj_type = 2)

  0%|          | 0/1098 [00:00<?, ?it/s]

#### Inspect an example object

In [12]:
objects_gfp[0]

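| | ID | x | y | z | t | dummy | states | label | prob | area | class id | prob_interphase | prob_prometaphase | prob_metaphase | prob_anaphase | prob_apoptosis |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 160.435897 | 4.226496 | 0.0 | 0 | False | 0 | 0 | 0.0 | 234 | 1 | 1.0 | 1.944189e-10 | 1.602896e-10 | 2.379132e-11 | 9.544462e-10 |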
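The `prob_*` values come from applying a softmax to the model outputs, which are treated as unnormalised logits. That is consistent with the example above: a softmax over five inputs that already lay in [0, 1] could never exceed e / (e + 4) ≈ 0.40, yet `prob_interphase` is 1.0. A NumPy sketch of the conversion done in `classify_objects`; the logit values are made up for illustration and `logits_to_properties` is a hypothetical helper.

```python
import numpy as np

LABELS = ["interphase", "prometaphase", "metaphase", "anaphase", "apoptosis"]


def logits_to_properties(logits):
    """Return (label index, {'prob_<stage>': probability}) for one object's logits."""
    z = np.asarray(logits, dtype=np.float64)
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    probs = e / e.sum()
    label = int(np.argmax(z))  # softmax preserves the argmax
    return label, {f"prob_{k}": float(p) for k, p in zip(LABELS, probs)}


# hypothetical logits for one object
label, props = logits_to_properties([12.0, -10.0, -10.5, -12.0, -8.5])
```

Taking the argmax of the raw outputs (as the notebook does) therefore gives the same label as taking it after the softmax.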


#### Save out classified GFP objects

In [13]:
with btrack.dataio.HDF5FileHandler(
    f'{root_dir}/{expt}/{pos}/objects_type_1.h5', 'w', obj_type='obj_type_1',
) as hdf:
    #hdf.write_segmentation(masks['mask'])
    hdf.write_objects(objects_gfp)

[INFO][2022/03/30 01:29:58 PM] Opening HDF file: /home/nathan/data/kraken/ras/ND0022/Pos12/objects_type_1.h5...
[INFO][2022/03/30 01:30:03 PM] Writing objects/obj_type_1
[INFO][2022/03/30 01:30:03 PM] Writing labels/obj_type_1
[INFO][2022/03/30 01:30:03 PM] Loading objects/obj_type_1 (386145, 5) (386145 filtered: None)
[INFO][2022/03/30 01:30:08 PM] Writing properties/obj_type_1/area (386145,)
[INFO][2022/03/30 01:30:08 PM] Writing properties/obj_type_1/class id (386145,)
[INFO][2022/03/30 01:30:08 PM] Writing properties/obj_type_1/prob_interphase (386145,)
[INFO][2022/03/30 01:30:08 PM] Writing properties/obj_type_1/prob_prometaphase (386145,)
[INFO][2022/03/30 01:30:08 PM] Writing properties/obj_type_1/prob_metaphase (386145,)
[INFO][2022/03/30 01:30:08 PM] Writing properties/obj_type_1/prob_anaphase (386145,)
[INFO][2022/03/30 01:30:08 PM] Writing properties/obj_type_1/prob_apoptosis (386145,)
[INFO][2022/03/30 01:30:08 PM] Closing HDF file: /home/nathan/data/kraken/ras/ND0022/Pos12

#### Save out classified RFP objects

In [15]:
with btrack.dataio.HDF5FileHandler(
    f'{root_dir}/{expt}/{pos}/objects_type_2.h5', 'w', obj_type='obj_type_2',
) as hdf:
    #hdf.write_segmentation(masks['mask'])
    hdf.write_objects(objects_rfp)

[INFO][2022/03/30 02:38:14 PM] Opening HDF file: /home/nathan/data/kraken/ras/ND0022/Pos12/objects_type_2.h5...
[INFO][2022/03/30 02:38:15 PM] Writing objects/obj_type_2
[INFO][2022/03/30 02:38:15 PM] Writing labels/obj_type_2
[INFO][2022/03/30 02:38:15 PM] Loading objects/obj_type_2 (38765, 5) (38765 filtered: None)
[INFO][2022/03/30 02:38:15 PM] Writing properties/obj_type_2/area (38765,)
[INFO][2022/03/30 02:38:16 PM] Writing properties/obj_type_2/class id (38765,)
[INFO][2022/03/30 02:38:16 PM] Writing properties/obj_type_2/prob_interphase (38765,)
[INFO][2022/03/30 02:38:16 PM] Writing properties/obj_type_2/prob_prometaphase (38765,)
[INFO][2022/03/30 02:38:16 PM] Writing properties/obj_type_2/prob_metaphase (38765,)
[INFO][2022/03/30 02:38:16 PM] Writing properties/obj_type_2/prob_anaphase (38765,)
[INFO][2022/03/30 02:38:16 PM] Writing properties/obj_type_2/prob_apoptosis (38765,)
[INFO][2022/03/30 02:38:16 PM] Closing HDF file: /home/nathan/data/kraken/ras/ND0022/Pos12/objects_

#### Save out both object types to one file

In [20]:
with btrack.dataio.HDF5FileHandler(
    f'{root_dir}/{expt}/{pos}/objects.h5', 'w', obj_type='obj_type_1',
) as hdf:
    #hdf.write_segmentation(masks['mask'])
    hdf.write_objects(objects_gfp)
    
with btrack.dataio.HDF5FileHandler(
    f'{root_dir}/{expt}/{pos}/objects.h5', 'a', obj_type='obj_type_2',
) as hdf:
    #hdf.write_segmentation(masks['mask'])
    hdf.write_objects(objects_rfp)

[INFO][2022/03/30 06:27:43 PM] Opening HDF file: /home/nathan/data/kraken/ras/ND0022/Pos12/objects.h5...
[INFO][2022/03/30 06:27:47 PM] Writing objects/obj_type_1
[INFO][2022/03/30 06:27:47 PM] Writing labels/obj_type_1
[INFO][2022/03/30 06:27:47 PM] Loading objects/obj_type_1 (386145, 5) (386145 filtered: None)
[INFO][2022/03/30 06:27:52 PM] Writing properties/obj_type_1/area (386145,)
[INFO][2022/03/30 06:27:52 PM] Writing properties/obj_type_1/class id (386145,)
[INFO][2022/03/30 06:27:52 PM] Writing properties/obj_type_1/prob_interphase (386145,)
[INFO][2022/03/30 06:27:52 PM] Writing properties/obj_type_1/prob_prometaphase (386145,)
[INFO][2022/03/30 06:27:52 PM] Writing properties/obj_type_1/prob_metaphase (386145,)
[INFO][2022/03/30 06:27:53 PM] Writing properties/obj_type_1/prob_anaphase (386145,)
[INFO][2022/03/30 06:27:53 PM] Writing properties/obj_type_1/prob_apoptosis (386145,)
[INFO][2022/03/30 06:27:53 PM] Closing HDF file: /home/nathan/data/kraken/ras/ND0022/Pos12/object

## 4. Batch process
Iterate over many experiments and positions (the normalisation and classification functions above must be defined first).
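The batch code sorts experiment and position names by the integers they contain, because plain string sorting puts "Pos10" before "Pos2" (as in the unsorted `os.listdir` output in the parallel section). The same sort key as a standalone function:

```python
import re


def natural_key(name):
    """Sort key made of the integers in a name, e.g. 'Pos10' -> [10]."""
    return [int(n) for n in re.findall(r"\d+", name)]


positions = ["Pos5", "Pos11", "Pos3", "Pos1", "Pos10", "Pos0", "Pos2"]
string_order = sorted(positions)                    # ['Pos0', 'Pos1', 'Pos10', 'Pos11', 'Pos2', 'Pos3', 'Pos5']
numeric_order = sorted(positions, key=natural_key)  # ['Pos0', 'Pos1', 'Pos2', 'Pos3', 'Pos5', 'Pos10', 'Pos11']
```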

In [12]:
import glob
import re
import time

In [26]:
time.sleep(14400)  # pause for 4 hours (14400 s) before starting the batch run

In [None]:
%%time
root_dir = '/home/nathan/data/kraken/ras'
expt_list = sorted([expt for expt in os.listdir(root_dir) 
                    if 'ND' in expt and os.path.isdir(os.path.join(root_dir, expt))
                    and '21' not in expt 
                    and '20' not in expt], 
                   key=lambda x: [int(y) for y in re.findall(r'\d+', x)])
pos_list = 'all'
overwrite = True

# load the classification model once (classify_objects uses this global `model`)
model = load_model('../models/cellx_classifier_stardist.h5')

for expt in tqdm(expt_list):
    
    # find the positions for this experiment; if pos_list is 'all', use every Pos* directory
    # (stored in a separate variable so 'all' is re-evaluated for every experiment)
    if pos_list == 'all':
        positions = sorted([pos for pos in os.listdir(f'{root_dir}/{expt}') 
                            if 'Pos' in pos 
                            and os.path.isdir(f'{root_dir}/{expt}/{pos}')], 
                           key=lambda x: [int(y) for y in re.findall(r'\d+', x)])
    else:
        positions = pos_list

    # iterate over all positions in that experiment
    for pos in tqdm(positions):

        # unless overwriting, skip positions that already have an objects h5 file
        existing = glob.glob(f'{root_dir}/{expt}/{pos}/*objects*.h5')
        if not overwrite and existing:
            print(existing, f'file found, skipping {expt}/{pos}')
            continue

        print(f'Starting {expt}/{pos}')
        # load segmentation images and apply necessary transforms and crops
        image_path = f'{root_dir}/{expt}/{pos}/{pos}_images'
        transform_path = f'{root_dir}/{expt}/{pos}/mask_reversed_clipped_transform_tensor.npy' #gfp_transform_tensor.npy'
        images = DaskOctopusLiteLoader(image_path, 
                                       transforms=transform_path,
                                       crop=(1200,1600), 
                                       remove_background=False)
        
        # ID the objects in each segmentation image and assign optional properties to them
        objects = btrack.utils.segmentation_to_objects(
            images['mask'], 
            properties=('area',),
            assign_class_ID=True,
        )

        # differentiate the objects based on class ID
        objects_gfp = [obj for obj in objects if obj.properties['class id'] == 1]
        objects_rfp = [obj for obj in objects if obj.properties['class id'] == 2]
        
        # load images for classification
        bf = images['brightfield']
        gfp = images['gfp']
        rfp = images['rfp']

        # classify objects
        print("Classifying objects")
        objects_gfp = classify_objects(bf, gfp, rfp, objects_gfp, obj_type=1)
        objects_rfp = classify_objects(bf, gfp, rfp, objects_rfp, obj_type=2)

        # save out classified objects as segmentation h5 file
        with btrack.dataio.HDF5FileHandler(
            f'{root_dir}/{expt}/{pos}/objects_type_1.h5', 'w', obj_type='obj_type_1',
        ) as hdf:
            #hdf.write_segmentation(masks['mask'])
            hdf.write_objects(objects_gfp)
        with btrack.dataio.HDF5FileHandler(
            f'{root_dir}/{expt}/{pos}/objects_type_2.h5', 'w', obj_type='obj_type_2',
        ) as hdf:
            #hdf.write_segmentation(masks['mask'])
            hdf.write_objects(objects_rfp)
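The experiment/position loop can also be expressed as a flat list of `(expt, pos)` jobs, built once, with "all" positions re-evaluated per experiment and the overwrite check applied up front. A `pathlib` sketch assuming the same layout as above (`<root>/<expt>/Pos*/` directories, outputs matching `*objects*.h5`); `find_jobs` and its `exclude` parameter are hypothetical.

```python
import re
from pathlib import Path


def natural_key(name):
    """Sort key made of the integers in a name, e.g. 'Pos10' -> [10]."""
    return [int(n) for n in re.findall(r"\d+", name)]


def find_jobs(root, pos_list="all", overwrite=True, exclude=("20", "21")):
    """Return sorted (expt, pos) pairs to process under `root`.

    Experiments are directories whose name contains 'ND' and none of the
    `exclude` substrings; positions are 'Pos*' directories (or `pos_list`).
    Without `overwrite`, positions already holding '*objects*.h5' are skipped."""
    root = Path(root)
    expts = sorted(
        (d.name for d in root.iterdir()
         if d.is_dir() and "ND" in d.name and not any(s in d.name for s in exclude)),
        key=natural_key,
    )
    jobs = []
    for expt in expts:
        # re-evaluate 'all' for each experiment rather than reusing the first list
        if pos_list == "all":
            positions = sorted(
                (d.name for d in (root / expt).iterdir()
                 if d.is_dir() and "Pos" in d.name),
                key=natural_key,
            )
        else:
            positions = list(pos_list)
        for pos in positions:
            if not overwrite and any((root / expt / pos).glob("*objects*.h5")):
                continue  # output already exists
            jobs.append((expt, pos))
    return jobs
```

Building the job list separately from the processing keeps the directory logic testable without loading any images.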

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

Starting ND0013/Pos0
Using cropping: (1200, 1600)


[INFO][2022/03/23 07:02:56 PM] Localizing objects from segmentation...


In [33]:
transform_path

'/home/nathan/data/kraken/ras/ND0013/Pos3/mask_reversed_transform_tensor_clipped.npy'

In [31]:
np.load(transform_path)

array([[[ 1.00000000e+00,  0.00000000e+00, -6.22149428e+01],
        [ 0.00000000e+00,  1.00000000e+00,  1.00000000e+01],
        [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00]],

       [[ 1.00000000e+00,  0.00000000e+00, -5.46285883e+01],
        [ 0.00000000e+00,  1.00000000e+00,  1.00000000e+01],
        [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00]],

       [[ 1.00000000e+00,  0.00000000e+00, -5.72211371e+01],
        [ 0.00000000e+00,  1.00000000e+00,  1.00000000e+01],
        [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00]],

       ...,

       [[ 1.00000000e+00,  0.00000000e+00, -5.03624896e-03],
        [ 0.00000000e+00,  1.00000000e+00, -7.44898619e-01],
        [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00]],

       [[ 1.00000000e+00,  0.00000000e+00,  2.17864873e+00],
        [ 0.00000000e+00,  1.00000000e+00,  9.02855882e-02],
        [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00]],

       [[ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00],
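The matrices above are pure translations, so the per-frame shifts can be read from their last column. A minimal sketch, assuming the same (T, 3, 3) layout with row 0 holding the x (column) shift and row 1 the y (row) shift, that reports the largest shift and checks whether a centre crop discards at least that many pixels on every side; `max_shift` and `crop_is_safe` are hypothetical helpers.

```python
import numpy as np


def max_shift(tensor):
    """Largest absolute x and y translation over all frames of a (T, 3, 3) tensor."""
    shifts = np.abs(tensor[:, :2, 2])  # column 2, rows 0 (x) and 1 (y)
    return shifts[:, 0].max(), shifts[:, 1].max()


def crop_is_safe(tensor, frame_shape, crop):
    """True if a centre crop of size `crop` (rows, cols) from a frame of
    `frame_shape` (rows, cols) removes at least the largest shift per side."""
    max_x, max_y = max_shift(tensor)
    margin_rows = (frame_shape[0] - crop[0]) // 2
    margin_cols = (frame_shape[1] - crop[1]) // 2
    return bool(margin_rows >= np.ceil(max_y) and margin_cols >= np.ceil(max_x))
```

Running `max_shift(np.load(transform_path))` on a real tensor gives the minimum margin the `crop` argument of the loader has to leave.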
 

In [34]:
images = DaskOctopusLiteLoader(image_path, 
                           transforms=transform_path,
                           crop=(1200,1600), 
                           remove_background=False)

Using cropping: (1200, 1600)


## Parallel batch process

In [7]:
def classify(pos):
    # relies on the notebook globals root_dir, expt and overwrite
    global model

    # unless overwriting, skip positions that already have an objects h5 file
    existing = glob.glob(f'{root_dir}/{expt}/{pos}/*objects*.h5')
    if not overwrite and existing:
        print(existing, f'file found, skipping {expt}/{pos}')
        return

    print(f'Starting {expt}/{pos}')
    # load segmentation images and apply necessary transforms and crops
    image_path = f'{root_dir}/{expt}/{pos}/{pos}_images'
    transform_path = f'{root_dir}/{expt}/{pos}/gfp_transform_tensor.npy'
    images = DaskOctopusLiteLoader(image_path, 
                                   transforms=transform_path,
                                   crop=(1200,1600), 
                                   remove_background=False)

    # ID the objects; the mask is also passed as the intensity image so that
    # max_intensity gives the class value of each segment
    objects = btrack.utils.segmentation_to_objects(
        images['mask'], images['mask'],
        properties=('area', 'max_intensity'),
    )

    # differentiate the objects based on class ID (1 = GFP, 2 = RFP)
    objects_gfp = [obj for obj in objects if obj.properties['max_intensity'] == 1]
    objects_rfp = [obj for obj in objects if obj.properties['max_intensity'] == 2]

    # load the classification model in this worker; classify_objects uses the global `model`
    model = load_model('../models/cellx_classifier_stardist.h5')

    # load images for classification
    bf = images['brightfield']
    gfp = images['gfp']
    rfp = images['rfp']

    # classify objects
    print("Classifying objects")
    objects_gfp = classify_objects(bf, gfp, rfp, objects_gfp, obj_type=1)
    objects_rfp = classify_objects(bf, gfp, rfp, objects_rfp, obj_type=2)

    # save out classified objects as segmentation h5 file
    with btrack.dataio.HDF5FileHandler(
        f'{root_dir}/{expt}/{pos}/objects_type_1.h5', 'w', obj_type='obj_type_1',
    ) as hdf:
        #hdf.write_segmentation(masks['mask'])
        hdf.write_objects(objects_gfp)
    with btrack.dataio.HDF5FileHandler(
        f'{root_dir}/{expt}/{pos}/objects_type_2.h5', 'w', obj_type='obj_type_2',
    ) as hdf:
        #hdf.write_segmentation(masks['mask'])
        hdf.write_objects(objects_rfp)

    return

In [8]:
from multiprocessing import Pool
cpus = os.cpu_count()
cpus

12

In [11]:
pos_list = [pos for pos in os.listdir(f'{root_dir}/{expt}') 
                    if 'Pos' in pos 
                    and os.path.isdir(f'{root_dir}/{expt}/{pos}')]
pos_list

['Pos5',
 'Pos11',
 'Pos3',
 'Pos1',
 'Pos8',
 'Pos10',
 'Pos0',
 'Pos2',
 'Pos6',
 'Pos7',
 'Pos9',
 'Pos4']

In [13]:
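if __name__ == '__main__':
    for expt in expt_list:
        # recompute the positions for every experiment (the list above covers one experiment only)
        pos_list = [pos for pos in os.listdir(f'{root_dir}/{expt}')
                    if 'Pos' in pos
                    and os.path.isdir(f'{root_dir}/{expt}/{pos}')]
        # classify() reads the global `expt`; with the 'fork' start method each new pool inherits its current value
        with Pool(cpus) as p:
            p.map(classify, pos_list)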
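An alternative to relying on the global `expt` inside `classify` is to pass each worker an explicit `(expt, pos)` tuple, so that a single pool can work through every experiment. A minimal sketch with a hypothetical `process_position` placeholder standing in for the body of `classify`; with `processes=1` it runs serially, which helps when debugging.

```python
from multiprocessing import Pool


def process_position(job):
    """Placeholder for the per-position work in `classify`; takes (expt, pos) explicitly."""
    expt, pos = job
    # ... load images, localise, classify and save objects for this position here ...
    return f"{expt}/{pos}"


def run_jobs(jobs, processes=1):
    """Run process_position over all jobs, serially or with a process pool.
    Pool.map returns results in the same order as `jobs`."""
    if processes == 1:
        return [process_position(job) for job in jobs]
    with Pool(processes) as pool:
        return pool.map(process_position, jobs)
```

In a script, call `run_jobs(jobs, processes=os.cpu_count())` under an `if __name__ == '__main__':` guard so that worker processes do not re-run the top-level code.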

Starting ND0010/Pos3Starting ND0010/Pos4Starting ND0010/Pos11Starting ND0010/Pos5Starting ND0010/Pos7Starting ND0010/Pos1Starting ND0010/Pos10Starting ND0010/Pos0Starting ND0010/Pos8
Starting ND0010/Pos2
Starting ND0010/Pos9

Starting ND0010/Pos6


Using cropping: (1200, 1600)Using cropping: (1200, 1600)

Using cropping: (1200, 1600)Using cropping: (1200, 1600)Using cropping: (1200, 1600)Using cropping: (1200, 1600)
Using cropping: (1200, 1600)

Using cropping: (1200, 1600)




Using cropping: (1200, 1600)
Using cropping: (1200, 1600)

Using cropping: (1200, 1600)

Using cropping: (1200, 1600)



[INFO][2022/02/07 05:18:20 PM] Localizing objects from segmentation...
[INFO][2022/02/07 05:18:20 PM] Found intensity_image data
[INFO][2022/02/07 05:18:20 PM] Localizing objects from segmentation...
[INFO][2022/02/07 05:18:20 PM] Calculating weighted centroids using intensity_image
[INFO][2022/02/07 05:18:20 PM] Localizing objects from segmentation...
[INFO][2022/02/07 05:18:20 PM] Found intensity_image data
[INFO][2022/02/07 05:18:20 PM] Calculating weighted centroids using intensity_image
[INFO][2022/02/07 05:18:20 PM] Localizing objects from segmentation...
[INFO][2022/02/07 05:18:20 PM] Found intensity_image data
[INFO][2022/02/07 05:18:20 PM] Found intensity_image data
[INFO][2022/02/07 05:18:20 PM] Calculating weighted centroids using intensity_image
[INFO][2022/02/07 05:18:20 PM] Localizing objects from segmentation...
[INFO][2022/02/07 05:18:20 PM] Calculating weighted centroids using intensity_image
[INFO][2022/02/07 05:18:20 PM] Localizing objects from segmentation...
[INFO]

KeyboardInterrupt: 