In [8]:
import os
import random

import cv2
import numpy as np
import torch
from torch import Tensor
from torch.nn import ModuleList
from torchvision import transforms
from torchvision.transforms import RandomAffine
from torchvision.transforms import ToPILImage
from torchvision.transforms import ToTensor
from torchvision.transforms.functional import InterpolationMode

In [1]:
class RandomRotate():
    """Rotate a (C, H, W) image tensor by a random angle, then crop back to size.

    The image is reflect-padded by a quarter of its height, rotated with
    torchvision's ``RandomAffine`` and cropped back to its original size.
    The reflect padding fills the corners uncovered by the rotation with
    image content. The crop window is placed at a random offset around the
    centre, so the output is also slightly translated.

    Args:
        degrees: With ``precise=False`` this is passed straight to
            ``RandomAffine`` (a single number ``d`` means an angle drawn
            from ``[-d, d]``). With ``precise=True`` the angle is drawn from
            ``[degrees - 0.1, degrees + 0.1]``, i.e. essentially ``degrees``.
        interpolation: ``"NEAREST"`` or ``"BILINEAR"``.
        precise: Selects between the two ``degrees`` interpretations above.

    Raises:
        TypeError: If ``interpolation`` is not one of the supported names.
    """

    def __init__(self, degrees, interpolation, precise=False):
        if interpolation == "NEAREST":
            interpolation = InterpolationMode.NEAREST
        elif interpolation == "BILINEAR":
            interpolation = InterpolationMode.BILINEAR
        else:
            raise TypeError(f"Transform interpolation {interpolation} not supported.")
        if precise:
            self.rotate = RandomAffine([degrees - 0.1, degrees + 0.1],
                                       interpolation=interpolation)
        else:
            self.rotate = RandomAffine(degrees, interpolation=interpolation)
        # Padding transform; (re)built in __call__ from the input height.
        self.pad = None

    def __call__(self, x):
        # Unpacking exactly two values requires a 3-D (C, H, W) tensor.
        old_h, old_w = x.shape[1:]

        # Pad by a quarter of the height. Rebuild whenever the height changes
        # so one instance reused on differently sized images does not keep a
        # stale pad/crop margin computed from the first image it saw.
        pad_size = old_h // 4
        if self.pad is None or pad_size != self.pad_size:
            self.pad_size = pad_size
            self.hp = pad_size // 2
            self.pad = transforms.Pad(pad_size, padding_mode='reflect')

        x = self.pad(x)
        h, w = x.shape[1:]
        x = self.rotate(x)

        # The centred crop would start at pad_size. The start is drawn from
        # [hp, 2*pad_size - hp] inclusive (hence the +1, since randint's upper
        # bound is exclusive). This keeps the shift symmetric and avoids an
        # empty range (ValueError) when pad_size < 2.
        top = np.random.randint(self.hp, h - old_h - self.hp + 1)
        left = np.random.randint(self.hp, w - old_w - self.hp + 1)

        return x[:, top: top + old_h, left: left + old_w]

class HorizontalFlip:
    """Mirror an image tensor left-to-right.

    Flips the last dimension. That is the width axis for a (C, H, W) image
    (identical to flipping dim 2) and also for a batched (N, C, H, W) tensor.
    """

    def __call__(self, x):
        return torch.flip(x, [-1])
    
class VerticalFlip:
    """Mirror an image tensor top-to-bottom.

    Flips the second-to-last dimension. That is the height axis for a
    (C, H, W) image (identical to flipping dim 1) and also for a batched
    (N, C, H, W) tensor.
    """

    def __call__(self, x):
        return torch.flip(x, [-2])

In [2]:
import os

# Source images are read from INPUT_DIR. Augmented copies are written to
# OUTPUT_DIR under the same file name plus a suffix added during augmentation.
INPUT_DIR = 'Input'
OUTPUT_DIR = 'Output'

# Only regular files: cv2.imread on a subdirectory would return None.
# sorted() makes the listing independent of the filesystem's order.
fnames = sorted(
    fname for fname in os.listdir(INPUT_DIR)
    if os.path.isfile(os.path.join(INPUT_DIR, fname))
)
in_paths = [os.path.join(INPUT_DIR, fname) for fname in fnames]
out_paths = [os.path.join(OUTPUT_DIR, fname) for fname in fnames]

# cv2.imwrite does not create missing directories, so make sure it exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

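We now introduce an abstract class that lets a transformation be applied at several strength levels. For example, rotation can use small angles (< 5 degrees), medium angles (< 15 degrees) or large angles (< 45 degrees); Gaussian noise can likewise have several levels. Each level is rendered on a few sample images. The user's accept/reject verdicts on those samples then decide the strongest level that is safe to use. First, the `ImageManager` helper below loads, resizes and saves the sample images.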

In [17]:
class ImageManager:
    """Serve images from a fixed list of paths in a shuffled, cyclic order.

    Images are loaded with OpenCV, so channels stay in OpenCV's BGR order.
    Each image is resized so its short side is ``size`` pixels,
    centre-cropped to a square and returned as a float ``(C, H, W)`` tensor
    scaled to ``[0, 1]``. ``save_image`` performs the inverse conversion and
    writes the result to disk.

    Args:
        in_paths: Paths of the source images.
        out_paths: Output path paired with each entry of ``in_paths``.
    """

    def __init__(self, in_paths, out_paths):
        self.file_paths = list(zip(in_paths, out_paths))
        # Random visiting order over file_paths; get_image walks it cyclically.
        self.order = list(range(len(self.file_paths)))
        random.shuffle(self.order)
        self.index = 0

    @staticmethod
    def _scaled_size(height, width, size):
        """Return ``(new_height, new_width)`` with the short side equal to ``size``.

        The short side is set to ``size`` exactly rather than recomputed
        through a float scale factor. Float truncation could otherwise give
        ``size - 1`` and make the square crop one pixel short.
        """
        if height <= width:
            return size, max(size, int(width * size / height))
        return max(size, int(height * size / width)), size

    def _resize(self, img, size=256):
        """Resize an (H, W, C) array so its short side is ``size``, then centre-crop to a square."""
        new_height, new_width = self._scaled_size(img.shape[0], img.shape[1], size)
        img = cv2.resize(img, (new_width, new_height))  # cv2 takes (width, height)

        y_start = new_height // 2 - size // 2
        x_start = new_width // 2 - size // 2
        return img[y_start : y_start + size, x_start : x_start + size, :]

    def _next_paths(self):
        """Return the next ``(in_path, out_path)`` pair following ``self.order``, wrapping around."""
        if not self.file_paths:
            raise IndexError("ImageManager has no images to serve")
        in_path, out_path = self.file_paths[self.order[self.index]]
        self.index = (self.index + 1) % len(self.file_paths)
        return in_path, out_path

    def get_image(self):
        """Load the next image.

        Returns:
            tuple: ``(img, out_path)``, where ``img`` is a float32
            ``(C, size, size)`` tensor with values in ``[0, 1]``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the input file.
        """
        in_path, out_path = self._next_paths()

        img = cv2.imread(in_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {in_path}")
        img = self._resize(img)
        img = np.transpose(img, (2, 0, 1)) / 255

        return Tensor(img), out_path

    def save_image(self, out_path, img):
        """Write a ``(C, H, W)`` tensor with values in ``[0, 1]`` to ``out_path``.

        The tensor passed in is not modified.

        Raises:
            OSError: If ``cv2.imwrite`` reports failure (for example, when
                the output directory does not exist).
        """
        # Scale out-of-place: .numpy() shares memory with the tensor, so an
        # in-place `*= 255` would corrupt the caller's image.
        arr = img.detach().cpu().numpy().transpose(1, 2, 0) * 255
        arr = np.ascontiguousarray(np.clip(np.rint(arr), 0, 255).astype(np.uint8))

        if not cv2.imwrite(out_path, arr):
            raise OSError(f"Could not write image: {out_path}")

In [18]:
class AbstractTransformationConfig:
    """Base class for a transformation that can be applied at several strength levels.

    Subclasses fill ``self.tests``. ``augment`` renders every test on images
    from ``image_manager`` and saves them. ``config`` turns the user's
    verdicts on those renders into the strongest fully accepted level.

    Args:
        name: Short identifier. It is used in output file names and to match
            results back to this transformation.
        image_manager: Object with ``get_image()`` returning
            ``(img, out_path)`` and ``save_image(path, img)``. May be ``None``
            when only ``config`` is used.
    """

    def __init__(self, name, image_manager):
        self.name = name
        self.image_manager = image_manager

        # list of tests; each test is a dictionary with 3 attributes:
        #     'level_id': determines the strength of the transform
        #         if tests fail for level 2 then they will not work for level 3 and above
        #     'n': number of times the test must be repeated
        #     'transform': callable object that performs the transformation
        self.tests = []

    def augment(self):
        """Apply every test to fresh images, save the results and describe them.

        Output files are named ``<stem>_<name>_<level_id>_<k><ext>``, where
        ``k`` counts images per level.

        Returns:
            list of dict: one entry per saved image, with keys
            ``'level_id'``, ``'name'`` and ``'out_path'``.
        """
        augments = []
        test_counts = {}

        for test in self.tests:
            level_id, n, transform = test['level_id'], test['n'], test['transform']
            test_counts.setdefault(level_id, 0)

            for _ in range(n):
                img, out_path = self.image_manager.get_image()
                img = transform(img)

                # splitext only splits off the final extension, so dots elsewhere
                # in the path survive and extension-less paths still work
                # (split('.')[0]/[1] broke on both).
                stem, ext = os.path.splitext(out_path)
                new_path = f"{stem}_{self.name}_{level_id}_{test_counts[level_id]}{ext}"

                self.image_manager.save_image(new_path, img)

                augments.append({
                    'level_id': level_id,
                    'name': self.name,
                    'out_path': new_path,
                })
                test_counts[level_id] += 1

        return augments

    def config(self, results):
        """Return the highest strength level the user fully accepted.

        Only results whose ``'name'`` equals ``self.name`` are considered. A
        level is accepted when every one of its results has ``okay == 'true'``.
        The returned level must also have every lower level, starting from 0,
        accepted.

        Args:
            results: iterable of dicts with keys ``'level_id'``, ``'name'``
                and ``'okay'`` (the string ``'true'`` marks acceptance).

        Returns:
            int: the highest accepted level id, or ``-1`` when level 0 is
            not accepted or has no results.
        """
        level_counts = {}
        for result in results:
            if result['name'] != self.name:
                continue
            counts = level_counts.setdefault(result['level_id'], {'TOTAL': 0, 'CORRECT': 0})
            counts['TOTAL'] += 1
            if result['okay'] == 'true':
                counts['CORRECT'] += 1

        accepted = {level for level, counts in level_counts.items()
                    if counts['TOTAL'] == counts['CORRECT']}

        # Climb from level 0 while each next level is accepted.
        good_id = -1
        while good_id + 1 in accepted:
            good_id += 1
        return good_id
        
    
class RotateConfig(AbstractTransformationConfig):
    """Rotation offered at three strengths: about 5, 15 and 45 degrees.

    Each level renders two images rotated by ``+d`` degrees and two rotated
    by ``-d`` degrees, with ``d`` = 5, 15, 45 for levels 0, 1, 2.
    """

    # Nominal rotation angle for each level id (index == level id).
    _LEVEL_DEGREES = (5, 15, 45)

    def __init__(self, image_manager):
        super().__init__('rotate', image_manager)
        self.tests = [
            {'level_id': level, 'n': 2,
             'transform': RandomRotate(sign * angle, 'NEAREST', precise=True)}
            for level, angle in enumerate(self._LEVEL_DEGREES)
            for sign in (1, -1)
        ]

    def config(self, results):
        """Map the accepted level to a ``RandomPadRotate`` config dict, or ``None``."""
        degrees = {0: '5.0', 1: '15.0', 2: '45.0'}.get(super().config(results))
        if degrees is None:
            return None
        return {
            'name': 'RandomPadRotate',
            'interpolation': 'NEAREST',
            'degrees': degrees,
        }
        
class HorizontalFlipConfig(AbstractTransformationConfig):
    """Horizontal flip: a single level (0), rendered on four images."""

    def __init__(self, image_manager):
        super().__init__('hflip', image_manager)
        self.tests = [{'level_id': 0, 'n': 4, 'transform': HorizontalFlip()}]

    def config(self, results):
        """Return ``{'name': 'RandomHorizontalFlip'}`` unless no level was accepted (then ``None``)."""
        if super().config(results) == -1:
            return None
        return {'name': 'RandomHorizontalFlip'}
    
class VerticalFlipConfig(AbstractTransformationConfig):
    """Vertical flip: a single level (0), rendered on four images."""

    def __init__(self, image_manager):
        super().__init__('vflip', image_manager)
        self.tests = [{'level_id': 0, 'n': 4, 'transform': VerticalFlip()}]

    def config(self, results):
        """Return ``{'name': 'RandomVerticalFlip'}`` unless no level was accepted (then ``None``)."""
        if super().config(results) == -1:
            return None
        return {'name': 'RandomVerticalFlip'}
    
def augment(in_paths, out_paths):
    """Render every transformation's tests and return the saved samples in random order.

    All configs share one ImageManager, so each one draws the next images
    from the same cycle of input files.
    """
    image_manager = ImageManager(in_paths, out_paths)

    configs = (
        RotateConfig(image_manager),
        HorizontalFlipConfig(image_manager),
        VerticalFlipConfig(image_manager),
    )
    samples = [entry for conf in configs for entry in conf.augment()]

    random.shuffle(samples)
    return samples
    
    
def config(augments):
    """Print the accepted configuration (or ``None``) of each transformation, in a fixed order."""
    for config_cls in (RotateConfig, HorizontalFlipConfig, VerticalFlipConfig):
        print(config_cls(None).config(augments))

    

In [19]:
def randomOkay(augment, percent):
    """Simulate a user verdict: accept ``augment`` with probability ``percent``.

    Sets ``augment['okay']`` to ``'true'`` or ``'false'`` in place and
    returns the same dict.
    """
    rejected = np.random.rand() > percent
    augment['okay'] = 'false' if rejected else 'true'
    return augment

In [20]:
# Seed every RNG the pipeline uses so Restart & Run All reproduces the same
# samples: `random` (ImageManager / augment shuffles), NumPy (RandomRotate's
# crop, randomOkay) and torch (torchvision transforms draw from it).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

augments = augment(in_paths, out_paths)

print(len(augments))

# simulate the user rejecting some of them (each accepted with probability 0.9)
augments = [randomOkay(sample, 0.9) for sample in augments]

config(augments)

20
{'name': 'RandomPadRotate', 'interpolation': 'NEAREST', 'degrees': '5.0'}
None
None
