In [1]:
import sys
import cv2
import argparse
import numpy as np
import multiprocessing as mp
from skimage import (
    feature, filters, measure, morphology, segmentation
)
import modules.utils as utils
from functools import partial
from modules.Frame import Frame

## Functions

In [2]:
def generate_tile_paths(path, frame_id, starts, name_format):
    """Build the tile image paths for one frame, one path per entry in ``starts``.

    The tile number for entry ``start`` is ``frame_id + start - 1``, so a
    ``start`` of 1 maps ``frame_id`` directly onto the tile number.

    Parameters
    ----------
    path : str
        Directory containing the tile images.
    frame_id : int
        Frame number (the loading cell uses 1-based ids).
    starts : iterable of int
        Tile-number offset of each block of tiles (presumably one block per
        channel, in channel order -- confirm against the acquisition layout).
    name_format : str
        %-style file-name template with one integer field, e.g. 'Tile%06d.tif'.

    Returns
    -------
    list of str
        ``'<path>/<file name>'`` for each entry of ``starts``, in order.
    """
    # Apply %-formatting to the file name only, so a '%' in the directory
    # path is not interpreted as a format directive.
    return [f"{path}/{name_format % (frame_id + start - 1)}" for start in starts]

## Parameters

In [3]:
# Input/output locations, relative to the notebook's working directory
in_path = '../../P14_MultipleMyeloma/data/0ABBC03/image'
out_path = '../../P14_MultipleMyeloma/data/0ABBC03'  # NOTE(review): not used in this notebook
n_frames = 10    # number of consecutive frame ids to load (edge frames are skipped)
channels = ['DAPI', 'TRITC', 'Cy5', 'FITC']
# Tile-number offset of each tile block, presumably in the same order as
# `channels`. NOTE(review): the last gap is 4608 rather than 2304 -- confirm.
starts = [1, 2305, 4609, 9217]
offset = 79      # frame ids processed are offset+1 ... offset+n_frames
name_format = 'Tile%06d.tif'
n_threads = 4    # worker processes for the parallel cell; <= 0 means all CPUs

# Segmentation Parameters (consumed by segment_frame)
params = {
    'tophat_size': 45,      # elliptical top-hat kernel size (px)
    'opening_size': 5,      # elliptical opening kernel size (px)
    'blur_size': 5,         # Gaussian kernel size (px, must be odd)
    'blur_sigma': 2,        # Gaussian sigma
    'thresh_size': 25,      # block size of the local mean threshold (must be odd)
    # Per-channel threshold offset, indexed by channel index, so keep one
    # entry per channel.
    'thresh_offset': [-2000] * len(channels),
    'min_dist': 7,          # minimum distance between seed peaks (px)
    'size_min_thresh': 10,  # NOTE(review): not read by segment_frame
    'seed_ch': 'DAPI',      # channel whose mask produces the watershed seeds
    'save_mask': True,      # NOTE(review): not read by segment_frame
}

## Segment Events

In [10]:
def segment_frame(frame, params):
    """Segment a frame to identify all events and return their features.

    Pipeline:
      1. For every channel: top-hat, Gaussian blur, local-mean threshold and
         hole filling, giving one binary mask per channel.
      2. Foreground = union of the channel masks, cleaned with an opening.
      3. Seeds = local maxima of the distance transform of the opened
         ``params['seed_ch']`` mask.
      4. Marker-controlled watershed on the negated foreground distance map.

    Parameters
    ----------
    frame : Frame
        Frame whose ``image`` holds the channels along its last axis, indexed
        via ``frame.get_ch(name)`` for each name in ``frame.channels``.
        ``frame.mask`` is overwritten with the watershed label image.
        NOTE: when this runs in a multiprocessing worker, ``frame`` is a
        pickled copy, so the caller's object does not receive the mask.
    params : dict
        Keys read here: 'tophat_size', 'opening_size', 'blur_size',
        'blur_sigma', 'thresh_size', 'thresh_offset' (indexed by channel
        index), 'min_dist', 'seed_ch'. 'size_min_thresh' and 'save_mask'
        are not used by this function.

    Returns
    -------
    The result of ``utils.calc_basic_features(frame)`` for the labelled frame.
    Downstream, one result per frame is passed to ``pd.concat``.
    """
    # astype() returns a new array, so frame.image itself is never modified
    image = frame.image.astype(np.float32)

    tophat_kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (params['tophat_size'], params['tophat_size']))
    opening_kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (params['opening_size'], params['opening_size']))
    blur_ksize = (params['blur_size'], params['blur_size'])

    # Foreground is the union of all per-channel masks. Keep it as a uint8 0/1
    # mask: cv2.bitwise_or on float32 arrays ORs raw bit patterns, not logic.
    target_mask = np.zeros(image.shape[:2], dtype=np.uint8)
    for ch in frame.channels:
        i = frame.get_ch(ch)

        # White top-hat keeps bright structures smaller than the kernel
        # (background removal); the blur then suppresses pixel noise.
        channel = cv2.morphologyEx(image[..., i], cv2.MORPH_TOPHAT, tophat_kernel)
        channel = cv2.GaussianBlur(channel, blur_ksize, params['blur_sigma'])

        thresh_image = filters.threshold_local(
            image=channel,
            method='mean',
            block_size=params['thresh_size'],
            offset=params['thresh_offset'][i],
        )
        # Store the filled binary mask back into the channel slot so the seed
        # channel can be read from `image` after the loop.
        image[..., i] = utils.fill_holes((channel > thresh_image).astype('uint8'))
        target_mask |= (image[..., i] > 0).astype(np.uint8)

    # Postprocessing the masks. dstType must be passed by keyword: the fourth
    # positional argument of cv2.distanceTransform is the output array `dst`.
    target_mask = cv2.morphologyEx(target_mask, cv2.MORPH_OPEN, opening_kernel)
    target_dist = cv2.distanceTransform(
        target_mask, cv2.DIST_L2, 3, dstType=cv2.CV_32F)

    # Generating the seeds: one labelled marker per distance-map peak
    seed_mask = cv2.morphologyEx(
        image[..., frame.get_ch(params['seed_ch'])],
        cv2.MORPH_OPEN, opening_kernel
    )
    seed_dist = cv2.distanceTransform(
        seed_mask.astype(np.uint8), cv2.DIST_L2, 3, dstType=cv2.CV_32F)
    peak_coords = feature.peak_local_max(
        seed_dist, min_distance=params['min_dist'], exclude_border=False)
    seeds = np.zeros(seed_mask.shape, dtype=bool)
    seeds[tuple(peak_coords.T)] = True
    markers = measure.label(seeds)

    # Watershed on the negated distance map floods outward from the seeds,
    # restricted to the foreground, splitting touching events.
    event_mask = segmentation.watershed(-target_dist, markers, mask=target_mask)
    frame.mask = event_mask
    return utils.calc_basic_features(frame)

In [11]:
# Loading data: build a Frame for every frame id in the configured range and
# read its tiles. Edge frames are skipped, so `frames` can hold fewer than
# n_frames entries.
frames = []
for frame_id in range(offset + 1, offset + n_frames + 1):
    frame = Frame(frame_id=frame_id, channels=channels)
    if frame.is_edge():
        continue
    frame.readImage(paths=generate_tile_paths(
        path=in_path,
        frame_id=frame_id,
        starts=starts,
        name_format=name_format,
    ))
    frames.append(frame)

In [6]:
# Regular (single-process) run on the first loaded frame, handy for inspection.
# Uses its own name so it is not confused with the per-frame list `result`
# produced by the parallel cell.
# Optional visual check of the raw channels:
# utils.pltImageShow2by2(frames[0].image, title='Raw Images', size=(14,10))
assert frames, "no non-edge frames were loaded"
first_frame_features = segment_frame(frame=frames[0], params=params)

In [12]:
# Parallel processing: one segment_frame call per frame, in a process pool.
# NOTE: workers receive pickled copies of each Frame, so the `frame.mask` set
# inside segment_frame is not visible on the objects in `frames`.
# NOTE(review): under the 'spawn' start method (Windows/macOS default),
# functions defined in the notebook may not be importable by the workers.
n_proc = n_threads if n_threads > 0 else mp.cpu_count()
# Pool(0) raises, and workers beyond the number of frames would sit idle
n_proc = max(1, min(n_proc, len(frames)))
# The context manager shuts the pool down; map() has already returned all results
with mp.Pool(n_proc) as pool:
    result = pool.map(partial(segment_frame, params=params), frames)
print(len(result))

10


In [18]:
import pandas as pd

In [19]:
# Number of detected events per frame
for item in result:
    print(len(item))

# All frames' features in one table with a unique index. pd.concat raises
# on an empty list, so fall back to an empty frame.
all_features = pd.concat(result, ignore_index=True) if result else pd.DataFrame()
print(len(all_features))

1608
1661
1760
1749
1797
1750
1722
1685
1620
1585
16937
