In [None]:
import os
import sys

# Make the parent directory of this notebook importable.
# Resolve it to an absolute path now, so a later os.chdir() cannot change what it
# points to, and only append it once so re-running this cell does not grow sys.path.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import imageio
import imgaug as ia
import imgaug.augmenters as iaa
from tqdm import tqdm
import cv2 as cv
import matplotlib.pyplot as plt
import rasterio
import geopandas as gpd
import os
import numpy as np

In [None]:
# Side length (in pixels) of the square crops used by the augmentation pipeline
# and by modified_random_crop below.
crop_size = 512

# Seed the RNGs used in this notebook (imgaug's global RNG for the pipeline and
# NumPy's legacy global RNG for the crop helper) so Restart & Run All reproduces
# the same augmentation examples.
SEED = 42
ia.seed(SEED)
np.random.seed(SEED)

In [None]:
def display_image_in_actual_size(im_data, dpi=80):
    """Display an image so that one array pixel maps to one figure pixel.

    The figure size in inches is ``pixels / dpi`` and the figure itself is created
    with the same ``dpi``. A single borderless axes fills the whole figure.

    Args:
        im_data: image array indexed as (H, W) or (H, W, C) -- rows first.
        dpi: dots per inch used both to size the figure and as the figure's own
            dpi. Defaults to 80, the value previously hard-coded here.
    """
    height, width = im_data.shape[:2]

    # Size the figure in inches so that, at `dpi`, it is exactly width x height pixels.
    figsize = width / float(dpi), height / float(dpi)

    # Give the figure the same dpi. Previously the figure used rcParams['figure.dpi']
    # (not necessarily 80) while figsize assumed 80, so the image was rescaled
    # instead of being shown at its actual size.
    fig = plt.figure(figsize=figsize, dpi=dpi)

    # One axes covering the whole figure, with no spines, ticks or margins.
    ax = fig.add_axes([0, 0, 1, 1])
    ax.axis('off')

    # Show single-channel (2-D) arrays in grayscale rather than the default colormap;
    # cmap is ignored for RGB(A) arrays.
    ax.imshow(im_data, cmap='gray' if im_data.ndim == 2 else None)

    plt.show()

In [None]:
images_root = "../data/images"
labels_root = "../data/masks"

# Keep only regular, non-hidden files. Entries like `.ipynb_checkpoints` or
# `.DS_Store` sort before normal file names ('.' < digits/letters), so without this
# filter they could become images_paths[0].
images_paths = sorted(
    os.path.join(images_root, img_name)
    for img_name in os.listdir(images_root)
    if not img_name.startswith(".")
    and os.path.isfile(os.path.join(images_root, img_name))
)

# Each mask lives at labels_root/<image file stem>.png, aligned index-by-index
# with images_paths.
labels_paths = [
    os.path.join(labels_root, os.path.splitext(os.path.basename(img_path))[0] + ".png")
    for img_path in images_paths
]

In [None]:
# Open the first image with a context manager so the dataset handle is closed after
# reading. The bare rasterio.open(...).read() left the file open.
with rasterio.open(images_paths[0]) as dataset:
    # Keep at most the first three bands.
    src = dataset.read()[:3]

# read() yields (bands, rows, cols); move bands last -> (rows, cols, bands), the layout
# the imgaug / matplotlib cells below consume.
src = np.transpose(src, [1, 2, 0])

In [None]:
# Shape of the loaded image after transposition: (rows, cols, bands).
np.shape(src)

In [None]:
# Load the mask that belongs to images_paths[0] (OpenCV's default flag loads it as a
# 3-channel image, which the later cells rely on).
# cv.imread does not raise on a missing or unreadable file -- it returns None, which
# would only surface later as a confusing error -- so fail fast here instead.
label = cv.imread(labels_paths[0])
if label is None:
    raise FileNotFoundError(f"Could not read mask image: {labels_paths[0]}")

In [None]:
# Visual sanity check of the mask. Use explicit axes, a title naming the mask file,
# and a trailing ';' to suppress the AxesImage repr.
fig, ax = plt.subplots()
ax.imshow(label)
ax.set_title(os.path.basename(labels_paths[0]))
ax.axis("off");

In [None]:
# Distinct pixel values present in the mask (np.unique flattens its input anyway).
np.unique(label.ravel())

In [None]:
# Augmentation pipeline, called below with both the image and its segmentation map.
seq = iaa.Sequential([
    # Random up-scaling by a factor in [1, 4].
    iaa.Resize((1.0, 4.0)),
    # CropToFixedSize never pads. Pad first so the output is always
    # crop_size x crop_size, even when the image is smaller than crop_size.
    iaa.PadToFixedSize(width=crop_size, height=crop_size),
    iaa.CropToFixedSize(width=crop_size, height=crop_size),
    # k is drawn from 0..3. The former [1, 3] only ever rotated by 90 or 270 degrees;
    # combined with the flips below, it never produced the original orientation
    # (nor 180 deg / plain flips) -- only half of the 8 square symmetries.
    iaa.Rot90((0, 3)),
    iaa.Fliplr(0.5),
    iaa.Flipud(0.5),
    iaa.Sometimes(
        0.6,
        [
            iaa.TranslateX(percent=(-0.2, 0.2)),
            iaa.TranslateY(percent=(-0.2, 0.2)),
        ],
    ),
    iaa.Sometimes(
        0.6,
        iaa.Affine(rotate=(-23, 23)),
    ),
    iaa.Sometimes(
        0.5,
        iaa.Sharpen((0.0, 0.5), lightness=(0.75, 1.4)),
    ),
    iaa.Sometimes(
        0.05,
        iaa.SomeOf(
            1,
            [
                iaa.CLAHE(),
                iaa.AdditiveGaussianNoise(scale=(0, 25)),
                # sigma must be one (min, max) tuple. The former GaussianBlur(0, 1)
                # set sigma=0 (no blur at all) and passed 1 positionally to the next
                # parameter (seed or name, depending on the imgaug version).
                iaa.GaussianBlur(sigma=(0.0, 1.0)),
            ],
        ),
    ),
    iaa.Sometimes(
        0.01,
        iaa.ElasticTransformation(alpha=50, sigma=5),
    ),
])

In [None]:
# Integer division maps a 0/255 mask to 0/1 class ids. Any value in 1..254 floors
# to 0 (background), so this assumes a strictly binary 0/255 mask -- check the
# np.unique cell above.
class_ids = np.floor_divide(label, 255)
segmap = SegmentationMapsOnImage(class_ids, shape=src.shape)

In [None]:
def modified_random_crop(image, label, crop_size, n_tries=4):
    """Randomly crop ``image`` and ``label`` to a square, biased toward label content.

    A random window is drawn first. Then ``n_tries`` further random windows are
    drawn and scored by the sum of ``label`` inside them. The best window with a
    strictly positive score replaces the first one. If no candidate contains any
    label mass, the first random window is kept. Uses NumPy's global RNG
    (``np.random``).

    Args:
        image: array indexed as (rows, cols, ...); 2-D and 3-D arrays both work.
        label: array whose first two dimensions match ``image``.
        crop_size: side length of the square crop, in pixels.
        n_tries: number of candidate windows scored by label sum (default 4).

    Returns:
        ``(image_crop, label_crop)`` whose first two dimensions are
        ``(crop_size, crop_size)``.

    Raises:
        ValueError: if ``crop_size`` exceeds the image height or width.
    """
    height, width = image.shape[:2]
    if crop_size > height or crop_size > width:
        raise ValueError(
            f"crop_size={crop_size} exceeds image size {height}x{width}"
        )

    # randint's upper bound is exclusive. The +1 makes the last valid offset
    # reachable and allows crop_size == image size (randint(0, 0) would raise).
    x_high = width - crop_size + 1
    y_high = height - crop_size + 1

    x0 = np.random.randint(0, x_high)
    y0 = np.random.randint(0, y_high)

    best_score = 0
    for _ in range(n_tries):
        cand_x0 = np.random.randint(0, x_high)
        cand_y0 = np.random.randint(0, y_high)
        score = label[cand_y0:cand_y0 + crop_size, cand_x0:cand_x0 + crop_size].sum()
        if score > best_score:
            best_score = score
            x0, y0 = cand_x0, cand_y0

    # Slice only the two spatial axes so any trailing channel axis is kept as-is.
    window = (slice(y0, y0 + crop_size), slice(x0, x0 + crop_size))
    return image[window], label[window]

In [None]:
# Draw five independent augmentations of the same image/mask pair; each call
# returns an (augmented image, augmented segmentation map) tuple.
augmented_pairs = [seq(image=src, segmentation_maps=segmap) for _ in range(5)]
images_aug = [pair[0] for pair in augmented_pairs]
segmaps_aug = [pair[1] for pair in augmented_pairs]

In [None]:
# Three panels per augmented example, in this order:
# augmented image, mask overlaid on it, mask alone.
# draw() / draw_on_image() return a list of images; [0] keeps the first one.
cells = []
for image_aug, segmap_aug in zip(images_aug, segmaps_aug):
    overlay = segmap_aug.draw_on_image(image_aug)[0]
    mask_only = segmap_aug.draw(size=image_aug.shape[:2])[0]
    cells.extend([image_aug, overlay, mask_only])

In [None]:
# Three columns = the three panels appended per example, so each row is one example.
grid_image = ia.draw_grid(images=cells, cols=3)

In [None]:
# Save the example grid next to the notebook.
output_path = "example_augmentation.jpg"
imageio.imwrite(output_path, grid_image)

In [None]:
# Print the shape of each augmented image, one per line.
aug_shapes = [augmented_image.shape for augmented_image in images_aug]
for shape in aug_shapes:
    print(shape)

In [None]:
# Class of the augmentation pipeline object built above (an iaa.Sequential).
seq.__class__

In [None]:
from torchvision import transforms

# Scratch check: the type of a torchvision pipeline, for comparison with the
# imgaug pipeline type shown in the previous cell.
torchvision_pipeline = transforms.Compose(
    [transforms.CenterCrop(10), transforms.ToTensor()]
)
type(torchvision_pipeline)