In [1]:
import random
import numpy as np

In [2]:
try:
    from fuel.transformers._image import window_batch_bchw
    window_batch_bchw_available = True
except ImportError:
    window_batch_bchw_available = False
from fuel.transformers import ExpectsAxisLabels, SourcewiseTransformer
from fuel import config

In [7]:
class dataAugmentation(ExpectsAxisLabels, SourcewiseTransformer):
    """Crop, optionally mirror, and optionally rescale image sources.

    Operates on channel-first images: single examples with axis labels
    ``('channel', 'height', 'width')`` and batches with
    ``('batch', 'channel', 'height', 'width')`` (a 4-D ndarray, or a list /
    1-D object array of 3-D ndarrays).

    Parameters
    ----------
    data_stream : fuel data stream
        Stream whose sources are transformed.
    window_shape : tuple of int, optional
        ``(height, width)`` of the output window. Default ``(224, 224)``.
    horizontal_mirroring : bool, optional
        If True, each image is flipped along its width axis with
        probability 1/2.
    normalize : bool, optional
        If True, the output is cast to float32 and divided by 255
        (presumably the input holds 8-bit pixel values -- TODO confirm).
    random_crop : bool, optional
        If True, the window position is drawn uniformly at random for each
        image; otherwise the window is centred.
    **kwargs
        ``rng`` (optional ``numpy.random.RandomState``) drives crop offsets
        and mirroring; defaults to one seeded with ``config.default_seed``.
        Everything else is forwarded to ``SourcewiseTransformer``.

    Raises
    ------
    ValueError
        If a source has an unsupported layout or is smaller than the window.
    """

    _BATCH_FORMAT_ERROR = ("uninterpretable batch format; expected a list "
                           "of arrays with ndim = 3, or an array with "
                           "ndim = 4")

    def __init__(self, data_stream, window_shape=(224, 224),
                 horizontal_mirroring=False, normalize=False,
                 random_crop=True, **kwargs):
        # tuple() so ``shape[:2] + window_shape`` works even for list input.
        self.window_shape = tuple(window_shape)
        self.horizontal_mirroring = horizontal_mirroring
        self.normalize = normalize
        self.random_crop = random_crop
        # One persistent RNG: re-seeding on every batch would give every
        # batch exactly the same "random" crops.
        self.rng = kwargs.pop('rng', None)
        if self.rng is None:
            self.rng = np.random.RandomState(config.default_seed)
        kwargs.setdefault('produces_examples', data_stream.produces_examples)
        kwargs.setdefault('axis_labels', data_stream.axis_labels)
        # Sets self.data_stream, which verify_axis_labels relies on.
        super(dataAugmentation, self).__init__(data_stream, **kwargs)

    def transform_source_batch(self, source, source_name):
        """Transform one batch of images (the hook fuel calls for batches).

        Returns a list when given a list, otherwise an ndarray.
        """
        self.verify_axis_labels(('batch', 'channel', 'height', 'width'),
                                self.data_stream.axis_labels[source_name],
                                source_name)
        if (isinstance(source, list)
                or (isinstance(source, np.ndarray) and source.ndim == 1)):
            # Ragged batch: every element must itself be a (C, H, W) array.
            if not all(isinstance(b, np.ndarray) and b.ndim == 3
                       for b in source):
                raise ValueError(self._BATCH_FORMAT_ERROR)
            examples = [self._transform_example(im) for im in source]
            if isinstance(source, list):
                return examples
            return np.array(examples)
        if isinstance(source, np.ndarray) and source.ndim == 4:
            return self._transform_array_batch(source)
        raise ValueError(self._BATCH_FORMAT_ERROR)

    def transform_source_example(self, item, source_name):
        """Transform one (channel, height, width) image (fuel's example hook)."""
        self.verify_axis_labels(('channel', 'height', 'width'),
                                self.data_stream.axis_labels[source_name],
                                source_name)
        return self._transform_example(item)

    # Backward-compatible aliases for the original method names.
    batch_transformation = transform_source_batch
    single_transformation = transform_source_example

    def _transform_array_batch(self, source):
        """Crop, mirror and normalise a 4-D (batch, channel, H, W) ndarray."""
        window_height, window_width = self.window_shape
        batch_size = source.shape[0]
        height, width = source.shape[2:]
        maxoffset_y = height - window_height
        maxoffset_x = width - window_width
        # Checked for both crop modes: a centred crop cannot enlarge either.
        if maxoffset_x < 0 or maxoffset_y < 0:
            raise ValueError(
                "Got ndarray batch with image dimensions {} but "
                "requested window shape of {}".format(
                    source.shape[2:], self.window_shape))

        if self.random_crop:
            # randint's upper bound is exclusive; +1 keeps the max reachable.
            offsets_y = self.rng.randint(0, maxoffset_y + 1, size=batch_size)
            offsets_x = self.rng.randint(0, maxoffset_x + 1, size=batch_size)
            output = np.empty(source.shape[:2] + self.window_shape,
                              dtype=source.dtype)
            if window_batch_bchw_available:
                window_batch_bchw(source, offsets_y, offsets_x, output)
            else:
                # Pure-numpy fallback when fuel's compiled helper is missing.
                for i, (off_y, off_x) in enumerate(zip(offsets_y, offsets_x)):
                    output[i] = source[i, :, off_y:off_y + window_height,
                                       off_x:off_x + window_width]
        else:
            offset_y = maxoffset_y // 2
            offset_x = maxoffset_x // 2
            # Explicit end bounds: ``off:-off`` is empty when off == 0.
            output = source[:, :, offset_y:offset_y + window_height,
                            offset_x:offset_x + window_width]

        if self.horizontal_mirroring:
            flip = self.rng.randint(0, 2, size=batch_size).astype(bool)
            # np.where builds a new array, so the caller's source (which
            # ``output`` may be a view of) is never written to.
            output = np.where(flip[:, None, None, None],
                              output[..., ::-1], output)

        if self.normalize:
            output = output.astype(np.float32) / 255.0
        return output

    def _transform_example(self, item):
        """Crop, mirror and normalise one (channel, height, width) ndarray.

        Does no axis-label check, so it can also be applied to the
        elements of a ragged batch.
        """
        if not isinstance(item, np.ndarray) or item.ndim != 3:
            raise ValueError("uninterpretable example format; expected "
                             "ndarray with ndim = 3")
        window_height, window_width = self.window_shape
        height, width = item.shape[1:]
        maxoffset_y = height - window_height
        maxoffset_x = width - window_width
        if maxoffset_x < 0 or maxoffset_y < 0:
            raise ValueError("can't obtain ({}, {}) window from image "
                             "dimensions ({}, {})".format(
                                 window_height, window_width,
                                 height, width))

        if self.random_crop:
            offset_y = self.rng.randint(0, maxoffset_y + 1)
            offset_x = self.rng.randint(0, maxoffset_x + 1)
        else:
            offset_y = maxoffset_y // 2
            offset_x = maxoffset_x // 2

        cropped = item[:, offset_y:offset_y + window_height,
                       offset_x:offset_x + window_width]

        if self.horizontal_mirroring and self.rng.randint(2):
            # Reversed view instead of in-place assignment: the caller's
            # array is left untouched.
            cropped = cropped[:, :, ::-1]

        if self.normalize:
            cropped = cropped.astype(np.float32) / 255.0
        return cropped