In [1]:
import os
import numpy as np
import h5py
import random
from PIL import Image
import tensorflow as tf
import meta

In [10]:
# Dataset split directories; each is expected to hold the *.png images and a digitStruct.mat.
path_train, path_test, path_extra = ('data/train', 'data/test', 'data/extra')

In [11]:
def process_image(path_image_file, left, top, width, height, size=(64, 64)):
    """Crop a padded region around a bounding box and resize it.

    The box is grown by 15% of its width/height on every side (30% overall)
    before cropping. The crop is then resized to ``size``.

    Args:
        path_image_file: path of the image file to open.
        left, top: upper-left corner of the bounding box, in pixels.
        width, height: size of the bounding box, in pixels.
        size: (width, height) of the returned image; defaults to 64x64.

    Returns:
        A new ``PIL.Image.Image`` with dimensions ``size``.
    """
    crop_left = int(round(left - 0.15 * width))
    crop_top = int(round(top - 0.15 * height))
    crop_width = int(round(width * 1.3))
    crop_height = int(round(height * 1.3))
    # Use a context manager so the underlying file handle is released.
    with Image.open(path_image_file) as image:
        cropped = image.crop([crop_left, crop_top, crop_left + crop_width, crop_top + crop_height])
        # resize() returns a NEW image rather than modifying in place, so the result must be kept.
        # It also loads the pixel data while the file is still open.
        return cropped.resize(tuple(size))

In [15]:
def _read_bbox_attributes(mat_file, index):
    """Read the per-digit label/left/top/width/height lists for one image.

    Args:
        mat_file: an open ``h5py.File`` for ``digitStruct.mat``.
        index: 0-based row in ``digitStruct/bbox``.

    Returns:
        dict mapping each of 'label', 'left', 'top', 'width', 'height' to a
        list with one value per digit.
    """
    bbox = mat_file[mat_file['digitStruct']['bbox'][index].item()]
    image_attributes = {}
    for key in ('label', 'left', 'top', 'width', 'height'):
        attribute = bbox[key]
        # ``dataset[()]`` reads the whole dataset.
        # It replaces ``Dataset.value``, which was removed in h5py 3.0.
        values = attribute[()]
        if len(attribute) > 1:
            # Several digits: each row holds an object reference to a dataset
            # whose [0][0] entry is the value.
            image_attributes[key] = [mat_file[row.item()][()][0][0] for row in values]
        else:
            # A single digit: the value is stored inline.
            image_attributes[key] = [values[0][0]]
    return image_attributes


def convert_mat_file(path_image_file, path_mat_file):
    """Build a ``tf.train.Example`` for one image and its digit annotations.

    The image's file name (``<n>.png``, 1-based) selects row ``n - 1`` of
    ``digitStruct/bbox`` in ``path_mat_file``. The union of all digit boxes
    is turned into a square region around its center. That region is passed
    to ``process_image``, and the returned image is stored as raw bytes.

    Args:
        path_image_file: path to ``<n>.png``.
        path_mat_file: path to the HDF5-readable ``digitStruct.mat``.

    Returns:
        A ``tf.train.Example`` with features ``image`` (raw pixel bytes),
        ``num_digit_labels`` (int) and ``digits`` (5 ints).
        In ``digits``, unused slots hold 10 and a label of 10 is stored as 0.
        Returns ``None`` if the image has more than 5 digits.
    """
    # The file name stem is assumed to be the 1-based bbox row; TODO confirm for every split.
    index = int(os.path.splitext(os.path.basename(path_image_file))[0]) - 1
    # Close the HDF5 file as soon as the annotations have been read.
    with h5py.File(path_mat_file, 'r') as mat_file:
        image_attributes = _read_bbox_attributes(mat_file, index)

    digit_labels = image_attributes['label']
    num_digit_labels = len(digit_labels)
    if num_digit_labels > 5:
        # Only up to 5 digits fit in the fixed-length 'digits' feature; skip this example.
        return None

    # Unused positions keep the value 10; a label of 10 is stored as digit 0.
    digits = [10] * 5
    for i, digit_label in enumerate(digit_labels):
        digits[i] = int(digit_label if digit_label != 10 else 0)

    lefts, tops, widths, heights = (
        [int(v) for v in image_attributes[k]] for k in ('left', 'top', 'width', 'height')
    )
    # Bounding box enclosing every digit.
    min_left = min(lefts)
    min_top = min(tops)
    max_right = max(x + w for x, w in zip(lefts, widths))
    max_bottom = max(y + h for y, h in zip(tops, heights))
    # Square region centred on that box, with side equal to its longer edge.
    center_x = (min_left + max_right) / 2.0
    center_y = (min_top + max_bottom) / 2.0
    side = float(max(max_right - min_left, max_bottom - min_top))
    image_left = center_x - side / 2.0
    image_top = center_y - side / 2.0

    image = np.array(process_image(path_image_file, image_left, image_top, side, side)).tobytes()
    return tf.train.Example(features=tf.train.Features(
        feature={
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
            'num_digit_labels': tf.train.Feature(int64_list=tf.train.Int64List(value=[num_digit_labels])),
            'digits': tf.train.Feature(int64_list=tf.train.Int64List(value=digits))
        }
    ))

In [13]:
def convert_to_tfrecord(file_path):
    """Convert every ``*.png`` in ``file_path`` into one ``.tfrecords`` file.

    Annotations are read from ``<file_path>/digitStruct.mat``. The output is
    written to ``<file_path>/<dir name>.tfrecords``. Images for which
    ``convert_mat_file`` returns ``None`` are skipped.

    Args:
        file_path: directory holding the images and ``digitStruct.mat``.
    """
    path_mat_file = os.path.join(file_path, 'digitStruct.mat')
    # normpath strips a trailing separator, so 'data/train/' still yields 'train'.
    dir_name = os.path.basename(os.path.normpath(file_path))
    path_tfrecord_file = os.path.join(file_path, dir_name + '.tfrecords')
    print('new tfrecord file: ' + path_tfrecord_file)
    # Sort so records are written in the same order on every run.
    path_image_files = sorted(tf.gfile.Glob(os.path.join(file_path, '*.png')))
    print(str(len(path_image_files)) + ' total image files in ' + file_path)
    num_examples = 0
    # The context manager closes the writer even if a conversion raises.
    with tf.python_io.TFRecordWriter(path_tfrecord_file) as writer_tfrecord:
        for path_image_file in path_image_files:
            tf_example = convert_mat_file(path_image_file, path_mat_file)
            if tf_example is None:
                continue
            writer_tfrecord.write(tf_example.SerializeToString())
            num_examples += 1
    print(str(num_examples) + ' total image files in ' + file_path + ' have been processed')

In [None]:
# Write one .tfrecords file per dataset split, in train/test/extra order.
for split_dir in (path_train, path_test, path_extra):
    convert_to_tfrecord(split_dir)