In [1]:
import argparse
import sys
import os
import tensorflow as tf
from segmentation import seg_dataset_reader
from utils import weight_variable, bias_variable, conv2d_basic, conv2d_transpose_strided, max_pool_2x2
import datetime


FLAGS = None

  return f(*args, **kwds)


In [2]:
def conv_layer(input,r_field,input_c,out_c,nr):
    W = weight_variable([r_field, r_field, input_c, out_c], name="W"+str(nr))
    b = bias_variable([out_c], name="b"+str(nr))
    conv = conv2d_basic(input, W, b, name="conv"+str(nr))
    relu = tf.nn.relu(conv, name="relu"+str(nr))
    return relu

In [3]:
def deconv_layer(input,r_field,in_channels,out_channels, out_shape,nr, stride=2):
    W = weight_variable([r_field, r_field, out_channels, in_channels], name="W_t"+nr)
    b = bias_variable([out_channels], name="b_t"+nr)
    conv_t1 = conv2d_transpose_strided(input, W, b, out_shape)
    return conv_t1

In [4]:
def segment(image, keep_prob_conv, input_channels, output_channels, scope):

    with tf.variable_scope(scope):

        ###############
        # downsample  #
        ###############

        W2 = weight_variable([3, 3, input_channels, 64], name="W2")
        b2 = bias_variable([64], name="b2")
        conv2 = conv2d_basic(image, W2, b2, name="conv2")
        relu2 = tf.nn.relu(conv2, name="relu2")
        pool2 = max_pool_2x2(relu2)
        dropout2 = tf.nn.dropout(pool2, keep_prob=keep_prob_conv)

        W3 = weight_variable([3, 3, 64, 128], name="W3")
        b3 = bias_variable([128], name="b3")
        conv3 = conv2d_basic(dropout2, W3, b3, name="conv3")
        relu3 = tf.nn.relu(conv3, name="relu3")
        pool3 = max_pool_2x2(relu3)
        dropout3 = tf.nn.dropout(pool3, keep_prob=keep_prob_conv)

        W4 = weight_variable([3, 3, 128, 256], name="W4")
        b4 = bias_variable([256], name="b4")
        conv4 = conv2d_basic(dropout3, W4, b4, name="conv4")
        relu4 = tf.nn.relu(conv4, name="relu4")
        pool4 = max_pool_2x2(relu4)
        dropout4 = tf.nn.dropout(pool4, keep_prob=keep_prob_conv)

        W5 = weight_variable([3, 3, 256, 512], name="W5")
        b5 = bias_variable([512], name="b5")
        conv5 = conv2d_basic(dropout4, W5, b5, name="conv5")
        relu5 = tf.nn.relu(conv5, name="relu5")
        pool5 = max_pool_2x2(relu5)
        dropout5 = tf.nn.dropout(pool5, keep_prob=keep_prob_conv)

        W6 = weight_variable([3, 3, 512, 512], name="W6")
        b6 = bias_variable([512], name="b6")
        conv6 = conv2d_basic(dropout5, W6, b6, name="conv6")
        relu6 = tf.nn.relu(conv6, name="relu6")
        pool6 = max_pool_2x2(relu6)
        dropout6 = tf.nn.dropout(pool6, keep_prob=keep_prob_conv)

        W7 = weight_variable([3, 3, 512, 4096], name="W7")
        b7 = bias_variable([4096], name="b7")
        conv7 = conv2d_basic(dropout6, W7, b7, name="conv7")

        ############
        # upsample #
        ############

        deconv_shape1 = pool5.get_shape()
        W_t1 = weight_variable([4, 4, deconv_shape1[3].value, 4096], name="W_t1")
        b_t1 = bias_variable([deconv_shape1[3].value], name="b_t1")
        conv_t1 = conv2d_transpose_strided(conv7, W_t1, b_t1, output_shape=tf.shape(pool5))

        stacked_1 = tf.concat([conv_t1, pool5], -1)
        fuse_1_1 = conv_layer(stacked_1, 1, 2*deconv_shape1[3].value, deconv_shape1[3].value, "fuse_1_1")
        fuse_1_2 = conv_layer(fuse_1_1, 1, deconv_shape1[3].value, deconv_shape1[3].value, "fuse_1_2")

        deconv_shape2 = pool4.get_shape()
        W_t2 = weight_variable([4, 4, deconv_shape2[3].value, deconv_shape1[3].value], name="W_t2")
        b_t2 = bias_variable([deconv_shape2[3].value], name="b_t2")
        conv_t2 = conv2d_transpose_strided(fuse_1_2, W_t2, b_t2, output_shape=tf.shape(pool4))

        stacked_2 = tf.concat([conv_t2, pool4], -1)
        fuse_2_1 = conv_layer(stacked_2, 1, 2*deconv_shape2[3].value, deconv_shape2[3].value, "fuse_2_1")
        fuse_2_2 = conv_layer(fuse_2_1, 1, deconv_shape2[3].value, deconv_shape2[3].value, "fuse_2_2")

        deconv_shape3 = pool3.get_shape()
        W_t3 = weight_variable([4, 4, deconv_shape3[3].value, deconv_shape2[3].value], name="W_t3")
        b_t3 = bias_variable([deconv_shape3[3].value], name="b_t3")
        conv_t3 = conv2d_transpose_strided(fuse_2_2, W_t3, b_t3, output_shape=tf.shape(pool3))

        stacked_3 = tf.concat([conv_t3, pool3], -1)
        fuse_3_1 = conv_layer(stacked_3, 1, 2*deconv_shape3[3].value, deconv_shape3[3].value, "fuse_3_1")
        fuse_3_2 = conv_layer(fuse_3_1, 1, deconv_shape3[3].value, deconv_shape3[3].value, "fuse_3_2")

        deconv_shape4 = pool2.get_shape()
        W_t4 = weight_variable([4, 4, deconv_shape4[3].value, deconv_shape3[3].value], name="W_t4")
        b_t4 = bias_variable([deconv_shape4[3].value], name="b_t4")
        conv_t4 = conv2d_transpose_strided(fuse_3_2, W_t4, b_t4, output_shape=tf.shape(pool2))

        stacked_4 = tf.concat([conv_t4, pool2], -1)
        fuse_4_1 = conv_layer(stacked_4, 1, 2*deconv_shape4[3].value, deconv_shape4[3].value, "fuse_4_1")
        fuse_4_2 = conv_layer(fuse_4_1, 1, deconv_shape4[3].value, deconv_shape4[3].value, "fuse_4_2")

        # do the final upscaling
        shape = tf.shape(image)
        deconv_shape5 = tf.stack([shape[0], shape[1], shape[2], output_channels])
        W_t5 = weight_variable([16, 16, output_channels, deconv_shape4[3].value], name="W_t5")
        b_t5 = bias_variable([output_channels], name="b_t5")
        conv_t5 = conv2d_transpose_strided(fuse_4_2, W_t5, b_t5, output_shape=deconv_shape5, stride =2)


    annotation_pred = tf.argmax(conv_t5, dimension=3, name="prediction")
    return tf.expand_dims(annotation_pred, dim=3), conv_t5

In [5]:
def train(loss_val, var_list):
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
    grads = optimizer.compute_gradients(loss_val, var_list=var_list)
    if FLAGS.debug:
        # print(len(var_list))
        for grad, var in grads:
            add_gradient_summary(grad, var)
    return optimizer.apply_gradients(grads)

In [6]:
def main(unused_argv):
    print("Setting up image reader...")
    data_reader = seg_dataset_reader(FLAGS.data_dir, crop=FLAGS.crop, crop_size=FLAGS.crop_size)
    print("Images read")

    #Placeholders for FeedDict
    keep_probability_conv = tf.placeholder(tf.float32, name="keep_probability_conv")
    image = tf.placeholder(tf.float32, shape=[None, FLAGS.crop_size[0], FLAGS.crop_size[0], 1], name="image")
    annotation = tf.placeholder(tf.int32, shape=[None, FLAGS.crop_size[0], FLAGS.crop_size[0], 1], name="labels")

    # Apply FCN
    pred_annotation, logits = segment(image, keep_probability_conv, 1, FLAGS.nr_classes, "labels")

    # compute cross-entropy loss
    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                          labels=tf.squeeze(annotation, squeeze_dims=[3]),
                                                                          name="loss_labels")))
    # set up adam-optimizer
    trainable_var = tf.trainable_variables()
    train_op = train(loss, trainable_var)

    # get TF session
    sess = tf.Session()

    # set up saver
    saver = tf.train.Saver()


    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    if ckpt and ckpt.model_checkpoint_path:
        step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1]) # get the step from the last checkpoint
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")
    else:
        step = 0

    for itr in range(step, FLAGS.MAX_ITERATION):
        train_images, train_annotations= data_reader.next_batch(FLAGS.batch_size)
        feed_dict = {image: train_images, annotation: train_annotations, keep_probability_conv: 0.85}
        sess.run(train_op, feed_dict=feed_dict)

        print(itr)

        if itr % 10 == 0:
            train_loss = sess.run([loss], feed_dict=feed_dict)
            print("Step: %d, Train_loss: %g" % (itr, train_loss[0]))

        if itr % 500 == 0 and itr != 0:
            valid_images, valid_annotations = data_reader.get_test_records()
            valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations, keep_probability_conv: 1.0})
            print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss))
            saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)
    a,b = data_reader.get_test_records()
    valid_loss, output = sess.run([loss, pred_annotation], feed_dict={image: a, annotation: b, keep_probability_conv: 1.0})

In [7]:
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str,
                      default='deep_scores/',
                      help='Directory for storing input data')
  parser.add_argument("--batch_size", type=int, default=1, help="batch size for training")
  parser.add_argument("--crop", type=bool, default=True, help="batch size for training")
  parser.add_argument("--crop_size", type=bytearray, default=[500,500], help="batch size for training")
  parser.add_argument("--nr_classes", type=int, default=124, help="batch size for training")
  parser.add_argument("--logs_dir", type=str, default="logs/", help="path to logs directory")
  parser.add_argument("--MAX_ITERATION", type=int, default=10000, help="path to logs directory")
  parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for Adam Optimizer")
  parser.add_argument("--debug", type=bool, default=False, help="debug yes/no")
  FLAGS, unparsed = parser.parse_known_args()
  print(FLAGS)
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

Namespace(MAX_ITERATION=1000, batch_size=1, crop=True, crop_size=[500, 500], data_dir='deep_scores/', debug=False, learning_rate=0.0001, logs_dir='logs/', nr_classes=124)
Setting up image reader...
Initializing DeepScores Classification Batch Dataset Reader...
Splitting dataset, train: 20 images, test: 20 images
im working!3
im working!4
im working!3
im working!8
im working!5
im working!5
im working!7
im working!7
im working!6
im working!8
im working!9
im working!6
im working!4
im working!5
im working!8
im working!0
im working!8
im working!10
im working!0
im working!4
Training set done
im working!4
im working!2
im working!9
im working!8
im working!10
im working!7
im working!1
im working!5
im working!10
im working!0
im working!2
im working!1
im working!4
im working!8
im working!8
im working!6
im working!0
im working!10
im working!7
im working!1
Test set done
Images read
Instructions for updating:
Use the `axis` argument instead
0
Step: 0, Train_loss: 4.72356
1
2
3
4
5
6
7
8
9
10
Step: 1

Step: 700, Train_loss: 0.0292291
701
702
703
704
705
706
707
708
709
710
Step: 710, Train_loss: 0.00519398
711
712
713
714
715
716
717
718
719
****************** Epochs completed: 36******************
720
Step: 720, Train_loss: 0.0665416
721
722
723
724
725
726
727
728
729
730
Step: 730, Train_loss: 0.0193583
731
732
733
734
735
736
737
738
739
****************** Epochs completed: 37******************
740
Step: 740, Train_loss: 0.0987964
741
742
743
744
745
746
747
748
749
750
Step: 750, Train_loss: 0.0122018
751
752
753
754
755
756
757
758
759
****************** Epochs completed: 38******************
760
Step: 760, Train_loss: 0.0170351
761
762
763
764
765
766
767
768
769
770
Step: 770, Train_loss: 0.0100861
771
772
773
774
775
776
777
778
779
****************** Epochs completed: 39******************
780
Step: 780, Train_loss: 0.0919594
781
782
783
784
785
786
787
788
789
790
Step: 790, Train_loss: 0.00667265
791
792
793
794
795
796
797
798
799
****************** Epochs completed: 40*

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


Test out the result from the checkpoint