In [1]:
import tensorflow as tf
import numpy as np
import tensorflow.contrib.slim as slim


from bokeh.plotting import figure
from bokeh.io import output_notebook, show

from PIL import Image, ImageDraw
import os

In [2]:
output_notebook()

# Generator
![image.png](attachment:image.png)

In [3]:
class Generator:
    def __init__(self, batch_size, image_shape, stack_size=2, module_size=2, channel_size=[32, 64, 128]):
        self.batch_size = batch_size
        self.image_shape = image_shape
        self.stack_size = stack_size
        self.module_size = module_size
        self.channel_size = channel_size
        self.alpha = 0.2
        #with tf.variable_scope("generator"):
        self.inputs = tf.placeholder(tf.float32, shape=[None, *image_shape], name="inputs")
    
    def leaky_relu(self, inputs):
        return tf.maximum(inputs, inputs * self.alpha)
    
    def kinit(size, dtype, partition_info):
        return tf.random_normal(size, stddev=0.02)

    def build_module(self, inputs, name, channels=[32, 64, 128]):
        with tf.variable_scope(name + "_module"):
            with slim.arg_scope([slim.conv2d, slim.conv2d_transpose], padding="SAME", kernel_size=3, stride=2, activation_fn=self.leaky_relu, 
                          weights_initializer=tf.truncated_normal_initializer(stddev=0.01)):
                return self.conv_module(inputs, channels)

    def conv_module(self, inputs, channels):
        conv_layer = slim.conv2d(inputs, channels[0])
        if len(channels) > 1:
            #recursive
            inner_layer = self.conv_module(conv_layer, channels[1:])
            concat_layer = tf.concat([conv_layer, inner_layer], axis=-1)
        else:
            concat_layer = conv_layer
        deconv_layer = slim.conv2d_transpose(concat_layer, channels[0])
        return deconv_layer
    
    def build_stack(self, inputs, training=True):
        with slim.arg_scope([slim.conv2d], kernel_size=5, stride=1, padding="SAME",
                       activation_fn=self.leaky_relu, weights_initializer=tf.truncated_normal_initializer(stddev=0.02)):
            conv_layer = slim.conv2d(inputs, num_outputs=32, )
            pose_heatmap = self.build_module(conv_layer, "pose", self.channel_size)
            concat_layer = tf.concat([pose_heatmap, conv_layer], axis=-1)
            occlusion_heatmap = self.build_module(concat_layer, "occlusion", self.channel_size)
            pose_heatmap = slim.conv2d(pose_heatmap, 1)
            occlusion_heatmap = slim.conv2d(occlusion_heatmap, 1)
            pose_heatmap = tf.nn.tanh(pose_heatmap)
            occlusion_heatmap = tf.nn.tanh(occlusion_heatmap)
            output = tf.concat([conv_layer, pose_heatmap, occlusion_heatmap], axis=-1, name="stack_output")
            return output, pose_heatmap, occlusion_heatmap
    
    def build(self, training=True):
        module_input = self.inputs
        for i in range(self.module_size):
            with tf.variable_scope("generator_"+str(i), reuse=not training):
                module_output, pose_heatmap, occlusion_heatmap = self.build_stack(module_input, training)
                module_input = module_output
        self.pose_heatmap = pose_heatmap
        self.occlusion_heatmap = occlusion_heatmap

In [4]:
def resize_img(fileName, new_size=[512, 512]):
    filePattern = "./sample_img/{}.jpg"
    fileSource = filePattern.format(fileName)
    avatar = Image.open(fileSource)
    return avatar.resize(new_size, Image.ANTIALIAS)

In [5]:
test_img = "4e497571583a077564df4b547e40408fd9915ecc"

tmp_img = resize_img(test_img)
tmp_np_img = np.array(tmp_img)
tmp_np_img.shape
tmp_img.size

(512, 512)

In [6]:
import pickle
save_path = "./model_res/keypoint_annotation.pkl"

with open(save_path, mode="rb") as f:
    kp_data = pickle.load(f)

In [7]:
class Preprocessor:
    def __init__(self, image_folder, kp_data, batch_size=32, image_size=[512, 512]):
        self.image_folder = image_folder
        self.file_pattern = image_folder+"/{}.jpg"
        self.batch_size = batch_size
        self.image_size = image_size
        self.image_names = [os.path.splitext(file)[0] for file in os.listdir(image_folder)]
        self.kp_data = kp_data
    
    def get_batch(self):
        if self.image_names == None:
            return None
        total_amount = len(self.image_names)
        batch_count = total_amount // self.batch_size
        for i in range(0, batch_count * self.batch_size, self.batch_size):
            the_batch = self.image_names[i:i+self.batch_size]
            heatmaps = np.array(list(map(self.make_heatmap, the_batch)))
            yield np.array(list(map(self.preprocess, the_batch))), heatmaps[:, 0], heatmaps[:, 1]
        pass
    
    def preprocess(self, image_name):
        file_path = self.file_pattern.format(image_name)
        avatar = Image.open(file_path)

        resized_image = avatar.resize(self.image_size, Image.ANTIALIAS)
        return np.array(resized_image) / 255.
        
    
    def get_kp_data(self, image_name):
        human_kp_data = self.kp_data[image_name]['keypoint_annotations']['human1']
        return human_kp_data
    
    def make_heatmap(self, image_name):
        file_path = self.file_pattern.format(image_name)
        avatar = Image.open(file_path)
        width, height = avatar.size
        kp_data = self.get_kp_data(image_name)
        heatmap = np.ones((self.image_size[1], self.image_size[0], 1)) * -1
        occlusion_heatmap = np.ones((self.image_size[1], self.image_size[0], 1)) * -1
        for i in range(0, len(kp_data), 3):
            ori_x, ori_y, status = kp_data[i:i+3]
            new_x = int(ori_x * (self.image_size[0] * 1. / width))
            new_y = int(ori_y * (self.image_size[1] * 1. / height))
            if status == 1:
                heatmap[new_y, new_x, 0] = 1
            elif status == 2:
                occlusion_heatmap[new_y, new_x, 0] = 1
        return heatmap, occlusion_heatmap
    
    def DrawImage(self, image_name):
        file_path = self.file_pattern.format(image_name)
        avatar = Image.open(file_path)
        print(avatar.size)
        drawAvatar = ImageDraw.Draw(avatar)
        annotation_data = self.kp_data[image_name]
        draw_human_boundary(drawAvatar, annotation_data)
        draw_keypoint(drawAvatar, annotation_data)
        del drawAvatar
        return avatar

    def draw_human_boundary(drawAvatar, annotation_data):
        if 'human_annotations' in annotation_data:
            human_data_set = annotation_data['human_annotations']
            for humman_data in human_data_set:
                drawAvatar.rectangle(human_data_set[humman_data], outline=(255, 10, 0))

    def draw_keypoint(drawAvatar, annotation_data):
        if "keypoint_annotations" in annotation_data:
            keypoint_data_set = annotation_data["keypoint_annotations"]
            for keypoint_key in keypoint_data_set:
                points = keypoint_data_set[keypoint_key]
                for i in range(0, len(points), 3):
                    if points[i+2] == 1:
                        fill = (10, 255, 10)
                    elif points[i+2] == 2:
                        fill = (255, 10, 10)
                    else:
                        fill = None
                    if fill is not None:
                        arc_points = [points[i] - 3, points[i+1] - 3, points[i] + 3, points[i+1] + 3]
                        drawAvatar.arc(arc_points, start=0, end=360, fill=fill)
                        drawAvatar.text((points[i]+10, points[i+1]), "{}".format((i // 3)+1), fill=fill)

    

In [11]:
"""preprocessor.get_kp_data(test_img)

heatmap, occlusion_heatmap = preprocessor.make_heatmap(test_img)

Image.fromarray(((heatmap + 1) * 255).squeeze().astype(np.int32))


heatmap[((heatmap + 1) * 255).squeeze().astype(np.int32).nonzero()]

test_img = "1b75657cd05ff89859bf800a30c0691c776dd880"
preprocessor.DrawImage(test_img)


a, b, c = preprocessor.image_names[0:3]
print(a, b, c)

len(preprocessor.image_names)
list(map(lambda x: x[-1], preprocessor.image_names))

"""

'preprocessor.get_kp_data(test_img)\n\nheatmap, occlusion_heatmap = preprocessor.make_heatmap(test_img)\n\nImage.fromarray(((heatmap + 1) * 255).squeeze().astype(np.int32))\n\n\nheatmap[((heatmap + 1) * 255).squeeze().astype(np.int32).nonzero()]\n\ntest_img = "1b75657cd05ff89859bf800a30c0691c776dd880"\npreprocessor.DrawImage(test_img)\n\n\na, b, c = preprocessor.image_names[0:3]\nprint(a, b, c)\n\nlen(preprocessor.image_names)\nlist(map(lambda x: x[-1], preprocessor.image_names))\n\n'

In [30]:


class Discriminator:
    def __init__(self, batch_size, image_shape, output_size=14, channel_size=[32, 64, 128]):
        self.batch_size = batch_size
        self.image_size = image_size
        self.channel_size = channel_size
        self.output_size = output_size
        self.alpha = 0.2

    def leaky_relu(self, inputs):
        return tf.maximum(inputs, inputs * self.alpha)
    
    def kinit(size, dtype, partition_info):
        return tf.random_normal(size, stddev=0.02)
    
    def build_module(self, inputs, name, channels=[32, 64, 128]):
        with tf.variable_scope(name + "_module"):
            with slim.arg_scope([slim.conv2d, slim.conv2d_transpose], padding="SAME", kernel_size=3, stride=2, activation_fn=self.leaky_relu, 
                          weights_initializer=tf.truncated_normal_initializer(stddev=0.01)):
                return self.conv_module(inputs, channels)

    def conv_module(self, inputs, channels):
        conv_layer = slim.conv2d(inputs, channels[0])
        if len(channels) > 1:
            #recursive
            inner_layer = self.conv_module(conv_layer, channels[1:])
            concat_layer = tf.concat([conv_layer, inner_layer], axis=-1)
        else:
            concat_layer = conv_layer
        deconv_layer = slim.conv2d_transpose(concat_layer, channels[0])
        return deconv_layer
    
    def build(self, inputs, name, reuse=False):
        with tf.variable_scope("discriminator_"+name, reuse=reuse):
            conv_output = self.build_module(inputs, name, self.channel_size)
            flatten_layer = slim.flatten(conv_output)
            self.digits = slim.fully_connected(flatten_layer, self.output_size, activation_fn=None, 
                                 weights_initializer=tf.random_normal_initializer(stddev=0.01))
            
            self.outputs = tf.nn.sigmoid(self.digits)

In [58]:
batch_size= 2
image_size = [512, 512, 3]

In [59]:
tf.reset_default_graph()

#build generator model
generator = Generator(batch_size, image_size)
generator.build(True)
#get generator output
pose_heatmap = generator.pose_heatmap
occlusion_heatmap = generator.occlusion_heatmap

#create discriminator
pose_discriminator = Discriminator(batch_size, image_size)
conf_discriminator = Discriminator(batch_size, image_size)

#true heatmap input
pose_pl = tf.placeholder(tf.float32, [None, image_size[0], image_size[1], 1], name="pose_true_heatmap")
occlusion_pl = tf.placeholder(tf.float32, [None, image_size[0], image_size[1], 1], name="occlusion_true_heatmap")

#discriminator for true heatmap
pose_true_inputs = tf.concat([generator.inputs, pose_pl, occlusion_pl], axis=-1,
                            name="pose_true_inputs")
conf_true_inputs = tf.concat([pose_pl, occlusion_pl], axis=-1,
                            name="conf_true_inputs")
pose_discriminator.build(pose_true_inputs, "pose")
conf_discriminator.build(conf_true_inputs, "confidence")
pose_true_output = [pose_discriminator.digits, pose_discriminator.outputs]
conf_true_output = [conf_discriminator.digits, conf_discriminator.outputs]

#discriminator for fake heatmap
pose_fake_inputs = tf.concat([generator.inputs, generator.pose_heatmap, generator.occlusion_heatmap], axis=-1,
                            name="pose_fake_inputs")
conf_fake_inputs = tf.concat([generator.pose_heatmap, generator.occlusion_heatmap], axis=-1,
                            name="conf_fake_inputs")
pose_discriminator.build(pose_fake_inputs, "pose", True)
conf_discriminator.build(conf_fake_inputs, "confidence", True)
pose_fake_output = [pose_discriminator.digits, pose_discriminator.outputs]
conf_fake_output = [conf_discriminator.digits, conf_discriminator.outputs]

variables = tf.trainable_variables()
slim.summarize_tensors(variables)
merged = tf.summary.merge_all()


In [60]:
lam = 10
pose_eps = tf.random_uniform([image_size[0], image_size[1], 1], minval=0., maxval=1.)
pose_inter = pose_eps * pose_pl + (1. - pose_eps) * generator.pose_heatmap
occlusion_eps = tf.random_uniform([image_size[0], image_size[1], 1], minval=0., maxval=1.)
occlusion_inter = occlusion_eps * occlusion_pl + (1. - occlusion_eps) * generator.occlusion_heatmap

pose_inter_input = tf.concat([generator.inputs, pose_inter, occlusion_inter], axis=-1,
                            name="pose_fake_inputs")
pose_discriminator.build(pose_inter_input, "pose", True)
pose_grad = tf.gradients(pose_discriminator.digits, [pose_inter_input])[0]
pose_grad_norm = tf.sqrt(tf.reduce_sum((pose_grad)**2, axis=1))
pose_grad_pen = lam * tf.reduce_mean((pose_grad_norm - 1)**2)
D_pose_loss = tf.reduce_mean(pose_fake_output[0]) - tf.reduce_mean(pose_true_output[0]) + pose_grad_pen


conf_inter_input = tf.concat([pose_inter, occlusion_inter], axis=-1,
                            name="conf_fake_inputs")
conf_discriminator.build(conf_inter_input, "confidence", True)
conf_grad = tf.gradients(conf_discriminator.digits, [conf_inter_input])[0]
conf_grad_norm = tf.sqrt(tf.reduce_sum((conf_grad)**2, axis=1))
conf_grad_pen = lam * tf.reduce_mean((conf_grad_norm - 1)**2)
D_conf_loss = tf.reduce_mean(conf_fake_output[0]) - tf.reduce_mean(conf_true_output[0]) + conf_grad_pen

G_loss = -tf.reduce_mean(pose_fake_output[0]) - tf.reduce_mean(conf_fake_output[0])


In [61]:
all_variables = tf.trainable_variables()
generate_vars = [var for var in all_variables if var.name.startswith("generator")]
pose_d_vars = [var for var in all_variables if var.name.startswith("discriminator_pose")]
conf_d_vars = [var for var in all_variables if var.name.startswith("discriminator_conf")]

In [62]:

pose_d_vars


[<tf.Variable 'discriminator_pose/pose_module/Conv/weights:0' shape=(3, 3, 5, 32) dtype=float32_ref>,
 <tf.Variable 'discriminator_pose/pose_module/Conv/biases:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'discriminator_pose/pose_module/Conv_1/weights:0' shape=(3, 3, 32, 64) dtype=float32_ref>,
 <tf.Variable 'discriminator_pose/pose_module/Conv_1/biases:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'discriminator_pose/pose_module/Conv_2/weights:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'discriminator_pose/pose_module/Conv_2/biases:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'discriminator_pose/pose_module/Conv2d_transpose/weights:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'discriminator_pose/pose_module/Conv2d_transpose/biases:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'discriminator_pose/pose_module/Conv2d_transpose_1/weights:0' shape=(3, 3, 64, 192) dtype=float32_ref>,
 <tf.Variable 'discriminator_pose/pose_module/Conv2d_transpose

In [66]:
learning_rate = 0.001
beta1 = 0.2
preprocessor = Preprocessor("sample_img", kp_data, batch_size=batch_size)
epoches = 50
import time

In [64]:
D_pose_solver = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(D_pose_loss, var_list=pose_d_vars)
D_conf_solver = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(D_conf_loss, var_list=conf_d_vars)
G_solver = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(G_loss, var_list=generate_vars)


In [69]:
with tf.Session() as sess:
    fileWriter = tf.summary.FileWriter("summary", sess.graph)
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    count = 1
    for epoch in range(epoches):
        for i, (each_batch, pose_heatmap, occlusion_heatmap) in enumerate(preprocessor.get_batch()):
            #print(each_batch.shape, pose_heatmap.shape, occlusion_heatmap.shape)

            feed_dict={
                    generator.inputs:each_batch,
                    pose_pl:pose_heatmap,
                    occlusion_pl:occlusion_heatmap
                }
            begin = time.time()
            d_pose_loss, _ = sess.run([D_pose_loss, D_pose_solver], feed_dict=feed_dict)
            d_conf_loss, _ = sess.run([D_conf_loss, D_conf_solver], feed_dict=feed_dict)
            g_loss, _ = sess.run([G_loss, G_solver], feed_dict=feed_dict)
            spend = time.time() - begin
            summary = sess.run(merged, feed_dict=feed_dict)
            print("pose loss:{}".format(d_pose_loss),
                 "conf loss:{}".format(d_conf_loss),
                 "generate loss:{}".format(g_loss),
                 "spend {} sec".format(spend))
            fileWriter.add_summary(summary, i+1)
    saver.save(sess, "save/model.ckpt")

pose loss:9.910619735717773 conf loss:9.909833908081055 generate loss:-0.9241063594818115 spend 140.21509981155396 sec
pose loss:1.341780662536621 conf loss:2.6288623809814453 generate loss:-2.733858346939087 spend 109.09760022163391 sec
pose loss:-48.272586822509766 conf loss:-38.166419982910156 generate loss:21.18093490600586 spend 97.32292056083679 sec
pose loss:-219.35037231445312 conf loss:-178.5438232421875 generate loss:141.35708618164062 spend 108.5876157283783 sec
pose loss:-712.5606689453125 conf loss:-532.9873657226562 generate loss:14.20733642578125 spend 107.81012225151062 sec
pose loss:-1123.3970947265625 conf loss:-846.3607788085938 generate loss:-14515.1953125 spend 102.03330564498901 sec
pose loss:422.00494384765625 conf loss:462.3708801269531 generate loss:-9265.98046875 spend 83.09481430053711 sec
pose loss:47.05528259277344 conf loss:26.80214500427246 generate loss:-6425.888671875 spend 74.54572343826294 sec
pose loss:17.468074798583984 conf loss:9.030993461608887 g

KeyboardInterrupt: 