# Data Augmentation

## Imports

In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms

from OSVOS_PyTorch.dataloaders.custom_transforms import ScaleNRotate, Resize, RandomHorizontalFlip

## Paths & Constants

In [2]:
ANNOTATIONS_FOLDERS_PATH = "DAVIS_2016/DAVIS/Annotations/480p/"
IMAGES_FOLDERS_PATH = 'DAVIS_2016/DAVIS/JPEGImages/480p/'

ANNOTATIONS_AUGMENTED_FOLDERS_PATH = "DAVIS_2016/DAVIS/Annotations_augmented/480p/"
IMAGES_AUGMENTED_FOLDERS_PATH = 'DAVIS_2016/DAVIS/JPEGImages_augmented/480p/'

MEANVAL = (104.00699, 116.66877, 122.67892)
AUGMENTATION_COUNT = 2

## Functions

In [3]:
def save_sample(sample, frame, augmentation_count,
                annotations_augmented_folders_path, images_augmented_folders_path):
    
    image = sample['image']
    annotation = sample ['gt']

    file_name = '{}_{}'.format(frame[:5], augmentation_count)
    
    # Save image
    image_save_path = os.path.join(images_augmented_folders_path, file_name)
    np.save(image_save_path, image)
    
    # Save annotation
    annotation_save_path = os.path.join(annotations_augmented_folders_path, file_name)
    np.save(annotation_save_path, annotation)    

In [4]:
def augment_data(annotations_folders_path, images_folders_path,
                 annotations_augmented_folders_path, images_augmented_folders_path,
                 meanval, augmentation_count):

    # Augmentations
    composed_transforms = transforms.Compose([RandomHorizontalFlip(),
                                             ScaleNRotate()])
    # Get list of sequences
    sequences = os.listdir(images_folders_path)
    sequences.sort()
    
    # Iterate through sequences
    for i, sequence in enumerate(sequences):

        # Debug
        if (i > 2): break

        print('#{}: {}'.format(i, sequence))
        
        # Create folders to save augmented annotations and images
        annotations_aug_folder_path = os.path.join(annotations_augmented_folders_path, sequence)
        if not os.path.exists(annotations_aug_folder_path):
            os.makedirs(annotations_aug_folder_path)
            
        images_aug_folder_path = os.path.join(images_augmented_folders_path, sequence)
        if not os.path.exists(images_aug_folder_path):
            os.makedirs(images_aug_folder_path)

        # Get list of frames
        frames = os.listdir(os.path.join(images_folders_path, sequence))
        if '.ipynb_checkpoints' in frames:
            frames.remove('.ipynb_checkpoints')
        frames.sort()
        
        # Iterate through frames
        for j, frame in enumerate(frames):

            # Debug
            if (j > 2): break
            print('\t#{}: {}'.format(j, frame))
            
            if (sequence == 'bmx-bumps' and frame == '00059.png'): break
            if (sequence == 'surf' and frame == '00053.png'): break
                
            # Load annotation and image
            annotation_path = os.path.join(annotations_folders_path, sequence, frame[:5] + '.png')
            image_path = os.path.join(images_folders_path, sequence, frame)
            
            annotation = cv2.imread(annotation_path)
            annotation = np.array(annotation, dtype=np.float32)
            annotation = annotation/np.max([annotation.max(), 1e-8])
            
            image = cv2.imread(image_path)
            image = np.array(image, dtype=np.float32)
            image = np.subtract(image, np.array(meanval, dtype=np.float32))
                     
            # Create sample
            sample = {'image': image, 'gt': annotation}
            
            # Save original sample
            save_sample(sample, frame, '0',
                        annotations_aug_folder_path, images_aug_folder_path)
            
            # Apply augmentations and save them
            for i in range(augmentation_count):
                sample = composed_transforms(sample)
                save_sample(sample, frame, str(i+1),
                            annotations_aug_folder_path, images_aug_folder_path)
            
            # Show sample
            #rescaled = (255.0 / sample['image'].max() * (sample['image'] - sample['image'].min())).astype(np.uint8)
            #plt.imshow(rescaled)
            #plt.show()
            #plt.imshow(sample['gt'])
            #plt.show()

In [5]:
augment_data(ANNOTATIONS_FOLDERS_PATH, IMAGES_FOLDERS_PATH,
             ANNOTATIONS_AUGMENTED_FOLDERS_PATH, IMAGES_AUGMENTED_FOLDERS_PATH,
             MEANVAL, AUGMENTATION_COUNT)

#0: bear
	#0: 00000.jpg
	#1: 00001.jpg
	#2: 00002.jpg
#1: blackswan
	#0: 00000.jpg
	#1: 00001.jpg
	#2: 00002.jpg
#2: bmx-bumps
	#0: 00000.jpg
	#1: 00001.jpg
	#2: 00002.jpg
