In [1]:
import tensorflow as tf
import tensorflow.keras.backend as K

from tensorflow.keras.layers import Dense, Permute, Softmax, Activation, Add

In [2]:
# Fix TensorFlow's global RNG first so the randomly initialised Dense layers
# further down produce the same weights on every run, then report the version.
tf.random.set_seed(42)
print(tf.__version__)

2.6.0


In [3]:
# Toy hidden/cell states: the same row repeated 7 times, so every row along
# the first axis carries identical values.
hidden_state = tf.constant([[1, 2, 3]] * 7, dtype=tf.float32)
cell_state = tf.constant([[6, 7, 8]] * 7, dtype=tf.float32)

print(tf.concat(values=[hidden_state, cell_state], axis=1).shape)

(7, 6)


In [4]:
n = 4  # copies along the new axis 1; matches the last dimension of X defined below

# Stack [hidden; cell] (7, 6) and repeat it n times: (7, 6) -> (7, n, 6).
hs = K.repeat(x=tf.concat(values=[hidden_state, cell_state], axis=-1), n=n)
print(hs.shape)

(7, 4, 6)


### tf.keras.layers.Dense

- `Dense(T)` maps only the last axis to `T` units; all leading axes are kept unchanged.

In [5]:
T = 6  # number of units produced by the Dense projections below

# Dense acts on the last axis only, so (7, 4, 6) -> (7, 4, T).
# NOTE(review): this projection of hs is only printed, never stored; the Add
# cell further down combines the unprojected hs — confirm that is intended.
print(Dense(units=T)(hs).shape)

(7, 4, 6)


In [6]:
# One 5 x 4 matrix repeated 7 times along the first axis -> shape (7, 5, 4).
X = tf.constant(
    [[[1, 2, 3, 4], [3, 4, 5, 6], [5, 6, 7, 8], [9, 10, 11, 12], [11, 12, 13, 14]]] * 7,
    dtype=tf.float32,
)

print(X.shape)
print()

# Swap the two trailing axes: (7, 5, 4) -> (7, 4, 5).
X_tr = Permute(dims=(2, 1))(X)
print(X_tr.shape)

(7, 5, 4)

(7, 4, 5)


In [7]:
# Project each row of the transposed input (length 5) to T units:
# (7, 4, 5) -> (7, 4, T). X_tr is exactly Permute((2, 1))(X) from the cell above.
ux = Dense(T)(X_tr)
print(ux.shape)

(7, 4, 6)


In [8]:
# The same position ([0, 0]) in the two operands that are combined below.
print(hs[0, 0])
print(ux[0, 0])

tf.Tensor([1. 2. 3. 6. 7. 8.], shape=(6,), dtype=float32)
tf.Tensor([ 5.4773407  8.245257  -1.4534266  2.8034296  4.5096674 11.279674 ], shape=(6,), dtype=float32)


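### tf.keras.layers.Add

Combine the repeated states `hs` and the projected inputs `ux` element-wise. Then apply tanh in two ways and check that the results agree:

- `tf.math.tanh`

- `tf.keras.layers.Activation('tanh')`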

In [9]:
# Element-wise sum of two (7, 4, 6) tensors. The last axes line up only because
# hs carries 3 + 3 = 6 features and T = 6.
# NOTE(review): an additive attention score usually adds a *projected* hs
# (e.g. Dense(T)(hs)) — confirm that adding the raw hs is intended here.
temp_add = Add()([hs, ux])

print(temp_add.shape)

print(temp_add[0, 0])

(7, 4, 6)
tf.Tensor([ 6.4773407 10.245257   1.5465734  8.80343   11.509667  19.279675 ], shape=(6,), dtype=float32)


In [10]:
# Element-wise tanh via the functional TensorFlow op.
tanh_math_add = tf.math.tanh(x=temp_add)
print(tanh_math_add.get_shape())

(7, 4, 6)


In [11]:
# tanh again, this time as a Keras layer, to compare with the functional op above.
tanh_act_add = Activation('tanh')(temp_add)
print(tanh_act_add.get_shape())

(7, 4, 6)


In [12]:
# Check that the functional and the Keras-layer tanh agree.
# Summing the *signed* differences (the earlier sum(sum(sum(...)))) can give 0
# even when the tensors differ, because positive and negative errors cancel.
# Report the largest absolute difference instead, and fail loudly on a mismatch.
diff_tanh_add = tanh_math_add - tanh_act_add
print(tf.reduce_max(tf.abs(diff_tanh_add)))
tf.debugging.assert_near(tanh_math_add, tanh_act_add)

tf.Tensor(0.0, shape=(), dtype=float32)


In [13]:
# One score per row via a single-unit Dense layer: (7, 4, 6) -> (7, 4, 1).
# All 7 printed rows match because every input was built from one repeated row.
e_add = Dense(units=1)(tanh_act_add)
print(e_add[..., 0])
print(e_add.shape)

tf.Tensor(
[[-0.62823135 -0.66721845 -0.68682265 -0.6965132 ]
 [-0.62823135 -0.66721845 -0.68682265 -0.6965132 ]
 [-0.62823135 -0.66721845 -0.68682265 -0.6965132 ]
 [-0.62823135 -0.66721845 -0.68682265 -0.6965132 ]
 [-0.62823135 -0.66721845 -0.68682265 -0.6965132 ]
 [-0.62823135 -0.66721845 -0.68682265 -0.6965132 ]
 [-0.62823135 -0.66721845 -0.68682265 -0.6965132 ]], shape=(7, 4), dtype=float32)
(7, 4, 1)


In [14]:
# Softmax normalises over the last axis (axis=-1, the default). Here that axis
# has length 1, so every weight is exp(e) / exp(e) = 1. This cell demonstrates
# the pitfall; the next cell moves the scores onto the last axis first.
attn_add = Softmax(axis=-1)(e_add)
print(attn_add[..., 0])
print(attn_add.shape)

tf.Tensor(
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]], shape=(7, 4), dtype=float32)
(7, 4, 1)


In [15]:
# Move the 4 scores onto the last axis, (7, 4, 1) -> (7, 1, 4), so Softmax
# normalises across them; each printed row now sums to 1.
attn_add = Softmax(axis=-1)(Permute(dims=(2, 1))(e_add))
print(attn_add[:, 0])
print(attn_add.shape)

tf.Tensor(
[[0.26049453 0.25053403 0.24567033 0.24330117]
 [0.26049453 0.25053403 0.24567033 0.24330117]
 [0.26049453 0.25053403 0.24567033 0.24330117]
 [0.26049453 0.25053403 0.24567033 0.24330117]
 [0.26049453 0.25053403 0.24567033 0.24330117]
 [0.26049453 0.25053403 0.24567033 0.24330117]
 [0.26049453 0.25053403 0.24567033 0.24330117]], shape=(7, 4), dtype=float32)
(7, 1, 4)


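### tf.concat

Same pipeline, but `hs` and `ux` are concatenated along the last axis (6 + 6 = 12 features) instead of added. Tanh is again applied in two ways:

- `tf.math.tanh`

- `tf.keras.layers.Activation('tanh')`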

In [16]:
# Concatenate instead of adding: the feature axis grows from 6 + 6 to 12.
temp_concat = tf.concat(values=[hs, ux], axis=2)

print(temp_concat.shape)

print(temp_concat[0, 0])

(7, 4, 12)
tf.Tensor(
[ 1.         2.         3.         6.         7.         8.
  5.4773407  8.245257  -1.4534266  2.8034296  4.5096674 11.279674 ], shape=(12,), dtype=float32)


In [17]:
# Element-wise tanh via the functional TensorFlow op.
tanh_math_concat = tf.math.tanh(x=temp_concat)
print(tanh_math_concat.get_shape())

(7, 4, 12)


In [18]:
# tanh again, this time as a Keras layer, to compare with the functional op above.
tanh_act_concat = Activation('tanh')(temp_concat)
print(tanh_act_concat.get_shape())

(7, 4, 12)


In [19]:
# Check that the functional and the Keras-layer tanh agree.
# Summing the *signed* differences (the earlier sum(sum(sum(...)))) can give 0
# even when the tensors differ, because positive and negative errors cancel.
# Report the largest absolute difference instead, and fail loudly on a mismatch.
diff_tanh = tanh_math_concat - tanh_act_concat
print(tf.reduce_max(tf.abs(diff_tanh)))
tf.debugging.assert_near(tanh_math_concat, tanh_act_concat)

tf.Tensor(0.0, shape=(), dtype=float32)


In [20]:
# One score per row via a single-unit Dense layer: (7, 4, 12) -> (7, 4, 1).
# All 7 printed rows match because every input was built from one repeated row.
e_act = Dense(units=1)(tanh_act_concat)
print(e_act[..., 0])
print(e_act.shape)

tf.Tensor(
[[0.83395565 0.80961    0.76546884 0.69632435]
 [0.83395565 0.80961    0.76546884 0.69632435]
 [0.83395565 0.80961    0.76546884 0.69632435]
 [0.83395565 0.80961    0.76546884 0.69632435]
 [0.83395565 0.80961    0.76546884 0.69632435]
 [0.83395565 0.80961    0.76546884 0.69632435]
 [0.83395565 0.80961    0.76546884 0.69632435]], shape=(7, 4), dtype=float32)
(7, 4, 1)


In [21]:
# Softmax normalises over the last axis (axis=-1, the default). Here that axis
# has length 1, so every weight is exp(e) / exp(e) = 1. This cell demonstrates
# the pitfall; the next cell moves the scores onto the last axis first.
attn_input_act = Softmax(axis=-1)(e_act)
print(attn_input_act[..., 0])
print(attn_input_act.shape)

tf.Tensor(
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]], shape=(7, 4), dtype=float32)
(7, 4, 1)


In [22]:
# Move the 4 scores onto the last axis, (7, 4, 1) -> (7, 1, 4), so Softmax
# normalises across them; each printed row now sums to 1.
attn_input_act = Softmax(axis=-1)(Permute(dims=(2, 1))(e_act))
print(attn_input_act[:, 0])
print(attn_input_act.shape)

tf.Tensor(
[[0.26446813 0.25810724 0.24696186 0.23046279]
 [0.26446813 0.25810724 0.24696186 0.23046279]
 [0.26446813 0.25810724 0.24696186 0.23046279]
 [0.26446813 0.25810724 0.24696186 0.23046279]
 [0.26446813 0.25810724 0.24696186 0.23046279]
 [0.26446813 0.25810724 0.24696186 0.23046279]
 [0.26446813 0.25810724 0.24696186 0.23046279]], shape=(7, 4), dtype=float32)
(7, 1, 4)
