In [1]:
import os
import cv2
from tqdm import tqdm
from glob import glob
from albumentations import CenterCrop, RandomRotate90, RandomSizedCrop, HorizontalFlip, VerticalFlip, Sharpen
import rasterio
from rasterio.plot import reshape_as_image, reshape_as_raster
import imageio
from osgeo import gdal
import numpy as np

In [2]:
def load_data(path):
    """Collect the image and mask file paths stored under ``path``.

    Args:
        path: Dataset root folder containing an ``Image`` and a ``Mask``
            sub-folder.

    Returns:
        A tuple ``(images, masks)`` of two lists of file paths, each sorted
        in ascending order. A missing or empty sub-folder yields an empty
        list.
    """
    image_pattern = os.path.join(path, "Image/*")
    mask_pattern = os.path.join(path, "Mask/*")

    image_paths = glob(image_pattern)
    image_paths.sort()

    mask_paths = glob(mask_pattern)
    mask_paths.sort()

    return image_paths, mask_paths

In [3]:
def create_dir(path):
    """Create the directory ``path`` (and any missing parents).

    Calling it on a directory that already exists is a no-op.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok avoids the race between a separate existence check and the
    # creation call (another process may create the folder in between).
    os.makedirs(path, exist_ok=True)

In [4]:
def _split_stem_ext(path):
    """Return ``(stem, ext)`` of the file name in ``path``; ``ext`` keeps its dot."""
    # basename understands the platform's separators (including "\\" on
    # Windows), and splitext only treats the last dot as the extension.
    return os.path.splitext(os.path.basename(path))


def _write_raster(path, image, base_profile):
    """Write a (rows, cols, bands) array to ``path`` as a GeoTIFF."""
    # Rasters are stored as (bands, rows, cols), so reshape back before writing.
    raster = reshape_as_raster(image)
    bands, rows, cols = raster.shape

    profile = base_profile.copy()
    profile['driver'] = 'GTiff'
    # Dimensions come from the array actually written: crops and 90-degree
    # rotations can make it differ from the source raster.
    profile['count'] = bands
    profile['height'] = rows
    profile['width'] = cols

    with rasterio.open(path, 'w', **profile) as dst:
        dst.write(raster)


def augment_data(images, masks, save_path, augment=True):
    """Augment image/mask raster pairs and save the results as GeoTIFFs.

    Each pair is read with rasterio. When ``augment`` is true, six
    augmentations (centre crop, random 90-degree rotation, random sized crop,
    horizontal flip, vertical flip, sharpen) are each applied to the original
    pair; the untouched original is not saved. When ``augment`` is false the
    original pair is written unchanged.

    Outputs go to ``<save_path>/images`` and ``<save_path>/label`` (created if
    missing). A pair producing several outputs gets ``_<idx>`` appended to the
    file stem; a pair producing a single output keeps its original name.

    NOTE: every other profile entry (including the georeferencing transform)
    is copied from the source unchanged, so cropped or rotated outputs are no
    longer positioned exactly.

    Args:
        images: Sequence of image file paths.
        masks: Sequence of mask file paths, paired with ``images`` by position.
        save_path: Output root folder.
        augment: Whether to apply the augmentations (default ``True``).

    Raises:
        ValueError: If ``images`` and ``masks`` have different lengths.
    """
    H = 64
    W = 64

    # Pairs are matched by position; a length mismatch means misaligned pairs.
    if len(images) != len(masks):
        raise ValueError(
            f"Number of images ({len(images)}) and masks ({len(masks)}) differ"
        )

    image_dir = os.path.join(save_path, "images")
    mask_dir = os.path.join(save_path, "label")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    # Transforms do not depend on the current pair, so build them once.
    transforms = []
    if augment:
        transforms = [
            CenterCrop(H, W, p=1.0),
            RandomRotate90(p=1.0),
            RandomSizedCrop(min_max_height=(50, 101), height=H, width=W, p=1.0),
            HorizontalFlip(p=1.0),
            VerticalFlip(p=1.0),
            Sharpen(p=1.0),
        ]

    for image_file, mask_file in tqdm(zip(images, masks), total=len(images)):
        # Name and extension of the image and the mask.
        image_name, image_extn = _split_stem_ext(image_file)
        mask_name, mask_extn = _split_stem_ext(mask_file)

        # Read both rasters and close the datasets as soon as the pixel data
        # and profiles have been captured.
        with rasterio.open(image_file) as image_src, rasterio.open(mask_file) as mask_src:
            image_profile = image_src.profile
            mask_profile = mask_src.profile
            # Rasters are read as (bands, rows, cols); the augmentations
            # expect (rows, cols, bands).
            x_image = reshape_as_image(image_src.read())
            y_image = reshape_as_image(mask_src.read())

        if augment:
            # Every transform is applied to the original pair, not chained.
            # Prepend x_image / y_image here if the original should be saved too.
            save_images = []
            save_masks = []
            for aug in transforms:
                augmented = aug(image=x_image, mask=y_image)
                save_images.append(augmented["image"])
                save_masks.append(augmented["mask"])
        else:
            save_images = [x_image]
            save_masks = [y_image]

        # Only add an index suffix when this pair yields several outputs;
        # otherwise they would all overwrite the same file.
        single_output = len(save_images) == 1
        for idx, (i, m) in enumerate(zip(save_images, save_masks)):
            suffix = "" if single_output else f"_{idx}"
            tmp_img_name = f"{image_name}{suffix}{image_extn}"
            tmp_mask_name = f"{mask_name}{suffix}{mask_extn}"

            _write_raster(os.path.join(image_dir, tmp_img_name), i, image_profile)
            _write_raster(os.path.join(mask_dir, tmp_mask_name), m, mask_profile)

In [5]:
# Root folder of the source dataset (expects "Image" and "Mask" sub-folders).
# NOTE: absolute, machine-specific path -- adjust when running elsewhere.
path = "E:\\Tesis MPJ\\SAM\\Dataset_harmonized"
images, masks = load_data(path)
print(f"Original Images: {len(images)} - Original Masks: {len(masks)}")

# Images and masks are paired by sorted position, so the counts must agree.
assert len(images) == len(masks), (
    f"Found {len(images)} images but {len(masks)} masks under {path!r}"
)

Original Images: 405 - Original Masks: 405


In [6]:
#create_dir("new_data/images3")
#create_dir("new_data/masks3")

In [7]:
# Output root folder; augment_data writes into its "images" and "label" sub-folders.
save_path = "new_data"

In [8]:
# Run the augmentation over every image/mask pair and write the results under save_path.
augment_data(
    images=images,
    masks=masks,
    save_path=save_path,
    augment=True,
)

100%|██████████| 405/405 [04:41<00:00,  1.44it/s]
