In [1]:
import os
import cv2
from tqdm import tqdm
from glob import glob
from albumentations import CenterCrop, RandomRotate90, GridDistortion, HorizontalFlip, VerticalFlip

In [2]:
def load_data(path):
    """Return sorted lists of image and mask file paths found under ``path``.

    Files are collected from the ``images`` and ``masks`` sub-folders of
    ``path``. Both lists are sorted so entries can be paired by position,
    which assumes each image and its mask sort into the same position.
    """
    def _list_dir(sub_dir):
        # Sorted listing of every entry inside ``<path>/<sub_dir>``.
        return sorted(glob(os.path.join(path, sub_dir + "/*")))

    return _list_dir("images"), _list_dir("masks")

In [3]:
def _split_name(path):
    """Return ``(stem, extension)`` of the file at ``path``; extension keeps its dot.

    ``os.path`` is used instead of ``split("/")`` / ``split(".")`` so that
    separators produced by ``glob`` on Windows and file names containing
    extra dots (e.g. ``img.v2.png``) are handled correctly.
    """
    return os.path.splitext(os.path.basename(path))


def _augmented_pairs(image, mask, augment):
    """Return parallel lists of images and masks to save for one input pair.

    With ``augment`` the lists hold the original followed by a horizontal
    flip, a vertical flip and a ``RandomRotate90`` result; otherwise only the
    original. Each transform is applied to image and mask together so the
    two stay spatially aligned.
    """
    save_images, save_masks = [image], [mask]
    if augment:
        # NOTE(review): RandomRotate90 picks a random multiple of 90 degrees and
        # no seed is set, so this variant is not reproducible between runs (and
        # presumably can be a 0-degree rotation) — confirm this is intended.
        for aug in (HorizontalFlip(p=1.0), VerticalFlip(p=1.0), RandomRotate90(p=1.0)):
            augmented = aug(image=image, mask=mask)
            save_images.append(augmented["image"])
            save_masks.append(augmented["mask"])
    return save_images, save_masks


def _save_pair(image, mask, image_path, mask_path, height, width):
    """Resize ``image``/``mask`` to ``height`` x ``width`` and write them to disk.

    Masks are resized with nearest-neighbour interpolation so no new,
    in-between label colours appear along region borders; images use
    OpenCV's default (bilinear) interpolation.

    Raises:
        IOError: if OpenCV reports that either file could not be written.
    """
    image = cv2.resize(image, (width, height))
    mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    # cv2.imwrite signals failure (e.g. missing folder) through its return
    # value rather than raising, so check it explicitly.
    if not cv2.imwrite(image_path, image):
        raise IOError(f"Could not write image: {image_path}")
    if not cv2.imwrite(mask_path, mask):
        raise IOError(f"Could not write mask: {mask_path}")


def augment_data(images, masks, save_path, augment=True, height=513, width=513):
    """Resize (and optionally augment) image/mask pairs and save them to disk.

    Results are written to ``<save_path>/images`` and ``<save_path>/masks``
    (created if missing), the layout ``load_data`` reads, so the output can
    be loaded back with ``load_data(save_path)``.

    Args:
        images: image file paths, paired by position with ``masks``.
        masks: mask file paths; must have the same length as ``images``.
        save_path: root directory for the output.
        augment: if True, save the original plus a horizontal flip, a
            vertical flip and a random 90-degree rotation (4 pairs per
            input); otherwise save only the resized original.
        height: output height in pixels (default 513).
        width: output width in pixels (default 513).

    Output files are named ``<stem>_<k><ext>`` (``k`` = variant index) when
    several variants are saved per input, or ``<stem><ext>`` when only one is.

    Raises:
        ValueError: if ``images`` and ``masks`` differ in length.
        FileNotFoundError: if OpenCV cannot read an image or mask.
        IOError: if an output file cannot be written.
    """
    # zip() would silently drop unmatched files; fail loudly instead.
    if len(images) != len(masks):
        raise ValueError(
            f"Got {len(images)} images but {len(masks)} masks; they must pair up."
        )

    image_dir = os.path.join(save_path, "images")
    mask_dir = os.path.join(save_path, "masks")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    for image_file, mask_file in tqdm(zip(images, masks), total=len(images)):
        image_name, image_extn = _split_name(image_file)
        mask_name, mask_extn = _split_name(mask_file)

        # cv2.imread returns None instead of raising on unreadable files.
        image = cv2.imread(image_file, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_file}")
        mask = cv2.imread(mask_file, cv2.IMREAD_COLOR)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_file}")

        save_images, save_masks = _augmented_pairs(image, mask, augment)

        # Variants of the same input need distinct names, so append the index
        # whenever more than one variant is saved for this input.
        add_suffix = len(save_images) > 1
        for idx, (img, msk) in enumerate(zip(save_images, save_masks)):
            suffix = f"_{idx}" if add_suffix else ""
            _save_pair(
                img,
                msk,
                os.path.join(image_dir, f"{image_name}{suffix}{image_extn}"),
                os.path.join(mask_dir, f"{mask_name}{suffix}{mask_extn}"),
                height,
                width,
            )


In [4]:
# Input folder must contain `images/` and `masks/` sub-folders; augmented
# pairs are written under `<DATA_DIR>/new`.
# NOTE(review): absolute, machine-specific path — move to a config cell or an
# environment variable before sharing or re-running elsewhere.
DATA_DIR = "D:/Official/Masters Studies/Thesis/Dataset/Segmentation/Bolt Dataset/selected_set/Bolt_1299_Aug/red cables/"
path = DATA_DIR
# Derived from DATA_DIR so the two locations cannot drift apart.
new_path = os.path.join(DATA_DIR, "new")

orig_images, orig_masks = load_data(path)
print(f"Original Images: {len(orig_images)} - Original Masks: {len(orig_masks)}")
augment_data(orig_images, orig_masks, new_path, augment=True)

# `images`/`masks` are rebound to the augmented set, as before.
images, masks = load_data(new_path)
print(f"Augmented Images: {len(images)} - Augmented Masks: {len(masks)}")
assert len(images) == len(masks), "augmented images and masks are out of sync"

Original Images: 200 - Original Masks: 200


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:22<00:00,  8.72it/s]

Augmented Images: 800 - Augmented Masks: 800



