In [68]:
from glob import glob
from skimage.morphology import skeletonize
import cv2
import numpy as np
from skimage import measure, morphology
import random
from pathlib import Path
from PIL import Image

In [69]:
# Glob pattern for the source masks; change it to match the layout of the original datasets.
# glob returns files in arbitrary, filesystem-dependent order, so sort to make the list
# order (and anything derived from it) deterministic across machines.
mask_paths = sorted(glob('set_*/masks/*.tif'))

In [70]:
def remove_small_islands(mask, min_hole_size=5, min_island_size=50, connectivity=1):
    """Fill tiny holes in a binary mask and drop tiny foreground blobs.

    Parameters
    ----------
    mask : np.ndarray
        Mask expected to hold only 0/1 values (``1 - mask`` is used to obtain the
        background). It is modified IN PLACE and also returned.
    min_hole_size : int, default 5
        Background components smaller than this many pixels are filled with 1.
    min_island_size : int, default 50
        Foreground components smaller than this many pixels are set to 0.
    connectivity : int, default 1
        Passed to ``skimage.measure.label`` for both labelings
        (1 = orthogonal neighbours only).

    Returns
    -------
    np.ndarray
        The cleaned mask (the same array object as ``mask``).
    """
    # Label background components (holes and outer background).
    inverse_mask = 1 - mask
    labeled_inverse_regions = measure.label(inverse_mask, connectivity=connectivity)
    # After removing small background components, every pixel labelled 0 is either
    # original foreground or a removed small hole; setting them to 1 fills the holes.
    labeled_inverse_regions = morphology.remove_small_objects(
        labeled_inverse_regions, min_size=min_hole_size
    )
    mask[labeled_inverse_regions == 0] = 1
    # Label foreground components and clear the ones below `min_island_size`.
    labeled_regions = measure.label(mask, connectivity=connectivity)
    labeled_regions = morphology.remove_small_objects(
        labeled_regions, min_size=min_island_size
    )
    mask[labeled_regions == 0] = 0
    return mask

In [71]:
new_dset_paths = ['all_skeleton', 'mix_skeleton', 'no_skeleton']
# Every dataset variant gets the same <variant>/<split>/<img|mask> directory tree.
output_dirs = [
    Path(dset) / split_name / kind
    for dset in new_dset_paths
    for split_name in ('train', 'val', 'test')
    for kind in ('img', 'mask')
]
for output_dir in output_dirs:
    output_dir.mkdir(parents=True, exist_ok=True)

In [72]:
# --- Configuration for the dataset build -------------------------------------
SEED = 42  # fixed seed so the shuffle, and hence every split, is reproducible
TRAIN_FRACTION = 0.72  # shuffled positions [0, 72%) -> train
VAL_END_FRACTION = 0.95  # [72%, 95%) -> val, [95%, 100%] -> test
OUTPUT_DIM = (1024, 1024)  # (width, height) for both cv2.resize and PIL resize
GAP_CLOSE_KERNEL = np.ones((2, 2), np.uint8)
SKELETON_DILATE_KERNEL = np.ones((3, 3), np.uint8)


def _read_grayscale(path):
    """Load `path` as a single-channel image.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns
            None instead of raising).
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f'cv2.imread could not read {path}')
    return image


def _split_for_index(index, n_items):
    """Return 'train', 'val' or 'test' for position `index` in a list of `n_items`."""
    if index < int(n_items * TRAIN_FRACTION):
        return 'train'
    if index < int(n_items * VAL_END_FRACTION):
        return 'val'
    return 'test'


def _write_image(path, image):
    """Write `image` to `path`, raising instead of failing silently (cv2.imwrite returns False)."""
    if not cv2.imwrite(str(path), image):
        raise OSError(f'cv2.imwrite failed for {path}')


if not mask_paths:
    raise FileNotFoundError('mask_paths is empty; check the glob pattern for the source masks.')

# Shuffle a sorted *copy* once, outside the per-variant loop, so that:
#  * all dataset variants share the same train/val/test assignment and the same
#    index -> source-image mapping, making models trained on them comparable;
#  * the global `mask_paths` is not mutated, so re-running this cell is idempotent.
shuffled_mask_paths = sorted(mask_paths)
random.Random(SEED).shuffle(shuffled_mask_paths)
n_masks = len(shuffled_mask_paths)

# In 'mix_skeleton', the first half of the shuffled list gets skeleton labels.
# NOTE(review): train is also taken from the front of the list (first 72%), so every
# skeletonized mask lands in train while val/test keep full masks — confirm intended.
masks_to_skeletonize = set(shuffled_mask_paths[: n_masks // 2])

for new_dset_path in new_dset_paths:
    dset_name = Path(new_dset_path).name
    for i, mask_path in enumerate(shuffled_mask_paths):
        mask = _read_grayscale(mask_path)
        # Image path = mask path with 'masks' -> 'images' and '_mask' removed,
        # e.g. set_1/masks/a_mask.tif -> set_1/images/a.tif.
        # NOTE(review): str.replace rewrites every occurrence anywhere in the path.
        img_path = mask_path.replace('masks', 'images').replace('_mask', '')
        img = _read_grayscale(img_path)

        # Small dilation, then binarize to {0, 1}; any nonzero value counts as foreground.
        mask = cv2.dilate(mask, GAP_CLOSE_KERNEL, iterations=1)
        mask = (mask > 0).astype(np.uint8)
        mask = remove_small_islands(mask)

        if dset_name == 'all_skeleton' or (
            dset_name == 'mix_skeleton' and mask_path in masks_to_skeletonize
        ):
            # skeletonize's output dtype varies across scikit-image versions (bool or
            # uint8); cv2.dilate needs uint8, so normalize to {0, 1} uint8 first.
            mask = (skeletonize(mask, method='lee') > 0).astype(np.uint8)
            # Thicken the 1-px skeleton with a 3x3 dilation.
            mask = cv2.dilate(mask, SKELETON_DILATE_KERNEL, iterations=1)
        mask[mask > 0] = 255

        # Assign the split by position in the shuffled list rather than by file name:
        # file names can repeat across set_* folders.
        out_dir = Path(new_dset_path) / _split_for_index(i, n_masks)
        save_filename = f'{i}.png'

        # Downscale to OUTPUT_DIM: nearest-neighbour keeps the mask binary.
        mask = np.array(Image.fromarray(mask).resize(OUTPUT_DIM, Image.Resampling.NEAREST))
        # Pass `interpolation` by keyword: cv2.resize's third positional parameter is
        # `dst`, not `interpolation`.
        img = cv2.resize(img, OUTPUT_DIM, interpolation=cv2.INTER_LANCZOS4)

        _write_image(out_dir / 'mask' / save_filename, mask)
        _write_image(out_dir / 'img' / save_filename, img)