In [41]:
''' Import dependencies.

Modules:
    tensorflow: The neural network framework.
    layer: The customized single layers.
    module: The customized multi-layer modules.
    
    os: The basic module for path parsing.
    json: The basic module for config parsing.
    zipfile: Extract zip file.
    random: Shuffle indices.
    scipy.io: Load .mat file.
    numpy: To process .mat file data.
    
    tqdm: The 3rd-party looping visualizer.
'''

import layer, module
import tensorflow as tf
import os, json, zipfile, random
import scipy.io, scipy.misc, numpy as np
from enum import Enum
from tqdm import tqdm_notebook, trange

In [2]:
''' Parsing config.

The config file is read from ../config.json (resolved against the current
working directory) and every entry the notebook needs is exposed as a flag.
'''
CONFIG_PATH = os.path.abspath('../config.json')

# Only the JSON parsing needs the file handle, so the flags are defined after
# the handle has been released.
with open(CONFIG_PATH) as CONFIG_FILE:
    CONFIG = json.load(CONFIG_FILE)

flags = tf.app.flags
flags.DEFINE_string('project', CONFIG['project'], 'The project name.')

# Pretrained parameters live in a per-project sub-directory.
flags.DEFINE_string('path_to_pretrained',
                    os.path.join(
                        os.path.expanduser(CONFIG['pretrained']['path']),
                        CONFIG['project']),
                    'The path to pretrained parameters.')
flags.DEFINE_boolean('load_pretrained', CONFIG['pretrained']['load'], 'Load the latest pretrained parameters.')
flags.DEFINE_boolean('save_pretrained', CONFIG['pretrained']['save'], 'Save the trained parameters.')

# Logs live in a per-project sub-directory as well.
flags.DEFINE_string('path_to_log',
                    os.path.join(
                        os.path.expanduser(CONFIG['log']['path']),
                        CONFIG['project']),
                    'The path to log.')
flags.DEFINE_boolean('save_log', CONFIG['log']['save'], 'Save the log.')

flags.DEFINE_string('path_to_data',
                    os.path.expanduser(CONFIG['data']['path']),
                    'The path to data.')
flags.DEFINE_string('name_of_data', CONFIG['data']['name'], 'The name of data.')

flags.DEFINE_string('task', CONFIG['task']['step'], 'The task to be done.')
flags.DEFINE_integer('epoch', CONFIG['task']['train']['epoch'], 'The epoch to be trained.')
flags.DEFINE_string('metric', CONFIG['task']['eval']['metric'], 'The evaluation metric.')
flags.DEFINE_float('metric_coefficient', CONFIG['task']['eval']['coefficient'], 'The evaluation metric coefficient.')

FLAGS = flags.FLAGS

In [87]:
''' Common metadata.
'''

class DataInterface(object):
    ''' Outline of what a dataset reader provides.

    Every method here raises NotImplementedError; concrete readers supply
    their own implementations.

    Note: the double-underscore names are name-mangled per class, so a
    subclass that defines its own __corrupted/__reload gets separate
    attributes rather than overrides of the ones declared here.
    '''

    def __init__(self):
        ''' Construction is left to the concrete reader. '''
        raise NotImplementedError

    def __corrupted(self):
        ''' Integrity check hook; not implemented here. '''
        raise NotImplementedError

    def __reload(self):
        ''' Data (re)generation hook; not implemented here. '''
        raise NotImplementedError

    def getBatch(self, batch_size):
        ''' Batch retrieval hook; not implemented here.

        Args:
            batch_size: Requested batch size.
        '''
        raise NotImplementedError
        
class DataCenter(object):
    ''' Factory that hands out dataset readers located under one root directory.
    '''

    def __init__(self, data_root):
        '''
        Args:
            data_root: Directory that contains one sub-directory per dataset.
        '''
        self.__data_root = data_root

    def request(self, data_name, task):
        ''' Create the reader for the named dataset.

        Args:
            data_name: Dataset name; 'FLIC' and 'MPII' are recognised.
            task: Passed through unchanged to the reader.

        Returns:
            The reader instance built for <data_root>/<data_name>.

        Raises:
            NotImplementedError: If data_name is not a recognised dataset.
        '''
        path = os.path.join(self.__data_root, data_name)

        if data_name == 'FLIC':
            return FLIC(data_path = path, task = task)
        if data_name == 'MPII':
            return MPII(data_path = path, task = task)

        # Name the rejected dataset so a config typo is easy to spot.
        raise NotImplementedError('Unsupported dataset name: %r' % (data_name,))
            
class Joint(Enum):
    ''' Keypoint identifiers shared by the dataset readers.

    Values are the consecutive integers 1..31. Iterating the enum yields the
    members in definition order.
    '''

    # FLIC
    L_Shoulder = 1
    R_Shoulder = 2
    L_Elbow = 3
    R_Elbow = 4
    L_Wrist = 5
    R_Wrist = 6
    L_Hip = 7
    R_Hip = 8
    L_Knee = 9
    R_Knee = 10
    L_Ankle = 11
    R_Ankle = 12

    L_Eye = 13
    R_Eye = 14
    L_Ear = 15
    R_Ear = 16
    M_Nose = 17

    M_Shoulder = 18
    M_Hip = 19
    M_Ear = 20
    M_Torso = 21
    M_LUpperArm = 22
    M_RUpperArm = 23
    M_LLowerArm = 24
    M_RLowerArm = 25
    M_LUpperLeg = 26
    M_RUpperLeg = 27
    M_LLowerLeg = 28
    M_RLowerLeg = 29

    # MPII
    # The two aliases below are intentionally left disabled:
    # M_Pelvis = M_Hip
    # M_Thorax = M_Torso
    M_UpperNeck = 30
    M_HeadTOP = 31
    
class Basis(Enum):
    ''' Coordinate axes; the values are used as row indices into keypoint coordinate arrays. '''
    x = 0
    y = 1

In [132]:
''' FLIC data reader
'''       
class FLIC(DataInterface):
    ''' Reader for the FLIC dataset.

    Expects <data_path>/FLIC.zip to be present. On construction the archive is
    extracted and a random train/eval index split is written when either the
    extraction directory or the index files are missing. The annotation file
    is then loaded and both index files are opened for sequential reading.

    The index files stay open for the lifetime of the reader; call close()
    (also invoked on garbage collection) to release them.
    '''

    archive_file = 'FLIC.zip'
    extract_dir = 'FLIC'
    matlab_file = 'examples.mat'
    index_file = { 'train': 'train.txt', 'eval': 'eval.txt' }
    number_of_data = 5003
    train_ratio = 0.9
    eval_ratio = 1.0 - train_ratio

    # Column of each annotated joint inside the 'coords' array of an example.
    joint2index = {
        Joint.L_Shoulder: 0,
        Joint.L_Elbow: 1,
        Joint.L_Wrist: 2,
        Joint.R_Shoulder: 3,
        Joint.R_Elbow: 4,
        Joint.R_Wrist: 5,
        Joint.L_Hip: 6,
        Joint.R_Hip: 9,
        Joint.L_Eye: 12,
        Joint.R_Eye: 13,
        Joint.M_Nose: 16
    }

    def __init__(self, data_path, task):
        '''
        Args:
            data_path: Directory holding the archive, the extracted data and
                the index files.
            task: Key of the index file to iterate ('train' or 'eval').

        Raises:
            Exception: If the archive file is missing.
        '''
        self.__data_path = data_path
        self.__task = task
        # Start empty so close() has nothing to do if setup fails part-way.
        self.__index = {}

        if self.__corrupted():
            self.__reload()

        matlab_path = os.path.join(self.__data_path, FLIC.extract_dir, FLIC.matlab_file)
        self.__FLIC = scipy.io.loadmat(matlab_path)

        # Register each handle as soon as it is opened, so a failure on a
        # later file still lets close() release the earlier ones.
        for split, path in self.__indexPaths().items():
            self.__index[split] = open(path, 'r')

    def close(self):
        ''' Close the index files. Safe to call more than once. '''
        try:
            handles = self.__index
        except AttributeError:
            # Instance was never initialised far enough to own any files.
            return
        for handle in handles.values():
            handle.close()

    def __del__(self):
        # The finalizer hook is __del__; __delete__ belongs to the descriptor
        # protocol and is never invoked when a reader is discarded.
        self.close()

    def __indexPaths(self):
        ''' Map each split name to the full path of its index file. '''
        return { split: os.path.join(self.__data_path, name)
                 for split, name in FLIC.index_file.items() }

    def __corrupted(self):
        ''' Report whether the extracted data or the index files are missing.

        Returns:
            True when the extraction directory or either index file does not
            exist, False when everything is in place.

        Raises:
            Exception: If the archive file itself is missing.
        '''
        archive_path = os.path.join(self.__data_path, FLIC.archive_file)
        print('Check if the archive file exists...', end = '')
        if not os.path.exists(archive_path):
            raise Exception('You have to download the archive file from https://bensapp.github.io/flic-dataset.html')
        print('success!')
        print('\tpath:', archive_path)

        extract_path = os.path.join(self.__data_path, FLIC.extract_dir)
        print('Check if the archive file is extracted...', end = '')
        if not os.path.exists(extract_path):
            print('failed!')
            return True
        print('success!')
        print('\tpath:', extract_path)

        index_path = self.__indexPaths()
        print('Check if the index file is exists...', end = '')
        if not (os.path.exists(index_path['train']) and os.path.exists(index_path['eval'])):
            print('failed!')
            return True
        print('success!')
        print('\ttrain path:', index_path['train'])
        print('\teval path:', index_path['eval'])

        return False

    def __reload(self):
        ''' Extract the archive and write a fresh random train/eval split.

        The example indices 0..number_of_data-1 are shuffled; the first
        int(train_ratio * number_of_data) go to the train index file and the
        rest to the eval index file, one index per line.
        '''
        archive_path = os.path.join(self.__data_path, FLIC.archive_file)
        print('Extract the archive file...', end = '')
        with zipfile.ZipFile(archive_path, 'r') as archive_ref:
            archive_ref.extractall(self.__data_path)
        print('success!')
        extract_path = os.path.join(self.__data_path, FLIC.extract_dir)
        print('\tpath:', extract_path)

        index_path = self.__indexPaths()
        index_list = list(range(FLIC.number_of_data))
        random.shuffle(index_list)
        number_of_train = int(FLIC.train_ratio * FLIC.number_of_data)
        print('Generate the random train/eval set...', end = '')
        with open(index_path['train'], 'w') as train_index_file:
            for index in index_list[:number_of_train]:
                train_index_file.write(str(index) + '\n')

        with open(index_path['eval'], 'w') as eval_index_file:
            for index in index_list[number_of_train:]:
                eval_index_file.write(str(index) + '\n')
        print('success!')
        print('\ttrain path:', index_path['train'])
        print('\teval path:', index_path['eval'])

    def getBatch(self, batch_size):
        ''' Return the per-joint annotation mask.

        Args:
            batch_size: Currently unused.

        Returns:
            A list with one bool per Joint member, in enum order.
        '''
        # TODO: return batch_size (image, masking) pairs, one per index drawn
        # from __getNext(), once the image and keypoint loaders are finished.
        return self.__getMasking()

    def __getNext(self):
        ''' Read the next example index for the current task.

        Wraps around to the start of the index file when its end is reached.

        Raises:
            ValueError: If the line read is not an integer (e.g. the index
                file is empty).
        '''
        index_file = self.__index[self.__task]
        line = index_file.readline()
        if line == '':
            # End of file: rewind and start the next pass.
            index_file.seek(0)
            line = index_file.readline()
        return int(line)

    def __getImage(self, index):
        ''' Return the 'filepath' entry of the given example.

        The image itself is not loaded yet. The 'imgdims' field of the same
        record is what the earlier draft read for the image dimensions.
        '''
        return np.squeeze(np.squeeze(self.__FLIC['examples']['filepath'])[index])

    def __getKeypoints(self, index):
        ''' Print the (x, y) coordinates of every Joint for the given example.

        Joints without an entry in joint2index are reported as NaN. Nothing is
        returned.
        '''
        keypoints_ref = np.squeeze(np.squeeze(np.squeeze(self.__FLIC['examples'])[index])['coords'])
        for joint in Joint:
            if joint in FLIC.joint2index:
                column = FLIC.joint2index[joint]
                x = keypoints_ref[Basis.x.value][column]
                y = keypoints_ref[Basis.y.value][column]
            else:
                x = float('nan')
                y = float('nan')
            print(joint, x, y)

    def __getMasking(self):
        ''' Return, for each Joint in enum order, whether FLIC annotates it. '''
        return [joint in FLIC.joint2index for joint in Joint]

In [133]:
''' MPII data reader
'''
        
class MPII(DataInterface):
    ''' Placeholder for the MPII reader; every method raises NotImplementedError. '''

    def __init__(self, data_path, task):
        ''' Not implemented yet.

        Args:
            data_path: Directory of the dataset.
            task: Task the reader would serve.
        '''
        raise NotImplementedError

    def __corrupted(self):
        ''' Not implemented yet. '''
        raise NotImplementedError

    def __reload(self):
        ''' Not implemented yet. '''
        raise NotImplementedError

In [134]:
# Build the reader for the dataset and task selected by the flags.
dataset = DataCenter(
    data_root = FLAGS.path_to_data
).request(
    data_name = FLAGS.name_of_data,
    task = FLAGS.task)

Check if the archive file exists...success!
	path: /home/nulledge/Workspace/data/FLIC/FLIC.zip
Check if the archive file is extracted...success!
	path: /home/nulledge/Workspace/data/FLIC/FLIC
Check if the index file is exists...success!
	train path: /home/nulledge/Workspace/data/FLIC/train.txt
	eval path: /home/nulledge/Workspace/data/FLIC/eval.txt


In [135]:
# NOTE(review): FLIC.getBatch currently ignores batch_size and returns only
# the per-joint annotation mask (one bool per Joint member).
batch = dataset.getBatch(batch_size = 1)
print(batch)

[True, True, True, True, True, True, True, True, False, False, False, False, True, True, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False]


In [None]:
# Graph inputs: the image batch, the ground-truth heatmaps and a scalar flag
# that switches between training and inference behaviour.
# NOTE(review): `metadata` is not defined in any cell visible in this
# notebook - confirm where the image/heatmap sizes are meant to come from.
with tf.variable_scope('input'):
    images = tf.placeholder(
        dtype = tf.float32,
        shape = [None, metadata.image.height, metadata.image.width, metadata.image.channel],
        name = 'image')
    heatmaps_groundtruth = tf.placeholder(
        dtype = tf.float32,
        shape = [None, metadata.heatmap.height, metadata.heatmap.width, metadata.joint],
        name = 'heatmap_groundtruth')
    train = tf.placeholder(
        dtype = tf.bool,
        shape = (),
        name = 'train')

In [None]:
# Stem of the network: reduces the input image to the feature map that the
# hourglass stack consumes. The trailing shape comments are the author's
# annotations of the expected output size of each step.
# input size 256 * 256 * 3

with tf.variable_scope('compress'):
    # Strided 7x7 convolution followed by batch norm and ReLU.
    with tf.variable_scope('conv_bn_relu'):
        net = layer.conv(input = images, ksize = 7, kchannel = 64, kstride = 2) # 128 * 128 * 64
        net = layer.bn(input = net, train = train)
        net = layer.relu(input = net)

    # Bottleneck modules around a pooling step; the names 'A', 'B', 'C' are
    # passed to the project's bottleneck helper to tell the modules apart.
    net = module.bottleneck(input = net, kchannel = 128, train = train, name = 'A') # 128 * 128 * 128
    net = layer.pool(input = net) # 64 * 64 * 128
    net = module.bottleneck(input = net, kchannel = 128, train = train, name = 'B') # 64 * 64 * 128
    net = module.bottleneck(input = net, kchannel = 256, train = train, name = 'C') # 64 * 64 * 256

In [None]:
class tf_Spectrum:
    ''' Colour table used by tf_gray2color.

    Color is a float64 constant with five rows of three channel values; the
    channel order (RGB vs BGR) is not established here.
    '''
    Color = tf.constant(
        [[128, 0, 0],
         [255, 0, 0],
         [0, 255, 0],
         [0, 255, 255],
         [0, 0, 255]],
        dtype = tf.float64)

def tf_gray2color(gray, spectrum = tf_Spectrum.Color):
    ''' Colourise gray levels by linear interpolation along a colour table.

    Every span of 64 gray levels is one segment between two consecutive rows
    of `spectrum`: gray 0 maps to row 0, gray 64 to row 1, and so on. Values
    that fall outside the table saturate at its first/last row instead of
    producing an out-of-range lookup.

    Args:
        gray: Rank-2 tensor of gray levels (the weights are stacked along
            axis 2 to match the three colour channels).
        spectrum: Tensor of colours, one row per colour with three channels;
            needs at least two rows.

    Returns:
        Tensor of shape gray.shape + [3] holding the interpolated colours.
    '''
    segment = tf.cast(tf.floor_div(gray, 64), dtype = tf.int32)

    # Both `segment` and `segment + 1` are looked up below. tf.gather fails
    # on CPU and silently returns zeros on GPU for out-of-range indices, so
    # the segment is clamped to the last valid pair of rows.
    last_segment = tf.shape(spectrum)[0] - 2
    segment = tf.clip_by_value(segment, 0, last_segment)

    # Position inside the segment; clamped so out-of-table gray levels
    # saturate at the end colours rather than extrapolating.
    t = (gray - tf.cast(segment, dtype = gray.dtype) * 64) / (64)
    t = tf.clip_by_value(t, 0.0, 1.0)
    t = tf.stack([t]*3, 2)

    return (1-t)*tf.gather(spectrum, segment) + t*tf.gather(spectrum, segment+1)

def tf_merge(image, heatmaps):
    ''' Blend the colourised heatmaps of all joints over an image.

    Each joint channel is colourised with tf_gray2color and the per-pixel
    maximum over the joints is taken. The result is resized to the image
    size and mixed 60/40 with the image.

    NOTE(review): `metadata` is not defined in any cell visible in this
    notebook - confirm where the sizes and the joint count come from.

    Args:
        image: Single image tensor; cast to float32 before blending.
        heatmaps: Rank-3 tensor indexed as [row, column, joint].

    Returns:
        uint8 tensor holding the blended image.
    '''
    overlay = tf.zeros(
        shape = (metadata.heatmap.height, metadata.heatmap.width, metadata.image.channel),
        dtype = tf.float64)

    for joint in range(metadata.joint):
        joint_map = tf.cast(heatmaps[:, :, joint], dtype=tf.float64)
        overlay = tf.maximum(overlay, tf_gray2color(joint_map))

    overlay = tf.image.resize_images(overlay, [metadata.image.height, metadata.image.width])
    background = tf.cast(image, dtype = tf.float32)

    blended = tf.add(tf.multiply(overlay, 0.6), tf.multiply(background, 0.4))
    return tf.cast(blended, dtype = tf.uint8)

In [None]:
# Stack of hourglass stages. Every stage emits its own heatmap prediction,
# which is collected in `heatmaps` (index 0 = first stage).
last_stage = 8
heatmaps = []

for stage in range(1, last_stage+1):
    with tf.variable_scope('hourglass_' + str(stage)):
        # Keep the stage input for the residual sum at the end of the stage.
        prev = tf.identity(net)
        net = module.hourglass(input = net, train = train) # 64 * 64 * 256

        with tf.variable_scope('inter_hourglasses'):
            net = module.bottleneck(input = net, train = train) # 64 * 64 * 256
            net = layer.conv(input = net, ksize = 1, kchannel = 256) # 64 * 64 * 256
            net = layer.bn(input = net, train = train)
            net = layer.relu(input = net)

        with tf.variable_scope('heatmap'):
            # One output channel per joint.
            heatmap = layer.conv(input = net, ksize = 1, kchannel = metadata.joint) # 64 * 64 * joint
            heatmaps.append(heatmap)

        # Every stage except the last feeds the next one with the sum of its
        # features, its re-projected heatmap and its own input.
        if stage != last_stage:
            net = layer.conv(input = net, ksize = 1, kchannel = 256, name = 'inter')\
                + layer.conv(input = heatmap, ksize = 1, kchannel = 256, name = 'heatmap')\
                + prev # 64 * 64 * 256

# Visualisation of the first sample of the batch using the last stage's heatmaps.
merged = tf_merge(images[0], heatmaps[-1][0])

In [None]:
# The loss is built unconditionally because the evaluation loops fetch `loss`
# as well; only the optimizer is specific to training.
with tf.variable_scope('loss'):
    # Intermediate supervision: sum of the per-stage mean squared errors.
    loss = tf.losses.mean_squared_error(heatmaps_groundtruth, heatmaps[0])
    for stage in range(1, last_stage):
        loss = loss + tf.losses.mean_squared_error(heatmaps_groundtruth, heatmaps[stage])

    # NOTE(review): no `train` flag is defined in the config cell; the `task`
    # flag is used instead, assuming its training value is 'train' (the key
    # the FLIC reader uses for its train index) - confirm against config.json.
    if FLAGS.task == 'train':
        # Run the collected update ops (e.g. batch-norm statistics) together
        # with every optimisation step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.AdamOptimizer(name = 'optimizer', learning_rate = 0.00025).minimize(loss)

In [None]:
# Session, checkpoint saver and data reader used by the loops below.
# NOTE(review): `Reader`, `FLAGS.train` and `FLAGS.ckpt` are not defined in
# any cell visible in this notebook, and the data paths are hard-coded.
sess = tf.Session()
saver = tf.train.Saver()
reader = Reader(
    train = os.path.expanduser('~/Temp/train19122.dat'),
    test = os.path.expanduser('~/Temp/test2125.dat'))

if not FLAGS.train:
    # Evaluation starts from the saved checkpoint.
    saver.restore(sess, FLAGS.ckpt)
else:
    # Training starts from freshly initialised variables; resuming from the
    # checkpoint is not enabled.
    sess.run(tf.global_variables_initializer())

In [None]:
# Training loop (one checkpoint per epoch) or, otherwise, a full pass over the
# test set that writes one blended visualisation per batch.
# NOTE(review): `FLAGS.train`, `FLAGS.batch`, `FLAGS.ckpt` and `cv2` are not
# defined/imported in any cell visible in this notebook; the sample and
# iteration counts are hard-coded.
if FLAGS.train == True:
    for epoch in range(100):
        train_iter = tqdm_notebook(
            total = 19122,
            desc = 'epoch: {}/100'.format(epoch))
        for i in range(3187):
            train_images, train_heatmaps = reader.batch(
                size = FLAGS.batch,
                is_train = True)
            _, result = sess.run(
                [optimizer, loss],
                feed_dict = {
                    images: train_images,
                    heatmaps_groundtruth: train_heatmaps,
                    train: True})
            train_iter.set_postfix(loss = result)
            train_iter.update(FLAGS.batch)
        train_iter.close()
        temp = saver.save(sess, FLAGS.ckpt)
else:
    heatmap_idx = 0
    total_result = []
    test_iter = tqdm_notebook(total = 2125, desc = 'test')
    for i in range(2125 // 5):
        test_images, test_heatmaps = reader.batch(size = 5, is_train = False)
        result, output_heatmap = sess.run(
            [loss, merged],
            feed_dict = {
                images: test_images,
                heatmaps_groundtruth: test_heatmaps,
                train: False})
        test_iter.update(5)
        total_result.append(result)
        # One visualisation file per batch.
        cv2.imwrite('merged{}.jpg'.format(heatmap_idx), output_heatmap)
        heatmap_idx += 1
    test_iter.close()
    # Mean loss over the batches.
    print(sum(total_result)/(2125//5))

In [None]:
# Stand-alone evaluation pass over the test set; it repeats the evaluation
# branch of the training/evaluation loop earlier in the notebook.
# NOTE(review): `cv2` is not imported in any cell visible in this notebook,
# and `loss` only exists if the loss cell built it.
heatmap_idx = 0
total_result = []
test_iter = tqdm_notebook(total = 2125, desc = 'test')
for i in range(2125 // 5):
    test_images, test_heatmaps = reader.batch(size = 5, is_train = False)
    result, output_heatmap = sess.run(
        [loss, merged],
        feed_dict = {
            images: test_images,
            heatmaps_groundtruth: test_heatmaps,
            train: False})
    test_iter.update(5)
    total_result.append(result)
    # One visualisation file per batch.
    cv2.imwrite('merged{}.jpg'.format(heatmap_idx), output_heatmap)
    heatmap_idx += 1
test_iter.close()
# Mean loss over the batches.
print(sum(total_result)/(2125//5))