In [1]:
import torch
import torch.nn as nn
import os 
import numpy as np
import cv2
from glob import glob 
from tqdm import tqdm 
import imageio
from albumentations import HorizontalFlip,VerticalFlip,Rotate

In [2]:
def create_dir(path):
    """Create directory `path`, including any missing parents; no-op if it already exists.

    Uses ``exist_ok=True`` instead of a separate existence check, so there is no
    window in which something else can create the directory between the check and
    the ``makedirs`` call (which would otherwise raise ``FileExistsError``).

    Raises:
        FileExistsError: if `path` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

def create_req_dirs(root='new_data'):
    """Create the output tree ``<root>/{train,test}/{images,masks}``.

    Directories are created in the order train/images, train/masks,
    test/images, test/masks; existing ones are left alone.

    Args:
        root: base output directory. Defaults to ``'new_data'``, the location
            this notebook's main cell writes augmented data to.
    """
    for split in ('train', 'test'):
        for kind in ('images', 'masks'):
            create_dir(os.path.join(root, split, kind))


def load_data(path):
    """Collect sorted image and mask file paths for the training and test splits.

    Looks for ``<path>/<split>/images/*.tif`` (images) and
    ``<path>/<split>/1st_manual/*.gif`` (manual masks), where ``<split>`` is
    ``training`` or ``test``. Each list is sorted lexicographically; callers
    pair images with masks by list position, so the sort order matters.

    Returns:
        tuple: ``(train_x, train_y, test_x, test_y)`` — four lists of path strings.
    """
    def _sorted_files(split, folder, pattern):
        # Glob one split/sub-folder and return its matches in sorted order.
        return sorted(glob(os.path.join(path, split, folder, pattern)))

    train_images = _sorted_files('training', 'images', '*.tif')
    train_masks = _sorted_files('training', '1st_manual', '*.gif')
    test_images = _sorted_files('test', 'images', '*.tif')
    test_masks = _sorted_files('test', '1st_manual', '*.gif')
    return train_images, train_masks, test_images, test_masks

In [6]:
def _image_stem(path):
    """Return the file name of `path` without its directory or extension."""
    # os.path.basename honours the platform's separators; a hard-coded split on
    # '\\' returns the whole path on POSIX (e.g. './data/...'), and the following
    # split('.') then yields '' so every output file would collide.
    return os.path.splitext(os.path.basename(path))[0]


def _augmented_pairs(image, mask, transforms):
    """Return ``[(image, mask)]`` followed by one transformed pair per transform.

    Each transform is applied to the ORIGINAL image/mask (not chained), and is
    called with the ``image=`` / ``mask=`` keyword interface, reading back the
    ``'image'`` and ``'mask'`` entries of its result.
    """
    pairs = [(image, mask)]
    for transform in transforms:
        out = transform(image=image, mask=mask)
        pairs.append((out['image'], out['mask']))
    return pairs


def augment_data(images, masks, save_path, augment=True, size=(512, 512)):
    """Optionally augment image/mask pairs, resize them, and save them as PNGs.

    For each input pair this writes ``<save_path>/images/<stem>_<k>.png`` and
    ``<save_path>/masks/<stem>_<k>.png``. ``<stem>`` is the image file name
    without extension. ``k`` indexes the variant: 0 is the original pair. When
    ``augment`` is truthy, 1, 2 and 3 are the horizontal flip, the vertical flip
    and a random rotation limited to 45 degrees.

    Args:
        images: image file paths, read with ``cv2.imread`` in colour mode.
        masks: mask file paths, read with ``imageio.mimread`` (first frame used).
            Paired with ``images`` by list position.
        save_path: output root. Its ``images`` and ``masks`` sub-directories are
            not created here and must already exist.
        augment: whether to add the flipped/rotated variants.
        size: target size passed to ``cv2.resize`` as ``dsize`` (width, height).

    Raises:
        ValueError: if ``images`` and ``masks`` have different lengths.
        OSError: if an image cannot be read or an output file cannot be written.
    """
    # Pairs are matched by position; zip() would silently drop the extras and
    # every subsequent pair could be mismatched without any error.
    if len(images) != len(masks):
        raise ValueError(
            f'got {len(images)} images but {len(masks)} masks; '
            'they are paired by position and must match'
        )

    # Build the transforms once. p=1.0 means each one is always applied.
    if augment:
        transforms = (
            HorizontalFlip(p=1.0),
            VerticalFlip(p=1.0),
            Rotate(limit=45, p=1.0),
        )
    else:
        transforms = ()

    for image_path, mask_path in tqdm(zip(images, masks), total=len(images)):
        stem = _image_stem(image_path)

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise OSError(f'could not read image: {image_path}')
        # mimread returns a list of frames; the first frame is the mask.
        # Assumed 2-D (no colour channels) for these GIF masks — TODO confirm.
        mask = imageio.mimread(mask_path)[0]

        for k, (aug_image, aug_mask) in enumerate(_augmented_pairs(image, mask, transforms)):
            aug_image = cv2.resize(aug_image, size)
            # Nearest-neighbour only reuses values already present in the mask.
            # The default bilinear mode would create new intermediate values
            # along label boundaries.
            aug_mask = cv2.resize(aug_mask, size, interpolation=cv2.INTER_NEAREST)

            file_name = f'{stem}_{k}.png'
            image_out = os.path.join(save_path, 'images', file_name)
            mask_out = os.path.join(save_path, 'masks', file_name)

            # cv2.imwrite reports failure (e.g. a missing directory) via its
            # return value, not an exception.
            if not cv2.imwrite(image_out, aug_image):
                raise OSError(f'could not write image: {image_out}')
            if not cv2.imwrite(mask_out, aug_mask):
                raise OSError(f'could not write mask: {mask_out}')

        

In [8]:
if __name__ == "__main__":
    import random

    # Seed the RNGs that may drive the random Rotate augmentation. Older
    # albumentations releases sample transform parameters from Python's
    # `random`, so it is seeded alongside numpy.
    # NOTE(review): newer albumentations versions use an internal RNG that
    # ignores global seeds — confirm reproducibility against the installed version.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)

    # Collect sorted image/mask paths for the training and test splits.
    data_path = './data'
    train_x, train_y, test_x, test_y = load_data(data_path)

    # Make sure the output directories exist before anything is written.
    create_req_dirs()

    # Write augmented training pairs, and resized-only (augment=False) test pairs.
    augment_data(train_x, train_y, 'new_data/train', augment=True)
    augment_data(test_x, test_y, 'new_data/test', augment=False)


100%|██████████| 20/20 [00:02<00:00,  9.97it/s]
100%|██████████| 20/20 [00:00<00:00, 22.28it/s]
