In [1]:
import tensorflow as tf
import numpy as np

import os # for os.path and os.listdir

# Preprocessing

## General Configs

In [96]:
IMAGE_DIR = os.path.abspath('../../data/png/')
IMAGE_FILE_WILDCARD = '*.png'
DOC_TEMPLATES_DIR = os.path.abspath('../../doc_pic_generator/templates/')

## Read image into tf

In [97]:
def create_filename_queue(image_dir=IMAGE_DIR, image_file_wildcard=IMAGE_FILE_WILDCARD):
    """
    Make a queue of file names including all the images files in the relative image directory.
    
    Returns:
        A queue with the output strings (filenames)
    """
    filenames = tf.train.match_filenames_once(os.path.join(image_dir, image_file_wildcard))
    filename_queue = tf.train.string_input_producer(filenames)
    return filename_queue


def doc_templates_maps(doc_templates_dir=DOC_TEMPLATES_DIR, add_unknown=True):
    """
    Create 2 mapping ditcs of document template to index e.g. {'UNKNOWN': 0, 'doc_template_01': 1, 'doc_template_02': 2}
    Args:
        doc_templates_dir (str): path of the document templates are stored
        add_unknown (bool): include unknown as a document template
    
    Returns: tuple
        label_to_index (dict): keys are the labels of document template, values are index
        index_to_label (dict): keys are index, values are the labels of document template
    """
    template_names = [os.path.splitext(os.path.basename(template_filename))[0] for template_filename in os.listdir(doc_templates_dir)]
    if add_unknown:
        template_names = ['UNKNOWN'] + template_names
    label_to_index = {temaplte_name: i for i, temaplte_name in enumerate(template_names)}
    index_to_label = {val: key for key, val in label_to_index.items()}
    return label_to_index, index_to_label
    
    
def read_png(filename_queue, channels=3):
    """
    Read from filename_queue
    
    Returns: tuple
        image (tensor): the image tensor of (width, height, channels)
        template_name (str): e.g. 'doc_template_01'
    """
    image_reader = tf.WholeFileReader()

    # Read a whole file from the queue
    file_path, image_file = image_reader.read(filename_queue)

    # Decode the image as a PNG file, this will turn it into a Tensor which we can
    # then use in training.
    image = tf.image.decode_png(image_file, channels=channels)

    file_name = tf.string_split([file_path], delimiter='/').values[-1]

    # template_name is the label we want to classify 
    template_name = tf.string_split([file_name], delimiter='.').values[0]
    return image, template_name

# Define Graph

In [None]:
def cnn_model_fn(features, labels, mode, template_count):
    """
    
    """
    input_layer = tf.placeholder(tf.float32, shape=(None, 1684, 1190, 3), name='input_layer')

    # Convolutional Layer #1
    conv1 = tf.layers.conv2d(
      inputs=input_layer,
      filters=32,
      kernel_size=[5, 5],
      padding="same",
      activation=tf.nn.relu)

    # Pooling Layer #1
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

    # Convolutional Layer #2 and Pooling Layer #2
    conv2 = tf.layers.conv2d(
      inputs=pool1,
      filters=64,
      kernel_size=[5, 5],
      padding="same",
      activation=tf.nn.relu)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)


    # Dense Layer
    pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
    dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
    dropout = tf.layers.dropout(
      inputs=dense, rate=0.4, training=True)

    # Logits Layer
    logits = tf.layers.dense(inputs=dropout, units=template_count)

    predictions = {
      # Generate predictions (for PREDICT and EVAL mode)
      "classes": tf.argmax(input=logits, axis=1),
      # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
      # `logging_hook`.
      "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }

    # Calculate Loss (for both TRAIN and EVAL modes)
    onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=template_count)
    loss = tf.losses.softmax_cross_entropy(
      onehot_labels=onehot_labels, logits=logits)

# Train

In [98]:
# Create the mapping dicts
label_to_index, index_to_label = doc_templates_maps(doc_templates_dir=DOC_TEMPLATES_DIR, add_unknown=True)

In [70]:
# Start a new session to show example output.
with tf.Session() as sess:
    # Required to get the filename matching to run.
    # lesson learned... https://stackoverflow.com/questions/44143139/tensorflow-operation-tf-train-match-filenames-once-not-working
    init = (tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init)

    # Coordinate the loading of image files.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    
    try:
        while True:
            # Get an image tensor and print its value.
            label_tensor, image_tensor = sess.run([template_name, image])
            print(image_tensor.shape)
            print(label_tensor)
            
            coord.request_stop()
            coord.join(threads)
    except tf.errors.OutOfRangeError:
        # This will be raised when you reach the end of an epoch (i.e. the
        # iterator has no more elements).
        print('no more elements')

    # Perform any end-of-epoch computation here.
    print('Done training')



(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_02'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_02'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_02'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_02'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_02'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_02'
(1684, 1190, 3)
b'doc_template_01'
(1684, 1190, 3)
b'doc_template_02'
(1684, 1190, 3)
b'do