In [1]:
import os
import tensorflow as tf
import os.path as ops
import time
import numpy as np
from crnn_model import crnn_model
from local_utils import data_utils, log_utils
from global_configuration import config
# Module-level logger from the project's log_utils helper; train_shadownet logs
# restore/training progress through it.
logger = log_utils.init_logger()
# Suppress TensorFlow's own log messages below ERROR level so the pruning and
# training output of this notebook stays readable.
tf.logging.set_verbosity(tf.logging.ERROR)

In [2]:
def default_ignore_prune_condition(layer):
    """Return True for variables that `prune` should leave untouched.

    Only variables whose name contains 'kernel:0', 'w:0' or 'W:0' (conv / LSTM /
    projection weight matrices in this model) are pruned; everything else, e.g.
    biases and BatchNorm beta/gamma, is skipped. Note the match is a plain
    substring test, so any name ending in 'w:0' qualifies.
    """
    return not any(tag in layer.name for tag in ('kernel:0', 'w:0', 'W:0'))


def _magnitude_prune_mask(weights, fraction):
    """Boolean keep-mask dropping `fraction` of the currently non-zero weights.

    The threshold is the magnitude at index int(n_nonzero * fraction) of the
    sorted non-zero magnitudes; entries with |w| >= threshold are kept. Entries
    that are already zero are never kept, so pruning only ever accumulates.
    """
    magnitudes = np.abs(weights[weights != 0])
    cut = max(0, int(magnitudes.size * fraction))
    if cut >= magnitudes.size:
        # Layer already fully pruned, or fraction >= 1 asked to drop every
        # remaining weight: nothing survives (previously an IndexError).
        return np.zeros(weights.shape, dtype=bool)
    # np.partition(a, k)[k] is the k-th smallest value, same as sorted(a)[k].
    threshold = np.partition(magnitudes, cut)[cut]
    return np.abs(weights) >= threshold


def prune(sess, opt, loss, percent_to_prune, global_step, ignore_prune_condition=default_ignore_prune_condition, verbose=True):
    """Magnitude-prune trainable weights in place and build a masked train op.

    For every trainable variable not excluded by `ignore_prune_condition`, the
    smallest-magnitude `percent_to_prune` fraction of its currently non-zero
    entries is zeroed in `sess`, and the variable's gradient is multiplied by the
    resulting keep-mask, so pruned entries receive a zero gradient in the
    returned train op.

    Args:
        sess: tf.Session holding the variable values.
        opt: tf.train.Optimizer used to compute and apply gradients.
            NOTE(review): if `opt` has not already created its slot variables
            (this notebook calls `opt.minimize` before initialising/restoring),
            `apply_gradients` here would create new, uninitialised slots.
        loss: scalar loss tensor to differentiate.
        percent_to_prune: fraction (e.g. 0.05, not 5) of the remaining non-zero
            weights of each prunable variable to drop.
        global_step: step variable passed to `opt.apply_gradients`.
        ignore_prune_condition: predicate on a variable; when it returns True the
            variable is neither pruned nor gradient-masked.
        verbose: print the pruned fraction of every pruned variable.

    Returns:
        The op from `opt.apply_gradients` over the masked gradients, built under
        control dependencies on `tf.GraphKeys.UPDATE_OPS`.

    Note:
        Each call still adds a fresh gradient subgraph, the gradient-mask
        constants and a new train op to the default graph.
    """
    grads_and_vars = opt.compute_gradients(loss)
    # Index gradients by variable name instead of scanning the list per layer.
    grad_by_name = {var.name: (grad, var) for grad, var in grads_and_vars}
    print("Pruning...")

    pruned_grads_and_vars = []
    for layer in tf.trainable_variables():
        grad, var = grad_by_name[layer.name]
        if ignore_prune_condition(layer):
            pruned_grads_and_vars.append((grad, var))
            continue

        layer_weights = sess.run(layer)
        mask = _magnitude_prune_mask(layer_weights, percent_to_prune)
        if verbose:
            print("Percent Pruned: ", 1 - np.mean(mask), "(" + layer.name + ")")

        # Variable.load writes the value through the existing initializer op,
        # so zeroing the weights adds no new ops to the graph.
        layer.load(layer_weights * mask, sess)

        # Variables without a gradient are passed through unchanged
        # (apply_gradients skips None gradients).
        if grad is not None:
            grad = tf.multiply(grad, mask.astype(layer_weights.dtype))
        pruned_grads_and_vars.append((grad, var))

    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = opt.apply_gradients(pruned_grads_and_vars, global_step=global_step)
    return train_op

In [3]:
def train_shadownet(dataset_dir, weights_path=None, start_iteration=27, max_iters=30,
                    train_epochs=500, display_steps=10, pruning_percent_per_iteration=5):
    """Iteratively magnitude-prune and fine-tune the CRNN ShadowNet model.

    For each iteration in ``range(start_iteration, max_iters)``:

    1. save ``model_pre_pruned.ckpt-<iteration>``;
    2. call ``prune`` to zero ``pruning_percent_per_iteration`` percent of the
       remaining non-zero prunable weights and build a masked train op;
    3. print the fraction of trainable weights still non-zero and save
       ``model_post_pruned.ckpt-<iteration>``;
    4. run ``train_epochs`` training steps (one batch each) with the masked op;
    5. save ``model.ckpt-<iteration>``.

    All checkpoints are written to ``model/shadownet/checkpoints`` without a meta graph.

    :param dataset_dir: directory containing ``train_feature.tfrecords``.
    :param weights_path: checkpoint prefix to restore; ``None`` initialises all
        variables from scratch.
    :param start_iteration: first iteration index. The default (27) resumes a
        run whose last checkpoint is ``model.ckpt-26``.
    :param max_iters: iteration index to stop before.
    :param train_epochs: number of training steps (single batches, not passes
        over the dataset) run after each pruning step.
    :param display_steps: log cost / accuracy every ``display_steps`` steps.
    :param pruning_percent_per_iteration: percentage (0-100) of the remaining
        non-zero weights pruned per iteration.
    :return: None
    """
    batch_size = 32
    seq_length = 25
    # CTC loss and decoding treat every batch element as a full-length sequence.
    seq_lengths = np.full(batch_size, seq_length, dtype=np.int32)

    # decode the tf records to get the training data
    decoder = data_utils.TextFeatureIO().reader
    images, labels, imagenames = decoder.read_features(ops.join(dataset_dir, 'train_feature.tfrecords'),
                                                       num_epochs=None)
    inputdata, input_labels, _ = tf.train.shuffle_batch(
        tensors=[images, labels, imagenames], batch_size=batch_size, capacity=1000 + 2 * batch_size,
        min_after_dequeue=100, num_threads=1)
    inputdata = tf.cast(x=inputdata, dtype=tf.float32)

    # initialise the net model
    shadownet = crnn_model.ShadowNet(phase='Train', hidden_nums=256, layers_nums=2,
                                     seq_length=seq_length, num_classes=37)
    with tf.variable_scope('shadow', reuse=False):
        net_out = shadownet.build_shadownet(inputdata=inputdata)

    cost = tf.reduce_mean(tf.nn.ctc_loss(labels=input_labels, inputs=net_out, sequence_length=seq_lengths))
    decoded, _ = tf.nn.ctc_beam_search_decoder(net_out, seq_lengths, merge_repeated=False)
    sequence_dist = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), input_labels))

    global_step = tf.Variable(0, name='global_step', trainable=False)
    starter_learning_rate = config.cfg.TRAIN.LEARNING_RATE
    learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
                                               config.cfg.TRAIN.LR_DECAY_STEPS, config.cfg.TRAIN.LR_DECAY_RATE,
                                               staircase=True)
    opt = tf.train.AdadeltaOptimizer(learning_rate=learning_rate)
    # The op built here is never run: prune() builds the train op actually used.
    # Building it now creates the optimizer's slot variables before the
    # initializer and the Saver below exist, so those slots are initialised or
    # restored together with the model weights. Do not remove.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        opt.minimize(loss=cost, global_step=global_step)

    # Set saver configuration (max_to_keep=0: keep every checkpoint)
    saver = tf.train.Saver(max_to_keep=0)
    model_save_dir = 'model/shadownet/checkpoints'
    os.makedirs(model_save_dir, exist_ok=True)
    model_save_path = ops.join(model_save_dir, "model.ckpt")
    pre_pruned_path = ops.join(model_save_dir, "model_pre_pruned.ckpt")
    post_pruned_path = ops.join(model_save_dir, "model_post_pruned.ckpt")

    # Set sess configuration
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.cfg.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.cfg.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)
    coord = tf.train.Coordinator()
    threads = []
    try:
        with sess.as_default():
            if weights_path is None:
                logger.info('Training from scratch')
                sess.run(tf.global_variables_initializer())
            else:
                logger.info('Restore model from {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            for iteration in range(start_iteration, max_iters):
                print("Iteration %d/%d\n" % (iteration + 1, max_iters))
                saver.save(sess=sess, save_path=pre_pruned_path, global_step=iteration, write_meta_graph=False)
                train_op = prune(sess, opt, cost, pruning_percent_per_iteration / 100, global_step)
                print("Reduced to %f percent of its original size." % (100 * _nonzero_weight_fraction(sess)))
                saver.save(sess=sess, save_path=post_pruned_path, global_step=iteration, write_meta_graph=False)

                for epoch in range(train_epochs):
                    _, cost_value, seq_distance, preds, gt_labels = sess.run(
                        [train_op, cost, sequence_dist, decoded, input_labels])
                    accuracy = _char_accuracy(decoder.sparse_tensor_to_str(preds[0]),
                                              decoder.sparse_tensor_to_str(gt_labels))
                    if epoch % display_steps == 0:
                        logger.info('Epoch: {:d} cost= {:9f} seq distance= {:9f} train accuracy= {:9f}'.format(
                            epoch + 1, cost_value, seq_distance, accuracy))

                saver.save(sess=sess, save_path=model_save_path, global_step=iteration, write_meta_graph=False)
    finally:
        # Stop the input queue threads and release the session even when
        # restoring or training raises.
        coord.request_stop()
        coord.join(threads=threads)
        sess.close()


def _nonzero_weight_fraction(sess):
    """Fraction of all trainable-variable entries that are currently non-zero."""
    values = sess.run(tf.trainable_variables())
    nonzero = sum(np.count_nonzero(value) for value in values)
    total = sum(np.size(value) for value in values)
    return nonzero / total


def _char_accuracy(preds, gt_labels):
    """Mean positional character accuracy of decoded predictions.

    For each ground-truth string, count positions where the prediction agrees,
    comparing only up to the shorter of the two lengths, and divide by the
    ground-truth length. An empty ground truth scores 1 if the prediction is
    also empty, else 0. ``preds`` is indexed in step with ``gt_labels``.
    """
    scores = []
    for index, gt_label in enumerate(gt_labels):
        pred = preds[index]
        if len(gt_label) == 0:
            scores.append(1 if len(pred) == 0 else 0)
            continue
        correct = sum(1 for gt_char, pred_char in zip(gt_label, pred) if gt_char == pred_char)
        scores.append(correct / len(gt_label))
    return np.mean(np.array(scores).astype(np.float32), axis=0)

In [None]:
# Resume pruning from checkpoint model.ckpt-26 of the
# "5_prune_25_iter_50_epochs" run. An earlier starting point used here was
# "model/shadownet/shadownet_2019-02-14-03-49-38.ckpt-6076" (presumably the
# unpruned base model — confirm before reusing).
dataset_dir = "dataset/old_only"
weights_path = "model/shadownet/checkpoints/5_prune_25_iter_50_epochs/model.ckpt-26"

if ops.exists(dataset_dir):
    train_shadownet(dataset_dir, weights_path)
else:
    raise ValueError('{:s} doesn\'t exist'.format(dataset_dir))

INFO: 03-06 16:26:35: <ipython-input-3-e72a5828ab27>:66 * 139638111557440 Restore model from model/shadownet/checkpoints/5_prune_25_iter_50_epochs/model.ckpt-26


Iteration 28/30

Pruning...
Percent Pruned:  0.7453703703703703 (shadow/conv1/W:0)
Percent Pruned:  0.7495524088541667 (shadow/conv2/W:0)
Percent Pruned:  0.7496270073784722 (shadow/conv3/W:0)
Percent Pruned:  0.7496422661675347 (shadow/conv4/W:0)
Percent Pruned:  0.7496490478515625 (shadow/conv5/W:0)
Percent Pruned:  0.7496528625488281 (shadow/conv6/W:0)
Percent Pruned:  0.7496490478515625 (shadow/conv7/W:0)
Percent Pruned:  0.7496452331542969 (shadow/LSTMLayers/stack_bidirectional_rnn/cell_0/bidirectional_rnn/fw/basic_lstm_cell/kernel:0)
Percent Pruned:  0.7496452331542969 (shadow/LSTMLayers/stack_bidirectional_rnn/cell_0/bidirectional_rnn/bw/basic_lstm_cell/kernel:0)
Percent Pruned:  0.7496452331542969 (shadow/LSTMLayers/stack_bidirectional_rnn/cell_1/bidirectional_rnn/fw/basic_lstm_cell/kernel:0)
Percent Pruned:  0.7496439615885417 (shadow/LSTMLayers/stack_bidirectional_rnn/cell_1/bidirectional_rnn/bw/basic_lstm_cell/kernel:0)
Percent Pruned:  0.7492609797297297 (shadow/LSTMLayers/

INFO: 03-06 16:26:48: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 1 cost= 20.617172 seq distance=  0.671875 train accuracy=  0.229167
INFO: 03-06 16:27:24: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 11 cost= 14.091206 seq distance=  0.588542 train accuracy=  0.333333
INFO: 03-06 16:28:01: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 21 cost= 13.748120 seq distance=  0.531250 train accuracy=  0.411458
INFO: 03-06 16:28:38: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 31 cost= 12.187616 seq distance=  0.432292 train accuracy=  0.442708
INFO: 03-06 16:29:14: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 41 cost=  9.777232 seq distance=  0.427083 train accuracy=  0.447917
INFO: 03-06 16:29:50: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 51 cost=  7.025039 seq distance=  0.307292 train accuracy=  0.645833
INFO: 03-06 16:30:27: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 61 cost=  7.7

Iteration 29/30

Pruning...
Percent Pruned:  0.7581018518518519 (shadow/conv1/W:0)
Percent Pruned:  0.7620713975694444 (shadow/conv2/W:0)
Percent Pruned:  0.7621426052517362 (shadow/conv3/W:0)
Percent Pruned:  0.7621595594618056 (shadow/conv4/W:0)
Percent Pruned:  0.7621663411458334 (shadow/conv5/W:0)
Percent Pruned:  0.762170155843099 (shadow/conv6/W:0)
Percent Pruned:  0.7621660232543945 (shadow/conv7/W:0)
Percent Pruned:  0.7621625264485677 (shadow/LSTMLayers/stack_bidirectional_rnn/cell_0/bidirectional_rnn/fw/basic_lstm_cell/kernel:0)
Percent Pruned:  0.7621625264485677 (shadow/LSTMLayers/stack_bidirectional_rnn/cell_0/bidirectional_rnn/bw/basic_lstm_cell/kernel:0)
Percent Pruned:  0.7621625264485677 (shadow/LSTMLayers/stack_bidirectional_rnn/cell_1/bidirectional_rnn/fw/basic_lstm_cell/kernel:0)
Percent Pruned:  0.7621612548828125 (shadow/LSTMLayers/stack_bidirectional_rnn/cell_1/bidirectional_rnn/bw/basic_lstm_cell/kernel:0)
Percent Pruned:  0.7617715371621622 (shadow/LSTMLayers/w

INFO: 03-06 16:58:10: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 1 cost= 34.921154 seq distance=  0.916667 train accuracy=  0.182292
INFO: 03-06 16:58:48: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 11 cost= 25.020504 seq distance=  0.776042 train accuracy=  0.250000
INFO: 03-06 16:59:25: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 21 cost= 20.896587 seq distance=  0.630208 train accuracy=  0.322917
INFO: 03-06 17:00:02: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 31 cost= 18.344847 seq distance=  0.635417 train accuracy=  0.296875
INFO: 03-06 17:00:39: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 41 cost= 15.774158 seq distance=  0.536458 train accuracy=  0.385417
INFO: 03-06 17:01:16: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 51 cost= 14.454824 seq distance=  0.552083 train accuracy=  0.364583
INFO: 03-06 17:01:52: <ipython-input-3-e72a5828ab27>:117 * 139638111557440 Epoch: 61 cost= 13.2

## Trainable Layers:


[<tf.Variable 'shadow/conv1/W:0' shape=(3, 3, 3, 64) dtype=float32_ref>, <tf.Variable 'shadow/conv2/W:0' shape=(3, 3, 64, 128) dtype=float32_ref>, <tf.Variable 'shadow/conv3/W:0' shape=(3, 3, 128, 256) dtype=float32_ref>, <tf.Variable 'shadow/conv4/W:0' shape=(3, 3, 256, 256) dtype=float32_ref>, <tf.Variable 'shadow/conv5/W:0' shape=(3, 3, 256, 512) dtype=float32_ref>, <tf.Variable 'shadow/BatchNorm/beta:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'shadow/BatchNorm/gamma:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'shadow/conv6/W:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'shadow/BatchNorm_1/beta:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'shadow/BatchNorm_1/gamma:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'shadow/conv7/W:0' shape=(2, 2, 512, 512) dtype=float32_ref>, <tf.Variable 'shadow/LSTMLayers/stack_bidirectional_rnn/cell_0/bidirectional_rnn/fw/basic_lstm_cell/kernel:0' shape=(768, 1024) dtype=float32_ref>, <tf.Variable 'shadow/LSTMLayers/stack_bidirectional_rnn/cell_0/bidirectional_rnn/fw/basic_lstm_cell/bias:0' shape=(1024,) dtype=float32_ref>, <tf.Variable 'shadow/LSTMLayers/stack_bidirectional_rnn/cell_0/bidirectional_rnn/bw/basic_lstm_cell/kernel:0' shape=(768, 1024) dtype=float32_ref>, <tf.Variable 'shadow/LSTMLayers/stack_bidirectional_rnn/cell_0/bidirectional_rnn/bw/basic_lstm_cell/bias:0' shape=(1024,) dtype=float32_ref>, <tf.Variable 'shadow/LSTMLayers/stack_bidirectional_rnn/cell_1/bidirectional_rnn/fw/basic_lstm_cell/kernel:0' shape=(768, 1024) dtype=float32_ref>, <tf.Variable 'shadow/LSTMLayers/stack_bidirectional_rnn/cell_1/bidirectional_rnn/fw/basic_lstm_cell/bias:0' shape=(1024,) dtype=float32_ref>, <tf.Variable 'shadow/LSTMLayers/stack_bidirectional_rnn/cell_1/bidirectional_rnn/bw/basic_lstm_cell/kernel:0' shape=(768, 1024) dtype=float32_ref>, <tf.Variable 'shadow/LSTMLayers/stack_bidirectional_rnn/cell_1/bidirectional_rnn/bw/basic_lstm_cell/bias:0' shape=(1024,) dtype=float32_ref>, <tf.Variable 
'shadow/LSTMLayers/w:0' shape=(512, 37) dtype=float32_ref>]