# Phenotype classification using CellX

This notebook shows how to take segmented time-lapse microscopy images and use H2B fluorescence markers to classify the mitotic (cell-cycle) state of each cell.

The sections of this notebook are as follows:

1. Load images
2. Localise the objects
3. Classify the objects
4. Filter the objects
5. Run btrack, linking the object locations over time

The data used in this notebook are time-lapse microscopy images with H2B-GFP/RFP markers, which show the spatial extent of the nucleus and its mitotic state.

This notebook uses the Dask-based OctopusLite image loader (`DaskOctopusLiteLoader`) from the CellX/Lowe lab project.

In [1]:
from octopuslite import DaskOctopusLiteLoader
import btrack
from tqdm import tqdm
import numpy as np
from scipy.special import softmax
import os
import matplotlib.pyplot as plt
from skimage.io import imread, imshow
%matplotlib inline
plt.rcParams['figure.figsize'] = [18,8]

In [7]:
def image_generator(files, crop=None):
    """Lazily yield images read from ``files``, optionally centre-cropped.

    Images are read one at a time with ``imread``, so a long time-lapse
    never has to be held in memory all at once.

    Parameters
    ----------
    files : sequence of str
        Image file paths, yielded in the given order. An empty sequence
        yields nothing.
    crop : tuple of int, optional
        Size of the centred crop along the leading axes, one entry per axis.
        Axes beyond ``len(crop)`` (e.g. a trailing channel axis) are left
        uncropped. The crop window is computed from the first image's shape
        and applied to every image. ``None`` (default) yields images as read.

    Yields
    ------
    numpy.ndarray
        One (optionally cropped) image per file.
    """
    if len(files) == 0:
        return

    if crop is None:
        for filename in files:
            yield imread(filename)
        return

    # Compute the centred crop window once, from the first image's shape,
    # and reuse that first image rather than reading it from disk again.
    first = imread(files[0])
    starts = [int((dim - size) // 2) for dim, size in zip(first.shape, crop)]
    window = tuple(slice(start, start + int(size)) for start, size in zip(starts, crop))

    yield first[window]
    for filename in files[1:]:
        yield imread(filename)[window]

## 1. Load segmentation images

In [2]:
# load the segmentation masks for one experiment / position
expt = 'ND0009'
pos = 'Pos3'
CROP = (1200, 1600)  # centred crop size, kept identical for the loader and the generators
image_path = f'/home/nathan/data/kraken/ras/{expt}/{pos}/{pos}_stardist_masks'
# NOTE(review): remove_background=True on label masks may alter label values
# (masks['mask_*'] is later written out as the segmentation) — confirm intended.
masks = DaskOctopusLiteLoader(image_path, crop=CROP, remove_background=True)
# read the masks one frame at a time via the generator instead of loading the whole stack
segmentation_gfp = image_generator(masks.files('mask_gfp'), crop=CROP)
segmentation_rfp = image_generator(masks.files('mask_rfp'), crop=CROP)

Using cropping: (1200, 1600)


## 2. Localise the objects

#### GFP object localisation

In [None]:
# Build a fresh mask generator here: generators are single-use, so reusing one
# created in an earlier cell would silently yield no objects on a re-run.
segmentation_gfp = image_generator(masks.files('mask_gfp'), crop=(1200, 1600))
objects_gfp = btrack.utils.segmentation_to_objects(
    segmentation_gfp,
    properties=('area', ),
)

#### (Optional) RFP object localisation

In [5]:
# Build a fresh mask generator here: generators are single-use, so reusing one
# created in an earlier cell would silently yield no objects on a re-run.
segmentation_rfp = image_generator(masks.files('mask_rfp'), crop=(1200, 1600))
objects_rfp = btrack.utils.segmentation_to_objects(
    segmentation_rfp,
    properties=('area', ),
)

[INFO][2022/01/20 01:13:59 PM] Localizing objects from segmentation...
[INFO][2022/01/20 01:22:29 PM] Objects are of type: <class 'dict'>
[INFO][2022/01/20 01:22:29 PM] ...Found 30102 objects in 1072 frames.


#### Objects can also be assigned measured values using `skimage.measure.regionprops` properties
This requires loading the intensity images to be measured first.

In [None]:
images = DaskOctopusLiteLoader(f'/home/nathan/data/kraken/ras/{expt}/{pos}/{pos}_aligned', crop=(1200, 1600))
# Fresh mask generator: the one from the loading cell is exhausted once the
# GFP localisation cell has consumed it, which would give zero objects here.
segmentation_gfp = image_generator(masks.files('mask_gfp'), crop=(1200, 1600))
# Distinct name so the global `gfp` image array used by classify_objects is not clobbered.
gfp_frames = image_generator(images.files('gfp'), crop=(1200, 1600))
objects_gfp = btrack.utils.segmentation_to_objects(
    segmentation_gfp,
    gfp_frames,
    properties=('area', 'mean_intensity'),
)

## 3. Classify the objects 

In [6]:
from cellx import load_model
from cellx.tools.image import InfinitePaddedImage
from skimage.transform import resize

# CellX classifier used by classify_objects
MODEL_PATH = '/home/nathan/analysis/segment-classify-track/models/cellx_classifier_stardist.h5'
model = load_model(MODEL_PATH)

In [7]:
# Class names in the order of the classifier's output columns
# (prediction index i is reported as LABELS[i]).
LABELS = [
    "interphase",
    "prometaphase",
    "metaphase",
    "anaphase",
    "apoptosis",
]

In [8]:
def normalize_channels(x):
    """Z-score normalise each channel (last axis) of ``x`` independently.

    Returns a new float32 array. The input is not modified, and integer
    inputs are not truncated, because normalisation happens on a float copy.
    """
    out = np.array(x, dtype=np.float32)  # always a copy
    for channel in range(out.shape[-1]):
        out[..., channel] = normalize(out[..., channel])
    return out

def normalize(x):
    """Z-score normalise ``x`` to zero mean and unit standard deviation.

    The standard deviation is floored at ``1 / x.size`` so constant inputs do
    not divide by zero. An empty input is returned as an empty float32 array.
    """
    xf = x.astype(np.float32)
    if xf.size == 0:
        return xf
    mx = np.mean(xf)
    sd = max(float(np.std(xf)), 1.0 / xf.size)

    return (xf - mx) / sd

In [9]:
def classify_objects(image, objects, obj_type=1, crop_half_width=40, crop_size=64):
    """Predict a mitotic-state label for every object with the CellX classifier.

    For each frame ``n`` of ``image``, a two-channel frame is built by stacking
    ``image[n]`` with the matching fluorescence frame: the global ``gfp`` stack
    when ``obj_type == 1``, otherwise the global ``rfp`` stack. A square window
    of side ``2 * crop_half_width`` is cut around each object centroid. The
    window is reflect-padded at the image border and resized to
    ``crop_size`` x ``crop_size``. It is then normalised per channel and
    classified by the global ``model``.

    Each classified object is updated in place:

    - ``obj.label`` is set to the arg-max class index.
    - ``prob_<label>`` softmax probabilities, one per entry of the global
      ``LABELS``, are merged into ``obj.properties``.
    - Existing properties such as ``area`` are kept.

    Objects whose crop has an unexpected shape are reported and left unchanged.

    Parameters
    ----------
    image : array-like
        3-D stack indexed as ``image[n]``. Each frame must support
        ``.compute()`` (e.g. a dask array).
    objects : list
        btrack objects exposing ``t``, ``x`` (column), ``y`` (row), ``label``
        and ``properties``.
    obj_type : int
        ``1`` pairs ``image`` with ``gfp``; any other value pairs it with ``rfp``.
    crop_half_width : int
        Half the side length, in pixels, of the window cut around each object.
    crop_size : int
        Side length the window is resized to before classification.

    Returns
    -------
    list
        The same ``objects`` list; the objects are modified in place.
    """
    # Index objects by frame once instead of rescanning the whole list per frame.
    objects_by_frame = {}
    for obj in objects:
        objects_by_frame.setdefault(obj.t, []).append(obj)

    # Fluorescence stack paired with `image` as the classifier's second channel.
    fp = gfp if obj_type == 1 else rfp
    expected_shape = (crop_size, crop_size, 2)

    for n in tqdm(range(image.shape[0])):
        frame_objects = objects_by_frame.get(n)
        if not frame_objects:
            continue  # nothing to classify, so don't load this frame

        frame = np.stack(
            [image[n, ...].compute(), fp[n, ...].compute()],
            axis=-1,
        )
        # Reflect padding keeps windows around objects near the border full size.
        vol = InfinitePaddedImage(frame, mode='reflect')

        crops = []
        to_update = []  # objects whose crop passed the shape check, in crop order
        for obj in frame_objects:
            xs = slice(int(obj.x - crop_half_width), int(obj.x + crop_half_width), 1)
            ys = slice(int(obj.y - crop_half_width), int(obj.y + crop_half_width), 1)

            crop = vol[ys, xs, :]
            crop = resize(crop, (crop_size, crop_size), preserve_range=True).astype(np.float32)

            if crop.shape == expected_shape:
                crops.append(normalize_channels(crop))
                to_update.append(obj)
            else:
                print(f"t={n}: skipping object with unexpected crop shape {crop.shape}")

        if not crops:
            continue

        pred = model.predict(np.stack(crops, axis=0))

        # Prediction rows line up with `to_update`, not `frame_objects`:
        # objects whose crop was rejected above have no row in `pred`.
        assert pred.shape[0] == len(to_update)
        for obj, obj_pred in zip(to_update, pred):
            pred_label = np.argmax(obj_pred)
            pred_softmax = softmax(obj_pred)
            probs = {f"prob_{k}": pred_softmax[ki] for ki, k in enumerate(LABELS)}

            obj.label = pred_label
            # Merge rather than replace, so measured properties (e.g. 'area',
            # used by the size filter) survive classification.
            obj.properties = {**(obj.properties or {}), **probs}

    return objects

#### Load raw images for classifier

In [None]:
# Aligned raw stacks for the classifier, cropped the same way as the masks.
images = DaskOctopusLiteLoader(
    f'/home/nathan/data/kraken/ras/{expt}/{pos}/{pos}_aligned',
    crop=(1200, 1600),
)
bf, gfp, rfp = (images[channel] for channel in ('brightfield', 'gfp', 'rfp'))

#### Classify objects

In [None]:
# obj_type 1 pairs the brightfield with the GFP stack, 2 with the RFP stack.
objects_gfp = classify_objects(image=bf, objects=objects_gfp, obj_type=1)
objects_rfp = classify_objects(image=bf, objects=objects_rfp, obj_type=2)

#### Inspect objects

In [None]:
# Peek at one classified object to check its label and prob_* properties.
first_gfp_object = objects_gfp[0]
first_gfp_object

#### Save out classified GFP objects

In [None]:
# Write the GFP segmentation and its classified objects to one HDF5 file.
with btrack.dataio.HDF5FileHandler(
    f'/home/nathan/data/kraken/ras/{expt}/{pos}/segmented_gfp.h5', 'w', obj_type='obj_type_1',
) as hdf:
    # The GFP objects were localised from the 'mask_gfp' channel, so store
    # that segmentation alongside them.
    hdf.write_segmentation(masks['mask_gfp'])
    hdf.write_objects(objects_gfp)

#### Save out classified RFP objects

In [None]:
rfp_h5_path = f'/home/nathan/data/kraken/ras/{expt}/{pos}/segmented.h5'
# NOTE(review): the GFP cell writes 'segmented_gfp.h5', but this file has no
# channel suffix — confirm 'segmented.h5' is the intended RFP output path.
with btrack.dataio.HDF5FileHandler(rfp_h5_path, 'w', obj_type='obj_type_2') as hdf:
    hdf.write_segmentation(masks['mask_rfp'])
    hdf.write_objects(objects_rfp)

## 4. Filter the objects 

Remove segments that are too small to plausibly be cells.

In [None]:
# Minimum segment 'area' property; smaller segments are too small to be cells.
MIN_AREA = 100.

def filter_by_area(objects, min_area=MIN_AREA):
    """Keep only the objects whose 'area' property is strictly above ``min_area``."""
    return [o for o in objects if o.properties['area'] > min_area]

filtered_gfp_objects = filter_by_area(objects_gfp)
filtered_rfp_objects = filter_by_area(objects_rfp)

## 5. Run btrack  

Link each object to its position at the following time point, and export the tracks to an HDF5 file.

#### For GFP objects

In [None]:
# initialise a tracker session using a context manager
with btrack.BayesianTracker() as tracker:

    # configure the tracker using a config file
    tracker.configure_from_file(
        "/home/nathan/analysis/BayesianTracker/models/MDCK_config_new.json"
    )
    tracker.max_search_radius = 40

    # append the objects to be tracked
    tracker.append(filtered_gfp_objects)

    # set the volume
    tracker.volume=((0, 1200), (0, 1600), (-1e5, 1e5))

    # track them (in interactive mode)
    tracker.track_interactive(step_size=100)

    # generate hypotheses and run the global optimizer
    tracker.optimize()

    tracker.export((f'/home/nathan/data/kraken/ras/{expt}/{pos}/tracks.h5'), obj_type='obj_type_1')

    # get the tracks in a format for napari visualization (optional)
    data, properties, graph = tracker.to_napari(ndim=2)
    
    gfp_tracks = tracker.tracks

#### For RFP objects

In [None]:
# initialise a tracker session using a context manager
with btrack.BayesianTracker() as tracker:

    # configure the tracker using a config file
    tracker.configure_from_file(
        "/home/nathan/analysis/BayesianTracker/models/MDCK_config_new.json"
    )
    tracker.max_search_radius = 40

    # append the objects to be tracked
    tracker.append(filtered_rfp_objects)

    # set the volume
    tracker.volume=((0, 1200), (0, 1600), (-1e5, 1e5))

    # track them (in interactive mode)
    tracker.track_interactive(step_size=100)

    # generate hypotheses and run the global optimizer
    tracker.optimize()

    tracker.export((f'/home/nathan/data/kraken/ras/{expt}/{pos}/tracks.h5'), obj_type='obj_type_2')

    # get the tracks in a format for napari visualization (optional)
    data, properties, graph = tracker.to_napari(ndim=2)
    
    rfp_tracks = tracker.tracks

In [None]:
# First reconstructed GFP track
first_gfp_track = gfp_tracks[0]
first_gfp_track

In [None]:
# First reconstructed RFP track
first_rfp_track = rfp_tracks[0]
first_rfp_track