In [1]:
import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
# List the current contents of data_aug/, the folder the augmentation cells below write into.
print(os.listdir("data_aug"))

['depths.csv', 'sample_submission.csv', 'test.zip', 'train', 'train.csv', 'train.zip', 'train_crop']


In [7]:
class BlurAwareCrop():
    '''
    Paired image/mask augmentation that randomly crops sharp images and leaves blurry ones alone.

    Usage contract: call the instance on an RGB image first, then on its matching mask.
    The RGB call measures sharpness and, when the image is sharp enough, with probability
    ``prob`` samples ONE square crop window. The following non-RGB (mask) call re-applies
    exactly that window, so image and mask stay pixel-aligned.

    Args:
        prob: probability of cropping an image whose sharpness exceeds ``blur_thres``.
        blur_thres: threshold on the Laplacian variance (see ``sharp_measure``); images at
            or below it are returned unchanged.
        min_crop: smallest crop side length in pixels (inclusive).
        return_size: crops are resized to ``return_size x return_size``; the crop side is
            drawn from ``[min_crop, return_size)`` (capped by the image size).
            Uncropped images keep their original size.
    '''
    def __init__(self, prob=0.7, blur_thres=200, min_crop=20, return_size=50):
        self.prob = prob
        self.blur_thres = blur_thres
        self.min_crop = min_crop
        self.return_size = return_size
        # Transform chosen by the most recent RGB call; re-used for the mask that follows it.
        self.tr = None

    # reference: https://www.pyimagesearch.com/2015/09/07/blur-detection-with-opencv/
    def sharp_measure(self, img_pil):
        '''Return the variance of the Laplacian of ``img_pil``; higher means sharper.

        Assumes a 3-channel RGB PIL image, since the RGB->BGR conversion requires one.
        '''
        img_cv = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
        return cv2.Laplacian(img_cv, cv2.CV_64F).var()

    def _crop_resize(self, left, top, size):
        '''Build a transform that applies one fixed square crop window, then resizes it.'''
        box = (left, top, left + size, top + size)
        out_size = (self.return_size, self.return_size)

        def apply(img):
            # Masks are label maps: nearest-neighbour keeps their values intact. Bilinear
            # (the torchvision Resize default, kept for images) would invent intermediate
            # values along object boundaries.
            resample = Image.BILINEAR if img.mode == 'RGB' else Image.NEAREST
            return img.crop(box).resize(out_size, resample)

        return apply

    def __call__(self, img):
        '''
        Transform ``img``.

        RGB input (salt image): decide whether to crop, sample a fixed crop window, store the
        resulting transform in ``self.tr`` and apply it.
        Any other mode (mask): apply the transform stored by the previous RGB call unchanged.

        Raises:
            RuntimeError: if a non-RGB image is passed before any RGB image.
        '''
        if img.mode == 'RGB':
            if self.sharp_measure(img) > self.blur_thres and np.random.rand() < self.prob:
                width, height = img.size
                # The crop can never be larger than the image itself.
                high = min(self.return_size, min(width, height) + 1)
                crop_size = np.random.randint(self.min_crop, high)
                # Sample the window ONCE here. transforms.RandomCrop would draw a new
                # position on every call, misaligning the mask with its image.
                top = np.random.randint(0, height - crop_size + 1)
                left = np.random.randint(0, width - crop_size + 1)
                self.tr = self._crop_resize(left, top, crop_size)
            else:
                self.tr = transforms.Compose([])
        elif self.tr is None:
            raise RuntimeError('BlurAwareCrop must be called on an RGB image before its mask')
        return self.tr(img)

In [8]:
# Shared, stateful augmenter: for each pair, call it on the image first, then on the mask.
tr = BlurAwareCrop(prob=0.7, blur_thres=200, min_crop=20, return_size=50)

In [4]:
# Only the `id` column is needed: it names the image and mask PNG files.
fnames = pd.read_csv(filepath_or_buffer='data/train.csv', usecols=['id'])

In [9]:
def show_example(index, transform=None, data_dir='data'):
    '''
    Plot one training image and its mask before and after the blur-aware crop.

    Args:
        index: label in ``fnames.id`` of the sample to show.
        transform: a ``BlurAwareCrop``-like object; defaults to the notebook-level ``tr``.
            It must expose ``sharp_measure``, ``blur_thres`` and ``prob`` and be called on
            the image before the mask.
        data_dir: root folder containing ``train/images`` and ``train/masks``.
    '''
    transform = tr if transform is None else transform
    fname = fnames.id[index]
    with Image.open(f'{data_dir}/train/images/{fname}.png') as img, \
         Image.open(f'{data_dir}/train/masks/{fname}.png') as mask:
        sharpness = transform.sharp_measure(img)
        print(f"image sharpness: {sharpness}")
        if sharpness > transform.blur_thres:
            print(f"image is sharp enough, cropping is applied with probability {transform.prob}")
        else:
            print("image is blurry, cropping will not be applied")

        # The image must be transformed before the mask: the image call picks the
        # transform that the mask call then re-uses.
        img_after = transform(img)
        mask_after = transform(mask)

        panels = [
            ('image before transform', img),
            ('mask before transform', mask),
            ('image after transform', img_after),
            ('mask after transform', mask_after),
        ]
        fig, axes = plt.subplots(1, 4, figsize=(16, 9))
        for ax, (title, picture) in zip(axes, panels):
            ax.set_title(title)
            ax.imshow(picture)
        plt.show()

In [34]:
def save_image(src_dir='data/train', dst_dir='data_aug/train'):
    '''
    Apply the notebook-level ``tr`` to every training image/mask pair and save the results.

    For each id in ``fnames``, the function reads two files:
    ``{src_dir}/images/<id>.png`` and ``{src_dir}/masks/<id>.png``.

    It writes the results to two folders, creating them if needed:
    ``{dst_dir}/images_blurriness/`` and ``{dst_dir}/masks_blurriness/``.

    Args:
        src_dir: folder holding the original ``images/`` and ``masks/`` subfolders.
        dst_dir: folder under which the ``*_blurriness`` output folders are written.
    '''
    img_out = f'{dst_dir}/images_blurriness'
    mask_out = f'{dst_dir}/masks_blurriness'
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(img_out, exist_ok=True)
    os.makedirs(mask_out, exist_ok=True)
    for fname in fnames.id:
        with Image.open(f'{src_dir}/images/{fname}.png') as img, \
             Image.open(f'{src_dir}/masks/{fname}.png') as mask:
            # Order matters: the image call chooses the transform the mask call re-uses.
            tr(img).save(f'{img_out}/{fname}.png')
            tr(mask).save(f'{mask_out}/{fname}.png')

In [35]:
# Augment every training image/mask pair and write the results under data_aug/train/.
save_image()

In [5]:
def save_image_crop(src_dir='data_aug/train_crop', dst_dir='data_aug/train_crop',
                    tiles=('leftUpper', 'leftBottom', 'rightUpper', 'rightBottom')):
    '''
    Apply the notebook-level ``tr`` to every crop tile of every training id and save it.

    For each id in ``fnames`` and each tile name, the function reads two files:
    ``{src_dir}/images/<id>-<tile>.png`` and ``{src_dir}/masks/<id>-<tile>.png``.

    It writes the results under the same file name to two folders, creating them if needed:
    ``{dst_dir}/images_blurriness/`` and ``{dst_dir}/masks_blurriness/``.

    Args:
        src_dir: folder holding the tile ``images/`` and ``masks/`` subfolders.
        dst_dir: folder under which the ``*_blurriness`` output folders are written.
        tiles: tile suffixes to process, in processing order.
    '''
    img_out = f'{dst_dir}/images_blurriness'
    mask_out = f'{dst_dir}/masks_blurriness'
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(img_out, exist_ok=True)
    os.makedirs(mask_out, exist_ok=True)
    for fname in fnames.id:
        for tile in tiles:
            name = f'{fname}-{tile}.png'
            with Image.open(f'{src_dir}/images/{name}') as img, \
                 Image.open(f'{src_dir}/masks/{name}') as mask:
                # Order matters: the image call chooses the transform the mask call re-uses.
                tr(img).save(f'{img_out}/{name}')
                tr(mask).save(f'{mask_out}/{name}')

In [9]:
# Augment the four crop tiles of every training id and write the results under data_aug/train_crop/.
save_image_crop()