In [1]:
import os

from lab12_util import *

# --- Dataset location --------------------------------------------------------
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
DEST_DIRECTORY = 'dataset/cifar10'
DATA_DIRECTORY = DEST_DIRECTORY + '/cifar-10-batches-bin'

# --- Image geometry ----------------------------------------------------------
IMAGE_HEIGHT = 32
IMAGE_WIDTH = 32
IMAGE_DEPTH = 3
# Side length of the square crop that the record parsers hand to the model.
IMAGE_SIZE_CROPPED = 24

# --- Binary record layout ----------------------------------------------------
# Each fixed-length record is LABEL_BYTES of label followed by IMAGE_BYTES of
# pixel data. IMAGE_BYTES is derived from the dimensions above so the record
# length cannot drift out of sync with them.
LABEL_BYTES = 1
IMAGE_BYTES = IMAGE_HEIGHT * IMAGE_WIDTH * IMAGE_DEPTH

# --- Training / evaluation sizes ---------------------------------------------
BATCH_SIZE = 128
NUM_CLASSES = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000

# Fetch the archive into DEST_DIRECTORY. The helper comes from lab12_util (not
# shown here); presumably it skips the download when the data already exists.
maybe_download_and_extract(DEST_DIRECTORY, DATA_URL)

>> Done


In [2]:
from tensorflow.contrib.data import FixedLengthRecordDataset, Iterator

def cifar10_record_distort_parser(record):
    ''' Decode one raw record and apply the training-time augmentation.
    -----
    Args:
        record:
            a raw fixed-length record: LABEL_BYTES of label followed by
            IMAGE_BYTES of pixels stored as [depth, height, width].
    Returns:
        label:
            int32 tensor of shape [1] holding the label of the record.
        image:
            float32 tensor of shape
            [IMAGE_SIZE_CROPPED, IMAGE_SIZE_CROPPED, 3]: randomly cropped,
            randomly flipped, brightness/contrast jittered and standardized.
    '''
    raw_bytes = tf.decode_raw(record, tf.uint8)

    # The leading LABEL_BYTES byte(s) carry the class label.
    label_bytes = tf.strided_slice(raw_bytes, [0], [LABEL_BYTES])
    label = tf.cast(label_bytes, tf.int32)
    label.set_shape([1])

    # The remaining bytes are the pixels, stored [depth, height, width].
    pixel_bytes = tf.strided_slice(
        raw_bytes, [LABEL_BYTES], [LABEL_BYTES + IMAGE_BYTES])
    chw_image = tf.reshape(
        pixel_bytes, [IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH])
    # Reorder to [height, width, depth] and switch to float for the image ops.
    hwc_image = tf.cast(tf.transpose(chw_image, [1, 2, 0]), tf.float32)

    # Random augmentation: crop, horizontal flip, brightness, contrast.
    crop_shape = [IMAGE_SIZE_CROPPED, IMAGE_SIZE_CROPPED, 3]
    augmented = tf.random_crop(hwc_image, crop_shape)
    augmented = tf.image.random_flip_left_right(augmented)
    augmented = tf.image.random_brightness(augmented, max_delta=63)
    augmented = tf.image.random_contrast(augmented, lower=0.2, upper=1.8)

    # Per-image standardization: zero mean, scaled by the (adjusted) standard
    # deviation of the pixels.
    augmented = tf.image.per_image_standardization(augmented)

    # Pin the static shape so downstream batching sees a fully defined shape.
    augmented.set_shape(crop_shape)
    return label, augmented


def cifar10_record_crop_parser(record):
    ''' Decode one raw record for evaluation (deterministic, no augmentation).
    -----
    Args:
        record:
            a raw fixed-length record: LABEL_BYTES of label followed by
            IMAGE_BYTES of pixels stored as [depth, height, width].
    Returns:
        label:
            int32 tensor of shape [1] holding the label of the record.
        image:
            float32 tensor of shape
            [IMAGE_SIZE_CROPPED, IMAGE_SIZE_CROPPED, 3]: cropped/padded to the
            target size and standardized.
    '''
    raw_bytes = tf.decode_raw(record, tf.uint8)

    # The leading LABEL_BYTES byte(s) carry the class label.
    label_bytes = tf.strided_slice(raw_bytes, [0], [LABEL_BYTES])
    label = tf.cast(label_bytes, tf.int32)
    label.set_shape([1])

    # The remaining bytes are the pixels, stored [depth, height, width].
    pixel_bytes = tf.strided_slice(
        raw_bytes, [LABEL_BYTES], [LABEL_BYTES + IMAGE_BYTES])
    chw_image = tf.reshape(
        pixel_bytes, [IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH])
    # Reorder to [height, width, depth] and switch to float for the image ops.
    hwc_image = tf.cast(tf.transpose(chw_image, [1, 2, 0]), tf.float32)

    # Bring the image to the cropped size without any randomness.
    target_side = IMAGE_SIZE_CROPPED
    cropped = tf.image.resize_image_with_crop_or_pad(
        hwc_image, target_side, target_side)

    # Same standardization as the training parser.
    standardized = tf.image.per_image_standardization(cropped)
    standardized.set_shape([target_side, target_side, 3])
    return label, standardized


def cifar10_iterator(filenames, batch_size, cifar10_record_parser):
    ''' Create a dataset and return a one-shot iterator over it,
    which provides a way to extract elements from this dataset.
    -----
    Args:
        filenames:
            a tensor of filenames of the fixed-length binary record files.
        batch_size:
            number of parsed records stacked into each batch.
        cifar10_record_parser:
            a callable mapping one raw record to a (label, image) pair,
            e.g. cifar10_record_distort_parser or cifar10_record_crop_parser.
    Returns:
        iterator:
            an Iterator providing a way to extract elements from the created dataset.
        output_types:
            the output types of the created dataset.
        output_shapes:
            the output shapes of the created dataset.
    '''
    # One record = label byte(s) immediately followed by the image bytes.
    record_bytes = LABEL_BYTES + IMAGE_BYTES
    dataset = tf.data.FixedLengthRecordDataset(filenames, record_bytes)

    # Parse records in parallel; the parser is passed directly (no wrapper).
    dataset = dataset.map(cifar10_record_parser, num_parallel_calls=16)
    # Repeat the input indefinitely; the caller decides how many batches to draw.
    dataset = dataset.repeat()
    # Honour the caller-supplied batch size rather than the module-level default.
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()

    return iterator, dataset.output_types, dataset.output_shapes

In [3]:
# Build the whole input pipeline + model graph from scratch.
# Clearing the default graph first means everything below is added to an
# empty graph each time this cell runs.
tf.reset_default_graph()

# The five training batch files (data_batch_1.bin .. data_batch_5.bin) and the
# single test batch file inside DATA_DIRECTORY.
training_files = [
    os.path.join(DATA_DIRECTORY, 'data_batch_%d.bin' % i) for i in range(1, 6)]
testing_files = [os.path.join(DATA_DIRECTORY, 'test_batch.bin')]

filenames_train = tf.constant(training_files)
filenames_test = tf.constant(testing_files)

# Training data goes through the random-distortion parser, test data through
# the deterministic crop parser. Only the training pipeline's types/shapes are
# kept; both parsers return an int32 [1] label and a float [24, 24, 3] image,
# so the test pipeline is expected to match them.
iterator_train, types, shapes = cifar10_iterator(filenames_train, BATCH_SIZE,
                                                 cifar10_record_distort_parser)
iterator_test, _, _ = cifar10_iterator(filenames_test, BATCH_SIZE,
                                       cifar10_record_crop_parser)

# Feedable iterator: `handle` is a scalar string placeholder that is fed with
# either iterator_train's or iterator_test's string handle at sess.run time,
# so the same graph serves both training and testing.
handle = tf.placeholder(tf.string, shape=[])
iterator = Iterator.from_string_handle(handle, types, shapes)
labels_images_pairs = iterator.get_next()

# CNN model. CNN_Model comes from lab12_util (not shown here); the meaning of
# each argument is inferred from its name only -- TODO confirm against the
# helper module.
model = CNN_Model(
    batch_size=BATCH_SIZE,
    num_classes=NUM_CLASSES,
    num_training_example=NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN,
    num_epoch_per_decay=350.0,
    init_lr=0.1,
    moving_average_decay=0.9999)

# Keep the batch unpacking/reshaping ops on the CPU.
with tf.device('/cpu:0'):
  labels, images = labels_images_pairs
  # Each parsed label has shape [1]; flatten the batch of labels to
  # [BATCH_SIZE] and pin the image batch to a fully defined static shape.
  labels = tf.reshape(labels, [BATCH_SIZE])
  images = tf.reshape(
      images, [BATCH_SIZE, IMAGE_SIZE_CROPPED, IMAGE_SIZE_CROPPED, IMAGE_DEPTH])
with tf.variable_scope('model'):
  logits = model.inference(images)
# train: loss and the op that performs one optimization step
global_step = tf.train.get_or_create_global_step()
total_loss = model.loss(logits, labels)
train_op = model.train(total_loss, global_step)
# test: per-example boolean, True when the true label is the top-1 prediction
top_k_op = tf.nn.in_top_k(logits, labels, 1)

In [4]:
%%time
# TODO4:
# 1. train the CNN model for 10 epochs
# 2. show the loss of each epoch
# 3. report the accuracy of the 10-epoch model
# 4. measure the running time with the '%%time' cell magic
# tips:
# feed the `handle` placeholder with the training or the testing iterator's
# string handle to choose between training and testing data.

NUM_EPOCH = 10
NUM_BATCH_PER_EPOCH = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN // BATCH_SIZE
ckpt_dir = './model/'

# Make sure the checkpoint directory exists up front: in TF 1.x Saver.save can
# fail with "Parent directory ... doesn't exist, can't save." -- which would
# only surface after a full epoch of training.
os.makedirs(ckpt_dir, exist_ok=True)

# train
saver = tf.train.Saver()
with tf.Session() as sess:
  # Resume from the latest checkpoint if there is one, else start fresh.
  ckpt = tf.train.get_checkpoint_state(ckpt_dir)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    # assume the name of checkpoint is like '.../model.ckpt-1000'
    gs = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
    sess.run(tf.assign(global_step, gs))
  else:
    # no checkpoint found
    sess.run(tf.global_variables_initializer())

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  loss = []  # summed batch loss of every epoch run in this cell
  try:
    # Feeding this handle routes the feedable iterator to the training data.
    train_iterator_handle = sess.run(iterator_train.string_handle())
    for epoch in range(NUM_EPOCH):
      batch_losses = []
      for step in range(NUM_BATCH_PER_EPOCH):
        batch_loss, _ = sess.run([total_loss, train_op],
                                 feed_dict={handle: train_iterator_handle})
        batch_losses.append(batch_loss)
      loss_this_epoch = np.sum(batch_losses)
      # global_step keeps counting across resumed runs, so the printed epoch
      # number is cumulative rather than relative to this cell execution.
      gs = global_step.eval()
      print('loss of epoch %d: %f' % (gs / NUM_BATCH_PER_EPOCH, loss_this_epoch))
      loss.append(loss_this_epoch)
      saver.save(sess, ckpt_dir + 'model.ckpt', global_step=gs)
  finally:
    # Always stop and join any started threads, even if training raised.
    coord.request_stop()
    coord.join(threads)

print('Done')

# Evaluate using the variables selected by model.ema.variables_to_restore().
# `model.ema` is defined in lab12_util (not shown); presumably it is the
# moving-average helper configured via moving_average_decay -- TODO confirm.
variables_to_restore = model.ema.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
  # Restore variables from disk.
  ckpt = tf.train.get_checkpoint_state(ckpt_dir)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
      # Only whole batches are evaluated, so the last
      # NUM_EXAMPLES_PER_EPOCH_FOR_EVAL % BATCH_SIZE examples are not counted.
      num_iter = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL // BATCH_SIZE
      total_sample_count = num_iter * BATCH_SIZE
      true_count = 0
      # Feeding this handle routes the feedable iterator to the test data.
      test_iterator_handle = sess.run(iterator_test.string_handle())
      for _ in range(num_iter):
        predictions = sess.run(top_k_op, feed_dict={handle: test_iterator_handle})
        # predictions is boolean per example; summing counts the correct ones.
        true_count += np.sum(predictions)
      print('Accurarcy: %d/%d = %f' % (true_count, total_sample_count,
                                       true_count / total_sample_count))
    finally:
      # Always stop and join any started threads, even if evaluation raised.
      coord.request_stop()
      coord.join(threads)
  else:
    print('train first')

loss of epoch 1: 1516.859375
loss of epoch 2: 1192.032959
loss of epoch 3: 972.954529
loss of epoch 4: 814.379822
loss of epoch 5: 699.359558
loss of epoch 6: 614.506470
loss of epoch 7: 552.036133
loss of epoch 8: 508.715881
loss of epoch 9: 473.167358
loss of epoch 10: 447.772247
Done
INFO:tensorflow:Restoring parameters from ./model/model.ckpt-3900
Accurarcy: 7616/9984 = 0.762821
CPU times: user 4min 26s, sys: 20.4 s, total: 4min 46s
Wall time: 3min 18s
