In [1]:
import import_ipynb
import config
from torch.utils.data import Dataset
import cv2
from scipy.ndimage.filters import gaussian_filter
import numpy as np
from skimage.io import imread, imshow
import torch

importing Jupyter notebook from config.ipynb


In [None]:
# To keep touching cells distinct (and avoid a cluttered mask), carve a thin gap along each cell's boundary.
# The helpers below find the boundary band of every labelled cell in the mask and set it to 0 (background).
def fix_patch(patch, val):
    """Zero out the boundary band of the region labelled `val` inside `patch`.

    Only the pixels equal to `val` are blurred (Gaussian, sigma=0.7). Every pixel
    whose blurred value lies strictly between int(0.5 * val) and int(0.9 * val)
    -- the transition zone around the edge of the `val` region -- is set to 0 in
    the result, whatever label it carries; all remaining pixels are copied as-is.

    Parameters
    ----------
    patch : np.ndarray
        Label window (0 = background).
    val : int
        Label whose boundary band should be erased.

    Returns
    -------
    np.ndarray
        A new array (`patch` itself is not modified); its dtype may be promoted
        by the final multiplication.
    """
    only_label = np.where(patch == val, patch, 0)
    smoothed = gaussian_filter(only_label, sigma=0.7)
    low, high = int(0.5 * val), int(0.9 * val)
    on_boundary = (smoothed > low) & (smoothed < high)
    keep = np.where(on_boundary, 0, 1)
    return patch * keep


def smart_matrix_indexing(r_min, r_max, c_min, c_max, mat):
    """Grow a bounding box by 3 pixels on every side, clamped to `mat`.

    The minima are clamped at 0 and the maxima at the last valid row/column
    *index* of `mat` (shape - 1), so the returned values are inclusive bounds.

    Returns
    -------
    tuple
        (r_min, r_max, c_min, c_max) after padding and clamping.
    """
    margin = 3
    last_row, last_col = np.subtract(mat.shape, (1, 1))
    top = np.maximum(r_min - margin, 0)
    bottom = np.minimum(r_max + margin, last_row)
    left = np.maximum(c_min - margin, 0)
    right = np.minimum(c_max + margin, last_col)
    return top, bottom, left, right


def fix_segmentation_maps(mask):
    """Erase the boundary band of every labelled region so touching regions separate.

    Labels are processed in ascending order. For each one, a window around its
    bounding box (padded and clamped by `smart_matrix_indexing`) is passed to
    `fix_patch` and the result is written back into `mask`.

    Parameters
    ----------
    mask : np.ndarray
        Label image; 0 is background and every value > 0 is treated as a
        separate region. Must be 2-D (the row/column unpacking below fails
        otherwise). Modified in place.

    Returns
    -------
    np.ndarray
        The same `mask` object, after processing.
    """
    labels = np.unique(mask)
    for val in labels[labels > 0]:
        rows, cols = np.nonzero(mask == val)
        # Erasing an earlier label's boundary band can wipe out a small later
        # label entirely; min()/max() on empty index arrays would raise.
        if rows.size == 0:
            continue
        r_min, r_max, c_min, c_max = smart_matrix_indexing(
            rows.min(), rows.max(), cols.min(), cols.max(), mask)
        # smart_matrix_indexing returns inclusive indices (maxima clamped to
        # shape - 1), so the exclusive slice end needs +1; otherwise the last
        # image row/column is never processed.
        window = (slice(r_min, r_max + 1), slice(c_min, c_max + 1))
        mask[window] = fix_patch(mask[window], val)
    return mask

In [3]:
class SegmentationDataset(Dataset):
    """Image / binary-mask pairs for segmentation, loaded from disk on demand.

    The image is read with OpenCV and converted from BGR to RGB. The mask is read
    with ``skimage.io.imread``, passed through ``fix_segmentation_maps`` (which
    erases the boundary band of each labelled region so touching regions stay
    separated) and then binarised to uint8 0/255.

    Parameters
    ----------
    imagePaths : sequence of str
        Image file paths; entry ``i`` is paired with ``maskPaths[i]``.
    maskPaths : sequence of str
        Label-mask file paths (0 = background, any other value = a region id).
    transforms : callable or None
        When not None, applied separately to the image and to the mask.
    """

    def __init__(self, imagePaths, maskPaths, transforms):
        # store the image and mask filepaths, and augmentation transforms
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms

    def __len__(self):
        # return the number of total samples contained in the dataset
        return len(self.imagePaths)

    @staticmethod
    def _binarize(mask):
        """Return a uint8 array with every non-zero pixel of `mask` set to 255 (others 0)."""
        return np.where(mask != 0, 255, 0).astype(np.uint8)

    def _load_image(self, idx):
        """Read image `idx` from disk and return it in RGB channel order."""
        imagePath = self.imagePaths[idx]
        image = cv2.imread(imagePath)
        # cv2.imread reports an unreadable/missing file by returning None.
        if image is None:
            raise FileNotFoundError(f"could not read image: {imagePath}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_mask(self, idx):
        """Read mask `idx`, separate touching regions, and binarise it to 0/255 uint8."""
        label_mask = imread(self.maskPaths[idx])
        # Keep integer label images in their native dtype: casting e.g. a 16-bit
        # instance mask to uint8 wraps ids modulo 256, merging distinct regions
        # and turning id 256 into background. Non-integer masks are cast as before.
        if not np.issubdtype(label_mask.dtype, np.integer):
            label_mask = label_mask.astype(np.uint8)
        label_mask = fix_segmentation_maps(label_mask)
        return self._binarize(label_mask)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample `idx`, transformed if `transforms` is set."""
        image = self._load_image(idx)
        fixed_mask = self._load_mask(idx)

        # check to see if we are applying any transformations
        if self.transforms is not None:
            # apply the transformations to both image and its mask
            # NOTE(review): if `transforms` resizes with an interpolating
            # filter, the returned mask will no longer be strictly binary.
            image = self.transforms(image)
            fixed_mask = self.transforms(fixed_mask)
        # return a tuple of the image and its mask
        return (image, fixed_mask)

#img_path = config.IMAGE_DATASET_PATH
#mask_path = config.MASK_DATASET_PATH

from imutils import paths
from tqdm import tqdm
import matplotlib.pyplot as plt

from torchvision import transforms

from torchvision import transforms as tv_transforms  # alias keeps this cell re-runnable

# Preprocessing shared by images and masks: numpy array -> PIL image -> fixed
# network input size -> float tensor in [0, 1] with shape (C, H, W).
# The name `transforms` is kept for code outside this cell, but it is built from
# the `tv_transforms` alias so re-running no longer hits the shadowed module.
transforms = tv_transforms.Compose([
    tv_transforms.ToPILImage(),
    tv_transforms.Resize((config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)),
    tv_transforms.ToTensor(),
])

imagePaths = sorted(paths.list_images(config.IMAGE_DATASET_PATH))
maskPaths = sorted(paths.list_images(config.MASK_DATASET_PATH))
# Images and masks are paired purely by sorted position, so the counts must agree.
assert len(imagePaths) == len(maskPaths), (
    f"found {len(imagePaths)} images but {len(maskPaths)} masks"
)

trainDS = SegmentationDataset(imagePaths=imagePaths, maskPaths=maskPaths,
                              transforms=transforms)

# Load the first sample once: every __getitem__ call re-reads both files and
# re-runs the mask post-processing.
sample_image, sample_mask = trainDS[0]

# ToTensor yields channel-first (C, H, W) tensors; matplotlib needs (H, W) or (H, W, 3).
fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))
ax_image.imshow(sample_image.permute(1, 2, 0).numpy())
ax_image.set_title("image")
ax_image.axis("off")
ax_mask.imshow(sample_mask.squeeze(0).numpy(), cmap="gray")
ax_mask.set_title("mask")
ax_mask.axis("off")
plt.show()

sample_image.shape, sample_mask.shape