# In [19]:
import glob, os
import tensorflow as tf
from itertools import groupby
from collections import defaultdict

# The author runs the CPU-only TensorFlow build, which prints warnings at run
# time; setting TF_CPP_MIN_LOG_LEVEL to '2' hides them ("out of sight, out of
# mind").
# NOTE(review): this is assigned after `import tensorflow`; setting it before
# the import is the usual recommendation — confirm it still takes effect here.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')

# 数据源：http://vision.stanford.edu/aditya86/ImageNetDogs/
image_filenames = glob.glob( './imagenet-dogs/n02*/*.jpg' )

training_dataset = defaultdict(list)
testing_dataset = defaultdict(list)

image_filename_with_breed = list( map(lambda filename:(filename.split('/')[2], filename), image_filenames) )

for dog_breed, breed_images in groupby(image_filename_with_breed, lambda x: x[0]):
    for i, breed_image in enumerate(breed_images):
        if i % 5 == 0:
            testing_dataset[dog_breed].append(breed_image[1])
        else:
            training_dataset[dog_breed].append(breed_image[1])

    breed_training_count = len(training_dataset[dog_breed])
    breed_testing_count = len(testing_dataset[dog_breed])
    
    assert round(breed_testing_count / (breed_training_count + breed_testing_count), 2) > 0.18, \
    "Not enough testing images."
    
    

def imageset_to_records_files(dataset, record_location, resize,
                              channels=1, image_type='jpg', one_record_count=100):
    """Decode, resize and write a labelled image set to sharded TFRecord files
    using TensorFlow ops.

    Each successfully decoded image becomes one ``tf.train.Example`` with two
    bytes features: ``'label'`` (the dataset key, UTF-8 encoded) and
    ``'image'`` (the resized pixels cast to uint8, as raw bytes). A new shard
    ``"{record_location}-{index}.tfrecords"`` is started every
    ``one_record_count`` images. Images that fail to load or decode are
    reported with a print and skipped.

    Args:
        dataset: mapping of label (str) -> iterable of image file paths.
        record_location: path prefix for the output shards; the containing
            directory is not created here.
        resize: ``[height, width]``, passed as the ``size`` argument of
            ``tf.image.resize_images``.
        channels: ``channels`` argument for ``tf.image.decode_jpeg`` /
            ``tf.image.decode_png``.
        image_type: ``'jpg'`` or ``'png'``.
        one_record_count: number of images per output shard.

    Raises:
        ValueError: if ``image_type`` is neither ``'jpg'`` nor ``'png'``.
    """
    if image_type == 'jpg':
        decode_image = tf.image.decode_jpeg
    elif image_type == 'png':
        decode_image = tf.image.decode_png
    else:
        # Fail fast: any other value leaves no decoder to build the pipeline with.
        raise ValueError("unsupported image_type: {!r}".format(image_type))

    # Build the read/decode/resize pipeline ONCE, in a private graph, and feed
    # each filename through a placeholder. Creating new ops per image would grow
    # the graph without bound and slow down every subsequent sess.run.
    graph = tf.Graph()
    with graph.as_default():
        filename_input = tf.placeholder(tf.string, shape=[])
        decoded = decode_image(tf.read_file(filename_input), channels=channels)
        resized_uint8 = tf.cast(tf.image.resize_images(decoded, resize), tf.uint8)

    writer = None
    current_index = 0
    image_number = 0

    with tf.Session(graph=graph) as sess:
        try:
            for breed, images_filenames in dataset.items():
                image_label = breed.encode("utf-8")

                for image_filename in images_filenames:
                    # Rotate to a new shard every `one_record_count` images.
                    if image_number % one_record_count == 0:
                        if writer is not None:
                            writer.close()

                        record_filename = "{record_location}-{current_index}.tfrecords".format(
                            record_location=record_location, current_index=current_index)

                        writer = tf.python_io.TFRecordWriter(record_filename)
                        current_index += 1

                    image_number += 1

                    # File reading and decoding only execute inside sess.run,
                    # so this is where missing/corrupt images raise.
                    try:
                        image_loaded = sess.run(resized_uint8,
                                                feed_dict={filename_input: image_filename})
                    except tf.errors.OpError:
                        print('image decode fail: ', image_filename)
                        continue

                    image_bytes = image_loaded.tobytes()

                    feature = {
                        'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_label])),
                        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
                    }

                    example = tf.train.Example(features=tf.train.Features(feature=feature))

                    writer.write(example.SerializeToString())
        finally:
            # Close the last shard even on error; there is none for an empty dataset.
            if writer is not None:
                writer.close()

# pil_imageset_to_records_files: Pillow-based variant of imageset_to_records_files;
# the author measured it to be several times faster than the TF-graph version.
def pil_imageset_to_records_files(dataset, record_location, resize,
                                  channels=1, image_type='jpg', one_record_count=100):
    """Load, resize and write a labelled image set to sharded TFRecord files
    using Pillow.

    Each successfully loaded image becomes one ``tf.train.Example`` with two
    bytes features: ``'label'`` (the dataset key, UTF-8 encoded) and
    ``'image'`` (the resized pixels as raw bytes from ``Image.tobytes``). A new
    shard ``"{record_location}-{index}.tfrecords"`` is started every
    ``one_record_count`` images. Images that fail to load are reported with a
    print and skipped.

    Args:
        dataset: mapping of label (str) -> iterable of image file paths.
        record_location: path prefix for the output shards; the containing
            directory is not created here.
        resize: ``[height, width]``, the same convention as the ``size``
            argument of ``tf.image.resize_images`` used by
            ``imageset_to_records_files``, so both writers produce the same
            pixel layout for the same ``resize`` value.
        channels: 1 -> grayscale ('L'), 3 -> 'RGB', 4 -> 'RGBA'; any other
            value keeps each file's native PIL mode.
        image_type: unused (Pillow detects the format from the file contents);
            kept for signature compatibility with ``imageset_to_records_files``.
        one_record_count: number of images per output shard.
    """
    from PIL import Image

    # Convert every image to a fixed mode so each record has the same byte
    # length (e.g. a grayscale JPEG must still yield 3 channels when asked).
    target_mode = {1: 'L', 3: 'RGB', 4: 'RGBA'}.get(channels)
    # PIL's resize takes (width, height); `resize` is [height, width].
    target_size = (resize[1], resize[0])

    writer = None
    current_index = 0
    image_number = 0

    try:
        for breed, images_filenames in dataset.items():
            image_label = breed.encode("utf-8")

            for image_filename in images_filenames:
                # Rotate to a new shard every `one_record_count` images.
                if image_number % one_record_count == 0:
                    if writer is not None:
                        writer.close()

                    record_filename = "{record_location}-{current_index}.tfrecords".format(
                        record_location=record_location, current_index=current_index)

                    writer = tf.python_io.TFRecordWriter(record_filename)
                    current_index += 1

                image_number += 1

                try:
                    # `with` releases the file handle that Image.open keeps open.
                    with Image.open(image_filename) as img:
                        if target_mode is not None:
                            img = img.convert(target_mode)
                        # Bilinear is the closest Pillow filter to the
                        # tf.image.resize_images default used by the TF variant.
                        img = img.resize(target_size, Image.BILINEAR)
                        image_bytes = img.tobytes()
                except (OSError, ValueError):
                    print('image decode fail: ', image_filename)
                    continue

                feature = {
                    'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_label])),
                    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
                }

                example = tf.train.Example(features=tf.train.Features(feature=feature))

                writer.write(example.SerializeToString())
    finally:
        # Close the last shard even on error; there is none for an empty dataset.
        if writer is not None:
            writer.close()


# Size passed to both writers; tf.image.resize_images interprets it as
# [height, width].
RECORD_IMAGE_SIZE = [250, 151]
TEST_RECORD_PREFIX = './record_files/test-image/testing-image'
TRAIN_RECORD_PREFIX = './record_files/train-image/training-image'

# The writer functions only open shard files, so create the output
# directories up front (otherwise a fresh checkout fails on the first shard).
for record_prefix in (TEST_RECORD_PREFIX, TRAIN_RECORD_PREFIX):
    os.makedirs(os.path.dirname(record_prefix), exist_ok=True)

imageset_to_records_files(testing_dataset, TEST_RECORD_PREFIX, RECORD_IMAGE_SIZE)
pil_imageset_to_records_files(training_dataset, TRAIN_RECORD_PREFIX, RECORD_IMAGE_SIZE)

print('convert finish.')

# Output: convert finish.
