In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

In [9]:
X = np.random.randint(1000, size=(1000, 16))
Y = np.random.randint(1000, size=(1000, 11))

In [10]:
dataset = tf.data.Dataset.from_tensor_slices((X, Y))

In [11]:
dataset = dataset.batch(32)

In [13]:
x0, y0 = next(iter(dataset))
x0.shape, y0.shape

(TensorShape([32, 16]), TensorShape([32, 11]))

In [16]:
x = layers.Embedding(1000, 16)(x0)
x.shape

TensorShape([32, 16, 16])

In [20]:
x1, h1 = layers.GRU(8, return_sequences=True,
                       return_state=True,
                       recurrent_initializer='glorot_uniform')(x)
x1.shape, h1.shape

(TensorShape([32, 16, 8]), TensorShape([32, 8]))

In [35]:
h11 = tf.expand_dims(h1, axis=0)
h11.shape

TensorShape([1, 32, 8])

In [37]:
ht = h11[-1]
ht.shape

TensorShape([32, 8])

In [74]:
c_global = ht

In [41]:
# q_t = tf.expand_dims(h11, 1)
# q_t.shape

In [42]:
q1 = layers.Dense(8, use_bias=False)(x1)
q1.shape

TensorShape([32, 16, 8])

In [44]:
q2 = layers.Dense(8, use_bias=False)(ht)
q2.shape

TensorShape([32, 8])

In [46]:
mask = tf.ones((32,16))
mask.shape

TensorShape([32, 16])

In [55]:
q2_expand = tf.broadcast_to(tf.expand_dims(q2, 1), q1.shape)
q2_expand.shape

TensorShape([32, 16, 8])

In [57]:
q2_masked = tf.broadcast_to(tf.expand_dims(mask, -1), q1.shape)*q2_expand
q2_masked.shape

TensorShape([32, 16, 8])

In [65]:
a = layers.Dense(1, use_bias=False)(tf.reshape(tf.sigmoid(q1 + q2_masked), shape=(-1, 8)))
anpha = tf.reshape(a, mask.shape)
anpha.shape

TensorShape([32, 16])

In [67]:
anpha_exp = tf.broadcast_to(tf.expand_dims(anpha, 2), shape=x1.shape)
anpha_exp.shape

TensorShape([32, 16, 8])

In [75]:
c_local = tf.reduce_sum(anpha_exp*x1, axis=1)
c_local.shape

TensorShape([32, 8])

In [78]:
c_t = tf.concat([c_local, c_global], 1)
c_t.shape

TensorShape([32, 16])

In [80]:
item_indices = tf.range(1000)
item_indices.shape

TensorShape([1000])

In [83]:
item_embs = layers.Embedding(1000, 16)(item_indices)
item_embs.shape

TensorShape([1000, 16])

In [87]:
B = tf.transpose(layers.Dense(16, use_bias=False)(item_embs), perm=(1,0))
B.shape

TensorShape([16, 1000])

In [88]:
tf.matmul(c_t, B)

<tf.Tensor: shape=(32, 1000), dtype=float32, numpy=
array([[-2.6993142e-04, -7.7177648e-04, -3.0237159e-03, ...,
        -2.0859393e-03,  3.1117271e-03,  1.6633859e-03],
       [ 1.9095363e-03, -4.3103634e-03, -1.6339528e-03, ...,
         8.7697189e-03, -3.5207812e-03,  1.8739251e-03],
       [ 1.7727338e-03, -3.0188120e-03,  2.0405627e-03, ...,
        -2.0793173e-05, -3.6653685e-03, -5.7584851e-04],
       ...,
       [-4.9045839e-04,  2.9522530e-03, -1.3343183e-03, ...,
        -1.1360727e-03,  6.5652337e-03,  1.6718714e-03],
       [-1.4801731e-03,  5.2711801e-03,  4.7737225e-03, ...,
         5.9710736e-03,  6.8791618e-04, -1.4408793e-03],
       [ 8.5831922e-04, -3.5764184e-03,  4.6938710e-04, ...,
        -3.5385431e-03, -4.4757500e-03, -2.3120157e-03]], dtype=float32)>