In [1]:
'''Import required modules.

Modules:
    layer: The customized single layers.
    module: The customized residual and hourglass modules.
    tensorflow: The deep-learning framework.
    os: Needed to build and expand file-system paths.
    cv2: OpenCV; the network's results are saved with it.
    tqdm: Progress-bar tool for visualizing loops.
    reader: The interface to the training and test data.
'''

import layer
import module
import tensorflow as tf
import os
import cv2
from tqdm import tqdm_notebook, trange
from reader import *

  return f(*args, **kwds)


In [2]:
# Create the data reader from the training and test files stored under ~/Temp.
reader = Reader(
    train = os.path.expanduser('~/Temp/train19122.dat'),
    test = os.path.expanduser('~/Temp/test2125.dat'),
)

In [3]:
# Pull a sample batch of two (image, heatmap) pairs from the training split.
image, heatmap = reader.batch(size = 2, is_train = True)

In [4]:
# Sanity check: write the two sampled images and, for each of them, channel 12
# of its heatmap to JPEG files under ~/Temp so they can be inspected by eye.
# NOTE(review): channel 12 presumably selects one particular joint - confirm which.
# The last call is left as a bare expression so the cell displays its result.
cv2.imwrite(os.path.expanduser('~/Temp/rgb_00001.jpg'), image[0])
cv2.imwrite(os.path.expanduser('~/Temp/rgb_00002.jpg'), image[1])
cv2.imwrite(os.path.expanduser('~/Temp/heat_00001.jpg'), heatmap[0][:, :, 12])
cv2.imwrite(os.path.expanduser('~/Temp/heat_00002.jpg'), heatmap[1][:, :, 12])

True

In [5]:
class Metadata(object):
    '''Plain attribute container; fields are attached after construction.'''


class Image(Metadata):
    '''Width, height and channel count of an image-like array.'''

    __slots__ = ['width', 'height', 'channel']

    def __init__(self, width, height, channel):
        self.width, self.height, self.channel = width, height, channel


# Dimensions shared by the rest of the notebook: the network input image,
# a single heatmap plane, and the number of joints.
metadata = Metadata()
metadata.image = Image(width = 256, height = 256, channel = 3)
metadata.heatmap = Image(width = 64, height = 64, channel = 1)
metadata.joint = 16

In [6]:
# Run-time options of the notebook, declared as TensorFlow flags:
# checkpoint path, train/test switch and batch size.
flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string(
    'ckpt',
    os.path.expanduser('~/Checkpoints/hourglass_MPII.ckpt'),
    'The path for checkpoint.')
flags.DEFINE_bool(
    'train',
    True,
    'Whether update parameter or not.')
flags.DEFINE_integer(
    'batch',
    6,
    'The batch size.')

In [7]:
# Graph inputs. The leading (batch) dimension of both tensors is left open.
with tf.variable_scope('input'):
    # Input images: [batch, image height, image width, image channels].
    images = tf.placeholder(
        tf.float32,
        [None, metadata.image.height, metadata.image.width, metadata.image.channel],
        name = 'image')
    # Target heatmaps: one heatmap-sized plane per joint.
    heatmaps_groundtruth = tf.placeholder(
        tf.float32,
        [None, metadata.heatmap.height, metadata.heatmap.width, metadata.joint],
        name = 'heatmap_groundtruth')
    # Scalar boolean handed to the layers that take a `train` argument.
    train = tf.placeholder(tf.bool, (), name = 'train')

In [8]:
# Stem of the network: reduces the input image to the feature map that the
# hourglass stack consumes. The trailing comments give the feature-map size
# (height * width * channels) the author expects after each step.
# input size 256 * 256 * 3

with tf.variable_scope('compress'):
    # Strided 7x7 convolution followed by batch normalization and ReLU.
    with tf.variable_scope('conv_bn_relu'):
        net = layer.conv(input = images, ksize = 7, kchannel = 64, kstride = 2) # 128 * 128 * 64
        net = layer.bn(input = net, train = train)
        net = layer.relu(input = net)

    # Three bottleneck modules with one pooling step after the first.
    # NOTE(review): the names 'A'/'B'/'C' presumably become variable-scope names
    # inside module.bottleneck - confirm before renaming (checkpoints may depend on them).
    net = module.bottleneck(input = net, kchannel = 128, train = train, name = 'A') # 128 * 128 * 128
    net = layer.pool(input = net) # 64 * 64 * 128
    net = module.bottleneck(input = net, kchannel = 128, train = train, name = 'B') # 64 * 64 * 128
    net = module.bottleneck(input = net, kchannel = 256, train = train, name = 'C') # 64 * 64 * 256

In [9]:
class tf_Spectrum:
    '''Color tables used to render gray-level heatmaps.

    Color holds five anchor colors, one per row, as a float64 constant.
    NOTE(review): the channel order looks like BGR (results are written with
    cv2), which would make the ramp dark blue -> blue -> green -> yellow -> red;
    confirm.
    '''
    Color = tf.constant(
        [[128, 0, 0], [255, 0, 0], [0, 255, 0], [0, 255, 255], [0, 0, 255]],
        dtype = tf.float64)

def tf_gray2color(gray, spectrum = tf_Spectrum.Color):
    '''Map gray levels to colors by piecewise-linear interpolation over a spectrum.

    The gray range is cut into segments of width 64; a value in segment k is
    blended linearly between spectrum[k] and spectrum[k + 1].

    Args:
        gray: Numeric tensor of gray levels. Values are clipped to [0, 255]
            before the lookup.
        spectrum: Tensor whose rows are the anchor colors. It needs at least
            five rows, because gray level 255 reads rows 3 and 4.

    Returns:
        A tensor shaped like gray with one extra trailing axis holding the
        interpolated color.
    '''
    # The caller passes raw network outputs, which may fall outside [0, 255].
    # Without clipping, such values produce gather indices outside the
    # spectrum: tf.gather raises on CPU and silently returns zeros on GPU.
    gray = tf.clip_by_value(gray, 0, 255)
    indices = tf.floor_div(gray, 64)

    # Fractional position of each value inside its segment.
    t = (gray - indices * 64) / (64)
    # A trailing axis of size one broadcasts over the color axis, so the blend
    # works for any input rank and any number of color channels.
    t = tf.expand_dims(t, -1)
    indices = tf.cast(indices, dtype = tf.int32)

    return (1-t)*tf.gather(spectrum, indices) + t*tf.gather(spectrum, indices+1)

def tf_merge(image, heatmaps):
    '''Overlay the colorized joint heatmaps of one sample on its input image.

    Every joint plane of heatmaps is colorized with tf_gray2color and the
    planes are combined with an element-wise maximum. The result is resized to
    the input-image size and blended with the image (0.6 overlay, 0.4 image).

    Args:
        image: Image tensor of a single sample; it is cast to float32.
        heatmaps: Heatmap tensor of a single sample, indexed as
            [row, column, joint].

    Returns:
        The blended picture as a uint8 tensor.
    '''
    canvas_shape = (
        metadata.heatmap.height,
        metadata.heatmap.width,
        metadata.image.channel)
    overlay = tf.zeros(shape = canvas_shape, dtype = tf.float64)

    # Keep, per pixel and channel, the strongest color over all joints.
    for index in range(metadata.joint):
        joint_plane = tf.cast(heatmaps[:, :, index], dtype = tf.float64)
        overlay = tf.maximum(overlay, tf_gray2color(joint_plane))

    target_size = [metadata.image.height, metadata.image.width]
    overlay = tf.image.resize_images(overlay, target_size)

    picture = tf.cast(image, dtype = tf.float32)
    blended = tf.add(tf.multiply(overlay, 0.6), tf.multiply(picture, 0.4))
    return tf.cast(blended, dtype = tf.uint8)

In [10]:
# Stack of `last_stage` hourglass stages. Every stage predicts one heatmap per
# joint; all predictions are collected in `heatmaps` (index 0 = first stage).
last_stage = 8
heatmaps = []

for stage in range(1, last_stage+1):
    with tf.variable_scope('hourglass_' + str(stage)):
        # Keep the stage input for the skip connection added at the end of the stage.
        prev = tf.identity(net)
        net = module.hourglass(input = net, train = train) # 64 * 64 * 256

        # Feature processing between the hourglass and the heatmap prediction.
        with tf.variable_scope('inter_hourglasses'):
            net = module.bottleneck(input = net, train = train) # 64 * 64 * 256
            net = layer.conv(input = net, ksize = 1, kchannel = 256) # 64 * 64 * 256
            net = layer.bn(input = net, train = train)
            net = layer.relu(input = net)

        # 1x1 convolution down to one channel per joint: this stage's prediction.
        # NOTE: this rebinds the notebook-level name `heatmap`.
        with tf.variable_scope('heatmap'):
            heatmap = layer.conv(input = net, ksize = 1, kchannel = metadata.joint) # 64 * 64 * joint
            heatmaps.append(heatmap)

        # Input of the next stage: 1x1-conv projections of the features and of the
        # predicted heatmap, plus the stage input. Skipped after the last stage.
        if stage != last_stage:
            net = layer.conv(input = net, ksize = 1, kchannel = 256, name = 'inter')\
                + layer.conv(input = heatmap, ksize = 1, kchannel = 256, name = 'heatmap')\
                + prev # 64 * 64 * 256

# Visualization op: overlays the last stage's heatmaps of the FIRST sample in the
# batch on that sample's input image.
merged = tf_merge(images[0], heatmaps[-1][0])

In [11]:
# Loss: sum over all stages of the mean squared error between the ground-truth
# heatmaps and that stage's prediction.
#
# The loss is built unconditionally because the evaluation code fetches `loss`
# as well; when it was created only under FLAGS.train, running the notebook in
# test mode failed with a NameError. Only the optimizer depends on the flag.
with tf.variable_scope('loss'):
    loss = tf.losses.mean_squared_error(heatmaps_groundtruth, heatmaps[0])
    for stage in range(1, len(heatmaps)):
        loss = loss + tf.losses.mean_squared_error(heatmaps_groundtruth, heatmaps[stage])

    if FLAGS.train:
        # The optimizer stays inside the 'loss' scope, exactly where it was
        # created before, so the graph built in training mode is unchanged.
        # Run the ops collected in UPDATE_OPS together with every training step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.AdamOptimizer(name = 'optimizer', learning_rate = 0.00025).minimize(loss)

In [12]:
# Session, checkpoint saver and a freshly constructed data reader
# (same files as the reader created at the top of the notebook).
saver = tf.train.Saver()
sess = tf.Session()
reader = Reader(
    train = os.path.expanduser('~/Temp/train19122.dat'),
    test = os.path.expanduser('~/Temp/test2125.dat'),
)

if not FLAGS.train:
    # Test mode: load the trained parameters from the checkpoint.
    saver.restore(sess, FLAGS.ckpt)
else:
    # Training mode: start from freshly initialized variables.
    # To resume an earlier run, additionally restore the checkpoint here:
    #     saver.restore(sess, FLAGS.ckpt)
    sess.run(tf.global_variables_initializer())

In [13]:
# Driver cell: trains for a fixed number of epochs when FLAGS.train is set,
# otherwise runs through the test data once, printing the mean loss and
# writing one overlay image per batch.
if FLAGS.train:
    train_samples = 19122  # presumably the size of the training set (cf. train19122.dat)
    epochs = 100
    # Derive the number of steps from the batch size instead of hard-coding it,
    # so the step count and the progress-bar total stay consistent when
    # FLAGS.batch changes (19122 // 6 == 3187, the previously hard-coded value).
    steps_per_epoch = train_samples // FLAGS.batch

    for epoch in range(epochs):
        train_iter = tqdm_notebook(
            total = train_samples,
            desc = 'epoch: ' + str(epoch) + '/' + str(epochs))
        for i in range(steps_per_epoch):
            train_images, train_heatmaps = reader.batch(size = FLAGS.batch, is_train = True)
            _, result = sess.run([optimizer, loss],
                feed_dict = {
                    images: train_images,
                    heatmaps_groundtruth: train_heatmaps,
                    train: True})
            train_iter.set_postfix(loss = result)
            train_iter.update(FLAGS.batch)
        train_iter.close()
        # The same checkpoint path is overwritten after every epoch.
        saver.save(sess, FLAGS.ckpt)
else:
    test_samples = 2125  # presumably the size of the test set (cf. test2125.dat)
    test_batch = 5
    test_steps = test_samples // test_batch

    heatmap_idx = 0
    total_result = []
    test_iter = tqdm_notebook(total = test_samples, desc = 'test')
    for i in range(test_steps):
        test_images, test_heatmaps = reader.batch(size = test_batch, is_train = False)
        result, output_heatmap = sess.run([loss, merged],
            feed_dict = {
                images: test_images,
                heatmaps_groundtruth: test_heatmaps,
                train: False})
        test_iter.update(test_batch)
        total_result.append(result)
        # `merged` only renders the first sample of the batch, so exactly one
        # overlay image is written per batch.
        cv2.imwrite('merged' + str(heatmap_idx) + '.jpg', output_heatmap)
        heatmap_idx += 1
    test_iter.close()
    # Mean of the per-batch losses.
    print(sum(total_result)/test_steps)













































































































































































































































































































In [14]:
# Evaluation pass over the test data: 2125 samples in batches of 5.
# For every batch the loss is recorded and the overlay image produced by
# `merged` (first sample of the batch only) is written to the working directory.
batch_count = 2125 // 5
test_iter = tqdm_notebook(total = 2125, desc = 'test')
total_result = []
heatmap_idx = 0

for i in range(batch_count):
    test_images, test_heatmaps = reader.batch(size = 5, is_train = False)
    feed = {
        images: test_images,
        heatmaps_groundtruth: test_heatmaps,
        train: False,
    }
    result, output_heatmap = sess.run([loss, merged], feed_dict = feed)
    test_iter.update(5)
    total_result.append(result)
    cv2.imwrite('merged' + str(heatmap_idx) + '.jpg', output_heatmap)
    heatmap_idx += 1

test_iter.close()
# Mean of the per-batch losses.
print(sum(total_result) / batch_count)


203.425501888
