In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
import tensorflow as tf
from preprocessing import inception_preprocessing
from nets import inception_v3
from PIL import Image

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from tensorflow.python.ops import control_flow_ops

slim = tf.contrib.slim

In [None]:
def load_batch(dataset, batch_size, height, width, is_training=True):
    """Build batched tensors of preprocessed images, raw images and labels.

    Args:
      dataset: a slim Dataset exposing 'image' and 'label' items
        (e.g. the one returned by get_dataset).
      batch_size: number of examples per batch.
      height: output height of both the preprocessed and the raw images.
      width: output width of both the preprocessed and the raw images.
      is_training: forwarded to inception_preprocessing.preprocess_image.

    Returns:
      (images, images_raw, labels):
        images: output of inception_preprocessing.preprocess_image, batched.
        images_raw: the decoded images, only resized to (height, width), batched.
        labels: the 'label' items, batched.
    """
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset, common_queue_capacity=32, common_queue_min=8)
    image_raw, label = data_provider.get(['image', 'label'])
    image = inception_preprocessing.preprocess_image(
        image_raw, height, width, is_training=is_training)

    # Add a batch axis for the resize, then remove ONLY that axis. A bare
    # tf.squeeze would also drop a size-1 channel axis (grayscale input).
    # It would also lose the static rank when the channel count is unknown,
    # and tf.train.batch needs fully defined shapes.
    image_raw = tf.expand_dims(image_raw, 0)
    image_raw = tf.image.resize_images(image_raw, [height, width])
    image_raw = tf.squeeze(image_raw, [0])

    images, images_raw, labels = tf.train.batch(
        [image, image_raw, label],
        batch_size=batch_size,
        num_threads=1,
        capacity=2 * batch_size)

    return images, images_raw, labels
    

In [None]:
def get_dataset(dataset_dir, train_sample_size, split_name):
    """Describe one split of the t72/willys TFRecord files as a slim Dataset.

    Args:
      dataset_dir: directory containing files matching
        't72willys_<split_name>_*.tfrecord'.
      train_sample_size: stored as the Dataset's num_samples. Note that it is
        used for whichever split is requested, not only 'train'.
      split_name: split identifier substituted into the file pattern.

    Returns:
      A slim.dataset.Dataset with two items: 'image' (decoded from the
      'image/encoded' and 'image/format' features) and 'label' (the
      'image/class/label' feature). It has two classes: 0 -> 't72',
      1 -> 'willys'.
    """
    class_names = ['t72', 'willys']
    pattern = os.path.join(dataset_dir, 't72willys_%s_*.tfrecord' % split_name)

    feature_spec = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }
    handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(feature_spec, handlers)

    return slim.dataset.Dataset(
        data_sources=pattern,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=train_sample_size,
        items_to_descriptions={'image': 'images', 'label': 'labels'},
        num_classes=len(class_names),
        labels_to_names=dict(enumerate(class_names)))

In [None]:
# Build the input pipeline and pull a single training batch out of it.
# `im_raw` (resized but otherwise unprocessed images) feeds the
# augmentation cells below.
DATASET_DIR = os.environ.get('T72_WILLYS_DIR', '/root/qq/t72_willys')
NUM_TRAIN_SAMPLES = 100
BATCH_SIZE = 100
IMAGE_SIZE = 299

train_dataset = get_dataset(DATASET_DIR, NUM_TRAIN_SAMPLES, 'train')
images, images_raw, labels = load_batch(train_dataset, BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE)
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        im, im_raw, lab = sess.run([images, images_raw, labels])
    finally:
        # Always stop and join the queue-runner threads, even if sess.run
        # fails. Errors are deliberately NOT swallowed: every later cell
        # needs im_raw, so failing here gives the real cause.
        coord.request_stop()
        coord.join(threads)

In [None]:
def apply_with_random_selector(x, func, num_cases):
    """Apply func(x, case) for a case drawn at random when the graph runs.

    One graph branch is built per case in range(num_cases). While building
    its branch, `func` receives the case as a plain Python int. At run time
    a scalar int32 selector is sampled from [0, num_cases), and only the
    branch whose case matches it receives the real `x`.

    Args:
      x: input Tensor.
      func: Python callable func(tensor, case) returning a Tensor.
      num_cases: Python int, number of branches to build.

    Returns:
      The output of the selected branch.
    """
    selector = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
    branch_outputs = []
    for case_index in range(num_cases):
        # switch() yields (output_false, output_true); x is forwarded on
        # output_true only when this case equals the sampled selector.
        _, routed_x = control_flow_ops.switch(x, tf.equal(selector, case_index))
        branch_outputs.append(func(routed_x, case_index))
    # merge() returns (value, value_index); keep the value of the live branch.
    merged_output, _ = control_flow_ops.merge(branch_outputs)
    return merged_output

In [None]:
def distort_color(image, color_ordering=3, fast_mode=True, scope=None):
    """Randomly distort the colours of a single image.

    The colour ops do not commute, so `color_ordering` picks one fixed
    order in which to apply them.

    Args:
      image: 3-D Tensor holding one image with values in [0, 1].
      color_ordering: Python int choosing the op order.
        fast_mode=True: 0 applies brightness then saturation; any other
        value applies saturation then brightness.
        fast_mode=False: must be 0, 1, 2 or 3, each selecting one order of
        brightness, saturation, hue and contrast.
      fast_mode: if True, skip the slower random_hue and random_contrast ops.
      scope: optional name for the enclosing name_scope.

    Returns:
      3-D Tensor with the distorted image clipped to [0, 1].

    Raises:
      ValueError: if fast_mode is False and color_ordering is not in [0, 3].
    """
    def brightness(img):
        return tf.image.random_brightness(img, max_delta=32. / 255.)

    def saturation(img):
        return tf.image.random_saturation(img, lower=0.5, upper=1.5)

    def hue(img):
        return tf.image.random_hue(img, max_delta=0.2)

    def contrast(img):
        return tf.image.random_contrast(img, lower=0.5, upper=1.5)

    with tf.name_scope(scope, 'distort_color', [image]):
        if fast_mode:
            pipeline = (brightness, saturation) if color_ordering == 0 else (saturation, brightness)
        else:
            full_pipelines = (
                (0, (brightness, saturation, hue, contrast)),
                (1, (saturation, brightness, contrast, hue)),
                (2, (contrast, hue, brightness, saturation)),
                (3, (hue, saturation, contrast, brightness)),
            )
            pipeline = next(
                (ops for ordering, ops in full_pipelines if color_ordering == ordering),
                None)
            if pipeline is None:
                raise ValueError('color_ordering must be in [0, 3]')

        for distort in pipeline:
            image = distort(image)

        # The random_* ops do not necessarily keep values in range.
        return tf.clip_by_value(image, 0.0, 1.0)

In [None]:
scope = None
# Spatial size the random crop is resized back to.
DISTORTED_SIZE = 299
# `im_raw` comes from tf.image.resize_images in load_batch. Assuming the
# decoded images are uint8 (slim's Image handler default), it is float32 on
# a 0-255 scale, NOT the [0, 1] scale inception_preprocessing works in.
# The plot cell also treats these arrays as 0-255.
PIXEL_MAX = 255.0


def _clip_pixels(x):
    """Clamp a tensor to [0, PIXEL_MAX]; the random colour ops do not clamp."""
    return tf.clip_by_value(x, 0.0, PIXEL_MAX)


bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
image = im_raw[0]
with tf.name_scope(scope, 'distort_image', [image, bbox]):
    # Whole-image box drawn on the original, for reference.
    image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), bbox)

    # Random crop covering at least 10% of the box above.
    bbox_begin, bbox_size, distort_bbox = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        bounding_boxes=bbox,
        min_object_covered=0.1,
        aspect_ratio_range=(0.75, 1.33),
        area_range=(0.05, 1.0),
        max_attempts=100,
        use_image_if_no_bounding_boxes=True)
    cropped_image = tf.slice(image, bbox_begin, bbox_size)
    cropped_image.set_shape([None, None, 3])
    image_with_distorted_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), distort_bbox)

    # Resize the crop back with a randomly selected resize method (cases 0-3).
    num_resize_cases = 4
    distorted_image = apply_with_random_selector(
        cropped_image,
        lambda x, method: tf.image.resize_images(x, [DISTORTED_SIZE, DISTORTED_SIZE], method),
        num_cases=num_resize_cases)
    distorted_image_with_flip = tf.image.random_flip_left_right(distorted_image)

    # Colour distortions, one tensor per stage so each can be visualised.
    # Each stage is clipped to the pixel range, as distort_color does for
    # [0, 1] images, so the uint8 plots saturate instead of wrapping around.
    distorted_image_with_contrast = _clip_pixels(
        tf.image.random_contrast(distorted_image_with_flip, lower=0.5, upper=1.5))
    distorted_image_with_hue = _clip_pixels(
        tf.image.random_hue(distorted_image_with_contrast, max_delta=0.2))
    distorted_image_with_saturation = _clip_pixels(
        tf.image.random_saturation(distorted_image_with_hue, lower=0.5, upper=1.5))
    # random_brightness adds an absolute delta. The 32/255 fraction used by
    # inception_preprocessing must be expressed on this 0-255 scale, or it is
    # a no-op.
    distorted_image_with_brightness = _clip_pixels(
        tf.image.random_brightness(distorted_image_with_saturation,
                                   max_delta=32. / 255. * PIXEL_MAX))

    # Inception-style normalisation to [-1, 1]: rescale to [0, 1] first, then
    # subtract 0.5 and multiply by 2. These two tensors are no longer 0-255
    # pixel images, so they render dark when cast to uint8.
    distorted_image_with_subtract = tf.subtract(distorted_image_with_brightness / PIXEL_MAX, 0.5)
    distorted_image_with_multiply = tf.multiply(distorted_image_with_subtract, 2.0)
 

In [None]:
# Show the input shape, then the symbolic tensor for every augmentation stage.
print(image.shape)
for staged_tensor in (
        image_with_box,
        bbox_begin,
        bbox_size,
        distort_bbox,
        distorted_image,
        distorted_image_with_flip,
        distorted_image_with_contrast,
        distorted_image_with_hue,
        distorted_image_with_saturation,
        distorted_image_with_brightness,
        distorted_image_with_subtract,
        distorted_image_with_multiply):
    print(staged_tensor)

In [None]:
# Evaluate every stage in ONE sess.run call. Each call re-samples the random
# crop, resize method, flip and colour ops, so separate calls would show
# stages from different random draws that do not follow from one another.
# NOTE: the results overwrite the tensor names (the plot cell reads them as
# arrays), so re-run the graph-building cell before re-running this one.
with tf.Session() as sess:
    (iwb,
     cropped_image,
     image_with_distorted_box,
     distorted_image,
     distorted_image_with_flip,
     distorted_image_with_contrast,
     distorted_image_with_hue,
     distorted_image_with_saturation,
     distorted_image_with_brightness,
     distorted_image_with_subtract,
     distorted_image_with_multiply) = sess.run([
        image_with_box,
        cropped_image,
        image_with_distorted_box,
        distorted_image,
        distorted_image_with_flip,
        distorted_image_with_contrast,
        distorted_image_with_hue,
        distorted_image_with_saturation,
        distorted_image_with_brightness,
        distorted_image_with_subtract,
        distorted_image_with_multiply,
    ])

In [None]:
def _show_pair(left_img, left_title, right_img, right_title):
    """Show two images side by side in a new figure and return the figure.

    Arrays are clipped to [0, 255] before the uint8 cast. Values the random
    ops pushed out of range therefore saturate instead of wrapping around
    (e.g. -1 would otherwise become 255).
    """
    pair_fig, axes = plt.subplots(1, 2)
    for ax, img, title in zip(axes, (left_img, right_img), (left_title, right_title)):
        ax.imshow(img.clip(0, 255).astype('uint8'))
        ax.set_title(title)
    return pair_fig


# (left image, left title, right image, right title), one entry per figure.
FIGURE_PAIRS = [
    (image, 'original',
     iwb[0], 'Image with box'),
    (cropped_image, 'Cropped with bbox',
     image_with_distorted_box[0], 'Image with distorted bbox'),
    (distorted_image, 'Distorted with random selector',
     distorted_image_with_flip, 'Distorted image with flip'),
    (distorted_image_with_contrast, 'Distorted image with contrast',
     distorted_image_with_hue, 'Distorted image with hue'),
    (distorted_image_with_saturation, 'Distorted image with saturation',
     distorted_image_with_brightness, 'Distorted image with brightness'),
    (distorted_image_with_subtract, 'Distorted image with subtract',
     distorted_image_with_multiply, 'Distorted image with multiply'),
]

for pair in FIGURE_PAIRS:
    fig = _show_pair(*pair)