# In [1]:
import argparse
import time

import tensorflow as tf
from PIL import ImageDraw, Image
from numpy.distutils.fcompiler import str2bool

from constants import VALIDATION_SIZE, TEST_SIZE
from dataset import load_dataset
from models import ALEXNET, VGGNET, LENET

def test(model, sess, saver, test_data, function, difficulty, log=True):
    """
    Evaluate ``model`` on ``test_data`` batch by batch and return the last metric.

    A dedicated session (separate from ``sess``) runs the input queue runners
    that produce test batches; each batch is scored with ``model.eval_metric``.

    Args:
        model: Object exposing ``eval_metric(images, labels)``, which returns a
            two-item (metric, predictions) result. It presumably evaluates in
            the session given to ``model.init_sess`` -- TODO confirm.
        sess: Not used by this function; kept for interface compatibility.
        saver: Not used by this function; kept for interface compatibility.
        test_data: Dataset split exposing ``get_batch_tensor(batch_size=...)``
            and ``get_labels_from_indices(indices, function, difficulty)``.
        function: Task name; with ``difficulty`` it selects the reported
            metric name ('rmse' for hard counting, 'accuracy' otherwise).
        difficulty: Task difficulty, forwarded to the label lookup.
        log: When True (default), show the first image of every batch and
            print the metric, the first label and the predictions.

    Returns:
        The metric of the last evaluated batch, or None if no batch was
        evaluated (e.g. the input queue was already exhausted).
    """
    batch_image, batch_image_index = test_data.get_batch_tensor(batch_size=TEST_SIZE)

    metric = 'rmse' if (function == 'count' and difficulty == 'hard') else 'accuracy'

    # Initialised up front so the final return cannot raise UnboundLocalError
    # when the very first batch fetch already raises OutOfRangeError.
    test_metric = None

    with tf.Session() as _sess:
        # Local variables must exist before the queue runners start (presumably
        # an epoch counter inside the input pipeline -- TODO confirm).
        _sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=_sess, coord=coord)
        try:
            while not coord.should_stop():
                images, indices = _sess.run([batch_image, batch_image_index])
                labels = test_data.get_labels_from_indices(indices, function, difficulty)
                test_metric, test_pred = model.eval_metric(images, labels)
                if log:
                    # Visual spot-check of the first sample in the batch.
                    Image.fromarray(images[0]).show()
                    print('test ' + metric + ': %.4f' % (test_metric))
                    print(labels[0])
                    print(test_pred)
        except tf.errors.OutOfRangeError:
            # Raised once the input queue has been closed and drained; this is
            # how the queue-runner loop above terminates.
            print('Test input queue exhausted; stopping')
        finally:
            coord.request_stop()
            coord.join(threads)
    return test_metric

def run(model_name, function, difficulty, learning_rate=0.0025):
    """
    Build a model, restore its checkpoint and evaluate it on the test split.

    Args:
        model_name: 'ALEXNET' or 'VGGNET'; any other value falls back to LENET.
        function: Task name, forwarded to the model constructor and ``test``.
        difficulty: Task difficulty, forwarded to the model constructor and ``test``.
        learning_rate: Forwarded to the model constructor.

    Returns:
        Whatever ``test`` returns for the restored model.
    """
    # Unrecognised names deliberately fall back to LENET.
    model_classes = {'ALEXNET': ALEXNET, 'VGGNET': VGGNET}

    with tf.Session() as sess:
        # Define computation graph & initialize
        print('Building network & initializing variables')
        model_cls = model_classes.get(model_name, LENET)
        model = model_cls(function, learning_rate, difficulty)

        model.init_sess(sess)
        saver = tf.train.Saver()

        # Process data
        print("Load dataset")
        dataset = load_dataset()
        test_data = dataset.test

        print('Loading best checkpointed model')
        saver.restore(sess, model.model_filename)
        # Evaluate while the session is still open, and hand the metric back
        # to the caller instead of discarding it.
        return test(model, sess, saver, test_data, function, difficulty)

# In [None]:
# Entry point: evaluate the AlexNet counting model on the "moderate" split.
# Guarded so that importing this module does not trigger a full evaluation
# (in a notebook __name__ is "__main__", so the cell still runs).
if __name__ == "__main__":
    run("ALEXNET", "count", "moderate")