In [1]:
import tensorflow as tf
import tensorflow.keras.backend as K

from tensorflow import keras
from tensorflow.keras.layers import Dense, Permute

In [2]:
print(tf.__version__)
print()
print(tf.constant(42))

2.3.0

tf.Tensor(42, shape=(), dtype=int32)


In [3]:
hidden_state = tf.constant([[1, 2, 3] for _ in range(7)], dtype=tf.float32)
cell_state = tf.constant([[6, 7, 8] for _ in range(7)], dtype=tf.float32)

print(tf.concat([hidden_state, cell_state], axis=-1).shape)

(7, 6)


In [4]:
n = 4

hs = K.repeat(tf.concat([hidden_state, cell_state], axis=-1), n)
print(hs.shape)

(7, 4, 6)


- tf.keras.layers.Dense

In [5]:
T = 6

print(Dense(T)(hs).shape)

(7, 4, 6)


In [6]:
X = tf.constant([[[1, 2, 3, 4], [3, 4, 5, 6], [5, 6, 7, 8],
                  [9, 10, 11, 12], [11, 12, 13, 14]] for _ in range(7)], dtype=tf.float32)

# print(X)
print(X.shape)
print()

X_tr = Permute((2, 1))(X)
# print(X_tr)
print(X_tr.shape)

(7, 5, 4)

(7, 4, 5)


In [7]:
ux = Dense(T)(Permute((2, 1))(X))
print(ux.shape)

(7, 4, 6)


In [8]:
print(hs[0, 0, :])

print(ux[0, 0, :])

tf.Tensor([1. 2. 3. 6. 7. 8.], shape=(6,), dtype=float32)
tf.Tensor([  8.911063   -5.3239098  -3.0053906 -14.583972  -13.258175   -0.6091072], shape=(6,), dtype=float32)


### tf.keras.layers.Add

- tf.math.tanh

- tf.keras.layers.Activation

In [9]:
temp_add = tf.keras.layers.Add()([hs, ux])

print(temp_add.shape)

print(temp_add[0, 0, :])

(7, 4, 6)
tf.Tensor(
[ 9.9110632e+00 -3.3239098e+00 -5.3906441e-03 -8.5839720e+00
 -6.2581749e+00  7.3908930e+00], shape=(6,), dtype=float32)


In [10]:
tanh_math_add = tf.math.tanh(temp_add)
print(tanh_math_add.shape)

(7, 4, 6)


In [11]:
tanh_act_add = tf.keras.layers.Activation(activation='tanh')(temp_add)
print(tanh_act_add.shape)

(7, 4, 6)


In [12]:
diff_tanh_add = tanh_math_add - tanh_act_add
print(sum(sum(sum(diff_tanh_add))))

tf.Tensor(0.0, shape=(), dtype=float32)


In [13]:
e_add = tf.keras.layers.Dense(1)(tanh_act_add)
print(e_add[:, :, 0])
print(e_add.shape)

tf.Tensor(
[[-1.6520028 -1.5468907 -1.4531701 -1.3780476]
 [-1.6520028 -1.5468907 -1.4531701 -1.3780476]
 [-1.6520028 -1.5468907 -1.4531701 -1.3780476]
 [-1.6520028 -1.5468907 -1.4531701 -1.3780476]
 [-1.6520028 -1.5468907 -1.4531701 -1.3780476]
 [-1.6520028 -1.5468907 -1.4531701 -1.3780476]
 [-1.6520028 -1.5468907 -1.4531701 -1.3780476]], shape=(7, 4), dtype=float32)
(7, 4, 1)


In [14]:
attn_add = tf.keras.layers.Softmax()(e_add)
print(attn_add[:, :, 0])
print(attn_add.shape)

tf.Tensor(
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]], shape=(7, 4), dtype=float32)
(7, 4, 1)


### tf.concat

- tf.math.tanh

- tf.keras.layers.Activation

In [15]:
temp_concat = tf.concat([hs, ux], axis=-1)

print(temp_concat.shape)

print(temp_concat[0, 0, :])

(7, 4, 12)
tf.Tensor(
[  1.          2.          3.          6.          7.          8.
   8.911063   -5.3239098  -3.0053906 -14.583972  -13.258175   -0.6091072], shape=(12,), dtype=float32)


In [16]:
tanh_math_concat = tf.math.tanh(temp_concat)
print(tanh_math_concat.shape)

(7, 4, 12)


In [17]:
tanh_act_concat = tf.keras.layers.Activation(activation='tanh')(temp_concat)
print(tanh_act_concat.shape)

(7, 4, 12)


In [18]:
diff_tanh = tanh_math_concat - tanh_act_concat
print(sum(sum(sum(diff_tanh))))

tf.Tensor(0.0, shape=(), dtype=float32)


In [19]:
e_act = tf.keras.layers.Dense(1)(tanh_act_concat)
print(e_act[:, :, 0])
print(e_act.shape)

tf.Tensor(
[[-1.2826318 -1.5102277 -1.7723515 -1.9854836]
 [-1.2826318 -1.5102277 -1.7723515 -1.9854836]
 [-1.2826318 -1.5102277 -1.7723515 -1.9854836]
 [-1.2826318 -1.5102277 -1.7723515 -1.9854836]
 [-1.2826318 -1.5102277 -1.7723515 -1.9854836]
 [-1.2826318 -1.5102277 -1.7723515 -1.9854836]
 [-1.2826318 -1.5102277 -1.7723515 -1.9854836]], shape=(7, 4), dtype=float32)
(7, 4, 1)


In [20]:
attn_input_act = tf.keras.layers.Softmax()(e_act)
print(attn_input_act[:, :, 0])
print(attn_input_act.shape)

tf.Tensor(
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]], shape=(7, 4), dtype=float32)
(7, 4, 1)
