In [1]:
from glob import glob
from pathlib import Path
import os.path as osp
import cv2
import albumentations as A
from fl_tissue_model_tools.transforms import get_elastic_dual_transform
from fl_tissue_model_tools.preprocessing import get_augmentor, augment_images

In [2]:
'''
This notebook is used to test the augmentation pipeline for binary segmentation.
'''

# Side length, in pixels, of the square random crop.
CROP_SIZE = 572

# Geometric stage: random rotation (zero padding for both image and mask),
# a random CROP_SIZE x CROP_SIZE crop, and a random flip.
spatial_transforms = [
    A.Rotate(p=0.5, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0),
    A.RandomCrop(height=CROP_SIZE, width=CROP_SIZE),
    A.Flip(p=0.5),
]

# Degradation stage: with probability 0.8, apply either multiplicative noise or blur.
noise_or_blur = A.OneOf(
    [A.MultiplicativeNoise(p=0.5), A.AdvancedBlur(p=0.5)],
    p=0.8,
)

# Intensity stage: with probability 0.75, apply one intensity adjustment
# (per the albumentations OneOf docs, the children's p values act as
# relative selection weights).
intensity_adjustment = A.OneOf(
    [
        A.RandomGamma(p=0.2),
        A.RandomBrightnessContrast(p=0.4),
        A.RandomToneCurve(scale=0.75, p=0.2),
        A.CLAHE(p=0.2),
    ],
    p=0.75,
)

albumentations_transform = A.Compose([*spatial_transforms, noise_or_blur, intensity_adjustment])

# Project-provided elastic transform; both transforms are handed to get_augmentor as a tuple.
elastic_transform = get_elastic_dual_transform()
transform_list = (albumentations_transform, elastic_transform)
augmentor = get_augmentor(transform_list)

In [3]:
# Sanity check: the transforms passed to get_augmentor are bundled as a tuple.
print(transform_list.__class__)

<class 'tuple'>


In [4]:
def get_img_mask_paths(img_dir, mask_dir=None, img_suffix_pattern='.tif', mask_suffix_pattern='_mask.tif'):
    """Collect image files and their matching mask files, paired 1:1.

    An image named ``<sample><img_suffix_pattern>`` is paired with the mask
    ``<sample><mask_suffix_pattern>`` located in ``mask_dir``. Suffixes are
    matched literally against the end of each file name.

    Args:
        img_dir: Directory containing the images.
        mask_dir: Directory containing the masks. Defaults to ``img_dir``.
        img_suffix_pattern: File-name suffix identifying images.
        mask_suffix_pattern: File-name suffix identifying masks.

    Returns:
        ``(img_paths, mask_paths)``: the image paths sorted lexicographically,
        and for each image the path of its mask, in the same order.

    Raises:
        ValueError: if images and masks cannot be told apart (same directory
            and same suffix), if the numbers of images and masks differ, or if
            an image has no matching mask.
    """
    from glob import escape as glob_escape

    if mask_dir is None:
        mask_dir = img_dir
    # Compare normalized absolute paths so that e.g. 'data' and 'data/' are
    # recognized as the same directory.
    same_dir = osp.normpath(osp.abspath(img_dir)) == osp.normpath(osp.abspath(mask_dir))
    if same_dir and img_suffix_pattern == mask_suffix_pattern:
        raise ValueError('directories and suffixes for images and masks are identical')
    # In a shared directory, when one suffix ends with the other (e.g. '.tif'
    # vs '_mask.tif'), the glob for the shorter suffix also matches files of
    # the other kind, and those must be filtered out.
    exclude_mask_suffix_from_img_search = same_dir and mask_suffix_pattern.endswith(img_suffix_pattern)
    exclude_img_suffix_from_mask_search = same_dir and img_suffix_pattern.endswith(mask_suffix_pattern)

    # Escape the directories so glob metacharacters in them ('[', '*', '?')
    # are matched literally.
    img_paths = glob(osp.join(glob_escape(img_dir), f'*{img_suffix_pattern}'))
    if exclude_mask_suffix_from_img_search:
        img_paths = [p for p in img_paths if not p.endswith(mask_suffix_pattern)]

    # A set gives O(1) lookups in the pairing loop below.
    mask_filenames = {Path(fp).name for fp in glob(osp.join(glob_escape(mask_dir), f'*{mask_suffix_pattern}'))}
    if exclude_img_suffix_from_mask_search:
        mask_filenames = {name for name in mask_filenames if not name.endswith(img_suffix_pattern)}

    # sort paths and make sure images and masks are paired 1:1
    if len(img_paths) != len(mask_filenames):
        raise ValueError(f'number of images ({len(img_paths)}) and masks ({len(mask_filenames)}) is different')
    img_paths = sorted(img_paths)
    mask_paths = []
    for img_path in img_paths:
        img_name = Path(img_path).name
        # Strip only the trailing suffix; str.replace would also remove
        # occurrences elsewhere in the name (e.g. 'x.tif.tif' -> 'x').
        sample_name = img_name[:len(img_name) - len(img_suffix_pattern)]
        mask_fname = sample_name + mask_suffix_pattern
        if mask_fname not in mask_filenames:
            raise ValueError(f'mask {mask_fname} not found for image {img_name}')
        mask_paths.append(osp.join(mask_dir, mask_fname))

    return img_paths, mask_paths

In [5]:
# Input/output configuration.
# NOTE(review): img_dir is a machine-specific absolute path; adjust it for your environment.
img_dir = '/home/bean/lab/training-data/jan31'
mask_dir = img_dir
out_path = "transforms_output"
out_ext = ".png"
num_cycles = 1  # number of augmentation passes run by the output loop

img_paths, mask_paths = get_img_mask_paths(img_dir, mask_dir)
# Load everything as single-channel grayscale (cv2.IMREAD_GRAYSCALE == 0).
images = [cv2.imread(fp, cv2.IMREAD_GRAYSCALE) for fp in img_paths]
masks = [cv2.imread(fp, cv2.IMREAD_GRAYSCALE) for fp in mask_paths]
# cv2.imread signals failure by returning None instead of raising; fail fast
# here rather than with an obscure error inside the augmentation pipeline.
unreadable = [fp for fp, arr in zip(img_paths + mask_paths, images + masks) if arr is None]
if unreadable:
    raise IOError(f'could not read {len(unreadable)} file(s), e.g. {unreadable[0]}')
im_mask_pairs = list(zip(images, masks))
Path(out_path).mkdir(parents=True, exist_ok=True)

In [6]:
# Run the augmentor num_cycles times over the dataset. Each augmented pair is
# saved as <sample>_aug<N><ext> and <sample>_mask_aug<N><ext>; the cycle
# number <N> is only appended when more than one cycle is run.
for cycle in range(num_cycles):
    augmented = augment_images(images, masks, augmentor)
    cycle_num = str(cycle + 1) if num_cycles > 1 else ''
    for i, (img, mask) in enumerate(augmented):
        # assumes augment_images preserves input order, so index i maps back to img_paths[i]
        sample_name = Path(img_paths[i]).stem
        img_out = osp.join(out_path, f"{sample_name}_aug{cycle_num}{out_ext}")
        mask_out = osp.join(out_path, f"{sample_name}_mask_aug{cycle_num}{out_ext}")
        # cv2.imwrite reports a failed write by returning False rather than raising.
        for out_file, arr in ((img_out, img), (mask_out, mask)):
            if not cv2.imwrite(out_file, arr):
                raise IOError(f'failed to write {out_file}')