In [1]:
import numpy as np

import tensorflow as tf
import tensorflow.contrib.slim as slim

import data_provider
import model

Instructions for updating:
Use the retry module or similar alternatives.


## Settings

In [2]:
# Dataset location: both training inputs live under one root directory.
IMAGENET_ROOT = '/data/ImageNet/'  # ADJUST (or override either path below)
TRAIN_DATA_ROOT = IMAGENET_ROOT + 'train/'
TRAIN_LIST_FILE = IMAGENET_ROOT + 'labels/train.txt'

# Input pipeline
BATCH_SIZE = 64  # ADJUST
CROP_SIZE = 224  # ADJUST
CLASSES = [281, 239]  # Tabby cat, Bernese mountain dog

# Optimisation: the learning rate is decayed from LR_START to LR_END over the run.
LR_START = 0.01  # ADJUST
LR_END = LR_START / 1e4  # ADJUST
MOMENTUM = 0.9  # ADJUST
NUM_EPOCHS = 1000  # ADJUST

# Logging and checkpoints (OUTPUT_ROOT is the Supervisor's logdir)
OUTPUT_ROOT = '/data/logs/vggA_BN'  # ADJUST
LOG_EVERY_N = 10  # ADJUST

## Prepare training data queue

In [3]:
# Batched training tensors for the selected classes, plus the number of
# training samples those classes contribute.
train_image, train_label, num_samples = data_provider.imagenet_data(
    TRAIN_DATA_ROOT,
    TRAIN_LIST_FILE,
    classes=CLASSES,
    crop_size=(CROP_SIZE, CROP_SIZE),
    batch_size=BATCH_SIZE,
)

# Total number of optimiser steps for the whole run. It is also used as the
# decay horizon of the learning-rate schedule further down.
iters = (NUM_EPOCHS * num_samples) // BATCH_SIZE

print('Number of train samples: {}'.format(num_samples))
print('Number of train iterations: {}'.format(iters))

Number of train samples: 2600
Number of train iterations: 40625


## Build network graph

In [4]:
# Build the classifier in training mode. All variables it creates are placed
# under the 'net' variable scope.
with tf.variable_scope(name_or_scope='net'):
    logits = model.model(train_image, is_training=True)

## Add loss function

In [5]:
# Cross-entropy between the class-index labels (sparse form) and the network
# logits. It is exported to TensorBoard under the tag 'loss/loss'.
with tf.name_scope('loss'):
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=train_label, logits=logits)
# Bind the summary op to a real name rather than `_`. Assigning `_` in IPython
# shadows the output-history variable and stops `_`, `__` and `___` from being
# updated for the rest of the session.
loss_summary = tf.summary.scalar('loss/loss', loss)

## Set up training parameters and structures

In [6]:
global_step = tf.train.get_or_create_global_step()

# Polynomial decay (default power) from LR_START towards LR_END over all
# `iters` optimiser steps.
lr = tf.train.polynomial_decay(
    learning_rate=LR_START,
    global_step=global_step,
    decay_steps=iters,
    end_learning_rate=LR_END,
)
lr_summary = tf.summary.scalar('learning_rate', lr)

# One op that applies a momentum-SGD update, increments global_step, and
# evaluates to the loss.
# NOTE(review): only `loss` is optimised here. If model.model registers
# regularization losses, they are not included. Consider
# tf.losses.get_total_loss() and confirm against the model definition.
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=MOMENTUM)
train_op = slim.learning.create_train_op(
    loss,
    optimizer,
    global_step=global_step,
)

# The Supervisor owns OUTPUT_ROOT: it restores from and checkpoints into it,
# and writes summaries there on a timer.
supervisor = tf.train.Supervisor(
    logdir=OUTPUT_ROOT,
    global_step=global_step,
    save_summaries_secs=60,  # ADJUST
    save_model_secs=300,  # ADJUST
)

Instructions for updating:
Please switch to tf.train.MonitoredTrainingSession


## Run training loop

In [7]:
# Main loop: one optimiser step per iteration, resuming from the restored
# global_step. managed_session() runs the init/restore ops and starts the
# queue runners plus the periodic checkpoint and summary services.
with supervisor.managed_session() as sess:

    start_iter = sess.run(global_step)
    for it in range(start_iter, iters):
        # A failing background service (e.g. a queue runner) asks the
        # coordinator to stop. Honour that instead of issuing more steps.
        if supervisor.should_stop():
            break

        # train_op applies one update and evaluates to the current loss.
        loss_value = sess.run(train_op)

        if it % LOG_EVERY_N == 0:
            print('[{} / {}] loss_net = {}'.format(it, iters, loss_value))

    # Checkpoints are otherwise written only every save_model_secs, so the
    # last stretch of training would be lost. Save once more after a clean
    # finish (skipped when nothing ran or a stop was requested).
    if start_iter < iters and not supervisor.should_stop():
        supervisor.saver.save(sess, supervisor.save_path,
                              global_step=global_step)

INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Starting standard services.
INFO:tensorflow:Starting queue runners.
INFO:tensorflow:Saving checkpoint to path /data/logs/vggA_BN/model.ckpt
INFO:tensorflow:global_step/sec: 0
INFO:tensorflow:Recording summary at step 0.
[0 / 40625] loss_net = 0.83679741621
[10 / 40625] loss_net = 0.723195254803
[20 / 40625] loss_net = 1.42143285275
[30 / 40625] loss_net = 1.17433404922
INFO:tensorflow:global_step/sec: 0.550915
INFO:tensorflow:Recording summary at step 33.
[40 / 40625] loss_net = 0.824943423271
[50 / 40625] loss_net = 0.632728517056
[60 / 40625] loss_net = 0.682134866714
[70 / 40625] loss_net = 0.819607257843
[80 / 40625] loss_net = 0.360083520412
[90 / 40625] loss_net = 0.87302428484
[100 / 40625] loss_net = 0.60107588768
[110 / 40625] loss_net = 0.481993943453
INFO:tensorflow:global_step/sec: 1.29998
INFO:tensorflow:Recording summary at step 111.
[120 / 40625] loss_net = 0.523851931095
[

KeyboardInterrupt: 