In [1]:
import sys
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim

# Make the project-local packages (dataset definitions, nets, preprocessing)
# importable. The membership check keeps this cell idempotent: re-running it
# no longer appends duplicate entries to sys.path.
for _extra_path in ('../datasets/', '../nets/', '../preprocessing/'):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

import orientset

# Directory containing the orientation TFRecord files.
TFRECORD_DIR = '../data/tfrecord/'
# NOTE(review): the split name '' is passed through unchanged; confirm in
# orientset.get_split which files it selects.
oriset = orientset.get_split('', TFRECORD_DIR)

In [2]:
def ori_network(inputs, is_training=True, keep_prob=0.5, scope='ori_network',
                num_classes=4):
    """VGG-16-style orientation classifier.

    The convolutional stages conv1-conv5 use the same layer names as VGG-16,
    so their weights can be restored from a VGG-16 checkpoint after mapping
    the 'ori_network' scope to 'vgg_16'. fc6-fc8 are new, smaller heads
    (1024 -> 512 -> num_classes).

    Args:
      inputs: image batch, assumed [batch, height, width, channels]
        (TODO confirm). The final squeeze requires fc6 (a 7x7 VALID conv)
        to produce a 1x1 spatial map. That means the map must be 7x7 after
        the five 2x2 max-pools, e.g. 224x224 inputs.
      is_training: when True, dropout is active; otherwise it is a no-op.
      keep_prob: dropout keep probability used by dropout6 and dropout7.
      scope: variable scope wrapping every layer.
      num_classes: number of output logits. It defaults to 4, the value
        previously hard-coded, and should match the dataset's num_classes.

    Returns:
      Unscaled logits of shape [batch, num_classes].
    """
    with tf.variable_scope(scope, 'ori_network', [inputs]):
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            activation_fn=tf.nn.relu,
                            weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                            weights_regularizer=slim.l2_regularizer(0.0005)):
            # VGG-16 convolutional trunk (layer names must match the checkpoint).
            net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
            net = slim.max_pool2d(net, [2, 2], scope='pool1')
            net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
            net = slim.max_pool2d(net, [2, 2], scope='pool2')
            net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
            net = slim.max_pool2d(net, [2, 2], scope='pool3')
            net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
            net = slim.max_pool2d(net, [2, 2], scope='pool4')
            net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
            net = slim.max_pool2d(net, [2, 2], scope='pool5')
            # fc6 is a convolution covering the whole remaining spatial map.
            net = slim.conv2d(net, 1024, [7, 7], padding='VALID', scope='fc6')
            net = slim.dropout(net, keep_prob, is_training=is_training,
                               scope='dropout6')
            # Fully-connected head; fully_connected acts on the last axis, so
            # the [batch, 1, 1, C] layout is preserved until the squeeze.
            net = slim.fully_connected(net, 512, scope='fc7')
            net = slim.dropout(net, keep_prob, is_training=is_training,
                               scope='dropout7')
            # Output layer: one linear logit per class.
            net = slim.fully_connected(net, num_classes, scope='fc8',
                                       activation_fn=None)
            net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
            return net

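## Restore pretrained VGG-16 weights and select trainable layers

The convolutional layers are initialised from the VGG-16 checkpoint and kept frozen; only fc6–fc8 are trained.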

In [3]:
import vgg_preprocessing, vgg

In [None]:
# Hyper-parameters and paths for this fine-tuning run.
BATCH_SIZE = 32
LEARNING_RATE = 1e-5
VGG_CHECKPOINT = '../checkpoints/vgg_16.ckpt'
TRAIN_LOG_DIR = './logs/'
# Layers trained from scratch; every other variable is restored from VGG-16.
TRAINABLE_SCOPES = ['ori_network/fc6', 'ori_network/fc7', 'ori_network/fc8']

with tf.Graph().as_default():
    # Read one training example and apply VGG training-time preprocessing.
    data_provider = slim.dataset_data_provider.DatasetDataProvider(oriset)
    image, label = data_provider.get(['image', 'label'])
    VGG_IMAGE_SIZE = vgg.vgg_16.default_image_size
    image = vgg_preprocessing.preprocess_for_train(image, VGG_IMAGE_SIZE, VGG_IMAGE_SIZE)

    # Batch the data.
    batch_image, batch_label = tf.train.batch([image, label], batch_size=BATCH_SIZE,
                                              allow_smaller_final_batch=True)
    batch_one_hot_label = slim.one_hot_encoding(batch_label, oriset.num_classes)
    # NOTE(review): assumes each decoded label has shape [1], which makes the
    # one-hot tensor [batch, 1, num_classes]; the squeeze drops that axis.
    # Confirm against the orientset decoder.
    batch_one_hot_label = tf.squeeze(batch_one_hot_label, [1])

    # Build the training network.
    logits = ori_network(batch_image, is_training=True)

    # Classification loss. softmax_cross_entropy also registers it in the
    # LOSSES collection.
    cross_entropy = tf.losses.softmax_cross_entropy(batch_one_hot_label, logits)
    # Add the L2 weight-decay terms created by ori_network's
    # weights_regularizer. Without this the regularizer is silently ignored.
    total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('cross_entropy', cross_entropy)
    tf.summary.scalar('total_loss', total_loss)

    # Collect the variables we want to train (the fc heads only).
    variables_to_train = [
        var
        for trainable_scope in TRAINABLE_SCOPES
        for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, trainable_scope)
    ]

    # Everything outside the trainable scopes is initialised from the VGG-16
    # checkpoint. This is built BEFORE the optimizer and train op, so the Adam
    # slot variables and the global step they create are not requested from
    # the checkpoint. Variable names are mapped from this graph's
    # 'ori_network' scope to the checkpoint's 'vgg_16' scope.
    variables_to_restore = {
        var.op.name.replace('ori_network', 'vgg_16'): var
        for var in slim.get_variables_to_restore(exclude=TRAINABLE_SCOPES)
    }

    # Create the optimizer.
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

    # Create the train op; gradients are applied only to variables_to_train.
    train_op = slim.learning.create_train_op(total_loss,
                                             optimizer,
                                             variables_to_train=variables_to_train,
                                             summarize_gradients=True)

    # Restore the pretrained parameters when the session starts.
    init_fn = slim.assign_from_checkpoint_fn(VGG_CHECKPOINT, variables_to_restore)

    # Start training.
    # NOTE(review): the recorded run only logged 'global_step/sec: 0' (no step
    # completed). If this recurs, check that the TFRecord files were found
    # and that the input queue is being filled.
    slim.learning.train(train_op, TRAIN_LOG_DIR, log_every_n_steps=1,
                        init_fn=init_fn,
                        save_summaries_secs=30)

INFO:tensorflow:Restoring parameters from ../checkpoints/vgg_16.ckpt
INFO:tensorflow:Starting Session.
INFO:tensorflow:Starting Queues.
INFO:tensorflow:global_step/sec: 0
INFO:tensorflow:global_step/sec: 0
