In [114]:
from skimage.io import imread
def get_input(filename, data_dir):
    
    img = imread(data_dir + filename)
    
    return(img)

In [115]:
import numpy as np
import pandas as pd    
def get_output(filename, data_dir,label_file=None):
    
    img_id = get_img_id(data_dir + filename)
    labels = label_file.loc[filename]
    
    return(labels)

In [116]:
def preprocess_input(image):
    import imgaug as ia
    import imgaug.augmenters as iaa
    
    seq = iaa.Sometimes(0.5,iaa.SomeOf((0, None),[
        iaa.Fliplr(0.3),
        iaa.Flipud(0.2),
        iaa.GaussianBlur(sigma=(0.0, 3.0)),
        iaa.Multiply((1.2, 1.5)), # change brightness, doesn't affect BBs
        iaa.AdditiveGaussianNoise(scale=(0, 0.05*255))
        ], random_order=True))
    
    # Augment BBs and images.
    image_aug = seq(image=image)
    
    return image_aug

In [117]:
def image_generator(files,data_dir,label_file, batch_size = 64):
    
    while True:
          # Select files (paths/indices) for the batch
          batch_paths = np.random.choice(a = files, 
                                         size = batch_size)
          batch_input = []
          batch_output = [] 
          
          # Read in each input, perform preprocessing and get labels
          for filename in batch_paths:
              input = get_input(filename, data_dir )
              output = get_output(filename,data_dir,label_file=label_file )
            
              input = preprocess_input(image=input)
              batch_input += [ input ]
              batch_output += [ output ]
          # Return a tuple of (input,output) to feed the network
          batch_x = np.array( batch_input )
          batch_y = np.array( batch_output )
        
          yield( batch_x, batch_y )

In [112]:
import cv2

data_dir = "./train_imgs/"
image_names = ["Pipstrel-Virus_Bodensee_2018-02-13_15-41-05.jpg","Pipstrel-Virus_Bodensee_2018-02-13_15-41-06.jpg"]
label_file = pd.Series(['boat','boat'])
label_file.index = image_names
imgs_list = np.array([image_name,image_name])
generator = image_generator(imgs_list,data_dir,label_file,batch_size = 2)

In [113]:
next(generator)

(array([[[[135, 131,  93],
          [136, 131,  93],
          [136, 132,  93],
          ...,
          [132, 126, 132],
          [132, 126, 132],
          [132, 126, 133]],
 
         [[136, 132,  93],
          [136, 132,  93],
          [136, 132,  93],
          ...,
          [132, 126, 132],
          [132, 126, 132],
          [132, 126, 133]],
 
         [[136, 132,  93],
          [136, 132,  93],
          [136, 133,  93],
          ...,
          [132, 126, 132],
          [132, 126, 132],
          [132, 126, 133]],
 
         ...,
 
         [[ 86,  86,  85],
          [ 86,  86,  86],
          [ 86,  86,  86],
          ...,
          [179, 177, 177],
          [178, 176, 177],
          [177, 175, 176]],
 
         [[ 85,  85,  85],
          [ 86,  86,  85],
          [ 86,  86,  85],
          ...,
          [179, 177, 177],
          [178, 176, 177],
          [177, 175, 176]],
 
         [[ 85,  85,  85],
          [ 86,  86,  85],
          [ 86,  86,  85],
   