In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import PIL.Image
import torch, torchvision

In [None]:
# Directory the source training images live in, and the directory intended
# to receive the augmented copies.
readLocation = '../data/train/'
writeLocation = '../data/augmentedImages/'

# Load the label table, rename its two columns positionally, and drop every
# row labelled 'new_whale' (presumably the placeholder for animals without a
# known individual ID — TODO confirm against the dataset description).
original_data = pd.read_csv(os.path.join('../data/', 'train.csv'))
original_data.columns = ['Image', 'Whale_ID']
original_data = original_data.loc[original_data['Whale_ID'].ne('new_whale')]


In [None]:
class AugmentImage():
    """Namespace of static helpers that build augmented copies of an image.

    ``get_list_of_transformations`` is the entry point. It returns a list of
    augmented images together with a parallel list of short names describing
    how each image was produced.
    """

    # Side length of the square crops taken by ``augment_image``. Images are
    # upscaled to at least this size first so every crop fits. Presumably the
    # input size of the downstream model — TODO confirm.
    CROP_SIZE = 299

    @staticmethod
    def resize_if_needed(image, min_size=CROP_SIZE):
        """Upscale ``image`` so that neither side is shorter than ``min_size``.

        ``image`` may be a PIL image, or anything ``ToPILImage`` accepts
        (tensor / ndarray). Each side shorter than ``min_size`` is stretched
        to ``min_size``. Sides that are already long enough are unchanged.
        Returns a PIL image.
        """
        if not isinstance(image, PIL.Image.Image):
            # ToPILImage is a transform class: build it, then apply it.
            image = torchvision.transforms.ToPILImage()(image)
        width, height = image.size
        target = (max(width, min_size), max(height, min_size))
        if target != (width, height):
            # PIL's resize takes (width, height), whereas torchvision's
            # Resize takes (height, width).
            image = image.resize(target, PIL.Image.BILINEAR)
        return image

    @staticmethod
    def augment_image(img):
        """Return the ten square crops of ``img`` produced by ``TenCrop``.

        The image is upscaled first if needed, so the crops always fit. The
        order is torchvision's: top-left, top-right, bottom-left,
        bottom-right and center crops of the image, followed by the same
        five crops of its horizontally flipped version.
        """
        img = AugmentImage.resize_if_needed(img)
        size = AugmentImage.CROP_SIZE
        return torchvision.transforms.TenCrop((size, size))(img)

    @staticmethod
    def gaussian_noise(img, mean=0, stddev=0.1):
        """Return ``img`` plus element-wise Gaussian noise N(mean, stddev**2).

        ``img`` may be a floating-point tensor or a PIL image. A PIL image is
        first converted with ``ToTensor`` (values scaled to [0, 1]). The
        result is an unclamped tensor with the same shape as the input.
        """
        if isinstance(img, PIL.Image.Image):
            img = torchvision.transforms.ToTensor()(img)
        noise = torch.empty_like(img).normal_(mean, stddev)
        return img + noise

    @staticmethod
    def add_noise_and_saturate(img):
        """Add Gaussian noise to ``img``, then saturate values into [0, 1].

        Saturation clips values pushed outside [0.0, 1.0] to 0 or 1. A PIL
        input yields a PIL output, so noisy images can be saved like the
        others. A tensor input yields a tensor.
        """
        was_pil = isinstance(img, PIL.Image.Image)
        noisy = AugmentImage.gaussian_noise(img).clamp(0.0, 1.0)
        return torchvision.transforms.ToPILImage()(noisy) if was_pil else noisy

    @staticmethod
    def _bilinear_rotation(degrees):
        """Build a ``RandomRotation(degrees)`` that resamples bilinearly.

        torchvision releases that define ``InterpolationMode`` take an
        ``interpolation=`` argument. Older ones only accept ``resample=``
        with a PIL filter constant.
        """
        transforms = torchvision.transforms
        if hasattr(transforms, 'InterpolationMode'):
            return transforms.RandomRotation(
                degrees, interpolation=transforms.InterpolationMode.BILINEAR)
        return transforms.RandomRotation(degrees, resample=PIL.Image.BILINEAR)

    @staticmethod
    def general_transform(img):
        """Randomly jitter colour, maybe flip horizontally, and rotate within ±20°."""
        return torchvision.transforms.Compose([
            torchvision.transforms.ColorJitter(brightness=.04, hue=.05, saturation=.05),
            torchvision.transforms.RandomHorizontalFlip(),
            AugmentImage._bilinear_rotation(20),
        ])(img)

    @staticmethod
    def get_list_of_transformations(img):
        """Build every augmented variant of ``img`` together with its name.

        Returns ``(images, names)``, where ``images[i]`` was produced by the
        transformation ``names[i]``. The 13 base variants are followed by a
        noisy copy of each, named with a ``noisy_`` prefix.
        """
        img = AugmentImage.resize_if_needed(img)
        # Each transform below returns a new PIL image rather than modifying
        # its input, so all of them can share `img` without copying it.
        base_images = [
            AugmentImage.general_transform(img),
            torchvision.transforms.functional.hflip(img),
            torchvision.transforms.RandomRotation(17)(img),
        ] + list(AugmentImage.augment_image(img))
        base_names = [
            'brightness_flip_rotation',
            'hflip',
            'rotation',
            'crop_tl',
            'crop_tr',
            'crop_bl',
            'crop_br',
            'crop_center',
            'flipped_crop_tl',
            'flipped_crop_tr',
            'flipped_crop_bl',
            'flipped_crop_br',
            'flipped_crop_center'
        ]
        images = base_images + [AugmentImage.add_noise_and_saturate(i) for i in base_images]
        names = base_names + ['noisy_' + name for name in base_names]
        return images, names



In [None]:
maxLength = original_data.shape[0]
for item in range(maxLength):
    imageFile, imageLabel = original_data.iloc[item]
    img = PIL.Image.open(os.path.join(writeLocatiom, imageFile))
    listOfImgs = AugmentImage.get_list_of_transformations(img)
    for newImage in listOfImgs:
        newImage

In [None]:
new_data = original_data.copy()

new_data.append({'Image': i,
                 'Whale_ID':}, ignore_index=True)