In [1]:
import numpy as np
import tensorflow as tf

In [3]:
from deepgroebner.networks import ParallelEmbeddingLayer

In [30]:
@tf.function
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute scaled dot-product attention and return outputs and weights.

    Each query row is compared against every key row; the resulting scores
    are scaled by the square root of the key feature size, optionally masked,
    normalized with a softmax over the key axis, and used to average the
    value rows.

    Parameters
    ----------
    Q : `Tensor` of type `tf.float32` and shape (..., dq, d1)
        Tensor of queries as rows.
    K : `Tensor` of type `tf.float32` and shape (..., dkv, d1)
        Tensor of keys as rows.
    V : `Tensor` of type `tf.float32` and shape (..., dkv, d2)
        Tensor of values as rows.
    mask : `Tensor` of type `tf.bool` and shape (..., 1, dkv), optional
        True marks a valid key/value row. Rows marked False receive a large
        negative logit and therefore (near) zero attention weight.

    Returns
    -------
    output : `Tensor` of type `tf.float32` and shape (..., dq, d2)
        Attention-weighted combination of the value rows for each query.
    attention_weights : `Tensor` of type `tf.float32` and shape (..., dq, dkv)
        Softmax weights each query assigns to each key.

    """
    # Scale factor is the square root of the size of the last axis of K.
    key_dim = tf.cast(tf.shape(K)[-1], tf.float32)
    scores = tf.matmul(Q, K, transpose_b=True) / tf.math.sqrt(key_dim)
    if mask is not None:
        # Push masked-out positions towards -inf so softmax ignores them.
        penalty = tf.cast(~mask, tf.float32) * -1e9
        scores = scores + penalty
    weights = tf.nn.softmax(scores)
    return tf.matmul(weights, V), weights

In [59]:
class AttentionPoolingLayer(tf.keras.layers.Layer):
    """Pool the rows of each example into one value using learned attention.

    The input rows are projected to keys and values by two dense layers. A
    single trainable query attends over the keys, the attended value vector
    is formed, and a final dense layer maps it to one scalar per example.

    Parameters
    ----------
    dim : int
        Size of the key/value projections and of the trainable query.
    **kwargs
        Standard `tf.keras.layers.Layer` keyword arguments (e.g. `name`).

    """

    def __init__(self, dim, **kwargs):
        super(AttentionPoolingLayer, self).__init__(**kwargs)
        self.dim = dim
        self.Wk = tf.keras.layers.Dense(dim)
        self.Wv = tf.keras.layers.Dense(dim)
        self.dense = tf.keras.layers.Dense(1)

    def build(self, batch_input_shape):
        """Create the single trainable query row of shape (1, dim)."""
        self.Q = self.add_weight(name='query',
                                 shape=[1, self.dim],
                                 initializer='glorot_normal')
        super(AttentionPoolingLayer, self).build(batch_input_shape)

    def call(self, batch, mask=None):
        """Return one pooled value per example.

        Parameters
        ----------
        batch : `Tensor` of type `tf.float32` and shape (batch, rows, features)
            Input rows for each example.
        mask : `Tensor` of type `tf.bool` and shape (batch, rows), optional
            True marks a valid row; rows marked False are ignored by the
            attention.

        Returns
        -------
        output : `Tensor` of type `tf.float32` and shape (batch, 1)
            Pooled scalar for each example.

        """
        K = self.Wk(batch)
        V = self.Wv(batch)
        if mask is not None:
            # The query is (1, dim), so the attention logits have shape
            # (batch, 1, rows). The mask therefore needs exactly one inserted
            # query axis: (batch, 1, rows). Inserting two axes would produce
            # (batch, 1, 1, rows), which broadcasts against the logits to
            # (batch, batch, 1, rows) and mixes examples together.
            mask = mask[:, tf.newaxis, :]
        X, _ = scaled_dot_product_attention(self.Q, K, V, mask=mask)
        # dense(X) is (batch, 1, 1); drop the trailing unit axis.
        return tf.squeeze(self.dense(X), axis=-1)

In [60]:
# Build a pooling layer whose key/value projections and query have size 13.
layer = AttentionPoolingLayer(13)

In [63]:
# Random float32 test input of shape (3, 10, 8) with values in [0, 1).
# NOTE(review): no seed is set, so the values displayed below change on every run.
state = np.random.rand(3, 10, 8).astype(np.float32)
state

array([[[0.06788629, 0.57301205, 0.04953672, 0.34167394, 0.33334586,
         0.5058849 , 0.745126  , 0.959517  ],
        [0.8084456 , 0.8914087 , 0.30999458, 0.8740726 , 0.15012875,
         0.5063418 , 0.76032823, 0.5649372 ],
        [0.80377656, 0.90998304, 0.98673   , 0.8215631 , 0.48435026,
         0.64458954, 0.16177076, 0.09962187],
        [0.59020656, 0.9020649 , 0.46243715, 0.14894478, 0.62541765,
         0.357886  , 0.93383676, 0.91721267],
        [0.6725505 , 0.03963019, 0.25039616, 0.9717347 , 0.4407414 ,
         0.43552145, 0.47752053, 0.6523828 ],
        [0.28118044, 0.64923483, 0.04277411, 0.550966  , 0.43449387,
         0.48481938, 0.7963205 , 0.6660466 ],
        [0.64957374, 0.27097702, 0.2389206 , 0.9549677 , 0.43790942,
         0.36531553, 0.5649379 , 0.8762377 ],
        [0.40529162, 0.96557724, 0.1817433 , 0.06427106, 0.63150704,
         0.46903282, 0.7546195 , 0.7125125 ],
        [0.9801216 , 0.6511995 , 0.5362876 , 0.43970835, 0.13224849,
         0.

In [64]:
# Apply the layer without a mask; the displayed result has shape (3, 1).
layer(state)

<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-0.4520214 ],
       [-0.48208076],
       [-0.38833997]], dtype=float32)>