In [None]:
# Enable IPython's autoreload extension (mode 2) so edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import os
import cv2
from PIL import Image
import json
import math
import numpy as np
import matplotlib.pylab as plt
from matplotlib.path import Path as mpPath
from tqdm import tqdm

In [None]:
import joblib
from joblib import Parallel, delayed
import contextlib

In [None]:
# Mount Google Drive inside the Colab runtime; the dataset and output paths used
# further down all live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager that makes joblib report finished batches to a progress bar.

    While the context is active, ``joblib.parallel.BatchCompletionCallBack`` is
    replaced by a subclass that advances ``tqdm_object`` by the batch size each
    time a batch completes. On exit the original callback class is put back and
    the progress bar is closed, even if the body raised.

    Args:
        tqdm_object: progress bar exposing ``update(n=...)`` and ``close()``.

    Yields:
        The same ``tqdm_object`` that was passed in.
    """
    original_callback = joblib.parallel.BatchCompletionCallBack

    class TqdmBatchCompletionCallBack(original_callback):
        def __call__(self, *args, **kwargs):
            # Advance the bar first, then let joblib do its normal bookkeeping.
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallBack
    try:
        yield tqdm_object
    finally:
        # Always undo the monkey-patch so later Parallel calls are unaffected.
        joblib.parallel.BatchCompletionCallBack = original_callback
        tqdm_object.close()

In [None]:
def create_cell_dict(path):
    """Index a COCO-style annotation file by image id.

    Args:
        path: path to a JSON file with an ``images`` list (entries providing
            ``id`` and ``file_name``) and an ``annotations`` collection (entries
            providing ``image_id``, ``segmentation`` and ``bbox``).
            ``annotations`` may be a mapping, as the files read by this notebook
            use, or a plain list as in standard COCO files.

    Returns:
        dict mapping each image id to
        ``{'segmentation': [...], 'bbox': [...], 'path': [...]}``. ``path``
        holds the image file name(s); ``segmentation`` holds the first polygon
        of every annotation of that image and ``bbox`` the matching boxes.

    Raises:
        KeyError: if an annotation refers to an image id missing from ``images``.
    """
    with open(path) as f:
        data = json.load(f)

    d = {}
    # Walk the image records themselves: counting unique ids (the previous
    # approach) skipped trailing records whenever an id appeared twice.
    for img_dict in data['images']:
        entry = d.setdefault(img_dict['id'], {'segmentation': [], 'bbox': [], 'path': []})
        entry['path'].append(img_dict['file_name'])

    annotations = data['annotations']
    if isinstance(annotations, dict):
        annotations = annotations.values()

    for annotation in annotations:
        entry = d[annotation['image_id']]
        # Only the first polygon of each annotation is kept.
        entry['segmentation'].append(annotation['segmentation'][0])
        entry['bbox'].append(annotation['bbox'])

    return d


def get_cell_mask(raw_segmentation, binary_mask=True, shape=(520, 704)):
    """Rasterize cell polygons into a single mask image.

    Args:
        raw_segmentation: iterable of polygons, each a flat
            ``[x0, y0, x1, y1, ...]`` sequence of pixel coordinates.
        binary_mask: if True return a ``uint8`` mask with 1 inside any cell;
            otherwise return a ``uint16`` label image where the i-th polygon
            (1-based) is painted with the value ``i``.
        shape: ``(height, width)`` of the output; defaults to the 520x704 size
            this notebook was written for.

    Returns:
        ``numpy.ndarray`` of the requested ``shape``; 0 marks background.

    Notes:
        Where polygons overlap in label mode, the later polygon wins. Summing
        the labels instead would invent ids that belong to no cell or to a
        different one. Polygons with fewer than three vertices enclose no area
        and are skipped, but still consume their label number.
    """
    dtype = np.uint8 if binary_mask else np.uint16
    height, width = shape
    array = np.zeros((height, width), dtype=dtype)

    for lbl, cell_mask in enumerate(raw_segmentation, 1):
        n_vertices = len(cell_mask) // 2
        if n_vertices < 3:
            continue

        # Flat list alternates x (column) and y (row); a trailing unpaired
        # value is ignored.
        cols = np.asarray(cell_mask[0:2 * n_vertices:2], dtype=float)
        rows = np.asarray(cell_mask[1:2 * n_vertices:2], dtype=float)

        # Only pixels within the polygon's bounding box can lie inside it, so
        # test those instead of the whole image for every cell.
        row_start = max(int(np.floor(rows.min())), 0)
        row_stop = min(int(np.ceil(rows.max())) + 1, height)
        col_start = max(int(np.floor(cols.min())), 0)
        col_stop = min(int(np.ceil(cols.max())) + 1, width)
        if row_start >= row_stop or col_start >= col_stop:
            continue  # polygon lies entirely outside the image

        # Vertices and query points are both expressed as (row, col).
        path = mpPath(np.column_stack([rows, cols]))
        grid_rows, grid_cols = np.mgrid[row_start:row_stop, col_start:col_stop]
        points = np.column_stack([grid_rows.ravel(), grid_cols.ravel()])
        inside = path.contains_points(points).reshape(grid_rows.shape)

        # The slice is a view, so the boolean assignment writes into `array`.
        array[row_start:row_stop, col_start:col_stop][inside] = 1 if binary_mask else lbl

    return array


def get_cell_contour(raw_segmentation):
    """Draw the outer contour of every cell into a single binary image.

    Args:
        raw_segmentation: polygons in the format accepted by ``get_cell_mask``.

    Returns:
        ``uint8`` array with the same shape as the label mask; contour pixels
        are 1, everything else 0.
    """
    mask = get_cell_mask(raw_segmentation, binary_mask=False)

    # Drop the background label by value. Slicing off the first unique entry
    # would discard a real cell whenever no pixel is background.
    labels = np.unique(mask)
    labels = labels[labels != 0]

    contours = np.zeros_like(mask, dtype=np.uint8)

    for label in labels:
        # Process cells one at a time so touching cells keep separate outlines.
        mask_l = (mask == label).astype(np.uint8)
        contours_l, _ = cv2.findContours(mask_l, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # Colour (1, 0, 0) writes the value 1 into the single-channel canvas.
        cv2.drawContours(contours, contours_l, -1, (1, 0, 0))

    return contours

In [None]:
def extract_image_data(data_dir, data):
    """Load every image referenced by ``data`` into one array.

    Args:
        data_dir: directory that contains the image files.
        data: mapping as produced by ``create_cell_dict``; the first entry of
            each value's ``'path'`` list is used as the file name.

    Returns:
        Tuple ``(images, images_names)``: the images stacked along a new first
        axis, and the full path string of each image in the same order.
    """
    root = Path(data_dir)
    images_names = [str(root / entry['path'][0]) for entry in data.values()]
    images = np.stack([np.array(Image.open(name)) for name in images_names])

    return images, images_names


def extract_mask_data(data_dir, data, binary_mask=True, parallel=False, n_jobs=1):
    """Rasterize the masks of every image in ``data``.

    Args:
        data_dir: directory of the source images; only used to build the
            returned path strings.
        data: mapping as produced by ``create_cell_dict``.
        binary_mask: forwarded to ``get_cell_mask``.
        parallel: if truthy, run through joblib (``loky`` backend) with a
            progress bar; otherwise process the images one after another.
        n_jobs: number of joblib workers when ``parallel`` is set.

    Returns:
        Tuple ``(masks, masks_names)``: masks stacked along a new first axis,
        and the source image path string for each mask in the same order.
    """
    root = Path(data_dir)
    masks_names = [str(root / entry['path'][0]) for entry in data.values()]

    if not parallel:
        masks = [get_cell_mask(entry['segmentation'], binary_mask=binary_mask)
                 for entry in data.values()]
    else:
        with tqdm_joblib(tqdm(desc='Livecell masks extraction', total=len(data))):
            pool = Parallel(n_jobs=n_jobs, backend='loky')
            masks = pool(delayed(get_cell_mask)(entry['segmentation'], binary_mask=binary_mask)
                         for entry in data.values())

    return np.stack(masks), masks_names


def extract_contour_data(data, parallel=False, n_jobs=1):
    """Compute the contour image of every entry in ``data``.

    Args:
        data: mapping as produced by ``create_cell_dict``; each value's
            ``'segmentation'`` list is passed to ``get_cell_contour``.
        parallel: if truthy, run through joblib (``loky`` backend) with a
            progress bar; otherwise process the images one after another.
        n_jobs: number of joblib workers when ``parallel`` is set.

    Returns:
        The contour images stacked along a new first axis, in the iteration
        order of ``data``.
    """
    if parallel:
        with tqdm_joblib(tqdm(desc='Livecell contours extraction', total=len(data))):
            worker = Parallel(n_jobs=n_jobs, backend='loky')
            contours = worker(delayed(get_cell_contour)(entry['segmentation'])
                              for entry in data.values())
    else:
        # Must use get_cell_contour here as well: this branch used to call
        # get_cell_mask and therefore returned filled masks, not contours.
        contours = [get_cell_contour(entry['segmentation']) for entry in data.values()]

    return np.stack(contours)

In [None]:
# Root of the LiveCell_dataset_2021 directory on the mounted drive.
data_dir = Path('/content/drive/MyDrive/SartoriousDatasets/LiveCell')

# Raw images and the single-cell annotation folders both sit below the root.
image_dir, annotation_dir = data_dir / 'images', data_dir / 'LIVECell_single_cells'

# Sorted entries of the annotation directory; preview the first three.
annotation_cell_dirs = sorted(annotation_dir.glob('*'))
annotation_cell_dirs[:3]

[PosixPath('/content/drive/MyDrive/SartoriousDatasets/LiveCell/LIVECell_single_cells/shsy5y')]

In [None]:
def create_dirs(save_dir):
    """Create the ``train``/``val``/``test`` output folders, each with an ``images`` sub-folder.

    The call is idempotent: existing folders are left untouched and any missing
    one is created. Previously nothing was created when ``train`` already
    existed, so a partially built tree was never completed.

    Args:
        save_dir: root output directory (``str`` or ``pathlib.Path``).
    """
    save_dir = Path(save_dir)

    for split_name in ('train', 'val', 'test'):
        # parents=True also creates the split folder itself.
        (save_dir / split_name / 'images').mkdir(parents=True, exist_ok=True)
        
# Directory where the parsed images and masks are saved. It is the same location
# as data_dir, so the train/val/test folders are created next to the raw dataset.
save_dir = Path('/content/drive/MyDrive/SartoriousDatasets/LiveCell')

create_dirs(save_dir)

In [None]:
def parse_livecell_images(annotation_cell_dirs, save_dir, cell_dict):
    """Export the images referenced by each annotation split as PNG files.

    For every ``*.json`` file in every annotation directory, the split name is
    taken from the last ``_``-separated token of the file stem. The referenced
    images are read from the module-level ``image_dir`` and written to
    ``save_dir/<split>/images/<image stem>.png``.

    Args:
        annotation_cell_dirs: iterable of ``pathlib.Path`` directories, one per
            cell type, containing the split JSON files.
        save_dir: root output directory (``str`` or ``pathlib.Path``).
        cell_dict: maps an annotation directory name to the name of the
            matching image sub-folder.

    Raises:
        ValueError: if a JSON file's split name is not ``train``, ``val`` or
            ``test``.
        KeyError: if an annotation directory name is missing from ``cell_dict``.
    """
    save_dir = Path(save_dir)

    for cell_dir in annotation_cell_dirs:
        cell_name = cell_dict[str(cell_dir.stem)]  # A172, BT474 etc.

        for split_path in sorted(cell_dir.glob('*.json')):  # train|val|test
            data = create_cell_dict(split_path)

            split_name = str(split_path.stem).split('_')[-1]  # train|val|test

            # train and val images share one source folder; test has its own.
            if split_name in ('train', 'val'):
                data_image_dir = image_dir/'livecell_train_val_images'/cell_name
            elif split_name == 'test':
                data_image_dir = image_dir/'livecell_test_images'/cell_name
            else:
                raise ValueError(f"Split name not known: {split_name}")

            images, images_names = extract_image_data(data_image_dir, data)

            # Force 8-bit output. NOTE(review): this wraps values silently if
            # the source images are not already 8-bit; confirm against the data.
            images = images.astype(np.uint8)

            cur_save_dir = save_dir/split_name/'images'
            # Create the target folder on demand, as parse_livecell_masks does,
            # so this function does not depend on create_dirs having run.
            cur_save_dir.mkdir(parents=True, exist_ok=True)

            for image, img_name in zip(images, images_names):
                img_save_path = cur_save_dir/(str(Path(img_name).stem) + '.png')
                Image.fromarray(image).save(img_save_path)
                

def parse_livecell_masks(annotation_cell_dirs, save_dir, cell_dict,
                         mask_dir_name='masks', mask_type='binary',
                         parallel=True, n_jobs=1):
    """Rasterize the annotations of each split and save one TIFF mask per image.

    Masks are written to
    ``save_dir/<split>/<mask_dir_name>/<image stem>_<suffix>.tif``, where the
    suffix depends on ``mask_type``.

    Args:
        annotation_cell_dirs: iterable of ``pathlib.Path`` directories, one per
            cell type, containing the split JSON files.
        save_dir: root output directory (``str`` or ``pathlib.Path``).
        cell_dict: maps an annotation directory name to the name of the
            matching image sub-folder.
        mask_dir_name: name of the sub-folder created inside each split folder.
        mask_type: one of

            * ``'binary'``: the binary cell mask;
            * ``'mask_w_contour'``: binary mask and contour image stacked on a
              trailing axis, giving ``[H, W, 2]`` per image;
            * ``'categorical'``: binary mask plus contour image in a single
              channel, so contour pixels that lie on a cell get the value 2.
        parallel: forwarded to the extraction helpers.
        n_jobs: forwarded to the extraction helpers.

    Raises:
        ValueError: if ``mask_type`` is not one of the values above (checked
            before any work is done), or if a JSON file's split name is not
            ``train``, ``val`` or ``test``.
        KeyError: if an annotation directory name is missing from ``cell_dict``.
    """
    stem_dict = {
        'binary': 'mask',
        'mask_w_contour': 'mask_w_contour',
        'categorical': 'categorical_mask'
    }

    # Fail fast: rasterizing a split can take hours, so reject a bad mask_type
    # before starting rather than after the masks have been computed.
    if mask_type not in stem_dict:
        raise ValueError(
            f"Unknown mask_type {mask_type!r}; expected one of {sorted(stem_dict)}")

    mask_stem = stem_dict[mask_type]
    save_dir = Path(save_dir)

    for cell_dir in annotation_cell_dirs:
        cell_name = cell_dict[str(cell_dir.stem)]  # A172, BT474 etc.

        for split_path in sorted(cell_dir.glob('*.json')):  # train|val|test
            data = create_cell_dict(split_path)

            split_name = str(split_path.stem).split('_')[-1]  # train|val|test

            # train and val images share one source folder; test has its own.
            if split_name in ('train', 'val'):
                data_image_dir = image_dir/'livecell_train_val_images'/cell_name
            elif split_name == 'test':
                data_image_dir = image_dir/'livecell_test_images'/cell_name
            else:
                raise ValueError(f"Split name not known: {split_name}")

            # data_image_dir is only used to derive the output file names.
            masks, masks_names = extract_mask_data(
                data_image_dir, data, binary_mask=True, parallel=parallel, n_jobs=n_jobs)

            if mask_type == 'mask_w_contour':
                contours = extract_contour_data(data, parallel=parallel, n_jobs=n_jobs)
                final_masks = np.stack([masks, contours], axis=-1)  # [N, H, W, 2]
            elif mask_type == 'categorical':
                contours = extract_contour_data(data, parallel=parallel, n_jobs=n_jobs)
                final_masks = masks + contours
            else:  # 'binary'
                final_masks = masks

            cur_save_dir = save_dir/split_name/mask_dir_name
            os.makedirs(cur_save_dir, exist_ok=True)

            for mask, mask_name in zip(final_masks, masks_names):
                mask_save_path = cur_save_dir/(str(Path(mask_name).stem) + '_' + mask_stem + '.tif')
                Image.fromarray(mask).save(mask_save_path)

In [None]:
# Maps each lowercase annotation folder name to the name of its image folder.
cell_dict = {
    name.lower(): name
    for name in ('A172', 'BT474', 'BV2', 'Huh7', 'MCF7', 'RatC6', 'SHSY5Y', 'SkBr3', 'SKOV3')
}

In [None]:
# Export the images of every annotated split as PNG files under save_dir/<split>/images.
parse_livecell_images(annotation_cell_dirs, save_dir, cell_dict)

In [None]:
# Build two-channel targets (binary mask + cell contours) and save them as TIFF files
# under save_dir/<split>/masks_w_contours. `parallel` keeps its default (True), so the
# work runs through joblib with 4 workers. Long-running: the recorded run below took
# roughly 13 s per image for each of the two passes.
parse_livecell_masks(annotation_cell_dirs, save_dir, cell_dict,
                     mask_dir_name='masks_w_contours', mask_type='mask_w_contour', n_jobs=4)

Livecell masks extraction: 100%|██████████| 449/449 [1:36:05<00:00, 12.84s/it]
Livecell contours extraction: 100%|██████████| 449/449 [1:44:36<00:00, 13.98s/it]
Livecell masks extraction: 100%|██████████| 79/79 [16:59<00:00, 12.91s/it]
Livecell contours extraction: 100%|██████████| 79/79 [18:13<00:00, 13.85s/it]
