In [1]:
import tensorflow as tf


In [2]:
from forward_kinematics import FK

  # def mlp_out(self, input_, reuse=False, name="mlp_out"):
  #   out = qlinear(input_, 4 * (self.n_joints + 1), name="dec_fc")
  #   return out
import numpy as np


In [5]:
class EncoderDecoderGRU(object):
  def __init__(self,
               batch_size,
               n_joints,
               layers_units,
               keep_prob,
               ):

    self.n_joints = n_joints
    self.batch_size = batch_size
    self.max_len=60
    self.kp = keep_prob
    self.layers_units=layers_units
    self.fk=FK()
    self.parents = np.array([
      -1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 0, 10, 11, 12, 3, 14, 15, 16, 3, 18, 19,
      20
  ])

    self.gen=self.generator(layers_units)

  def generator(self, layers_units):
    enc_gru = self.gru_model(layers_units)
    dec_gru = self.gru_model(layers_units)

    seqA_ = tf.keras.Input(
        shape=(self.max_len, 3 * self.n_joints + 4),
        batch_size=self.batch_size,
        name="seqA")
    skelA_ = tf.keras.Input(
        shape=(self.max_len, 3 * self.n_joints),
        batch_size=self.batch_size,
        name="skelA_")
    skelB_ = tf.keras.Input(
        shape=(self.max_len, 3 * self.n_joints),
        batch_size=self.batch_size,
        name="skelB")

    b_local = []
    b_global = []
    b_quats = []
    a_local = []
    a_global = []
    a_quats = []

    statesA_AB = []
    statesB_AB = []
    statesA_BA = []
    statesB_BA = []
    fc=tf.keras.layers.Dense(4 * (self.n_joints + 1))
    for units in layers_units:
      statesA_AB += [tf.zeros([self.batch_size, units])]
      statesB_AB += [tf.zeros([self.batch_size, units])]
      statesA_BA += [tf.zeros([self.batch_size, units])]
      statesB_BA += [tf.zeros([self.batch_size, units])]
    print(statesA_AB)
    print(seqA_[:, 0, :])
    print(seqA_.shape, tf.expand_dims(seqA_[:, 0, :], 1).shape)
    for t in range(self.max_len):
      """Retarget A to B"""       
      ptA_in = seqA_[:, t, :]

      n = enc_gru(tf.expand_dims(ptA_in, 1), initial_state=statesA_AB)
      statesA_AB=n[1:]
      if t == 0:
        ptB_in = tf.zeros([self.batch_size, 3 * self.n_joints + 4])
      else:
        ptB_in = tf.concat([b_local[-1], b_global[-1]], axis=-1)

      ptcombined = tf.concat(
          values=[skelB_[:, 0, 3:], ptB_in, statesA_AB[-1]], axis=1)
      print(ptcombined.shape)
      n = dec_gru(tf.expand_dims(ptcombined, 1), initial_state=statesB_AB)
      statesB_AB=n[1:]
      angles_n_offset = fc(statesB_AB[-1])
      output_angles = tf.reshape(angles_n_offset[:, :-4],
                                  [self.batch_size, self.n_joints, 4])
      b_global.append(angles_n_offset[:, -4:])
      b_quats.append(self.normalized(output_angles))

      skel_in = tf.reshape(skelB_[:, 0, :], [self.batch_size, self.n_joints, 3])
      skel_in = skel_in 

      output = (self.fk.run(self.parents, skel_in, output_angles))
      output = tf.reshape(output, [self.batch_size, -1])
      b_local.append(output)
      """Retarget B back to A"""
      ptB_in = tf.concat([b_local[-1], b_global[-1]], axis=-1)

      n= enc_gru(tf.expand_dims(ptB_in, 1), initial_state=statesB_BA)
      statesB_BA=n[1:]

      if t == 0:
        ptA_in = tf.zeros([self.batch_size, 3 * self.n_joints + 4])
      else:
        ptA_in = tf.concat([a_local[-1], a_global[-1]], axis=-1)

      ptcombined = tf.concat(
          values=[skelA_[:, 0, 3:], ptA_in, statesB_BA[-1]], axis=1)
      n = dec_gru(tf.expand_dims(ptcombined, 1), initial_state=statesA_BA)
      statesA_BA=n[1:]
      angles_n_offset = fc(statesA_BA[-1])
      output_angles = tf.reshape(angles_n_offset[:, :-4],
                                [self.batch_size, self.n_joints, 4])
      a_global.append(angles_n_offset[:, -4:])
      a_quats.append(self.normalized(output_angles))

      skel_in = tf.reshape(skelA_[:, 0, :], [self.batch_size, self.n_joints, 3])
      skel_in = skel_in 

      output = (self.fk.run(self.parents, skel_in, output_angles))
      output = tf.reshape(output, [self.batch_size, -1])
      a_local.append(output)

      return tf.keras.Model(inputs=[seqA_, skelA_, skelB_], outputs=[b_local, b_global, b_quats, a_local, a_global, a_quats]) 
      
  def gru_model(self, layers_units, rnn_type="GRU"):
    gru_cells = [tf.keras.layers.GRUCell(units, dropout=(1-self.kp)) for units in layers_units]
    gru_layer=tf.keras.layers.RNN(gru_cells, return_state=True)
    return gru_layer
  def normalized(self, angles):
    lengths = tf.math.sqrt(tf.math.reduce_sum(tf.math.square(angles), axis=-1))
    return angles / lengths[..., None]

In [6]:
layers_units = []
for i in range(2):
    layers_units.append(512)
gru=EncoderDecoderGRU(16, 22, layers_units, 0.9)

[<tf.Tensor: shape=(16, 512), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>, <tf.Tensor: shape=(16, 512), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>]
Tensor("strided_slice_157:0", shape=(16, 70), dtype=float32)
(16, 60, 70) (16, 1, 70)
(16, 645)
