In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys

import tarfile
from six.moves import cPickle as pickle
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

In [None]:
import os
# Empty device list: presumably meant to keep TensorFlow off the GPU for this
# notebook -- TODO confirm that is still the intent.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# File name of the CIFAR-10 archive.
CIFAR_FILENAME = 'cifar-10-python.tar.gz'
# Full archive URL, formed by appending the file name to the hosting prefix.
CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME
# Sub-folder of the data directory that main() reads the batch files from.
CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'

In [None]:
def _int64_feature(value):
  """Wraps a single integer `value` in a tf.train.Feature (int64_list)."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)


def _bytes_feature(value):
  """Wraps a single bytes `value` in a tf.train.Feature (bytes_list)."""
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)


def _get_file_names():
  """Maps each dataset split to the batch files expected in the input dir.

  Returns:
    A dict with keys 'train' (data_batch_1 .. data_batch_4),
    'validation' (data_batch_5) and 'eval' (test_batch); every value is a
    list of file names.
  """
  return {
      'train': ['data_batch_%d' % batch_id for batch_id in xrange(1, 5)],
      'validation': ['data_batch_5'],
      'eval': ['test_batch'],
  }


def read_pickle_from_file(filename):
  """Unpickles and returns the object stored in `filename`.

  The file is opened in binary mode through `tf.gfile.Open`. On Python 3 the
  pickle is loaded with `encoding='bytes'`; on Python 2 the default loader is
  used.

  Args:
    filename: Path of the pickle file to read.

  Returns:
    The unpickled object.
  """
  # Only the Python 3 loader accepts the `encoding` keyword.
  load_kwargs = {'encoding': 'bytes'} if sys.version_info >= (3, 0) else {}
  with tf.gfile.Open(filename, 'rb') as handle:
    return pickle.load(handle, **load_kwargs)


def convert_to_tfrecord(input_files, output_file):
  """Writes every entry of the pickled input batches to one TFRecords file.

  Each input file is unpickled with `read_pickle_from_file`; its b'data' and
  b'labels' entries are paired by index and serialized as tf.train.Example
  records with an 'image' bytes feature and a 'label' int64 feature.

  Args:
    input_files: Iterable of pickle file paths to read.
    output_file: Path of the TFRecords file to write.
  """
  print('Generating %s' % output_file)
  with tf.python_io.TFRecordWriter(output_file) as record_writer:
    for batch_path in input_files:
      batch = read_pickle_from_file(batch_path)
      images = batch[b'data']
      batch_labels = batch[b'labels']
      # The number of labels drives the loop; images are looked up by index.
      for index, label in enumerate(batch_labels):
        feature_map = {
            'image': _bytes_feature(images[index].tobytes()),
            'label': _int64_feature(label),
        }
        example = tf.train.Example(
            features=tf.train.Features(feature=feature_map))
        record_writer.write(example.SerializeToString())

def main(data_dir):
  """Builds one TFRecords file per dataset split inside `data_dir`.

  Batch files are read from `<data_dir>/cifar-10-batches-py`; every split
  returned by `_get_file_names()` is written to `<data_dir>/<split>.tfrecords`
  after any previous file at that path has been removed.

  Args:
    data_dir: Directory containing the batch folder; also receives the
      generated `.tfrecords` files.
  """
  print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
  # The download_and_extract(data_dir) call is commented out, so nothing is
  # fetched here.
  batch_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)
  for split, batch_names in _get_file_names().items():
    batch_paths = [os.path.join(batch_dir, name) for name in batch_names]
    record_path = os.path.join(data_dir, split + '.tfrecords')
    try:
      os.remove(record_path)
    except OSError:
      # Best effort: a failed removal (e.g. no previous file) is ignored.
      pass
    # Convert to tf.train.Example and write them to TFRecords.
    convert_to_tfrecord(batch_paths, record_path)
  print('Done!')

In [None]:
# Display the split -> batch-file mapping.
_get_file_names()

In [None]:
# Point at a single batch file: the bare `cifar-10-batches-py/` path is a
# directory and cannot be unpickled.
data_dict = read_pickle_from_file('/root/data/cifar10/cifar-10-batches-py/data_batch_1')

In [None]:
# main('/root/data/cifar10/')

In [None]:
# Load the first training batch and keep its b'data' and b'labels' entries
# for inspection in the following cells.
data_dict = read_pickle_from_file('/root/data/cifar10/cifar-10-batches-py/data_batch_1')
data = data_dict[b'data']
labels = data_dict[b'labels']

In [None]:
# Shape of the batch's b'data' entry.
data_dict[b'data'].shape

In [None]:
# First entry of b'data', before any reshaping.
data_dict[b'data'][0, ...]

In [None]:
# Entry 0 is reshaped to (3, 32, 32) and transposed to (32, 32, 3) so that
# imshow receives the channel axis last.
plt.imshow(data_dict[b'data'][0, ...].reshape(3, 32, 32).transpose(1, 2, 0))

Read the TFRecords back

In [None]:
# Image geometry used by parser(): each decoded record is expected to hold
# DEPTH * HEIGHT * WIDTH values, reshaped to [DEPTH, HEIGHT, WIDTH].
HEIGHT = 32
WIDTH = 32
DEPTH = 3

In [None]:
# def preprocess(self, image):
# """Preprocess a single image in [height, width, depth] layout."""
# if self.subset == 'train' and self.use_distortion:
#   # Pad 4 pixels on each dimension of feature map, done in mini-batch
#   image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
#   image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
#   image = tf.image.random_flip_left_right(image)
# return image

In [None]:
def parser(serialized_example):
    """Decodes one serialized tf.Example into an (image, label) pair.

    Args:
        serialized_example: Serialized tf.Example carrying a string 'image'
            feature and an int64 'label' feature.

    Returns:
        A tuple (image, label): image is a float32 tensor laid out as
        [HEIGHT, WIDTH, DEPTH]; label is an int32 scalar tensor.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # The raw bytes form a flat uint8 vector of DEPTH * HEIGHT * WIDTH values.
    flat_pixels = tf.decode_raw(parsed['image'], tf.uint8)
    flat_pixels.set_shape([DEPTH * HEIGHT * WIDTH])

    # [depth * height * width] -> [depth, height, width] -> [height, width, depth].
    channels_first = tf.reshape(flat_pixels, [DEPTH, HEIGHT, WIDTH])
    channels_last = tf.transpose(channels_first, [1, 2, 0])
    image = tf.cast(channels_last, tf.float32)
    label = tf.cast(parsed['label'], tf.int32)

    # The custom preprocessing step (self.preprocess) is currently disabled.
    return image, label

In [None]:
# Input pipeline: read the training records, repeat them indefinitely,
# decode each record with parser() and group the results into batches of 4.
dataset = (
    tf.data.TFRecordDataset(['/root/data/cifar10/train.tfrecords'])
    .repeat()
    .map(parser, num_parallel_calls=4)
    .batch(4)
)
iterator = dataset.make_one_shot_iterator()
image_batch, label_batch = iterator.get_next()

In [None]:
# The dataset repeats indefinitely, so an unbounded loop would never finish
# and would keep opening figures; preview a fixed number of batches instead.
NUM_PREVIEW_BATCHES = 3

with tf.Session() as sess:
    for _ in range(NUM_PREVIEW_BATCHES):
        image = sess.run(image_batch)
        print(image.shape)
        plt.figure(figsize=(10, 10))
        # parser() yields float32 pixels; imshow renders integer RGB data in
        # the 0-255 range, for which uint8 is the natural dtype.
        plt.imshow(image[0, ...].squeeze().astype(np.uint8))
        plt.show()