In [13]:
import tensorflow as tf
import numpy as np
from matplotlib.pyplot import imshow
from tensorflow.contrib.layers.python.layers import batch_norm
from tensorflow.python.framework import ops
import os
TRAIN = True
INFER = False
BATCH_SIZE = 128
SAVE = True
global_step = tf.contrib.framework.get_or_create_global_step()
import urllib, StringIO
from PIL import Image

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [15]:
INFER_URLS = [
    ('deer', 'http://gfp.sd.gov/hunting/big-game/images/deer1-01.jpg'),
    ('frog', 'http://www.gaylordfunkyfish.com/images/Red-Eye-Tree-Frog-300x244.png'),
    ('airplane', 'https://media.licdn.com/mpr/mpr/jc/AAEAAQAAAAAAAAMiAAAAJGVlYTU5Y2YyLWQwMzYtNDlmZS04MDdlLWI0ZjJjZWRhYjk4ZQ.jpg'),
    ('ship', 'http://www.vships.com/media/92241/passenger-vessel.jpg'),
    ('dog', 'https://images-na.ssl-images-amazon.com/images/G/01/img15/pet-products/small-tiles/23695_pets_vertical_store_dogs_small_tile_8._CB312176604_.jpg'),
    ('bird', 'http://www.audubon.org/sites/default/files/styles/engagement_card/public/sfw_apa_2013_28342_232388_briankushner_blue_jay_kk_high.jpg')
]

In [2]:
def read_and_decode_single_example(filename, epochs=None):
    # first construct a queue containing a list of filenames.
    # this lets a user split up there dataset in multiple files to keep
    # size down
    filename_queue = tf.train.string_input_producer([filename],
                                                    num_epochs=epochs)
    # Unlike the TFRecordWriter, the TFRecordReader is symbolic
    reader = tf.TFRecordReader()
    # One can read a single serialized example from a filename
    # serialized_example is a Tensor of type string.
    _, serialized_example = reader.read(filename_queue)
    # The serialized example is converted back to actual values.
    # One needs to describe the format of the objects to be returned
    features = tf.parse_single_example(
        serialized_example,
        features={
            # We know the length of both fields. If not the
            # tf.VarLenFeature could be used
            'label': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], 'string')
        })
    # now return the converted data
    label = features['label']
    image_raw = features['image']
    image = tf.decode_raw(image_raw, tf.uint8)
    image = tf.reshape(image, [32,32,3])
    image = tf.cast(image, tf.float32) / 255
    return label, image

In [3]:
def distort_image(image):
    distorted_image = tf.random_crop(image, [24, 24, 3])
    distorted_image = tf.image.random_flip_left_right(distorted_image)
    distorted_image = tf.image.random_brightness(distorted_image, max_delta=0.2)
    distorted_image = tf.image.random_contrast(distorted_image, lower=0.7, upper=1.3)
    # distorted_image = tf.image.per_image_whitening(distorted_image) # renamed to per_image_standardization in latest release
    distorted_image = tf.clip_by_value(distorted_image, 0, 1)
    return distorted_image

def simple_crop(image):
    return tf.random_crop(image, [24, 24, 3])

In [25]:
import urllib2
def load_image(url):
    print url
    data = urllib2.urlopen(url).read()
    print data
    file = StringIO.StringIO(urllib.urlopen(url).read())
    img = Image.open(file)
    file.close()
    img.thumbnail((24,24))
    img = img.convert('RGB')
    return np.array(img.getdata()) / 255.0
print load_image('http://images-na.ssl-images-amazon.com/images/G/01/img15/pet-products/small-tiles/23695_pets_vertical_store_dogs_small_tile_8._CB312176604_.jpg')

http://images-na.ssl-images-amazon.com/images/G/01/img15/pet-products/small-tiles/23695_pets_vertical_store_dogs_small_tile_8._CB312176604_.jpg


HTTPError: HTTP Error 403: Forbidden

In [4]:
%matplotlib inline

# create image batches:

label, image = read_and_decode_single_example(filename, epochs=(None if TRAIN else 1))
if INFER:
    
    def url_producer(urls):
        def load_nth_url(n):
            return io.imread(urls[n])
        url_iter = tf.Variable(0, name='url_iter').count_up_to(len(urls))
        return tf.py_func(load_nth_url, [url_iter], stateful=True, Tout=[tf.uint8])
else:
    filename = "cifar-train.tfrecords" if TRAIN else "cifar-test.tfrecords"
    if TRAIN:
        images_batch, labels_batch = tf.train.shuffle_batch([distort_image(image), label], batch_size=BATCH_SIZE, capacity=1000, min_after_dequeue=500)
    else:
        images_batch, labels_batch = tf.train.batch([simple_crop(image), label], batch_size=BATCH_SIZE, allow_smaller_final_batch=True)

In [5]:
dropout_keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob')        

def weight_var(shape, stddev=0.1, weight_decay=0):
    initial = tf.truncated_normal(shape, stddev=stddev)
    v = tf.Variable(initial)
    if weight_decay > 0:
        l2 = tf.nn.l2_loss(v) * weight_decay
        tf.add_to_collection('losses', l2)
    return v

def create_fc(input, out_size):
    # input_dropped = tf.nn.dropout(input, dropout_keep_prob)
    in_size = input.get_shape()[-1].value
    w = weight_var([in_size, out_size], weight_decay=0.004)
    b = weight_var([out_size], weight_decay=0.004)
    x = tf.matmul(input, w)
    return tf.nn.relu(x + b)

def create_conv(input, out_channels, patch_size=3, stride=1, batch_norm=False, dropout=False):
    in_channels = input.get_shape()[-1].value
    w = weight_var([patch_size, patch_size, in_channels, out_channels])
    b = weight_var([out_channels], stddev=0)
    conv = tf.nn.conv2d(input, w, strides=[1,stride,stride,1], padding='SAME')
    if batch_norm: conv = create_batch_norm(conv)
    activation = tf.nn.relu(conv + b)
    if dropout: activation = create_dropout(activation)
    return activation

def create_max_pool(inputs, ksize=2, stride=2):
    return tf.nn.max_pool(inputs, ksize=[1, ksize, ksize, 1], strides=[1, stride, stride, 1], padding='SAME')

def create_batch_norm(inputs):
    return batch_norm(inputs, is_training=TRAIN)

def create_dropout(inputs):
    return tf.nn.dropout(inputs, dropout_keep_prob)

In [6]:
def forward(images):
    # images are 24x24x3
    for size in [32, 64, 64]:
        images = create_conv(images, size, dropout=True, batch_norm=False)
        images = create_max_pool(images)
    # now images are 6x6x32 = 1176:
    n_channels = images.get_shape()[-1].value
    vecs = tf.reshape(images, [-1, 3*3*n_channels])
    # vecs = tf.reshape(images, [-1, 24*24*3])
    vecs = create_fc(vecs, 512)
    vecs = batch_norm(vecs)
    logits = create_fc(vecs, 10)
    return logits

In [7]:
NUM_EPOCHS_PER_DECAY = 300.0      # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning rate decay factor.
INITIAL_LEARNING_RATE = 0.001       # Initial learning rate.
EXAMPLES_PER_EPOCH = 60000

learn_rate = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                  global_step,
                                  EXAMPLES_PER_EPOCH * NUM_EPOCHS_PER_DECAY / BATCH_SIZE,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=False)

# create optimizer:
logits = forward(images_batch)
cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels_batch))
loss = cross_entropy # + tf.add_n(tf.get_collection('losses'))
optimizer = tf.train.AdamOptimizer(learn_rate)
train_step = optimizer.minimize(loss, global_step=global_step)

predictions = tf.argmax(tf.nn.softmax(logits), 1)
n_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, labels_batch), tf.float32))

In [8]:
# 16 is a good model
save_path = 'models/21'
if not os.path.exists(save_path):
    os.mkdir(save_path)

session = tf.Session()
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
session.run(init_op)
    
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(save_path)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(session, ckpt.model_checkpoint_path)
    print 'Restored from checkpoint', ckpt.model_checkpoint_path
else:
    print 'Did not restore from checkpoint'

Restored from checkpoint models/21/model.ckpt-32850


In [9]:
tf.train.start_queue_runners(sess=session)
if TRAIN:
    while True:
        feed_dict = {dropout_keep_prob: 0.7}
        _, cur_loss, step, pred_labels, lr = session.run([train_step, loss, global_step, predictions, learn_rate], feed_dict=feed_dict)
        if step % 50 == 1:
            print "Step: {0}; Loss: {1}; learn rate: {2}".format(step, cur_loss, lr)
        if step % 150 == 0:
            if SAVE:
                saver.save(session, save_path + '/model.ckpt', global_step=step)
                print 'saved'
            print pred_labels
else:
    correct = 0
    total = 0
    i = 0
    while True:
        try:
            feed_dict = {dropout_keep_prob: 1}
            nc, pred = session.run([n_correct, predictions], feed_dict=feed_dict)
            correct += nc
            # if i % 10 == 0:
            #     print pred
            # this assumes all batches are BATCH_SIZE (the last one might be smaller) but that's okay
            total += BATCH_SIZE
            i += 1
            if i % 10 == 0:
                print "Accuracy so far: {0}".format(correct * 1.0 / total)
        except tf.errors.OutOfRangeError:
            print 'Done!'
            print "Final accuracy: {0}".format(correct * 1.0 / total)
            break

Step: 32851; Loss: 0.67416882515; learn rate: 0.000583982735407
Step: 32901; Loss: 0.563223361969; learn rate: 0.000583504850511
Step: 32951; Loss: 0.581252098083; learn rate: 0.00058302731486
saved
[1 3 2 5 2 8 1 6 6 9 3 7 8 8 7 3 5 8 7 4 2 9 3 1 9 9 9 2 8 1 1 1 8 8 4 8 8
 1 0 5 8 6 5 4 7 0 3 8 4 2 7 9 1 3 4 8 7 1 7 6 9 4 0 6 0 3 3 7 2 8 2 8 4 4
 1 4 7 4 6 3 3 5 7 1 6 2 1 3 3 0 7 0 7 9 3 0 7 1 4 0 8 7 3 5 2 1 3 1 6 0 5
 8 2 6 5 9 0 3 6 0 1 1 3 4 2 0 5 9]
Step: 33001; Loss: 0.729931712151; learn rate: 0.000582550186664
Step: 33051; Loss: 0.582016944885; learn rate: 0.000582073465921
Step: 33101; Loss: 0.673863530159; learn rate: 0.000581597094424
saved
[0 0 5 0 6 3 9 1 6 4 2 7 4 9 0 1 2 6 5 8 3 0 5 0 1 4 1 3 7 8 4 2 2 3 0 8 1
 8 8 1 6 7 0 0 6 2 0 9 5 8 0 8 7 5 6 7 3 2 0 4 1 8 4 9 3 1 4 1 6 0 4 8 5 5
 7 1 9 2 7 3 7 8 0 8 6 8 1 8 0 2 3 5 6 6 6 9 8 2 2 1 0 5 4 2 8 3 0 5 0 8 6
 0 2 2 3 1 3 4 9 4 8 5 5 3 1 2 6 6]
Step: 33151; Loss: 0.735530495644; learn rate: 0.000581121188588
Step: 33201; 

KeyboardInterrupt: 