In [20]:
from pathlib import Path
import logging
import numpy as np
import pandas as pd
from tqdm import tqdm
from natsort import natsorted
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from torchvision import transforms
import cv2 as cv

# The notebook runs from a sibling of the "data" directory, so resolve paths
# relative to the parent of the current working directory.
thispath = Path.cwd().resolve()

datadir = thispath.parent / "data"

# Destination for the augmented key/query samples written at the end of the cell.
outputdir = datadir / "Data_augmentation"
outputdir.mkdir(parents=True, exist_ok=True)

class Dataset_instance(Dataset):
    """Map-style dataset yielding ``(key, query)`` views of image patches.

    ``key`` is the patch as read from disk, converted from OpenCV's BGR order
    to RGB. ``query`` is the same patch passed through ``transform`` (called
    albumentations-style as ``transform(image=...)['image']``), or ``key``
    itself when no transform is given. When ``preprocess`` is set, it is applied
    to both views and the results are cast to ``torch.FloatTensor``.

    Args:
        wsi_path_patches: Indexable sequence of rows; element 0 of each row is
            used as the patch image path (presumably rows of the PyHIST CSVs —
            verify against caller).
        transform: Optional augmentation producing the query view.
        preprocess: Optional callable applied to both views.
    """

    def __init__(self, wsi_path_patches, transform=None, preprocess=None):
        self.wsi_path_patches = wsi_path_patches
        self.transform = transform
        self.preprocess = preprocess

    def __len__(self):
        return len(self.wsi_path_patches)

    def __getitem__(self, index):
        # cv.imread wants a str path; CSV-derived entries may not be plain str.
        path = str(self.wsi_path_patches[index][0])

        # cv.imread does not raise on a missing/unreadable file: it returns None,
        # which would otherwise fail with an opaque error inside cv.cvtColor.
        key = cv.imread(path)
        if key is None:
            raise FileNotFoundError(f"Could not read patch image: {path}")
        # OpenCV loads BGR; the rest of the pipeline works in RGB.
        key = cv.cvtColor(key, cv.COLOR_BGR2RGB)

        query = self.transform(image=key)["image"] if self.transform else key

        if self.preprocess:
            query = self.preprocess(query).type(torch.FloatTensor)
            key = self.preprocess(key).type(torch.FloatTensor)

        return key, query

# Probability used by most stochastic transforms in the augmentation pipeline.
prob_augmentation = 0.5

pyhistdir = datadir / "Mask_PyHIST_v2"
# natsorted gives a human-friendly order (e.g. "2" before "10"), so the patch
# list is built in a reproducible order.
dataset_path = natsorted(pyhistdir.rglob("*_densely_filtered_paths_v2.csv"))
if not dataset_path:
    # Fail here with a clear message instead of later inside DataLoader,
    # which cannot shuffle an empty dataset.
    raise FileNotFoundError(
        f"No '*_densely_filtered_paths_v2.csv' files found under {pyhistdir}"
    )

path_patches = []
for wsi_patches in tqdm(dataset_path, desc="Selecting all patches for training"):
    # Every CSV row becomes one entry; Dataset_instance later reads element 0
    # of each row as the patch path.
    csv_instances = pd.read_csv(wsi_patches).to_numpy()
    path_patches.extend(csv_instances)

number_patches = len(path_patches)
# NOTE: the root logger defaults to WARNING, so this INFO message is only shown
# if logging has been configured (e.g. logging.basicConfig(level=logging.INFO)).
logging.info("Total number of patches %d", number_patches)

# Augmentation pipeline that turns a patch (key) into its augmented view (query).
# Each stochastic step fires independently with probability prob_augmentation,
# except HueSaturationValue (always applied) and ToGray (p=0.2).
# Commented-out entries are alternatives that are currently disabled; they use
# prob_augmentation so they run as-is if re-enabled.
# NOTE(review): RandomResizedCrop(height=, width=) and ElasticTransform(alpha_affine=)
# follow the older albumentations API; newer releases changed/removed these
# arguments — pin the albumentations version or update them.
pipeline_transform = A.Compose([
    # A.RandomScale(scale_limit=(-0.005,0.005), interpolation=2, p=prob_augmentation),
    # A.RandomCrop(height=220, width=220, p=prob_augmentation),
    # A.Resize(224,224,always_apply=True),
    # A.MotionBlur(blur_limit=3, p=prob_augmentation),
    # A.MedianBlur(blur_limit=3, p=prob_augmentation),
    # A.CropAndPad(percent=(-0.01, -0.05),pad_mode=1,always_apply=True),

    # --- Geometric ---
    # Only fires with p=prob_augmentation, so when it does not fire the query
    # keeps the input patch's size.
    A.RandomResizedCrop(height=224, width=224, scale=(0.8, 1), p=prob_augmentation),
    A.VerticalFlip(p=prob_augmentation),
    A.HorizontalFlip(p=prob_augmentation),
    A.RandomRotate90(p=prob_augmentation),

    # --- Colour ---
    # p=1.0 is the non-deprecated equivalent of always_apply=True.
    A.HueSaturationValue(
        hue_shift_limit=(-25, 15),
        sat_shift_limit=(-20, 30),
        val_shift_limit=(-15, 15),
        p=1.0,
    ),
    A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=prob_augmentation),
    # A.GaussianBlur (blur_limit=(1, 3), sigma_limit=0, p=prob_augmentation),
    # A.HueSaturationValue(hue_shift_limit=(-25,10),sat_shift_limit=(-25,15),val_shift_limit=(-15,15),always_apply=True),
    # A.RGBShift (r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, always_apply=True, p=prob_augmentation),
    # A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=prob_augmentation),
    # A.RandomBrightness(limit=0.2, p=prob_augmentation),
    # A.RandomContrast(limit=0.2, p=prob_augmentation),
    # A.GaussNoise(p=prob_augmentation),

    # --- Distortions / blur ---
    # Named OpenCV constants replace the former magic ints
    # (INTER_CUBIC == 2, INTER_LINEAR == 1, BORDER_REFLECT_101 == 4).
    A.ElasticTransform(
        alpha=200,
        sigma=10,
        alpha_affine=10,
        interpolation=cv.INTER_CUBIC,
        border_mode=cv.BORDER_REFLECT_101,
        p=prob_augmentation,
    ),
    A.GridDistortion(
        num_steps=1,
        distort_limit=0.2,
        interpolation=cv.INTER_LINEAR,
        border_mode=cv.BORDER_REFLECT_101,
        p=prob_augmentation,
    ),
    A.GlassBlur(sigma=0.1, max_delta=1, iterations=1, p=prob_augmentation),
    A.OpticalDistortion(
        distort_limit=0.2,
        shift_limit=0.2,
        interpolation=cv.INTER_LINEAR,
        border_mode=cv.BORDER_REFLECT_101,
        value=None,
        p=prob_augmentation,
    ),
    # A.GridDropout (ratio=0.3, unit_size_min=3, unit_size_max=40, holes_number_x=3, holes_number_y=3, shift_x=1, shift_y=10, random_offset=True, fill_value=0, p=prob_augmentation),

    # --- Intensity ---
    A.Equalize(p=prob_augmentation),
    # A.Posterize(p=prob_augmentation, always_apply=True),
    # A.RandomGamma(p=prob_augmentation, always_apply=True),
    # A.Superpixels(p_replace=0.05, n_segments=100, max_size=128, interpolation=1, p=prob_augmentation),
    # A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3, p=prob_augmentation),
    A.ToGray(p=0.2),
    # A.Affine(shear = (-5, 5), translate_px = (-5,5), p = prob_augmentation),
    # A.Affine(translate_px = (-5,5), p = 1),
    # A.CoarseDropout (max_holes=20, max_height=10, max_width=10, min_holes=None, min_height=1, min_width=1, fill_value=0, p=prob_augmentation),
    # A.CoarseDropout (max_holes=20, max_height=10, max_width=10, min_holes=None, min_height=1, min_width=1, fill_value=255, p=prob_augmentation),
    ])

# Tensor conversion followed by a 224x224 resize. It is built here but NOT
# passed to the dataset below (preprocess=None), so the loader yields the raw
# RGB image arrays collated into tensors by DataLoader.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(224, 224), antialias=True),
    ]
)

params_instance = dict(
    batch_size=1,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
    num_workers=1,
)

# Only the first 50 patches are used for this visual check of the augmentations.
instances = Dataset_instance(path_patches[:50], pipeline_transform, None)
generator = DataLoader(instances, **params_instance)

# Write every (key, query) pair to outputdir for visual inspection.
# Dataset_instance converts images to RGB, but cv.imwrite expects BGR, so
# convert back before writing; otherwise red and blue channels are swapped.
# Assumes preprocess=None above, i.e. each item is an (H, W, 3) image rather
# than a float CHW tensor.
loader_batch_size = generator.batch_size
for a, (x_k, x_q) in enumerate(generator):
    # Iterate the batch explicitly instead of relying on squeeze(), so this
    # still works if batch_size is raised above 1. With batch_size=1 the
    # filenames are unchanged: "{a}_key.png" / "{a}_query.png".
    for offset, (key_img, query_img) in enumerate(zip(x_k, x_q)):
        sample_idx = a * loader_batch_size + offset
        for suffix, img in (("key", key_img), ("query", query_img)):
            out_path = outputdir / f"{sample_idx}_{suffix}.png"
            bgr = cv.cvtColor(img.numpy(), cv.COLOR_RGB2BGR)
            # cv.imwrite signals failure via its return value, not an exception.
            if not cv.imwrite(str(out_path), bgr):
                raise OSError(f"cv.imwrite failed for {out_path}")

Selecting all patches for training: 100%|██████████| 1366/1366 [00:09<00:00, 150.11it/s]


tensor([[[[151,  91, 144],
          [164, 102, 153],
          [179, 110, 160],
          ...,
          [236, 239, 238],
          [238, 239, 237],
          [238, 237, 239]],

         [[130,  73, 122],
          [164,  88, 139],
          [146,  81, 133],
          ...,
          [240, 239, 240],
          [241, 239, 239],
          [241, 237, 241]],

         [[145,  84, 146],
          [131,  72, 126],
          [129,  72, 135],
          ...,
          [240, 239, 239],
          [239, 239, 239],
          [239, 237, 241]],

         ...,

         [[184,  78, 102],
          [165,  73,  98],
          [144,  63, 108],
          ...,
          [241, 241, 239],
          [241, 241, 237],
          [240, 236, 237]],

         [[192,  85, 109],
          [166,  72,  97],
          [176,  87, 130],
          ...,
          [239, 235, 239],
          [240, 235, 237],
          [239, 240, 237]],

         [[197,  80, 115],
          [174,  74, 111],
          [159,  74, 112],
         