In [1]:
import numpy as np
import tensorflow as tf

from tensorflow.keras import layers as tf_layers

%matplotlib inline
import matplotlib as mpl
from matplotlib import pyplot as plt

from stacked_hourglass import hourglass_block, create_stacked_hourglass_model

In [2]:
# Spatial dimensions passed to the model input layers, and the number of
# channels produced by each model's output convolution.
WIDTH = 128
HEIGHT = 96
N_KEYPOINTS = 9

# Basic FCN-like model

In [3]:
# Filter counts for the basic FCN: N_START is used by the last decoder stage,
# the two BLOCK constants by the encoder stages (BLOCK1 also by the first decoder stage).
N_START, N_FILTERS_BLOCK1, N_FILTERS_BLOCK2 = 32, 64, 128

In [4]:
# Basic encoder/decoder FCN: two pooled encoder stages, two upsampling decoder
# stages, then a 1x1 convolution producing N_KEYPOINTS channels.
# NOTE(review): the input shape is declared as (WIDTH, HEIGHT, 3); Keras 2D layers
# treat the first spatial axis as rows -- confirm the data pipeline uses this order.
img_input = tf.keras.Input(shape=(WIDTH, HEIGHT, 3))

x = img_input

# Encoder blocks 1 and 2: two 3x3 ReLU convolutions followed by 2x2 max-pooling.
for enc_filters in (N_FILTERS_BLOCK1, N_FILTERS_BLOCK2):
    x = tf_layers.Conv2D(filters=enc_filters, kernel_size=(3, 3), padding='same', activation='relu')(x)
    x = tf_layers.Conv2D(filters=enc_filters, kernel_size=(3, 3), padding='same', activation='relu')(x)
    x = tf_layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)

# Decoder blocks 1 and 2: 2x upsampling followed by two 3x3 ReLU convolutions.
for dec_filters in (N_FILTERS_BLOCK1, N_START):
    x = tf_layers.UpSampling2D(size=(2, 2))(x)
    x = tf_layers.Conv2D(filters=dec_filters, kernel_size=(3, 3), padding='same', activation='relu')(x)
    x = tf_layers.Conv2D(filters=dec_filters, kernel_size=(3, 3), padding='same', activation='relu')(x)

# Output head: softmax-activated 1x1 convolution with one channel per keypoint.
keypoint_head = tf_layers.Conv2D(filters=N_KEYPOINTS, kernel_size=(1, 1), activation='softmax')
output = keypoint_head(x)

# Wrap the graph in a model and print its layer table.
basic_fcn_model = tf.keras.Model(inputs=img_input, outputs=output)
basic_fcn_model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 128, 96, 3)]      0         
_________________________________________________________________
conv2d (Conv2D)              (None, 128, 96, 64)       1792      
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 128, 96, 64)       36928     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 64, 48, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 64, 48, 128)       73856     
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 64, 48, 128)       147584    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 32, 24, 128)       0     

# Single hourglass

In [5]:
# "simple" variant
# def residual_module(x, n_filters):
#     return tf_layers.Conv2D(n_filters, (3, 3), padding='same', activation='relu')(x)

# Bottleneck-style alternative (disabled): 1x1 and 3x3 convolutions at
# n_filters // 2, then a 1x1 convolution back to n_filters, added to a
# 1x1-convolved copy of the input.
# def residual_module(x, n_filters):
#     skip = tf_layers.Conv2D(n_filters, (1, 1), activation='relu')(x)
    
#     x = tf_layers.Conv2D(n_filters // 2, (1, 1), activation='relu')(x)
#     x = tf_layers.Conv2D(n_filters // 2, (3, 3), padding='same', activation='relu')(x)
#     x = tf_layers.Conv2D(n_filters, (1, 1), activation='relu')(x)
#     x = tf_layers.Add()([skip, x])
    
#     return x

# Recursive hourglass draft (disabled): each level applies a residual module,
# max-pools, recurses with twice the filters, upsamples, and adds the pre-pool
# tensor back in; recursion stops once n_filters >= max_filters.
# NOTE(review): the live cells call `hourglass_block` imported from
# `stacked_hourglass` rather than this draft -- presumably that module holds the
# maintained version; confirm before deleting this cell.
# def hourglass_module(x, n_filters, max_filters=256):
#     if n_filters >= max_filters: # bottleneck
#         x = residual_module(x, n_filters)
#         x = residual_module(x, n_filters // 2)
#     else:                        # left and right half blocks
#         x = residual_module(x, n_filters)
#         skip = x
#         x = tf_layers.MaxPooling2D((2, 2), strides=(2, 2))(x)
        
#         x = hourglass_module(x, n_filters * 2, max_filters)
        
#         x = tf_layers.UpSampling2D((2, 2))(x)
#         x = tf_layers.Add()([skip, x])
#         x = residual_module(x, n_filters // 2)
        
#     return x

In [6]:
# Single hourglass network: stem convolution -> one hourglass block -> keypoint head.
img_input = tf.keras.Input(shape=(WIDTH, HEIGHT, 3))

# Stem: 3x3 ReLU convolution with 32 filters applied to the raw input.
stem_conv = tf_layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')
x = stem_conv(img_input)

# Hourglass block imported from stacked_hourglass, built in 'simple' mode.
x = hourglass_block(x, 64, 256, mode='simple')

# Head: softmax-activated 1x1 convolution with one channel per keypoint.
keypoint_head = tf_layers.Conv2D(filters=N_KEYPOINTS, kernel_size=(1, 1), activation='softmax')
output = keypoint_head(x)

single_hourglass_model = tf.keras.Model(inputs=img_input, outputs=output)
single_hourglass_model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 128, 96, 3)] 0                                            
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 128, 96, 32)  896         input_2[0][0]                    
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 128, 96, 64)  18496       conv2d_9[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 64, 48, 64)   0           conv2d_10[0][0]                  
____________________________________________________________________________________________

# Stacked hourglass

In [7]:
# Arguments passed to create_stacked_hourglass_model in the cell that builds
# the stacked network.
N_HOURGLASSES, START_FILTERS, MAX_FILTERS = 4, 32, 128

In [8]:
# Disabled draft of the stacked model builder; the notebook imports
# `create_stacked_hourglass_model` from `stacked_hourglass` instead.
# The draft chains n_hourglasses hourglass modules, emits one softmax keypoint
# output per module (named output_{i}), and adds a 1x1-convolved copy of each
# output (plus the previous stage's sum, when present) back into the trunk.
# NOTE(review): it calls `hourglass_module`, which in this notebook exists only
# as commented-out code, so it would not run if uncommented on its own.
# def create_stacked_hourglass_model(img_input, n_keypoints, n_hourglasses, start_filters, max_filters):
#     x = tf_layers.Conv2D(start_filters, (3, 3), padding='same', activation='relu')(img_input)
    
#     skip = None
#     output_list = []
#     for i in range(n_hourglasses):
#         x = hourglass_module(x, start_filters * 2, max_filters=max_filters)

#         output = tf_layers.Conv2D(n_keypoints, (1, 1), activation='softmax', name=f'output_{i}')(x)
#         output_list.append(output)
#         mapped_output = tf_layers.Conv2D(start_filters, (1, 1), activation='softmax')(output)

#         if skip is not None:
#             x = tf_layers.Add()([skip, mapped_output, x])
#         else:
#             x = tf_layers.Add()([mapped_output, x])
#         skip = x
    
#     stacked_hourglass_model = tf.keras.Model(img_input, output_list)
#     return stacked_hourglass_model

In [9]:
# Build the stacked hourglass network with the imported factory, print its
# layer table, then compile it for training.
img_input = tf.keras.Input(shape=(WIDTH, HEIGHT, 3))
stacked_hourglass_model = create_stacked_hourglass_model(
    img_input,
    N_KEYPOINTS,
    N_HOURGLASSES,
    START_FILTERS,
    MAX_FILTERS,
    mode='simple',
)
stacked_hourglass_model.summary()

# Adam with default settings, mean squared error as the loss.
adam_optimizer = tf.keras.optimizers.Adam()
stacked_hourglass_model.compile(
    optimizer=adam_optimizer,
    loss=tf.keras.losses.mean_squared_error,
)

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 128, 96, 3)] 0                                            
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 128, 96, 32)  896         input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 128, 96, 64)  18496       conv2d_17[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 64, 48, 64)   0           conv2d_18[0][0]                  
____________________________________________________________________________________________

# Sample input & output

In [10]:
# Smoke-test the compiled stacked hourglass model on constant dummy data:
# all-zero images as inputs and all-ones tensors as targets.
# Shapes are derived from the notebook constants so the dummy batch keeps
# matching the model input if WIDTH, HEIGHT or N_KEYPOINTS are changed.
SAMPLE_BATCH_SIZE = 8

sample_inputs = tf.zeros([SAMPLE_BATCH_SIZE, WIDTH, HEIGHT, 3])
sample_outputs = tf.ones([SAMPLE_BATCH_SIZE, WIDTH, HEIGHT, N_KEYPOINTS])

# Kept as the last expression so the returned History object is still displayed.
stacked_hourglass_model.fit(sample_inputs, sample_outputs, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x2b733ad5250>