In [None]:
import tensorflow as tf

from tensorflow.keras.layers import Layer, Flatten, Conv2D, MaxPool2D, UpSampling2D, Add, ReLU, BatchNormalization
from tensorflow.keras import Model, Input
from tensorflow.keras.backend import shape

In [None]:
class Residual(Layer):
  """Pre-activation bottleneck residual block.

  Main path: BN -> ReLU -> 1x1 conv (output_shape // 2 filters) -> BN -> ReLU ->
  3x3 conv (output_shape // 2 filters) -> BN -> ReLU -> 1x1 conv (output_shape
  filters). The result is added to a skip path. When `input_shape` differs from
  `output_shape`, the skip path is a 1x1 conv projecting to `output_shape`
  channels; otherwise the input is added unchanged.

  Args:
    input_shape: channel count of the input (an int despite the name). It is
      only used to decide whether the skip path needs a projection.
    output_shape: channel count of the output.
    **kwargs: forwarded to `Layer.__init__` (e.g. `name`).
  """

  def __init__(self, input_shape, output_shape, **kwargs):
    super().__init__(**kwargs)
    bottleneck = int(output_shape) // 2
    self.conv1 = Conv2D(bottleneck, kernel_size=(1, 1), strides=(1, 1), padding='same')
    self.conv2 = Conv2D(bottleneck, kernel_size=(3, 3), strides=(1, 1), padding='same')
    self.conv3 = Conv2D(output_shape, kernel_size=(1, 1), strides=(1, 1), padding='same')
    # A single ReLU is reused at all three activation points; it has no weights.
    self.relu = ReLU()
    self.bn1 = BatchNormalization()
    self.bn2 = BatchNormalization()
    self.bn3 = BatchNormalization()
    # Always constructed, so the sub-layer creation order stays the same for
    # every configuration. It is only invoked when need_skip is True.
    self.identity = Conv2D(output_shape, kernel_size=(1, 1), strides=(1, 1), padding='same')
    self.need_skip = input_shape != output_shape

  def call(self, inputs):
    res = self.identity(inputs) if self.need_skip else inputs
    x = self.conv1(self.relu(self.bn1(inputs)))
    x = self.conv2(self.relu(self.bn2(x)))
    x = self.conv3(self.relu(self.bn3(x)))
    return x + res

In [None]:
# recursive hourglass
class Hourglass(Layer):
  """Recursive hourglass module with a constant channel count.

  Upper branch: one Residual at the input resolution.
  Lower branch: 2x2 max-pool -> Residual -> (nested Hourglass of depth n-1 when
  n > 1, otherwise a single Residual) -> Residual -> 2x nearest upsampling.
  The two branches are summed, so the output resolution equals the input's.

  Args:
    n: recursion depth; each level pools by 2 once more before recursing.
    filter_size: channel count used by every Residual inside the module.
  """

  def __init__(self, n, filter_size):
    super().__init__()
    self.l1 = Residual(filter_size, filter_size)
    self.m = MaxPool2D(pool_size=(2, 2))
    self.l2 = Residual(filter_size, filter_size)
    self.n = n
    # Recurse until depth 1, where a plain Residual sits at the lowest resolution.
    self.l3 = Hourglass(n - 1, filter_size) if n > 1 else Residual(filter_size, filter_size)
    self.l4 = Residual(filter_size, filter_size)
    self.up = UpSampling2D(size=(2, 2))

  def call(self, inputs):
    upper = self.l1(inputs)
    lower = self.l2(self.m(inputs))
    lower = self.l4(self.l3(lower))
    return upper + self.up(lower)

In [None]:
def heatmap_loss_function(y_actual, y_predicted):
  """Per-sample mean squared error between heatmap tensors.

  Args:
    y_actual: ground-truth heatmaps, batch on axis 0.
    y_predicted: predicted heatmaps, broadcast-compatible with `y_actual`.

  Returns:
    Tensor of shape (batch,): the squared error averaged over every non-batch
    axis (H, W, C for channels-last heatmaps).
  """
  squared_error = tf.square(y_predicted - y_actual)
  # The previous `.mean(dim=...)` chain is PyTorch API and does not exist on TF
  # tensors. Reduce over all axes except the batch axis instead.
  return tf.reduce_mean(squared_error, axis=tf.range(1, tf.rank(squared_error)))

In [None]:
def combined_loss_function(hm_gt, hm, num_hourglasses=None):
  """Per-stage heatmap loss for a stacked-hourglass prediction tensor.

  Args:
    hm_gt: ground-truth heatmaps. Assumed to broadcast against one stage's
      prediction `hm[:, i]`, e.g. (batch, H, W, C) -- TODO confirm against the
      data pipeline.
    hm: stacked predictions with the stage axis at position 1, as produced by
      PoseEstimationModel: (batch, num_hourglasses, H, W, C).
    num_hourglasses: number of stages to score. Defaults to the static size of
      axis 1 of `hm`, which is what is used when Keras calls the loss with
      (y_true, y_pred).

  Returns:
    Tensor of shape (batch, num_hourglasses) holding each stage's per-sample
    loss.

  Raises:
    ValueError: if `num_hourglasses` is not given and axis 1 of `hm` has no
      static size.
  """
  if num_hourglasses is None:
    num_hourglasses = hm.shape[1]
    if num_hourglasses is None:
      raise ValueError(
          'combined_loss_function: cannot infer the number of stages from '
          'hm.shape[1]; pass num_hourglasses explicitly.')
  # Slice the stage axis across the whole batch. The earlier `hm[0][:, i]` was a
  # PyTorch-era tuple index: on the Keras prediction tensor, `hm[0]` selects the
  # first sample, so `[:, i]` then sliced a spatial axis instead of a stage.
  combined_loss = [
      heatmap_loss_function(hm_gt, hm[:, i]) for i in range(num_hourglasses)
  ]
  return tf.stack(combined_loss, axis=1)


In [None]:
class Features(Layer):
  """Per-stage feature head: a same-width Residual followed by a 1x1 conv.

  Args:
    input_shape: channel count (an int despite the name) used both for the
      Residual's input/output and for the 1x1 conv's filters.
  """

  def __init__(self, input_shape):
    super().__init__()
    channels = input_shape
    self.l1 = Residual(channels, channels)
    self.c1 = Conv2D(channels, kernel_size=(1, 1), strides=(1, 1), padding='same')

  def call(self, inputs):
    return self.c1(self.l1(inputs))


In [None]:
def PoseEstimationModel(num_hourglasses=4, num_heatmaps=256):
  """Build and compile a stacked-hourglass pose-estimation model.

  A stem (7x7 stride-2 conv, Residual blocks and one 2x2 max-pool) maps a
  256x256x3 image to a 64x64x256 feature map. That map is followed by
  `num_hourglasses` Hourglass(n=4) stages. Each stage emits a prediction via a
  1x1 conv. Every stage except the last adds 1x1 projections of its features
  and of its prediction back into the running feature map.

  Args:
    num_hourglasses: number of stacked hourglass stages (must be >= 1).
    num_heatmaps: channels of each stage's prediction. The default of 256
      preserves the previous behaviour; presumably this should equal the number
      of keypoints in the dataset -- TODO confirm.

  Returns:
    A compiled `Model` whose output stacks every stage's prediction along
    axis 1: (batch, num_hourglasses, 64, 64, num_heatmaps).

  Raises:
    ValueError: if num_hourglasses < 1.
  """
  if num_hourglasses < 1:
    raise ValueError(f'num_hourglasses must be >= 1, got {num_hourglasses}')

  inputs = Input(shape=(256, 256, 3))

  # Stem: 256x256 -> 128x128 (strided conv) -> 64x64 (max-pool), ending at 256 channels.
  x = Conv2D(64, kernel_size=(7, 7), strides=(2, 2), padding='same', activation='relu')(inputs)
  x = Residual(64, 128)(x)
  x = MaxPool2D(pool_size=(2, 2))(x)
  x = Residual(128, 128)(x)
  x = Residual(128, 256)(x)

  # Per-stage heads, created before the hourglasses (same order as before).
  features = [Features(256) for _ in range(num_hourglasses)]
  outs = [Conv2D(num_heatmaps, kernel_size=(1, 1), strides=(1, 1), padding='same')
          for _ in range(num_hourglasses)]
  merge_features = [Conv2D(256, kernel_size=(1, 1), strides=(1, 1), padding='same')
                    for _ in range(num_hourglasses - 1)]
  merge_predictions = [Conv2D(256, kernel_size=(1, 1), strides=(1, 1), padding='same')
                       for _ in range(num_hourglasses - 1)]

  combined_heatmap_predictions = []
  for i in range(num_hourglasses):
    h = Hourglass(n=4, filter_size=256)(x)
    f = features[i](h)
    prediction = outs[i](f)
    # Every stage's prediction is returned so the loss can score each stage.
    combined_heatmap_predictions.append(prediction)
    if i < num_hourglasses - 1:
      # Feed this stage's features and prediction into the next stage's input.
      x = x + merge_features[i](f) + merge_predictions[i](prediction)
  output = tf.stack(combined_heatmap_predictions, axis=1)

  model = Model(inputs, output)

  model.compile(
      optimizer='adam',
      # Pass the function object itself: Keras resolves loss *strings* only
      # against its built-in/registered losses, so 'combined_loss_function'
      # could not be found.
      # 'accuracy' was dropped: it is a classification metric and is not
      # meaningful for continuous heatmap regression.
      loss=combined_loss_function,
  )

  return model



In [None]:
# Build the default 4-stage stacked-hourglass model.
pem = PoseEstimationModel(num_hourglasses=4)