In [78]:
from glob import glob
from skimage.morphology import skeletonize
import cv2
import numpy as np
from skimage import measure, morphology
import random
from pathlib import Path
from PIL import Image

In [79]:
# Dataset layout configuration.
ground_truth_subfolder = 'masks'
images_subfolder = 'images'
ground_truth_suffix = '_mask'   # for example if the mask for image 'img_1.tif' is 'img_1_mask.tif'
extension = '.tif'
fov_mask_subfolder = None   # if None, fov masks will be generated
# change search pattern for ground_truth_paths to match mask paths in your dataset.
# glob() returns paths in arbitrary, filesystem-dependent order; sorting makes
# any split derived from this list reproducible across machines.
ground_truth_paths = sorted(glob(f'set_*/{ground_truth_subfolder}/*{ground_truth_suffix}{extension}'))

In [80]:
def remove_small_islands(gt: np.ndarray, min_hole_size: int = 5, min_object_size: int = 50) -> np.ndarray:
    """Clean a binary mask by filling tiny background gaps and deleting tiny blobs.

    Parameters
    ----------
    gt : np.ndarray
        2-D mask with 0 for background and 1 for foreground.
    min_hole_size : int, default 5
        4-connected background components with fewer than this many pixels
        are set to 1. This includes small components touching the image border.
    min_object_size : int, default 50
        4-connected foreground components with fewer than this many pixels
        are set to 0.

    Returns
    -------
    np.ndarray
        A cleaned copy of ``gt`` with the same dtype. The input array is not modified.
    """
    cleaned = gt.copy()
    # Label the background directly with ``== 0``. The previous ``1 - gt``
    # silently wrapped around for uint8 values > 1 (e.g. 1 - 255 -> 2).
    background_labels = measure.label(cleaned == 0, connectivity=1)
    kept_background = morphology.remove_small_objects(background_labels, min_size=min_hole_size)
    # Background pixels whose component was removed become foreground.
    # Pixels that were already foreground are also 0 here; setting them to 1 is a no-op.
    cleaned[kept_background == 0] = 1
    foreground_labels = measure.label(cleaned, connectivity=1)
    kept_foreground = morphology.remove_small_objects(foreground_labels, min_size=min_object_size)
    cleaned[kept_foreground == 0] = 0
    return cleaned

In [81]:
new_dset_paths = ['all_skeleton', 'mix_skeleton', 'no_skeleton']
# Pre-create <variant>/<split>/<category> for every combination.
# The export cell writes with cv2.imwrite, which does not create missing
# parent directories.
for dset_root in map(Path, new_dset_paths):
    for split_name in ('train', 'val', 'test'):
        for folder in ('img', 'gt', 'fov_mask'):
            (dset_root / split_name / folder).mkdir(parents=True, exist_ok=True)

In [82]:
# Build the dataset variants in `new_dset_paths` from ONE shared, reproducible split.
#
# Each mask/image pair is read and cleaned once, then written to every variant.
# Only the ground truth differs between variants:
#   * all_skeleton - every mask is skeletonized (Lee) and re-dilated with a 3x3 kernel
#   * mix_skeleton - only masks in `masks_to_skeletonize` are skeletonized
#   * no_skeleton  - masks are kept as cleaned full annotations
# All variants share one shuffle, so a given image lands in the same split
# under the same file name in every variant.
SPLIT_SEED = 42            # seeds the shuffle -> reproducible split
TRAIN_FRACTION = 0.72      # shuffled positions [0, 72 %) -> train
VAL_END_FRACTION = 0.95    # [72 %, 95 %) -> val, the rest -> test
OUTPUT_DIM = (1024, 1024)  # (width, height) as expected by PIL / cv2 resize


def _read_image(path, flags):
    """Read `path` with cv2, raising instead of silently returning None."""
    image = cv2.imread(path, flags)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {path}')
    return image


def _resize_nearest(mask):
    """Resize a label mask to OUTPUT_DIM without blending label values."""
    return np.array(Image.fromarray(mask).resize(OUTPUT_DIM, Image.Resampling.NEAREST))


def _write_image(path, image):
    """Write `image` to `path`, raising if cv2.imwrite reports failure."""
    if not cv2.imwrite(str(path), image):
        raise OSError(f'Could not write image: {path}')


if not ground_truth_paths:
    raise FileNotFoundError('No ground-truth masks matched the glob pattern in the config cell.')

# Sort before shuffling so the split does not depend on glob() order.
# Shuffle a copy so `ground_truth_paths` itself is left untouched.
shuffled_gt_paths = sorted(ground_truth_paths)
random.Random(SPLIT_SEED).shuffle(shuffled_gt_paths)

n_masks = len(shuffled_gt_paths)
train_end = int(n_masks * TRAIN_FRACTION)
val_end = int(n_masks * VAL_END_FRACTION)
# NOTE(review): the skeletonized half is the first half of the shuffled order,
# and 0.5 < TRAIN_FRACTION. In 'mix_skeleton' every skeletonized mask is
# therefore a train mask, while val/test keep full masks. Confirm this is intended.
masks_to_skeletonize = set(shuffled_gt_paths[: n_masks // 2])

mask_dilate_kernel = np.ones((2, 2), np.uint8)
skeleton_dilate_kernel = np.ones((3, 3), np.uint8)

for i, gt_path in enumerate(shuffled_gt_paths):
    # Assign the split by position rather than by file name. Name matching let
    # same-named masks from different set_* folders collide across splits.
    if i < train_end:
        split = 'train'
    elif i < val_end:
        split = 'val'
    else:
        split = 'test'
    save_filename = f'{i}.png'

    # Mask cleanup: thicken slightly, binarize to {0, 1}, drop tiny holes/specks.
    gt = _read_image(gt_path, cv2.IMREAD_GRAYSCALE)
    gt = cv2.dilate(gt, mask_dilate_kernel, iterations=1)
    gt = (gt > 0).astype(np.uint8)  # accepts both 0/1 and 0/255 masks
    gt = remove_small_islands(gt)

    # Image: downscale, then min-max stretch (from any bit depth) to 8-bit.
    img_path = gt_path.replace(ground_truth_subfolder, images_subfolder).replace(ground_truth_suffix, '')
    img = _read_image(img_path, cv2.IMREAD_ANYDEPTH)
    # `interpolation` must be a keyword: cv2.resize's 3rd positional argument is `dst`.
    img = cv2.resize(img, OUTPUT_DIM, interpolation=cv2.INTER_LANCZOS4)
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)

    # Field-of-view mask: load it if configured, otherwise mark the whole frame valid.
    if fov_mask_subfolder is not None:
        fov_mask_path = gt_path.replace(ground_truth_subfolder, fov_mask_subfolder)
        fov_mask = _resize_nearest(_read_image(fov_mask_path, cv2.IMREAD_GRAYSCALE))
    else:
        fov_mask = np.full((OUTPUT_DIM[1], OUTPUT_DIM[0]), 255, np.uint8)

    for new_dset_path in new_dset_paths:
        dset_name = Path(new_dset_path).name
        if dset_name == 'all_skeleton' or (dset_name == 'mix_skeleton' and gt_path in masks_to_skeletonize):
            # skeletonize's output dtype may differ between scikit-image
            # versions (uint8 or bool), so normalize to 0/1 uint8 for cv2.dilate.
            variant_gt = (skeletonize(gt.astype(bool), method='lee') > 0).astype(np.uint8)
            variant_gt = cv2.dilate(variant_gt, skeleton_dilate_kernel, iterations=1)
        else:
            variant_gt = gt
        # NOTE(review): nearest-neighbour downscaling can break thin (~3 px)
        # skeleton lines if source masks are much larger than OUTPUT_DIM.
        variant_gt = _resize_nearest(np.where(variant_gt > 0, 255, 0).astype(np.uint8))

        split_dir = Path(new_dset_path) / split
        _write_image(split_dir / 'gt' / save_filename, variant_gt)
        _write_image(split_dir / 'fov_mask' / save_filename, fov_mask)
        _write_image(split_dir / 'img' / save_filename, img)

  return func(*args, **kwargs)


: 