# TF Record Writing
- - - 
Followed codes are examples of writing TF Records with multi-threading

## 1. Environment Setting
- - - 
* Below block sets paths and parameters. See the details at the end of parameters.
* __You have to set bool parameters properly below__
* Multi directory data loading is not supported so far. (190319)

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
from PIL import Image
Image.LOAD_TRUNCATED_IMAGES = True
import os
import random
import sys
import threading
import io

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

tf.enable_eager_execution()

%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=6

#################################################
# You have to set this bool parameters properly #
#################################################
USE_LABEL_FILE = True
USE_GRAYSCALE = True
USE_RESIZING = False
USE_DATA_SLICING = False
USE_TRAIN_AUG = False
USE_VAL_AUG = False
#################################################
#################################################
#################################################

TARGET_TRAIN = "train"
TARGET_VAL = "val"

IMAGE_THRESHOLD = 700 # Threshold of width and height

MIN_IMAGE_DATASET = 3000 # Number of minimum count of image datasets
MAX_IMAGE_DATASET = 0 # Number of maximum count of image datasets

VAL_DATA_PROPORTION = 0.2 # Proportion of validation data out of 100

TRAIN_SHARDS = 8 # Number of shards in training TFRecord files. Should be multiple of threads
VAL_SHARDS = 4 # Number of shards in validation TFRecord files. Should be multiple of threads
NUM_OF_THREADS = 4 # Number of threads to preprocess the images

IMAGE_DIR_PATH = '/work/nas/emotion/01_Data/emotion_data' # Image directory used for one image source
# TRAIN_DIR_PATH = '/work/nas/emotion/01_Data/emotion_data' # Train image directory used for separated image source
# VAL_DIR_PATH = '/work/nas/emotion/01_Data/emotion_data' # Validation image directory used for separated image source
TRAIN_DIR_PATH = '/work/data/emotion/FER2013_Emotion/Training'
VAL_DIR_PATH = '/work/data/emotion/FER2013_Emotion/PrivateTest'

# TRAIN_DIR_PATH2 = '/home/youngtak.na/imagenet/ILSVRC/Data/CLS-LOC/train' # Use if there is one more folder
# VAL_DIR_PATH2 = '/home/youngtak.na/imagenet/ILSVRC/Data/CLS-LOC/val' # Use if there is one more folder

# OUTPUT_DIR_PATH = '/work/nas/emotion/02_TFRecords/190315_min_3000'
OUTPUT_DIR_PATH = './tfrecords/FER2013'
TRAIN_OUTPUT_DIR_PATH = OUTPUT_DIR_PATH + "/train"
VAL_OUTPUT_DIR_PATH = OUTPUT_DIR_PATH + "/val"

# LABEL_PATH = '/work/nas/emotion/01_Data/emotion_data/emotion_label.txt' # Could be file or dir that has folders named as labels
LABEL_PATH = '/work/data/emotion/FER2013_Emotion/label.txt' # Could be file or dir that has folders named as labels

LOG_FILE_PATH = OUTPUT_DIR_PATH + '/log.txt'

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=6


## 2. TF Record Utils
- - - 
* This block is about TF Record utils to convert to TF Record example.
* __You have to edit convert_to_example function for fit your model__

In [2]:
def int64_feature(value):
    """Wrapper for inserting int64 features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def float_feature(value):
    """Wrapper for inserting float features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert_to_example(filename, image_buffer, label, height, width):
    colorspace = b'RGB'
    channels = 3
    image_format = b'JPEG'
    
    if USE_GRAYSCALE :
        colorspace = b'L'
        channels = 1    

    example = tf.train.Example(features=tf.train.Features(feature={
          'image/height': int64_feature(height),
          'image/width': int64_feature(width),
          'image/colorspace': bytes_feature(colorspace),
          'image/channels': int64_feature(channels),
          'image/class/label': int64_feature(label),
          'image/format': bytes_feature(image_format),
          'image/filename': bytes_feature(os.path.basename(filename).encode('UTF-8')),
          'image/encoded': bytes_feature(image_buffer)}))
    return example

## 3. Image Processing Utils
- - - 
* This block is about image processing utils.

In [3]:
def png_to_jpeg(image_data, channels):
    image = tf.image.decode_jpeg(image_data, channels=channels) # 3
    return tf.image.encode_jpeg(image, format='rgb', quality=100)

def cmyk_to_rgb(image_data, channels):
    image = tf.image.decode_jpeg(image_data, channels=channels) # 0
    return tf.image.encode_jpeg(image, format='rgb', quality=100)

def decode_jpeg(image_data, channels):
    image = tf.image.decode_jpeg(image_data, channels=channels) # 3
    assert len(image.shape) == 3
    assert image.shape[2] == channels
    return image

def check_valid_image(img_path, grayscale=False) :
    try :
        img = Image.open(img_path)
        img.verify()
        if not grayscale and img.mode != 'RGB' and img.mode != 'CMYK':
            print(img_path, ' is excluded [', img.mode, ']')
            return False
    except Exception as ex :
        print_and_write_log("[Exception] : %s" % ex) 
        return False
       
    return True

def resize_image_threshold(image, threshold):
    resized_image = image
    width, height = image.size

    while ((width > threshold) | (height > threshold)):
        height = int(height / 2)
        width = int(width / 2)
        resized_image = image.resize((width, height))

    return resized_image

def convert_image_grayscale(image) :
    gray_image = image
    if image.mode != 'L' :
        gray_image = image.convert('L')

    return gray_image

def get_image_bytes(image) :
    image_bytes = io.BytesIO()
    image.save(image_bytes, format='JPEG')
    return image_bytes.getvalue()
    
def process_image(file_name):
    with Image.open(file_name) as image :
        if USE_RESIZING :
            image = resize_image_threshold(image, IMAGE_THRESHOLD)
        if USE_GRAYSCALE :
            image = convert_image_grayscale(image)

        width, height = image.size
        image_data = get_image_bytes(image)
        
#     image_data = tf.gfile.GFile(filename, 'rb').read()

    # Clean the dirty data.
    #   if _is_png(filename):
    #     # 1 image is a PNG.
    #     print('Converting PNG to JPEG for %s' % filename)
    #     image_data = coder.png_to_jpeg(image_data)
    #   elif _is_cmyk(filename):
    #     print('Converting CMYK to RGB for %s' % filename)
    #     image_data = coder.cmyk_to_rgb(image_data)
#     if USE_GRAYSCALE :
#         image = decode_jpeg(image_data, 1)

#     height = image.shape[0]
#     width = image.shape[1]

    return image_data, height, width

## 4. Basic Utils
- - - 
* This block is about basic utils for making TF Record

In [4]:
def print_and_write_log(log, print_log=True) :
    if print_log is True :
        print("{} : {}".format(datetime.now(), log))
        sys.stdout.flush()
    
    if not os.path.exists(OUTPUT_DIR_PATH) :
        make_directory(OUTPUT_DIR_PATH)
            
    with open(LOG_FILE_PATH, "a") as f_log :
        f_log.write("{} : {}\n".format(datetime.now(), log))
        
def make_directory(dir_path) :
    if not os.path.isdir(dir_path):
        os.makedirs(dir_path)
        print("Directory is created in ", dir_path)

def load_labels(label_path, use_file) :    
    labels = []
    if use_file :
        with open(label_path, 'r') as f :
            for line in f.readlines() :
                labels.append(line.strip())
#             labels.sort()
    else :
        image_dirs = os.listdir(label_path)
        image_dirs.sort()
        labels = image_dirs
    
    print_and_write_log("Label count : %d" % len(labels))
    return labels

def get_shards_and_batch(target) :
    assert target == TARGET_TRAIN or target == TARGET_VAL, "[Error] Target name is not matching"

    if target == TARGET_TRAIN :
        num_of_shards = TRAIN_SHARDS    
    elif target == TARGET_VAL :
        num_of_shards = VAL_SHARDS
    
    assert not num_of_shards % NUM_OF_THREADS, "[Error] Shards should be multiple of threads"
    num_shards_per_batch = int(num_of_shards / NUM_OF_THREADS)
    
    return num_shards_per_batch, num_of_shards

def get_augmented_index(index_len, aug_len) :
    np.random.seed(int(datetime.timestamp(datetime.now())))
    augmented_index = []
    while len(augmented_index) < aug_len :
        random_index = np.random.choice(range(0, index_len), index_len, replace=False)
        augmented_index.extend(random_index)
    
    return augmented_index[:aug_len]

def get_shuffled_index(index_len) :
    shuffled_index = np.arange(index_len)
    random.seed(int(datetime.timestamp(datetime.now())))
    random.shuffle(shuffled_index)
    return shuffled_index

## 5. Making Dataset and Writing TF Records
- - - 
* This block is about making dataset and TF Records.
> * Make_dataset
>> * Read datas from data directory with labels.
>> * You can get augmented datas with set __data_aug__ flag True.
>> * And it supports slicing dataset to divide training set and validation set.
> * Make_tfrecords
>> * Make TF Records by multi-threading.
>> * You have to set target with __"train" or "val"__ to get shards and batch.

In [5]:
def make_dataset(data_dir_path, original_labels, data_aug=False, slice_range=[0,1]):
    labels = []
    file_paths = []
    
    for label_index in range(len(original_labels)):
        matching_file_paths = []
        img_file_path = '%s/%s/*.jpg' % (data_dir_path, original_labels[label_index])
        
        cand_file_paths = tf.gfile.Glob(img_file_path)
        for img_path in cand_file_paths :
            if check_valid_image(img_path, grayscale=USE_GRAYSCALE) :
                matching_file_paths.append(img_path)
        
        matching_file_paths.sort()
        print_and_write_log('Matching files : %d' % len(matching_file_paths))
            
        if slice_range[0] != 0 or slice_range[1] != 1 :
            start_index = int(len(matching_file_paths) * slice_range[0])
            end_index = int(len(matching_file_paths) * slice_range[1])
            matching_file_paths = matching_file_paths[start_index : end_index]
            print_and_write_log('Slice matching files %d : %d to %d ' % (end_index - start_index, start_index, end_index))
       
        if data_aug == True and len(matching_file_paths) > MAX_IMAGE_DATASET :
            matching_file_paths = matching_file_paths[:MAX_IMAGE_DATASET]
        
        if data_aug == True and len(matching_file_paths) < MIN_IMAGE_DATASET :
            aug_len = MIN_IMAGE_DATASET - len(matching_file_paths)
            
            augmented_index = get_augmented_index(len(matching_file_paths), aug_len)
            augmented_file_paths = [matching_file_paths[i] for i in augmented_index]
            matching_file_paths.extend(augmented_file_paths)
            
            print_and_write_log("Add augmented files : %d" % aug_len)

        file_paths.extend(matching_file_paths)
        labels.extend([label_index] * len(matching_file_paths))

        print_and_write_log('Finished finding %s files %s in %d of %d classes.' % (len(matching_file_paths), 
                original_labels[label_index], label_index + 1, len(original_labels)))
        
    shuffled_index = get_shuffled_index(len(file_paths))

    file_paths = [file_paths[i] for i in shuffled_index]
    labels = [labels[i] for i in shuffled_index]
    
    print_and_write_log('Found %d JPEG files across %d labels inside %s.' %
        (len(file_paths), len(original_labels), data_dir_path))
    
    return file_paths, labels

def make_tfrecords_by_thread(file_paths, labels, target) :
    assert len(file_paths) == len(labels), "[Error] Files and labels length should be same"
    
    ranges = []
    threads = []
    coordinator = tf.train.Coordinator()

    spacing = np.linspace(0, len(file_paths), NUM_OF_THREADS + 1).astype(np.int)
    for i in xrange(len(spacing) - 1):
        ranges.append([spacing[i], spacing[i+1]])
    print_and_write_log('Launching {} threads for spacings: {}'.format(NUM_OF_THREADS, ranges))

    make_directory(os.path.join(OUTPUT_DIR_PATH, target))
    
    for thread_index in xrange(len(ranges)):
        args = (thread_index, ranges, file_paths, labels, target)
        t = threading.Thread(target=make_tfrecords_batch, args=args)
        t.start()
        threads.append(t)
    
    coordinator.join(threads)
    print_and_write_log('Finished writing all %d images in data set.' % (len(file_paths)))

def make_tfrecords_batch(thread_index, ranges, file_paths, labels, target):
    total_counter = 0
    num_shards_per_batch, num_of_shards = get_shards_and_batch(target)
    num_files_per_thread = ranges[thread_index][1] - ranges[thread_index][0]
    shard_ranges = np.linspace(ranges[thread_index][0], ranges[thread_index][1],
                             num_shards_per_batch + 1).astype(int)
    
    for index in xrange(num_shards_per_batch):
        shard_counter = 0
        shard_index = thread_index * num_shards_per_batch + index
              
        output_filename = '%s-%.3d-of-%.3d.tfrecord' % (target, shard_index + 1, num_of_shards)
        output_file = os.path.join(OUTPUT_DIR_PATH, target + "/" + output_filename)
        
        with tf.python_io.TFRecordWriter(output_file) as writer :
            for i in range(shard_ranges[index], shard_ranges[index + 1]) :
                try :
                    image_buffer, height, width = process_image(file_paths[i])

                    example = convert_to_example(file_paths[i], image_buffer, labels[i], height, width)
                    writer.write(example.SerializeToString())

                    shard_counter += 1
                    total_counter += 1

                    if not total_counter % 1000:
                        print_and_write_log('[thread %d]: Processed %d of %d images in thread batch.' %
                              (thread_index, total_counter, num_files_per_thread))
                except OSError as e :
                    print_and_write_log(str(e))

        print_and_write_log('[thread %d]: Wrote %d images to %s' %
              (thread_index, shard_counter, output_file))
    
    print_and_write_log('[thread %d]: Wrote %d images to %d shards.' %
        (thread_index, total_counter, num_shards_per_batch))

In [6]:
def main(unused_argv):
    if os.path.exists(LOG_FILE_PATH) :
        os.remove(LOG_FILE_PATH)
    
    original_labels = load_labels(LABEL_PATH, USE_LABEL_FILE)
    
    if USE_DATA_SLICING :
        assert VAL_DATA_PROPORTION > 0 and VAL_DATA_PROPORTION < 1, "[Error] Proportion should be between 0 and 1"
        train_file_paths, train_labels = make_dataset(IMAGE_DIR_PATH, original_labels, USE_TRAIN_AUG, [0, 1-VAL_DATA_PROPORTION])
        val_file_paths, val_labels = make_dataset(IMAGE_DIR_PATH, original_labels, USE_VAL_AUG, [1-VAL_DATA_PROPORTION, 1])
    else :
        train_file_paths, train_labels = make_dataset(TRAIN_DIR_PATH, original_labels, USE_TRAIN_AUG)
        val_file_paths, val_labels = make_dataset(VAL_DIR_PATH, original_labels, USE_VAL_AUG)
    
    make_tfrecords_by_thread(train_file_paths, train_labels, TARGET_TRAIN)
    make_tfrecords_by_thread(val_file_paths, val_labels, TARGET_VAL)

if __name__ == '__main__':
    tf.app.run()
    pass

2019-03-19 16:04:18.257435 : Label count : 7
2019-03-19 16:04:18.636855 : Matching files : 3995
2019-03-19 16:04:18.638725 : Finished finding 3995 files angry in 1 of 7 classes.
2019-03-19 16:04:18.703541 : Matching files : 436
2019-03-19 16:04:18.704915 : Finished finding 436 files disgust in 2 of 7 classes.
2019-03-19 16:04:19.102089 : Matching files : 4097
2019-03-19 16:04:19.104005 : Finished finding 4097 files fear in 3 of 7 classes.
2019-03-19 16:04:20.023710 : Matching files : 7215
2019-03-19 16:04:20.026326 : Finished finding 7215 files happy in 4 of 7 classes.
2019-03-19 16:04:20.658915 : Matching files : 4965
2019-03-19 16:04:20.660374 : Finished finding 4965 files neutral in 5 of 7 classes.
2019-03-19 16:04:21.358889 : Matching files : 4830
2019-03-19 16:04:21.361258 : Finished finding 4830 files sad in 6 of 7 classes.
2019-03-19 16:04:21.674659 : Matching files : 3171
2019-03-19 16:04:21.676562 : Finished finding 3171 files surprise in 7 of 7 classes.
2019-03-19 16:04:21.72

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
import matplotlib.pyplot as plt
import shutil
%matplotlib inline

def parse(data):
    features = {"image/encoded": tf.FixedLenFeature((), tf.string, default_value=""),
                "image/filename": tf.FixedLenFeature((), tf.string, default_value=""),
                "image/class/label": tf.FixedLenFeature((), tf.int64, default_value=0),
                "image/height": tf.FixedLenFeature((), tf.int64, default_value=0),
                "image/width": tf.FixedLenFeature((), tf.int64, default_value=0),
                "image/channels": tf.FixedLenFeature((), tf.int64, default_value=0)}
    parsed_features = tf.parse_single_example(data, features)
    return parsed_features["image/filename"], parsed_features["image/class/label"],\
            parsed_features["image/encoded"], parsed_features["image/width"], parsed_features["image/height"],\
            parsed_features["image/channels"]

def read_tfrecord(data_dir_path) :
    label_names = load_labels(LABEL_PATH, USE_LABEL_FILE)
    
    tfrecord_paths = os.listdir(data_dir_path)
    tfrecord_paths.sort()
    
    for tfrecord_file in tfrecord_paths :
        tfrecord_path = os.path.join(data_dir_path, tfrecord_file)
        print("tfrecord_path : {}".format(tfrecord_path))
        if not os.path.isfile(tfrecord_path):
            continue
        tfrecord_dataset = tf.data.TFRecordDataset(tfrecord_path)
        tfrecord_dataset = tfrecord_dataset.map(parse)
        tfrecord_dataset = tfrecord_dataset.repeat()
        tfrecord_dataset = tfrecord_dataset.batch(5)
        iterator = tfrecord_dataset.make_one_shot_iterator()
        for _ in range(2):
            try :
                file_name, label, image_buffer, width, height, channels = iterator.get_next()
            except :
                print("error")

            fig, axes = plt.subplots(1, 5, figsize=(20,20))
            for idx in range(5) :
                print("file_name : {}, label : {}".format(file_name[idx], label[idx]))
                image_data = tf.image.decode_jpeg(image_buffer[idx], channels=channels[idx])
                image = tf.reshape(image_data, tf.stack([width[idx], height[idx], channels[idx]]))
                
                axes[idx].imshow(image)
                axes[idx].set_title(label_names[label[idx]])
#                 axes[idx].set_title(label[idx].numpy())

In [None]:
# read_tfrecord("/work/nas/tfrecords/smile_final/train")
read_tfrecord(TRAIN_OUTPUT_DIR_PATH)
# read_tfrecord(VAL_OUTPUT_DIR_PATH)
