In [1]:
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3
import matplotlib.pyplot as plt

In [2]:
def parser(record):
    img_features = tf.parse_single_example(
            record,
            features={ 'Label'    : tf.FixedLenFeature([], tf.int64),
                       'image_raw': tf.FixedLenFeature([], tf.string),
                       'height':tf.FixedLenFeature([],tf.int64),
                       'width':tf.FixedLenFeature([],tf.int64),
                       'channel':tf.FixedLenFeature([],tf.int64)})
    height = tf.cast(img_features['height'], tf.int64)
    width = tf.cast(img_features['width'], tf.int64)
    channel = tf.cast(img_features['channel'], tf.int64)
#     image_shape = tf.parallel_stack([height, width, channel])
    label = tf.cast(img_features['Label'], tf.int64)
    
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)
    image = tf.reshape(image, [299,299,3])
#     image.set_shape([299,299,3])
    
    return image, label

In [3]:
def preprocess_for_train(image):
    image_data = tf.image.resize_images(image,[299,299],method=0)
    image_data = tf.image.random_saturation(image_data,lower=0.3,upper=1.3)
    image_data = tf.image.random_brightness(image_data,max_delta=60. /255.) #亮度
#     image_data = tf.image.random_contrast(image_data,0.5,1.5)
    image_data = tf.image.random_flip_left_right(image_data)
    return image_data

In [4]:
train_files = tf.train.match_filenames_once("./data/Train.tfrecords")
dataset = tf.data.TFRecordDataset(train_files)
dataset = dataset.map(parser)

test_files = tf.train.match_filenames_once("./data/Train.tfrecords")
test_dataset = tf.data.TFRecordDataset(test_files)
test_dataset = test_dataset.map(parser)

dataset = dataset.map(
    lambda image,label:(
        preprocess_for_train(image),label))
dataset = dataset.shuffle(2000).batch(8)

test_dataset = test_dataset.map(
    lambda image,label:(
        tf.image.resize_images(image,[299,299],method=0),label))
test_dataset = test_dataset.batch(32)

In [5]:
def get_tuned_variable():
    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")]
    variable_to_restore = []
#     slim.get_variables_to_restore()
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variable_to_restore.append(var)
    return variable_to_restore

def get_trainable_variable():
    scopes = [scope.strip() for scope in TRAINBLE_SCOPES.split(",")]
    variable_to_train = []
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope)
        variable_to_train.append(variables)
    return variable_to_train

In [6]:
NUM_EPOCH = 20
dataset = dataset.repeat(NUM_EPOCH)
Label_size = 830

iterator = dataset.make_initializable_iterator()
image_batch,label_batch = iterator.get_next()

test_iterator = test_dataset.make_initializable_iterator()
test_image_batch,test_label_batch = test_iterator.get_next()


In [7]:
CKPT_FILE = "inception_v3.ckpt"
MODEL_SAVE = "./transfer_model"

CHECKPOINT_EXCLUDE_SCOPES = "InceptionV3/Logits,InceptionV3/AuxLogits"

TRAINBLE_SCOPES = "InceptionV3/Logits,InceptionV3/AuxLogits"

learning_rate = 0.0001
batch_size = 8
#tf.reset_default_graph() 

initializer = tf.contrib.layers.xavier_initializer()
def weight_variable(shape):
    return tf.Variable(initializer(shape))

def bias_variable(shape):
    return tf.Variable(initializer(shape))

def conv2d(x,W):
    return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding="SAME")

def max_pool(x):
    return tf.nn.max_pool(x,ksize=[1,4,4,1],strides=[1,4,4,1],padding="SAME")

x = tf.placeholder(tf.float32,[None,299,299,3])
y = tf.placeholder(tf.int32,[None])
image_labels = tf.one_hot(y,Label_size)

with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    logits,_ = inception_v3.inception_v3(x,num_classes=Label_size)

trainable_variable = get_trainable_variable()


with tf.name_scope("loss"):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=image_labels,logits=logits))
    tf.summary.scalar('loss',loss)
with tf.name_scope('train'):
#     train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)

with tf.name_scope('evaluation'):
    with tf.name_scope('correct_prediction'):  
        correct_prediction = tf.equal(tf.argmax(logits,1),tf.argmax(image_labels,1))
    with tf.name_scope('accuracy'):
        accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
        tf.summary.scalar('accuracy',accuracy)
        
load_fn = slim.assign_from_checkpoint_fn(
    CKPT_FILE,
    get_tuned_variable(),
    ignore_missing_vars=True
)

merged = tf.summary.merge_all()

In [11]:
with tf.Session() as sess:
    load_fn(sess)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(iterator.initializer)
    
    count = 0
    while True:
        try:
            image,label = sess.run([image_batch,label_batch])
            sess.run(train_step, feed_dict={x: image, y: label})
            count+=1
            if count % 100 == 0:
                sess.run(test_iterator.initializer)
                test_acc = []
                test_losses = []
                while True:
                    try:
                        test_image,test_label = sess.run([test_image_batch,test_label_batch])
                        test_accuracy,test_loss = sess.run([accuracy,loss],feed_dict={x: test_image,y:test_label})
                        test_acc.append(test_accuracy)
                        test_losses.append(test_loss)
                    except:
                        print(np.mean(test_acc),np.mean(test_losses))
                        break
#                     image,label = sess.run([image_batch,label_batch])
#                 train_accuracy = accuracy.eval(feed_dict={x: image, y: label})
#                 print('Iter %d, accuracy %4.2f%%' % (count, train_accuracy*100))
            
        except err:
            print(err)
            print("Done")
            break

INFO:tensorflow:Restoring parameters from inception_v3.ckpt
0.008384147 6.9812703
0.0060975607 6.6870527
0.006859756 6.5500846
0.011779379 6.428483
0.013719512 6.3462524
0.017114745 6.247007
0.017530488 6.157304
0.022865854 6.085096
0.026261086 5.9564896
0.023628049 5.8510528
0.025914634 5.769358
0.019817073 5.6991334
0.037347563 5.6067066


NameError: name 'err' is not defined

* LOSS 下降速度太慢(Learning rate 過小)，缺Training_acc,Training_loss，沒畫圖。 之後更新