In [1]:
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import torch
import glob
import os, sys, time
import h5py
import numbers
import random
from torch.utils import data
import helpers
from pathlib import Path
import cv2
from torch.utils.data import DataLoader, Dataset
import PIL

In [2]:
def data_load(args, datatype):
    """Build a DataLoader for one data split.

    Args:
        args: configuration object. ``args`` itself is passed to the dataset
            constructor and ``args.batch_size`` sets the batch size.
        datatype: one of ``'Train'``, ``'Validation'`` or ``'Test'``.

    Returns:
        torch.utils.data.DataLoader over the chosen split. Only the training
        split is shuffled.

    Raises:
        ValueError: if ``datatype`` is not one of the three split names.
    """
    # NOTE(review): the dataset classes take a directory path as their first
    # argument; passing ``args`` only works if it is path-like — confirm.
    if datatype == 'Train':
        data = TrainData(args)
    elif datatype == 'Validation':
        data = ValidData(args)
    elif datatype == 'Test':
        data = TestData(args)
    else:
        raise ValueError('Choose the data type: Train, Validation, Test')

    # Previously `shuffle` was only bound in the Train branch (UnboundLocalError
    # for Validation/Test) and the loader was never returned.
    shuffle = datatype == 'Train'
    return torch.utils.data.DataLoader(data, batch_size=args.batch_size,
                                       shuffle=shuffle)

In [3]:
class TrainData(torch.utils.data.Dataset):
    """HDF5-backed training dataset returning ``(image, mask)`` tensor pairs.

    Every ``*.h5`` file in ``file_path`` is scanned once at construction time.
    Each top-level HDF5 dataset becomes one entry of ``self.data_info``, with
    keys ``file_path``, ``type`` (the dataset name), on-disk ``shape`` and
    ``cache_idx`` (-1 while not cached).

    A dataset named ``'mask'`` is preprocessed as a label map; every other
    dataset is preprocessed as an image. Sample ``i`` pairs the i-th
    ``'image'`` entry with the i-th ``'mask'`` entry, in sorted file order.

    Array contents are read lazily. A whole file is loaded into
    ``self.data_cache`` the first time one of its datasets is requested. While
    more than ``data_cache_size`` files are cached, the oldest other file is
    evicted; the file just loaded is always kept.

    Args:
        file_path: directory containing the ``.h5`` files.
        data_cache_size: maximum number of files kept in memory.
        transform: optional callable applied to the image array only. When
            falsy, the image is converted with ``torch.from_numpy``.
        recursive: also search sub-directories for ``.h5`` files.
        load_data: preprocess and cache every file eagerly at construction.

    Raises:
        RuntimeError: if no ``.h5`` file is found.
    """

    # Per-channel weights used to collapse mask channels 0, 1, 2 into a single
    # 2-D label map (weighted sum).
    MASK_WEIGHTS = (50, 100, 150)

    def __init__(self, file_path, data_cache_size=3, transform=None,
                 recursive=False, load_data=False):
        super().__init__()
        self.data_info = []
        self.data_cache = {}
        self.data_cache_size = data_cache_size
        # NOTE(review): the transform receives the image only, so a spatial
        # augmentation here is NOT applied to the mask. The pair-wise
        # Compose/RandomVerticalFlip/... classes in this notebook take
        # (img, label) and cannot be passed here as-is.
        self.transform = transform

        p = Path(file_path)
        assert p.is_dir(), '{} is not a directory'.format(file_path)
        files = sorted(p.glob('**/*.h5' if recursive else '*.h5'))
        if not files:
            raise RuntimeError('No hdf5 datasets found')

        for h5dataset_fp in files:
            self._add_data_infos(str(h5dataset_fp.resolve()), load_data)

    def __getitem__(self, index):
        """Return the ``index``-th ``(image, mask)`` pair as tensors."""
        x = self.get_data("image", index)
        x = self.transform(x) if self.transform else torch.from_numpy(x)
        y = torch.from_numpy(self.get_data("mask", index))
        return (x, y)

    def __len__(self):
        """Number of ``'image'`` entries found across all files."""
        return len(self.get_data_infos('image'))

    @staticmethod
    def _minmax(channel):
        """Linearly rescale a 2-D array to [0, 1] as float64.

        Matches ``cv2.normalize(..., norm_type=cv2.NORM_MINMAX)`` with its
        default alpha=1, beta=0. The difference is that it always computes in
        float64, whereas cv2 keeps the source depth, which rounds integer
        inputs to {0, 1}. A constant channel maps to all zeros, as in cv2.
        """
        channel = np.asarray(channel, dtype=np.float64)
        lo = channel.min()
        span = channel.max() - lo
        if span <= np.finfo(np.float64).eps:
            return np.zeros_like(channel)
        return (channel - lo) / span

    @classmethod
    def _preprocess(cls, dname, raw):
        """Turn a raw HDF5 array into the form stored in the cache.

        ``'mask'`` (H, W, >=3): channels 0..2 are weighted by ``MASK_WEIGHTS``
        and summed into a float64 (H, W) label map.

        Anything else: each channel of an (H, W, C) array (or a single 2-D
        array) is min-max normalised to [0, 1] as float64.
        """
        if dname == 'mask':
            # Cast before scaling so integer masks cannot overflow.
            return sum(raw[:, :, c].astype(np.float64) * w
                       for c, w in enumerate(cls.MASK_WEIGHTS))
        if raw.ndim == 2:
            return cls._minmax(raw)
        return np.stack([cls._minmax(raw[:, :, c]) for c in range(raw.shape[2])],
                        axis=2)

    def _add_data_infos(self, file_path, load_data):
        """Append one ``data_info`` entry per dataset in ``file_path``.

        When ``load_data`` is true, the preprocessed arrays are also cached
        immediately and their cache index is recorded.
        """
        with h5py.File(file_path, 'r') as h5_file:
            for dname, ds in h5_file.items():
                idx = -1  # -1 means "not cached yet"
                if load_data:
                    idx = self._add_to_cache(self._preprocess(dname, ds[()]), file_path)
                # ds.shape avoids reading the whole array just for its shape.
                self.data_info.append({'file_path': file_path, 'type': dname,
                                       'shape': ds.shape, 'cache_idx': idx})

    def _load_data(self, file_path):
        """Load and preprocess every dataset of ``file_path`` into the cache.

        Updates the matching ``data_info`` entries with their cache index. Then,
        while more than ``data_cache_size`` files are cached, evicts the oldest
        other file and marks its entries as uncached.
        """
        # A file's entries are contiguous and follow h5_file.items() order,
        # because _add_data_infos appended them that way.
        file_idx = next(i for i, v in enumerate(self.data_info)
                        if v['file_path'] == file_path)
        with h5py.File(file_path, 'r') as h5_file:
            for offset, (dname, ds) in enumerate(h5_file.items()):
                idx = self._add_to_cache(self._preprocess(dname, ds[()]), file_path)
                self.data_info[file_idx + offset]['cache_idx'] = idx

        # Dicts keep insertion order, so the first other key is the oldest file.
        while len(self.data_cache) > self.data_cache_size:
            victims = [k for k in self.data_cache if k != file_path]
            if not victims:
                break
            evicted = victims[0]
            del self.data_cache[evicted]
            for di in self.data_info:
                if di['file_path'] == evicted:
                    di['cache_idx'] = -1

    def _add_to_cache(self, data, file_path):
        """Append ``data`` to the cache list of ``file_path``; return its index."""
        entries = self.data_cache.setdefault(file_path, [])
        entries.append(data)
        return len(entries) - 1

    def get_data_infos(self, type):
        """Return the ``data_info`` entries whose dataset name equals ``type``."""
        return [di for di in self.data_info if di['type'] == type]

    def get_data(self, type, i):
        """Return the cached array for the ``i``-th dataset named ``type``.

        Loads the owning file first if it is not in the cache.
        """
        fp = self.get_data_infos(type)[i]['file_path']
        if fp not in self.data_cache:
            self._load_data(fp)
        # Re-read the entry: _load_data has just assigned its cache index.
        cache_idx = self.get_data_infos(type)[i]['cache_idx']
        return self.data_cache[fp][cache_idx]

In [4]:
class ValidData(torch.utils.data.Dataset):
    """HDF5-backed validation dataset returning ``(image, mask)`` tensor pairs.

    Every ``*.h5`` file in ``file_path`` is scanned once at construction time.
    Each top-level HDF5 dataset becomes one entry of ``self.data_info``, with
    keys ``file_path``, ``type`` (the dataset name), on-disk ``shape`` and
    ``cache_idx`` (-1 while not cached).

    A dataset named ``'mask'`` is preprocessed as a label map; every other
    dataset is preprocessed as an image. Sample ``i`` pairs the i-th
    ``'image'`` entry with the i-th ``'mask'`` entry, in sorted file order.

    Array contents are read lazily. A whole file is loaded into
    ``self.data_cache`` the first time one of its datasets is requested. While
    more than ``data_cache_size`` files are cached, the oldest other file is
    evicted; the file just loaded is always kept.

    Args:
        file_path: directory containing the ``.h5`` files.
        data_cache_size: maximum number of files kept in memory.
        transform: optional callable applied to the image array only. When
            falsy, the image is converted with ``torch.from_numpy``.
        recursive: also search sub-directories for ``.h5`` files.
        load_data: preprocess and cache every file eagerly at construction.

    Raises:
        RuntimeError: if no ``.h5`` file is found.
    """

    # Per-channel weights used to collapse mask channels 0, 1, 2 into a single
    # 2-D label map (weighted sum).
    MASK_WEIGHTS = (50, 100, 150)

    def __init__(self, file_path, data_cache_size=3, transform=None,
                 recursive=False, load_data=False):
        super().__init__()
        self.data_info = []
        self.data_cache = {}
        self.data_cache_size = data_cache_size
        # NOTE(review): the transform receives the image only; the mask is
        # never transformed.
        self.transform = transform

        p = Path(file_path)
        assert p.is_dir(), '{} is not a directory'.format(file_path)
        files = sorted(p.glob('**/*.h5' if recursive else '*.h5'))
        if not files:
            raise RuntimeError('No hdf5 datasets found')

        for h5dataset_fp in files:
            self._add_data_infos(str(h5dataset_fp.resolve()), load_data)

    def __getitem__(self, index):
        """Return the ``index``-th ``(image, mask)`` pair as tensors."""
        x = self.get_data("image", index)
        x = self.transform(x) if self.transform else torch.from_numpy(x)
        y = torch.from_numpy(self.get_data("mask", index))
        return (x, y)

    def __len__(self):
        """Number of ``'image'`` entries found across all files."""
        return len(self.get_data_infos('image'))

    @staticmethod
    def _minmax(channel):
        """Linearly rescale a 2-D array to [0, 1] as float64.

        Matches ``cv2.normalize(..., norm_type=cv2.NORM_MINMAX)`` with its
        default alpha=1, beta=0. The difference is that it always computes in
        float64, whereas cv2 keeps the source depth, which rounds integer
        inputs to {0, 1}. A constant channel maps to all zeros, as in cv2.
        """
        channel = np.asarray(channel, dtype=np.float64)
        lo = channel.min()
        span = channel.max() - lo
        if span <= np.finfo(np.float64).eps:
            return np.zeros_like(channel)
        return (channel - lo) / span

    @classmethod
    def _preprocess(cls, dname, raw):
        """Turn a raw HDF5 array into the form stored in the cache.

        ``'mask'`` (H, W, >=3): channels 0..2 are weighted by ``MASK_WEIGHTS``
        and summed into a float64 (H, W) label map.

        Anything else: each channel of an (H, W, C) array (or a single 2-D
        array) is min-max normalised to [0, 1] as float64.
        """
        if dname == 'mask':
            # Cast before scaling so integer masks cannot overflow.
            return sum(raw[:, :, c].astype(np.float64) * w
                       for c, w in enumerate(cls.MASK_WEIGHTS))
        if raw.ndim == 2:
            return cls._minmax(raw)
        return np.stack([cls._minmax(raw[:, :, c]) for c in range(raw.shape[2])],
                        axis=2)

    def _add_data_infos(self, file_path, load_data):
        """Append one ``data_info`` entry per dataset in ``file_path``.

        When ``load_data`` is true, the preprocessed arrays are also cached
        immediately and their cache index is recorded.
        """
        with h5py.File(file_path, 'r') as h5_file:
            for dname, ds in h5_file.items():
                idx = -1  # -1 means "not cached yet"
                if load_data:
                    idx = self._add_to_cache(self._preprocess(dname, ds[()]), file_path)
                # ds.shape avoids reading the whole array just for its shape.
                self.data_info.append({'file_path': file_path, 'type': dname,
                                       'shape': ds.shape, 'cache_idx': idx})

    def _load_data(self, file_path):
        """Load and preprocess every dataset of ``file_path`` into the cache.

        Updates the matching ``data_info`` entries with their cache index. Then,
        while more than ``data_cache_size`` files are cached, evicts the oldest
        other file and marks its entries as uncached.
        """
        # A file's entries are contiguous and follow h5_file.items() order,
        # because _add_data_infos appended them that way.
        file_idx = next(i for i, v in enumerate(self.data_info)
                        if v['file_path'] == file_path)
        with h5py.File(file_path, 'r') as h5_file:
            for offset, (dname, ds) in enumerate(h5_file.items()):
                idx = self._add_to_cache(self._preprocess(dname, ds[()]), file_path)
                self.data_info[file_idx + offset]['cache_idx'] = idx

        # Dicts keep insertion order, so the first other key is the oldest file.
        while len(self.data_cache) > self.data_cache_size:
            victims = [k for k in self.data_cache if k != file_path]
            if not victims:
                break
            evicted = victims[0]
            del self.data_cache[evicted]
            for di in self.data_info:
                if di['file_path'] == evicted:
                    di['cache_idx'] = -1

    def _add_to_cache(self, data, file_path):
        """Append ``data`` to the cache list of ``file_path``; return its index."""
        entries = self.data_cache.setdefault(file_path, [])
        entries.append(data)
        return len(entries) - 1

    def get_data_infos(self, type):
        """Return the ``data_info`` entries whose dataset name equals ``type``."""
        return [di for di in self.data_info if di['type'] == type]

    def get_data(self, type, i):
        """Return the cached array for the ``i``-th dataset named ``type``.

        Loads the owning file first if it is not in the cache.
        """
        fp = self.get_data_infos(type)[i]['file_path']
        if fp not in self.data_cache:
            self._load_data(fp)
        # Re-read the entry: _load_data has just assigned its cache index.
        cache_idx = self.get_data_infos(type)[i]['cache_idx']
        return self.data_cache[fp][cache_idx]

In [9]:
class TestData(torch.utils.data.Dataset):
    """HDF5-backed test dataset returning ``(image, mask)`` tensor pairs.

    Every ``*.h5`` file in ``file_path`` is scanned once at construction time.
    Each top-level HDF5 dataset becomes one entry of ``self.data_info``, with
    keys ``file_path``, ``type`` (the dataset name), on-disk ``shape`` and
    ``cache_idx`` (-1 while not cached).

    A dataset named ``'mask'`` is preprocessed as a label map; every other
    dataset is preprocessed as an image. Sample ``i`` pairs the i-th
    ``'image'`` entry with the i-th ``'mask'`` entry, in sorted file order.

    Array contents are read lazily. A whole file is loaded into
    ``self.data_cache`` the first time one of its datasets is requested. While
    more than ``data_cache_size`` files are cached, the oldest other file is
    evicted; the file just loaded is always kept.

    Args:
        file_path: directory containing the ``.h5`` files.
        data_cache_size: maximum number of files kept in memory.
        transform: optional callable applied to the image array only. When
            falsy, the image is converted with ``torch.from_numpy``.
        recursive: also search sub-directories for ``.h5`` files.
        load_data: preprocess and cache every file eagerly at construction.

    Raises:
        RuntimeError: if no ``.h5`` file is found.
    """

    # Per-channel weights used to collapse mask channels 0, 1, 2 into a single
    # 2-D label map (weighted sum).
    MASK_WEIGHTS = (50, 100, 150)

    def __init__(self, file_path, data_cache_size=3, transform=None,
                 recursive=False, load_data=False):
        super().__init__()
        self.data_info = []
        self.data_cache = {}
        self.data_cache_size = data_cache_size
        # NOTE(review): the transform receives the image only; the mask is
        # never transformed.
        self.transform = transform

        p = Path(file_path)
        assert p.is_dir(), '{} is not a directory'.format(file_path)
        files = sorted(p.glob('**/*.h5' if recursive else '*.h5'))
        if not files:
            raise RuntimeError('No hdf5 datasets found')

        for h5dataset_fp in files:
            self._add_data_infos(str(h5dataset_fp.resolve()), load_data)

    def __getitem__(self, index):
        """Return the ``index``-th ``(image, mask)`` pair as tensors."""
        x = self.get_data("image", index)
        x = self.transform(x) if self.transform else torch.from_numpy(x)
        y = torch.from_numpy(self.get_data("mask", index))
        return (x, y)

    def __len__(self):
        """Number of ``'image'`` entries found across all files."""
        return len(self.get_data_infos('image'))

    @staticmethod
    def _minmax(channel):
        """Linearly rescale a 2-D array to [0, 1] as float64.

        Matches ``cv2.normalize(..., norm_type=cv2.NORM_MINMAX)`` with its
        default alpha=1, beta=0. The difference is that it always computes in
        float64, whereas cv2 keeps the source depth, which rounds integer
        inputs to {0, 1}. A constant channel maps to all zeros, as in cv2.
        """
        channel = np.asarray(channel, dtype=np.float64)
        lo = channel.min()
        span = channel.max() - lo
        if span <= np.finfo(np.float64).eps:
            return np.zeros_like(channel)
        return (channel - lo) / span

    @classmethod
    def _preprocess(cls, dname, raw):
        """Turn a raw HDF5 array into the form stored in the cache.

        ``'mask'`` (H, W, >=3): channels 0..2 are weighted by ``MASK_WEIGHTS``
        and summed into a float64 (H, W) label map.

        Anything else: each channel of an (H, W, C) array (or a single 2-D
        array) is min-max normalised to [0, 1] as float64.
        """
        if dname == 'mask':
            # Cast before scaling so integer masks cannot overflow.
            return sum(raw[:, :, c].astype(np.float64) * w
                       for c, w in enumerate(cls.MASK_WEIGHTS))
        if raw.ndim == 2:
            return cls._minmax(raw)
        return np.stack([cls._minmax(raw[:, :, c]) for c in range(raw.shape[2])],
                        axis=2)

    def _add_data_infos(self, file_path, load_data):
        """Append one ``data_info`` entry per dataset in ``file_path``.

        When ``load_data`` is true, the preprocessed arrays are also cached
        immediately and their cache index is recorded.
        """
        with h5py.File(file_path, 'r') as h5_file:
            for dname, ds in h5_file.items():
                idx = -1  # -1 means "not cached yet"
                if load_data:
                    idx = self._add_to_cache(self._preprocess(dname, ds[()]), file_path)
                # ds.shape avoids reading the whole array just for its shape.
                self.data_info.append({'file_path': file_path, 'type': dname,
                                       'shape': ds.shape, 'cache_idx': idx})

    def _load_data(self, file_path):
        """Load and preprocess every dataset of ``file_path`` into the cache.

        Updates the matching ``data_info`` entries with their cache index. Then,
        while more than ``data_cache_size`` files are cached, evicts the oldest
        other file and marks its entries as uncached.
        """
        # A file's entries are contiguous and follow h5_file.items() order,
        # because _add_data_infos appended them that way.
        file_idx = next(i for i, v in enumerate(self.data_info)
                        if v['file_path'] == file_path)
        with h5py.File(file_path, 'r') as h5_file:
            for offset, (dname, ds) in enumerate(h5_file.items()):
                idx = self._add_to_cache(self._preprocess(dname, ds[()]), file_path)
                self.data_info[file_idx + offset]['cache_idx'] = idx

        # Dicts keep insertion order, so the first other key is the oldest file.
        while len(self.data_cache) > self.data_cache_size:
            victims = [k for k in self.data_cache if k != file_path]
            if not victims:
                break
            evicted = victims[0]
            del self.data_cache[evicted]
            for di in self.data_info:
                if di['file_path'] == evicted:
                    di['cache_idx'] = -1

    def _add_to_cache(self, data, file_path):
        """Append ``data`` to the cache list of ``file_path``; return its index."""
        entries = self.data_cache.setdefault(file_path, [])
        entries.append(data)
        return len(entries) - 1

    def get_data_infos(self, type):
        """Return the ``data_info`` entries whose dataset name equals ``type``."""
        return [di for di in self.data_info if di['type'] == type]

    def get_data(self, type, i):
        """Return the cached array for the ``i``-th dataset named ``type``.

        Loads the owning file first if it is not in the cache.
        """
        fp = self.get_data_infos(type)[i]['file_path']
        if fp not in self.data_cache:
            self._load_data(fp)
        # Re-read the entry: _load_data has just assigned its cache index.
        cache_idx = self.get_data_infos(type)[i]['cache_idx']
        return self.data_cache[fp][cache_idx]

In [10]:
#############################################################
#                                                           #
#       Data Transforms Functions                           # 
#                                                           #
#############################################################

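'''
    Adapted, with slight changes, from torchvision's transforms.py
    (https://github.com/pytorch/vision/blob/master/torchvision/transforms/transforms.py).
    Unlike torchvision, every transform below takes and returns an
    (img, label) pair, so one random draw is applied to both image and label.
    They call functional helpers (to_tensor, vflip, hflip, rotate, affine,
    _pil_interpolation_to_str) that are not defined in this cell.
'''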

class Compose(object):
    """Chain several pair-wise transforms into a single transform.

    Each transform is called as ``t(img, label)`` and must return the new
    ``(img, label)`` pair. That pair is then fed to the next transform.

    Args:
        transforms (list): the transforms to apply, in order.

    Example:
        >>> Compose([
        >>>     RandomVerticalFlip(),
        >>>     RandomHorizontalFlip(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, label):
        current_img, current_label = img, label
        for step in self.transforms:
            current_img, current_label = step(current_img, current_label)
        return current_img, current_label

    def __repr__(self):
        # One transform per line, indented by four spaces.
        parts = [self.__class__.__name__ + '(']
        parts.extend('    {0}'.format(step) for step in self.transforms)
        parts.append(')')
        return '\n'.join(parts)


class ToTensor(object):
    """Convert an (image, label) pair to tensors with ``to_tensor``.

    Both elements go through the same ``to_tensor`` helper. In torchvision
    that helper turns an (H x W x C) PIL Image / numpy array into a
    (C x H x W) FloatTensor and rescales uint8 data from [0, 255] to [0.0, 1.0].
    """

    def __call__(self, pic, label):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
            label: Label converted with the same helper.
        Returns:
            tuple: (converted image, converted label).
        """
        converted_pic = to_tensor(pic)
        converted_label = to_tensor(label)
        return converted_pic, converted_label

    def __repr__(self):
        return '{}()'.format(type(self).__name__)


class RandomVerticalFlip(object):
    """Vertically flip an (image, label) pair together with probability ``p``.

    Args:
        p (float): probability of flipping the pair. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, label):
        """
        Args:
            img (PIL Image): Image to be flipped.
            label: Label flipped together with ``img``.
        Returns:
            tuple: the pair, both flipped or both unchanged.
        """
        # A single random draw decides for both elements.
        if not random.random() < self.p:
            return img, label
        return vflip(img), vflip(label)

    def __repr__(self):
        return '{}(p={})'.format(type(self).__name__, self.p)


class RandomHorizontalFlip(object):
    """Horizontally flip an (image, label) pair together with probability ``p``.

    Args:
        p (float): probability of flipping the pair. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, label):
        """
        Args:
            img (PIL Image): Image to be flipped.
            label: Label flipped together with ``img``.
        Returns:
            tuple: the pair, both flipped or both unchanged.
        """
        # A single random draw decides for both elements.
        if not random.random() < self.p:
            return img, label
        return hflip(img), hflip(label)

    def __repr__(self):
        return '{}(p={})'.format(type(self).__name__, self.p)


class RandomRotation(object):
    """Rotate an (image, label) pair by the same random integer angle.

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            A single number d gives the range (-d, +d); a sequence must be
            (min, max). The angle is drawn as an integer uniformly from
            [min, max], with both ends inclusive.
        resample, expand, center: passed through unchanged to ``rotate``. In
            torchvision these select the PIL resampling filter, whether the
            output is enlarged to fit the rotated image, and the rotation
            center (default: image center).
    """

    def __init__(self, degrees=360, resample=False, expand=False, center=None):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

        self.resample = resample
        self.expand = expand
        self.center = center

    @staticmethod
    def get_params(degrees):
        """Return a random right-angle rotation: one of 0, 90, 180 or 270.

        Note: ``degrees`` is ignored, and ``__call__`` does not use this helper.
        """
        angle_list = [0, 90, 180, 270]
        return random.choice(angle_list)

    def __call__(self, img, label):
        """
        Args:
            img (PIL Image): Image to be rotated.
            label: Label rotated by the same angle.
        Returns:
            tuple: (rotated image, rotated label).
        """
        low, high = self.degrees
        # randint's upper bound is exclusive; the +1 makes the range inclusive,
        # so it is symmetric and degrees=0 (low == high) no longer raises.
        angle = np.random.randint(low, high + 1)
        return (rotate(img, angle, self.resample, self.expand, self.center),
                rotate(label, angle, self.resample, self.expand, self.center))

    def __repr__(self):
        format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
        format_string += ', resample={0}'.format(self.resample)
        format_string += ', expand={0}'.format(self.expand)
        if self.center is not None:
            format_string += ', center={0}'.format(self.center)
        format_string += ')'
        return format_string


class RandomAffine(object):
    """Random affine transformation of an (image, label) pair, keeping the center invariant.

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or float or int, optional): Range of degrees to select from.
            If shear is a number instead of sequence like (min, max), the range of degrees
            will be (-shear, +shear). Will not apply shear by default.
        resample: resampling filter passed through to ``affine``.
        fillcolor (int): fill value passed through to ``affine`` for the area
            outside the transform.
    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
                "degrees should be a list or tuple and it must be of length 2."
            self.degrees = degrees

        if translate is not None:
            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
                "translate should be a list or tuple and it must be of length 2."
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
                "scale should be a list or tuple and it must be of length 2."
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            if isinstance(shear, numbers.Number):
                if shear < 0:
                    raise ValueError("If shear is a single number, it must be positive.")
                self.shear = (-shear, shear)
            else:
                assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
                    "shear should be a list or tuple and it must be of length 2."
                self.shear = shear
        else:
            self.shear = shear

        self.resample = resample
        self.fillcolor = fillcolor

    @staticmethod
    def get_params(degrees, translate, scale_ranges, shears, img_size):
        """Draw random affine parameters.

        Args:
            img_size: indexable size; ``img_size[0]`` scales the x translation
                and ``img_size[1]`` the y translation.

        Returns:
            tuple: (angle, (dx, dy), scale, shear). The translations are
            rounded to whole numbers; the defaults are (0, 0), 1.0 and 0.0.
        """
        angle = random.uniform(degrees[0], degrees[1])
        if translate is not None:
            max_dx = translate[0] * img_size[0]
            max_dy = translate[1] * img_size[1]
            translations = (np.round(random.uniform(-max_dx, max_dx)),
                            np.round(random.uniform(-max_dy, max_dy)))
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = random.uniform(scale_ranges[0], scale_ranges[1])
        else:
            scale = 1.0

        if shears is not None:
            shear = random.uniform(shears[0], shears[1])
        else:
            shear = 0.0

        return angle, translations, scale, shear

    def __call__(self, img, label):
        """Apply one randomly drawn affine transform to ``img`` and ``label``.

        Args:
            img (PIL Image): Image to be transformed.
            label: Label transformed with the same parameters.
        Returns:
            The result of ``affine`` for the pair.
        """
        # NOTE(review): img.size is used as the image size (indexed [0] and
        # [1] in get_params). A numpy array's .size is a single int, so this
        # presumably requires a PIL image — confirm before passing arrays.
        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
        return affine(img, label, *ret, resample=self.resample, fillcolor=self.fillcolor)

    def __repr__(self):
        s = '{name}(degrees={degrees}'
        if self.translate is not None:
            s += ', translate={translate}'
        if self.scale is not None:
            s += ', scale={scale}'
        if self.shear is not None:
            s += ', shear={shear}'
        d = dict(self.__dict__)
        if self.resample > 0:
            s += ', resample={resample}'
            # Only translate the filter name when it is actually shown.
            # _pil_interpolation_to_str is not defined in this cell, so looking
            # it up unconditionally made repr() of a default instance fail.
            d['resample'] = _pil_interpolation_to_str[d['resample']]
        if self.fillcolor != 0:
            s += ', fillcolor={fillcolor}'
        s += ')'
        return s.format(name=self.__class__.__name__, **d)



