In [1]:
import os
from datetime import datetime
import time
import tensorflow as tf
from utils import Utils
from models import Models
import json

In [3]:
# Locations (relative to the notebook's working directory) used by the cells
# below: where the TFRecord inputs are read from and where results are written.
path_tfrecords, path_results = 'data/tfrecords', 'data/results'

In [4]:
# Read the per-split example counts recorded in meta.json next to the TFRecords.
meta_path = os.path.join(path_tfrecords, 'meta.json')
with open(meta_path, 'r') as f:
    content = json.load(f)

# The counts live under the 'num_examples' key, one entry per split.
num_examples = content['num_examples']
num_train = num_examples['train']
num_test = num_examples['test']
num_extra = num_examples['extra']

In [5]:
# Hyper-parameters consumed by the training cell.
training_options = dict(
    batch_size=32,        # examples per training batch
    learning_rate=0.01,   # initial rate fed to the exponential decay schedule
    patience=100,         # non-improving validations tolerated before stopping
    decay_steps=10000,    # steps between learning-rate decays
    decay_rate=0.9,       # multiplicative decay factor
)

# Starting value the early-stopping counter is reset to.
init_patience = training_options['patience']

In [None]:
# Build the training graph and run the optimisation loop with early stopping:
# training ends once the validation accuracy has failed to improve for
# `init_patience` consecutive evaluations.
with tf.Graph().as_default():
    # --- Inputs, model and loss -------------------------------------------
    image_batch, length_batch, digits_batch = Utils.build_batch(
        os.path.join(path_tfrecords, 'train.tfrecords'),
        num_train,
        training_options['batch_size'],
        True)  # shuffled = True
    length, digits = Models.cnn_inference(image_batch, 0.2)  # drop_rate = 0.2
    loss = Models.cnn_loss(length, digits, length_batch, digits_batch)

    # --- Optimiser with a staircase exponential learning-rate decay --------
    # tf_step is incremented by minimize() and also drives the decay schedule.
    tf_step = tf.Variable(0, name='tf_step', trainable=False)
    learning_rate = tf.train.exponential_decay(
        training_options['learning_rate'],
        global_step=tf_step,
        decay_steps=training_options['decay_steps'],
        decay_rate=training_options['decay_rate'],
        staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_optimizer = optimizer.minimize(loss, global_step=tf_step)

    # --- Summaries ----------------------------------------------------------
    tf.summary.image('image', image_batch)
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('learning_rate', learning_rate)
    summary = tf.summary.merge_all()

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(os.path.join(path_results, 'train'), sess.graph)
        sess.run(tf.global_variables_initializer())
        coordinator = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
        try:
            print('Training')
            # `init_patience` comes from the training-options cell; the
            # previous name `initial_patience` was never defined.
            patience = init_patience
            best_accuracy = 0.0
            while True:
                # One optimisation step; the merged summary already carries the
                # loss and learning rate, so they are not fetched separately.
                _, summary_result, tf_step_result = sess.run(
                    [train_optimizer, summary, tf_step])
                summary_writer.add_summary(summary_result, global_step=tf_step_result)

                # NOTE(review): validation runs after every single step and no
                # checkpoint is written in this cell -- confirm that
                # Models.evaluate does not expect one under path_results.
                print('Validation')
                accuracy = Models.evaluate(path_results, num_extra, tf_step_result)
                if accuracy > best_accuracy:
                    patience = init_patience
                    best_accuracy = accuracy
                else:
                    patience -= 1
                # `<=` so a non-positive configured patience cannot loop forever.
                if patience <= 0:
                    break
        finally:
            # Always flush the event file and stop the queue-runner threads,
            # even when training is interrupted or raises.
            summary_writer.close()
            coordinator.request_stop()
            coordinator.join(threads)
    print('Done')