In [1]:
import os
import numpy as np
import tensorflow as tf
from datetime import datetime
from alexnet import AlexNet
from datagenerator import ImageDataGenerator

"""
Configuration settings
"""

# Path to the textfiles for the trainings and validation set

train_file = 'train.txt'
val_file = 'val.txt'


# Learning params
learning_rate = 0.01
num_epochs = 5
batch_size = 5

# Network params
dropout_rate = 0.5
num_classes = 4
train_layers = ['fc8', 'fc7']

# How often we want to write the tf.summary data to disk
display_step = 1

# Path for tf.summary.FileWriter and to store model checkpoints
filewriter_path = "C:/Users/Jiankun/Desktop/AlexNet/filewriter"
checkpoint_path = "C:/Users/Jiankun/Desktop/AlexNet/checkpoint"

# Create parent path if it doesn't exist
if not os.path.isdir(checkpoint_path): os.mkdir(checkpoint_path)

In [2]:
# TF placeholder for graph input and output
x = tf.placeholder(tf.float32, [batch_size, 227, 227, 3])
y = tf.placeholder(tf.float32, [None, num_classes])
keep_prob = tf.placeholder(tf.float32)

In [3]:
#tf.reset_default_graph()

# Initialize model
model = AlexNet(x, keep_prob, num_classes, train_layers)

# Link variable to model output
score = model.fc8

In [4]:
# List of trainable variables of the layers we want to train
var_list = [v for v in tf.trainable_variables() if v.name.split('/')[0] in train_layers]

# Op for calculating the loss
with tf.name_scope("cross_ent"):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = score, labels = y))  

# Train op
with tf.name_scope("train"):
    # Get gradients of all trainable variables
    gradients = tf.gradients(loss, var_list)
    gradients = list(zip(gradients, var_list))
  
    # Create optimizer and apply gradient descent to the trainable variables
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars=gradients)

# Add gradients to summary  
for gradient, var in gradients:
    tf.summary.histogram(var.name + '/gradient', gradient)

# Add the variables we train to the summary  
for var in var_list:
    tf.summary.histogram(var.name, var)

    
# Add the loss to summary
tf.summary.scalar('cross_entropy', loss)
  


INFO:tensorflow:Summary name fc7/weights:0/gradient is illegal; using fc7/weights_0/gradient instead.
INFO:tensorflow:Summary name fc7/biases:0/gradient is illegal; using fc7/biases_0/gradient instead.
INFO:tensorflow:Summary name fc8/weights:0/gradient is illegal; using fc8/weights_0/gradient instead.
INFO:tensorflow:Summary name fc8/biases:0/gradient is illegal; using fc8/biases_0/gradient instead.
INFO:tensorflow:Summary name fc7/weights:0 is illegal; using fc7/weights_0 instead.
INFO:tensorflow:Summary name fc7/biases:0 is illegal; using fc7/biases_0 instead.
INFO:tensorflow:Summary name fc8/weights:0 is illegal; using fc8/weights_0 instead.
INFO:tensorflow:Summary name fc8/biases:0 is illegal; using fc8/biases_0 instead.


<tf.Tensor 'cross_entropy:0' shape=() dtype=string>

In [5]:
# Evaluation op: Accuracy of the model
with tf.name_scope("accuracy"):
    correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Add the accuracy to the summary
tf.summary.scalar('accuracy', accuracy)

# Merge all summaries together
merged_summary = tf.summary.merge_all()

# Initialize the FileWriter
writer = tf.summary.FileWriter(filewriter_path)

# Initialize an saver for store model checkpoints
saver = tf.train.Saver()

# Initalize the data generator seperately for the training and validation set
train_generator = ImageDataGenerator(train_file, 
                                     horizontal_flip = True, shuffle = True, nb_classes = num_classes)
val_generator = ImageDataGenerator(val_file, shuffle = False, nb_classes = num_classes) 

# Get the number of training/validation steps per epoch
train_batches_per_epoch = np.floor(train_generator.data_size / batch_size).astype(np.int16)
val_batches_per_epoch = np.floor(val_generator.data_size / batch_size).astype(np.int16)


In [6]:
# Start Tensorflow session
with tf.Session() as sess:
 
    # Initialize all variables
    sess.run(tf.global_variables_initializer())
  
    # Add the model graph to TensorBoard
    writer.add_graph(sess.graph)
  
    # Load the pretrained weights into the non-trainable layer
    model.load_initial_weights(sess)
  
    print("{} Start training...".format(datetime.now()))
    print("{} Open Tensorboard at --logdir {}".format(datetime.now(), 
                                                    filewriter_path))
  
    # Loop over number of epochs
    for epoch in range(num_epochs):
    
        print("{} Epoch number: {}".format(datetime.now(), epoch+1))
        
        step = 1
        
        while step < train_batches_per_epoch:
            
            # Get a batch of images and labels
            batch_xs, batch_ys = train_generator.next_batch(batch_size)
            
            # And run the training op
            sess.run(train_op, feed_dict={x: batch_xs, 
                                          y: batch_ys, 
                                          keep_prob: dropout_rate})
            
            # Generate summary with the current batch of data and write to file
            if step%display_step == 0:
                s = sess.run(merged_summary, feed_dict={x: batch_xs, 
                                                        y: batch_ys, 
                                                        keep_prob: 1.})
                writer.add_summary(s, epoch*train_batches_per_epoch + step)
                
            step += 1
            
        # Validate the model on the entire validation set
        print("{} Start validation".format(datetime.now()))
        test_acc = 0.
        test_count = 0
        for _ in range(val_batches_per_epoch):
            batch_tx, batch_ty = val_generator.next_batch(batch_size)
            acc = sess.run(accuracy, feed_dict={x: batch_tx, 
                                                y: batch_ty, 
                                                keep_prob: 1.})
            test_acc += acc
            test_count += 1
        test_acc /= test_count
        print("{} Validation Accuracy = {:.4f}".format(datetime.now(), test_acc))
        
        # Reset the file pointer of the image data generator
        val_generator.reset_pointer()
        train_generator.reset_pointer()
        
        print("{} Saving checkpoint of model...".format(datetime.now()))  
        
        #save checkpoint of the model
        checkpoint_name = os.path.join(checkpoint_path, 'model_epoch'+str(epoch+1)+'.ckpt')
        save_path = saver.save(sess, checkpoint_name)  
        
        print("{} Model checkpoint saved at {}".format(datetime.now(), checkpoint_name))
        
        

2017-11-04 16:15:39.402071 Start training...
2017-11-04 16:15:39.402071 Open Tensorboard at --logdir C:/Users/Jiankun/Desktop/AlexNet/filewriter
2017-11-04 16:15:39.402071 Epoch number: 1
2017-11-04 16:15:55.345748 Start validation
2017-11-04 16:15:56.023641 Validation Accuracy = 0.3000
2017-11-04 16:15:56.024142 Saving checkpoint of model...
2017-11-04 16:16:00.300105 Model checkpoint saved at C:/Users/Jiankun/Desktop/AlexNet/checkpoint\model_epoch1.ckpt
2017-11-04 16:16:00.300105 Epoch number: 2
2017-11-04 16:16:17.299408 Start validation
2017-11-04 16:16:17.971194 Validation Accuracy = 0.2000
2017-11-04 16:16:17.971695 Saving checkpoint of model...
2017-11-04 16:16:22.968850 Model checkpoint saved at C:/Users/Jiankun/Desktop/AlexNet/checkpoint\model_epoch2.ckpt
2017-11-04 16:16:22.968850 Epoch number: 3
2017-11-04 16:16:38.429826 Start validation
2017-11-04 16:16:39.061637 Validation Accuracy = 0.3000
2017-11-04 16:16:39.062138 Saving checkpoint of model...
2017-11-04 16:16:44.66160