# Преобразование обучающих изображений

In [12]:
import albumentations as albu
from albumentations.pytorch import ToTensorV2
import os

import pandas as pd 
import numpy as np
import glob
from tqdm import tqdm
import cv2
from sklearn.model_selection import train_test_split
# !pip install fiona
import fiona
# !pip install segmentation_models_pytorch
import segmentation_models_pytorch as smp

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from torchvision.models import resnet18
from torchvision.utils import draw_segmentation_masks

from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import clear_output

import warnings
warnings.filterwarnings("ignore")

import json

In [13]:
class EyeDataset(Dataset):
    """
    Dataset that loads eye images together with their segmentation masks.

    Images are discovered as ``<data_folder>/all_sorted_data/*i.png``; the mask
    belonging to ``<stem>i.png`` is expected in the same directory as
    ``<stem>m.png``. Items whose mask file is missing are returned as ``None``
    so the caller can skip them.
    """

    # File-name suffixes that distinguish an image from its mask.
    _IMAGE_SUFFIX = "i.png"
    _MASK_SUFFIX = "m.png"

    def __init__(self, data_folder: str, transform=None):
        # Label id of the vessel class (not read anywhere inside this class).
        self.class_ids = {"vessel": 1}

        self.data_folder = data_folder
        # Optional callable invoked as transform(image=..., mask=...).
        self.transform = transform
        # glob() order depends on the filesystem; sort so that indices (and any
        # numbering derived from iteration order) are reproducible.
        self._image_files = sorted(
            glob.glob(f"{data_folder}/all_sorted_data/*{self._IMAGE_SUFFIX}")
        )

    @staticmethod
    def read_image(path: str) -> np.ndarray:
        """Read ``path`` as an RGB float32 array with values scaled to [0, 1].

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return np.array(image / 255, dtype=np.float32)

    @classmethod
    def _mask_path_for(cls, image_path: str) -> str:
        """Return the mask path paired with ``image_path`` (``<stem>i.png`` -> ``<stem>m.png``).

        Raises:
            ValueError: if ``image_path`` does not end with the image suffix.
        """
        if not image_path.endswith(cls._IMAGE_SUFFIX):
            raise ValueError(f"Not an image file name: {image_path}")
        # Swap only the trailing suffix; a global str.replace("i", "m") would
        # also rewrite every 'i' in directory names and in the file stem.
        return image_path[: -len(cls._IMAGE_SUFFIX)] + cls._MASK_SUFFIX

    def __getitem__(self, idx: int) -> "dict | None":
        """Return ``{'image': ..., 'mask': ...}`` for item ``idx``, or None if it has no mask.

        Both arrays come from ``read_image``. When a transform was supplied,
        its result (called with ``image=`` and ``mask=`` keywords) is returned
        instead of the raw dict.
        """
        image_path = self._image_files[idx]
        mask_path = self._mask_path_for(image_path)
        if not os.path.isfile(mask_path):
            # Unpaired image: signal "skip" to the caller rather than raising.
            return None

        sample = {'image': self.read_image(image_path),
                  'mask': self.read_image(mask_path)}
        if self.transform is not None:
            sample = self.transform(**sample)
        return sample

    def __len__(self):
        """Number of discovered image files, including ones without a mask."""
        return len(self._image_files)


In [14]:
from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose
)

In [15]:
# Image augmentation pipeline.
# Default side length, in pixels, of the square outputs produced by strong_aug.
size = 512


def strong_aug(p=0.5, target_size=None):
    """Build the albumentations pipeline used to augment (image, mask) pairs.

    The pipeline resizes so the longest side equals ``target_size`` (cubic
    interpolation) and pads to ``target_size x target_size``. It then applies
    random 90-degree rotation and flips, optional Gaussian noise, blur,
    shift/scale/rotate, optical/grid distortion, one of
    CLAHE/sharpen/emboss/brightness-contrast, and HSV jitter.

    Args:
        p: probability passed to the outer ``Compose``.
            NOTE(review): whether the ``always_apply=True`` resize/pad still run
            when the Compose itself is skipped depends on the albumentations
            version. Confirm this, because otherwise skipped outputs are not
            ``target_size`` square.
        target_size: output side length. Defaults to the module-level ``size``,
            looked up at call time.

    Returns:
        An albumentations ``Compose``, to be called as ``aug(image=..., mask=...)``.
    """
    if target_size is None:
        target_size = size
    return Compose([
        # Fit the longest side to target_size, then pad the short side to a square.
        albu.LongestMaxSize(target_size, interpolation=cv2.INTER_CUBIC, always_apply=True),
        albu.PadIfNeeded(target_size, target_size, always_apply=True),
        RandomRotate90(),
        Flip(),
        OneOf([
            GaussNoise(),
        ], p=0.2),
        # Inside OneOf, each child's p acts as a relative selection weight.
        OneOf([
            MotionBlur(p=0.2),
            MedianBlur(blur_limit=3, p=0.1),
            Blur(blur_limit=3, p=0.1),
        ], p=0.2),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
        OneOf([
            OpticalDistortion(p=0.3),
            GridDistortion(p=0.1),
        ], p=0.2),
        # NOTE(review): IAASharpen / IAAEmboss belong to the imgaug-backed IAA*
        # family that newer albumentations releases deprecated/removed.
        OneOf([
            CLAHE(clip_limit=2),
            IAASharpen(),
            IAAEmboss(),
            RandomBrightnessContrast(),
        ], p=0.3),
        HueSaturationValue(p=0.3),
    ], p=p)



In [16]:
# Build the dataset over the "data" directory.
dataset = EyeDataset(data_folder="data")

In [11]:
# Offline augmentation: for every dataset item that has a mask, write
# `augmentations_per_sample` augmented (image, mask) pairs as PNG files named
# <n>.png in the two output directories, numbered consecutively from 0.
augmentations_per_sample = 3
images_out_dir = "data/images512"
masks_out_dir = "data/masks512"
os.makedirs(images_out_dir, exist_ok=True)
os.makedirs(masks_out_dir, exist_ok=True)

# The pipeline does not depend on the sample, so build it once.
augmentation = strong_aug(p=0.9)

count = 0
for sample in tqdm(dataset):
    # EyeDataset yields None for images without a matching mask file.
    if sample is None:
        continue

    # Keep channel 1 (green after read_image's BGR->RGB conversion) and map
    # [0, 1] floats back to 0..255. Round rather than truncate, so float32
    # error (e.g. 127.99999) does not drop a level.
    mask = np.rint(sample['mask'][:, :, 1] * 255).astype(np.uint8)
    # NOTE(review): NORM_MINMAX stretches each image's own min..max to 0..255
    # (a per-image contrast stretch), unlike the plain *255 used for the mask.
    # Confirm this is intended.
    image = cv2.normalize(sample['image'], None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)

    data = {"image": image, "mask": mask}
    for _ in range(augmentations_per_sample):
        augmented = augmentation(**data)
        Image.fromarray(augmented["image"].astype(np.uint8)).save(f'{images_out_dir}/{count}.png')
        Image.fromarray(augmented["mask"].astype(np.uint8)).save(f'{masks_out_dir}/{count}.png')
        count += 1

100%|████████████████████████████████████████████████████████████████████████████████| 510/510 [03:09<00:00,  2.69it/s]
