CIFAR-10 data from Keras

In [32]:
import tensorflow as tf
import os
import numpy as np
from random import shuffle

In [33]:
from tensorflow.keras.datasets import cifar10
# Load CIFAR-10 through Keras; load_data() returns a (train, test) pair of
# (images, labels) tuples.
train_split, test_split = cifar10.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

# Report the training image tensor shape and the size of each split.
n_train, n_test = len(x_train), len(x_test)
print('x_train shape:', x_train.shape)
print(n_train, 'train samples')
print(n_test, 'test samples')

x_train shape: (50000, 32, 32, 3)
50000 train samples
10000 test samples


In [69]:
# Merge the train and test splits into single arrays (train rows first, then
# test rows). The combined data is shuffled and redistributed later.
# Images are cast to float and labels to int.

X = np.concatenate([x_train, x_test]).astype(float)
Y = np.concatenate([y_train, y_test]).astype(int)

# Keras delivers labels with shape [n, 1]; flatten them to shape [n].

Y = Y.reshape((len(Y),))

n = len(Y)
print(X.shape,n)

(60000, 32, 32, 3) 60000


In [70]:
# Build a randomly shuffled list of row indices; it is used to reorder the
# data randomly when writing. Show the first three as a sanity check.

ix = [row for row in range(len(X))]
shuffle(ix)
ix[:3]

[37974, 13293, 38540]

In [71]:
# Derive the number of examples per shard from the shard count. Rounding up
# means the last shard may be smaller than the others.
# The shard count is set to 11; this will make wildcard easier later in code.

from math import ceil
shards = 11
n_examples = len(ix)
shard_size = ceil(n_examples / shards)
shard_size

5455

In [72]:
# Helper function for creating tfrecords files

def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an ``Int64List``.

    Args:
        value: the integer to store; it becomes the only element of the list.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` contains ``[value]``.
    """
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a ``BytesList``.

    Args:
        value: the bytes object to store; it becomes the only element of the list.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` contains ``[value]``.
    """
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)

In [73]:
!rm -rd ./data
!mkdir data

In [74]:
# Serialize every example, in shuffled order, into sharded TFRecord files
# named ./data/cifar10_data_NNN.tfrecords.
#
# Each record stores three features:
#   'idx'   - row index of the example in the combined X / Y arrays (int64)
#   'label' - class label taken from Y (int64)
#   'image' - raw bytes of the image array X[idx]; X was cast with
#             .astype(float) earlier, so a reader must decode with that dtype
#
# Each shard is written inside a `with` block so that its writer is flushed
# and closed before the next shard is opened; without the close, the tail of
# a shard may not reach disk. Iterating over slices of `ix` also means no
# empty trailing file is created when len(ix) is an exact multiple of
# shard_size.

f_prefix = './data/cifar10_data_'
f_digits = 3
for f_suffix, start in enumerate(range(0, len(ix), shard_size)):
    f_name = f_prefix + str(f_suffix).zfill(f_digits) + '.tfrecords'
    print('Writing...'+f_name)
    ct = 0
    # create tfrecord file for this shard
    with tf.python_io.TFRecordWriter(f_name) as writer:
        for idx in ix[start:start + shard_size]:
            image = X[idx,:,:,:]
            label = Y[idx]
            e = tf.train.Example(features=tf.train.Features(feature={
                'idx'     : _int64_feature(idx),
                'label'   : _int64_feature(label),
                # tobytes() yields the same bytes as the deprecated tostring()
                'image'   : _bytes_feature(image.tobytes())
                }))
            writer.write(e.SerializeToString())
            ct += 1
    print('finished writing '+f_name+' with '+str(ct)+' examples')
    


Writing..../data/cifar10_data_000.tfrecords
finished writing ./data/cifar10_data_000.tfrecords with 5455 examples
Writing..../data/cifar10_data_001.tfrecords
finished writing ./data/cifar10_data_001.tfrecords with 5455 examples
Writing..../data/cifar10_data_002.tfrecords
finished writing ./data/cifar10_data_002.tfrecords with 5455 examples
Writing..../data/cifar10_data_003.tfrecords
finished writing ./data/cifar10_data_003.tfrecords with 5455 examples
Writing..../data/cifar10_data_004.tfrecords
finished writing ./data/cifar10_data_004.tfrecords with 5455 examples
Writing..../data/cifar10_data_005.tfrecords
finished writing ./data/cifar10_data_005.tfrecords with 5455 examples
Writing..../data/cifar10_data_006.tfrecords
finished writing ./data/cifar10_data_006.tfrecords with 5455 examples
Writing..../data/cifar10_data_007.tfrecords
finished writing ./data/cifar10_data_007.tfrecords with 5455 examples
Writing..../data/cifar10_data_008.tfrecords
finished writing ./data/cifar10_data_008.tfr