# 使用 tensorflow 进行图像数据增广
- https://blog.csdn.net/lordofrobots/article/details/77160191
- https://blog.csdn.net/medium_hao/article/details/79227056

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf  
import numpy as np 
import random
import os
import sys
import matplotlib.pyplot as plt
from PIL import Image 
from PIL import ImageEnhance

  from ._conv import register_converters as _register_converters


In [2]:
def distort_color(image, color_ordering=0):  
#     print('color_ordering=', color_ordering)
    if color_ordering == 0:  
        image = tf.image.random_brightness(image, max_delta=0.5)#亮度  
        image = tf.image.random_saturation(image, lower=0.2, upper=1.8)#饱和度  
        image = tf.image.random_hue(image, max_delta=0.2)#色相  
        image = tf.image.random_contrast(image, lower=0.2, upper=1.8)#对比度  
    if color_ordering == 1:  
        image = tf.image.random_saturation(image, lower=0.2, upper=1.8)#饱和度  
        image = tf.image.random_hue(image, max_delta=0.2)#色相  
        image = tf.image.random_contrast(image, lower=0.2, upper=1.8)#对比度  
        image = tf.image.random_brightness(image, max_delta=0.5)#亮度  
    if color_ordering == 2:  
        image = tf.image.random_hue(image, max_delta=0.2)#色相  
        image = tf.image.random_contrast(image, lower=0.2, upper=1.8)#对比度  
        image = tf.image.random_brightness(image, max_delta=0.5)#亮度  
        image = tf.image.random_saturation(image, lower=0.2, upper=1.8)#饱和度  
    if color_ordering == 3:  
        image = tf.image.random_contrast(image, lower=0.2, upper=1.8)#对比度  
        image = tf.image.random_brightness(image, max_delta=0.5)#亮度   
        image = tf.image.random_saturation(image, lower=0.2, upper=1.8)#饱和度  
        image = tf.image.random_hue(image, max_delta=0.2)#色相  
    # return image  
    # 否则会报错 ValueError: Floating point image RGB values must be in the 0..1 range.
    return tf.clip_by_value(image, 0.0, 1.0)  


def preprocess_for_train(image, height, width, bbox=None):  
    if bbox is None:  
        bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])  
    if image.dtype != tf.float32:  
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)  
    
#     print('width=', width)
    # bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(tf.shape(image), bounding_boxes=bbox, min_object_covered=0.1)  
    # distorted_image = tf.slice(image, bbox_begin, bbox_size)  
    # distorted_image = tf.image.resize_images(distorted_image, [height, width], method=np.random.randint(4))  

#     distorted_image = tf.image.resize_images(image, [height, width], method=np.random.randint(4))  
#     distorted_image = tf.image.random_flip_left_right(distorted_image)  
    distorted_image = distort_color(image, np.random.randint(4))  
    return distorted_image  


def get_filenames(dataset_dir):
    """Returns a list of filenames
    """
    flower_root = dataset_dir
#     print('flower_root=', flower_root)
    dirname = os.path.dirname(dataset_dir)
    basename = os.path.basename(dataset_dir)
    augment_root = os.path.join(dirname, basename+'_aug')
#     print('augment_root=', augment_root)
    
    photo_filenames = []
    for filename in os.listdir(flower_root): #根目录
        path = os.path.join(flower_root, filename)
        # print('path =', path)
        if not os.path.isdir(path): #子目录
            photo_filenames.append(path)
#         break
    
    return photo_filenames, augment_root


def preprocess_each(filenames, aug_max=5):
    for i, fname in enumerate(filenames):
        with tf.gfile.FastGFile(fname, 'rb') as fimg:
            value = fimg.read()
            image_data = tf.image.decode_jpeg(value, channels=3) 
            #对一个图片进行多次增广 1~9 共 9 次，加上原图，训练集总共扩大到原来的 10 倍
            for j in range(1, aug_max): 
                result = preprocess_for_train(image_data, image_data.shape[0], image_data.shape[1])
                yield (i, j, fname, result)


def get_outputname(augment_root, filename, ix):
    basename = os.path.basename(filename)
#     print('basename=', basename)
    augment_root = '%s%s' % (augment_root, ix)
#     print('augment_root=', augment_root)
    if not tf.gfile.Exists(augment_root):
        tf.gfile.MakeDirs(augment_root)
    output_filename = '%s/aug%s_%s' % (augment_root, ix, basename)
#     print('output_filename=', output_filename)
    return output_filename

In [3]:
dataset_dir = '../../data/train_1w'  
filenames, augment_root = get_filenames(dataset_dir)
image_num = len(filenames)
# print('filenames=', filenames)

aug_max = 5 #每张图片增广的副本个数
step_size = 100 #处理的图片个数上限，否则内存溢出: aug_max * step_size = 500
cur_step = 52  #当前进度 0,1,2 ...
index_start = cur_step * step_size
index_end = (cur_step + 1) * step_size
if index_end > image_num:
    index_end = image_num

with tf.Session() as sess:  
    init = tf.global_variables_initializer()  
    sess.run(init)  
    
    for ix, jx, fname, result in preprocess_each(filenames[index_start:index_end], aug_max): 
        output_filename = get_outputname(augment_root, fname, jx)
        reimg = result.eval()  
        # 指定格式 format='JPEG' 否则生成 tfRecord 报错
        # 默认格式 PNG 文件太大 
        plt.imsave(output_filename, reimg, format='JPEG')
        sys.stdout.write('\r>> image_num=%s/%s aug=%s' % ((index_start+ix), image_num, jx))
        sys.stdout.flush()
    sys.stdout.write('\n')
    sys.stdout.flush()

>> image_num=5299/10602 aug=4
