In [None]:

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPool2D, UpSampling2D, Add, Input
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.losses import mean_squared_error
import tensorflow.keras.backend as K
import numpy as np
from  tensorflow.keras import layers, models, Sequential, backend
from tensorflow import keras
from tensorflow.keras.layers.experimental import preprocessing
import os, sys, copy, argparse, pickle
import cv2

import tensorflow as tf

from tensorflow.keras.layers import (
    Add,
    Concatenate,
    Conv2D,
    Input,
    Lambda,
    ReLU,
    MaxPool2D,
    UpSampling2D,
    ZeroPadding2D,
    BatchNormalization,
)

In [None]:

def BottleneckBlock_4(inputs, filters, strides=1, downsample=False, name=None):
    """Pre-activation bottleneck residual block.

    The residual path is three BatchNorm -> ReLU -> Conv2D stages:
    a 1x1 conv narrowing to ``filters // 2`` channels, a 3x3 conv with the
    given ``strides``, and a 1x1 conv widening back to ``filters`` channels.
    The result is summed with the shortcut.

    Args:
        inputs: 4-D Keras tensor (presumably channels-last, the Keras
            default -- TODO confirm for this project).
        filters: number of output channels.
        strides: stride of the 3x3 conv, and of the shortcut projection when
            ``downsample`` is True.
        downsample: if True, the shortcut is a 1x1 conv projecting ``inputs``
            to ``filters`` channels. If False the shortcut is ``inputs``
            itself, so ``inputs`` must already have ``filters`` channels and
            ``strides`` must be 1 for the final Add to line up.
        name: accepted for interface compatibility; currently unused.

    Returns:
        Keras tensor with ``filters`` channels.
    """
    def _bn_relu_conv(tensor, n_out, size, stride):
        # One pre-activation stage: normalize, activate, then convolve.
        tensor = BatchNormalization(momentum=0.9)(tensor)
        tensor = ReLU()(tensor)
        return Conv2D(filters=n_out, kernel_size=size, strides=stride,
                      padding='same', kernel_initializer='he_normal')(tensor)

    # The shortcut projection is built first, before the residual path.
    shortcut = (
        Conv2D(filters=filters, kernel_size=1, strides=strides,
               padding='same', kernel_initializer='he_normal')(inputs)
        if downsample else inputs
    )

    bottleneck = filters // 2
    residual = _bn_relu_conv(inputs, bottleneck, 1, 1)
    residual = _bn_relu_conv(residual, bottleneck, 3, strides)
    residual = _bn_relu_conv(residual, filters, 1, 1)

    return Add()([shortcut, residual])


def HourglassModule_4(inputs, order, filters, num_residual):
    """Recursive hourglass module of depth ``order``.

    Upper branch: ``1 + num_residual`` bottleneck blocks at the input
    resolution.

    Lower branch, in order:
      1. 2x max-pool.
      2. ``num_residual`` blocks.
      3. Either a recursive hourglass of order ``order - 1``, or, at the
         innermost level, ``num_residual`` blocks.
      4. ``num_residual`` more blocks.
      5. 2x upsampling back to the input resolution.

    The two branches are summed.

    Args:
        inputs: 4-D Keras tensor with ``filters`` channels. All blocks use
            ``downsample=False``, so their shortcuts are identities and the
            channel count must already match.
        order: recursion depth (number of pool/upsample levels), >= 1.
        filters: channel count used throughout the module.
        num_residual: number of extra bottleneck blocks at each step.

    Returns:
        Keras tensor with the same shape as ``inputs``. Height and width of
        ``inputs`` should be divisible by ``2 ** order``. MaxPool2D's default
        'valid' padding floors odd sizes, so a pool/upsample pair would not
        restore them and the final Add would fail.
    """
    # Upper branch: stays at the input resolution.
    up1 = BottleneckBlock_4(inputs, filters, downsample=False)
    for _ in range(num_residual):
        up1 = BottleneckBlock_4(up1, filters, downsample=False)

    # Lower branch: halve the resolution, process, then recurse.
    low1 = MaxPool2D(pool_size=2, strides=2)(inputs)
    for _ in range(num_residual):
        low1 = BottleneckBlock_4(low1, filters, downsample=False)

    low2 = low1
    if order > 1:
        # Recurse into this same _4 module so every level uses
        # BottleneckBlock_4 (previously this called a different/undefined
        # `HourglassModule`).
        low2 = HourglassModule_4(low1, order - 1, filters, num_residual)
    else:
        for _ in range(num_residual):
            low2 = BottleneckBlock_4(low2, filters, downsample=False)

    low3 = low2
    for _ in range(num_residual):
        low3 = BottleneckBlock_4(low3, filters, downsample=False)

    # Restore the resolution and merge with the upper branch.
    up2 = UpSampling2D(size=2)(low3)
    return Add()([up2, up1])


def LinearLayer(inputs, filters):
    """1x1 convolution followed by BatchNorm and ReLU.

    Acts as a per-pixel projection to ``filters`` channels (the network uses
    it "like a fully connected layer" on every spatial position).

    Args:
        inputs: 4-D Keras tensor.
        filters: number of output channels.

    Returns:
        Keras tensor with ``filters`` channels and the same spatial size.
    """
    # Layers are created in the same order as they are applied.
    project = Conv2D(filters=filters, kernel_size=1, strides=1,
                     padding='same', kernel_initializer='he_normal')
    normalize = BatchNormalization(momentum=0.9)
    activate = ReLU()
    return activate(normalize(project(inputs)))


def StackedHourglassNetwork_4(
        input_shape=(256, 256, 3), num_stack=4, num_residual=1,
        num_heatmap=16, num_seg=8):
    """Build a stacked-hourglass model with pose-heatmap and segmentation heads.

    Each of the ``num_stack`` stages runs an order-4 hourglass, refines the
    features, and emits two predictions:
      * ``pose_k``: ``num_heatmap`` channels, linear activation.
      * ``seg_k``: ``num_seg`` channels, sigmoid activation.
    Between stages (not after the last one), the stage features and both
    predictions are projected to 256 channels with 1x1 convs and summed to
    form the next stage's input.

    The stem halves the spatial size (stride-2 7x7 conv) and does not
    max-pool afterwards, so every head predicts at 1/2 of the input
    resolution. Input height and width divisible by 32 keep every
    pool/upsample pair of the order-4 hourglass aligned.

    Args:
        input_shape: shape of one input image, excluding the batch axis.
        num_stack: number of hourglass stages; must be >= 1.
        num_residual: number of extra bottleneck blocks at each step.
        num_heatmap: channels of each pose head.
        num_seg: channels of each segmentation head.

    Returns:
        tf.keras.Model whose outputs are ordered
        [pose_0, ..., pose_{num_stack-1}, seg_0, ..., seg_{num_stack-1}].

    Raises:
        ValueError: if ``num_stack`` < 1 (the model would have no outputs).
    """
    if num_stack < 1:
        raise ValueError(f"num_stack must be >= 1, got {num_stack}")

    inputs = Input(shape=input_shape)

    # Stem: stride-2 7x7 conv halves the spatial size.
    x = Conv2D(filters=64, kernel_size=7, strides=2, padding='same',
               kernel_initializer='he_normal')(inputs)
    x = BatchNormalization(momentum=0.9)(x)
    x = ReLU()(x)
    x = BottleneckBlock_4(x, 128, downsample=True)
    # NOTE: no max-pool after this block, so features stay at 1/2 of the
    # input resolution.
    x = BottleneckBlock_4(x, 128, downsample=False)
    x = BottleneckBlock_4(x, 256, downsample=True)

    y_pose = []
    y_seg = []
    for stack_idx in range(num_stack):
        x = HourglassModule_4(x, order=4, filters=256, num_residual=num_residual)
        for _ in range(num_residual):
            x = BottleneckBlock_4(x, 256, downsample=False)

        # Per-pixel 256-channel projection, like a fully connected layer.
        x = LinearLayer(x, 256)

        # Pose heatmaps: one channel per predicted heatmap.
        y_p = Conv2D(
            filters=num_heatmap,
            kernel_size=1,
            strides=1,
            padding='same',
            kernel_initializer='he_normal',
            activation='linear',
            name='pose_' + str(stack_idx))(x)
        y_pose.append(y_p)

        # Segmentation maps in [0, 1].
        y_s = Conv2D(
            filters=num_seg,
            kernel_size=1,
            strides=1,
            padding='same',
            kernel_initializer='he_normal',
            activation='sigmoid',
            name='seg_' + str(stack_idx))(x)
        y_seg.append(y_s)

        # Feed features and predictions back into the next stage. This must
        # test the *stack* index, not the inner residual-loop variable (which
        # is unbound when num_residual == 0).
        if stack_idx < num_stack - 1:
            feat_proj = Conv2D(filters=256, kernel_size=1, strides=1)(x)
            pose_proj = Conv2D(filters=256, kernel_size=1, strides=1)(y_p)
            seg_proj = Conv2D(filters=256, kernel_size=1, strides=1)(y_s)
            x = Add()([feat_proj, pose_proj, seg_proj])

    return tf.keras.Model(inputs, y_pose + y_seg, name='stacked_hourglass_4')

