In [131]:
import tensorflow as tf

from trainer.models.common.transformer import FeedForward, MultiHeadSelfAttentionLayer


class PositionalEmbedding(tf.keras.layers.Layer):
    """SASRec input embedding: scaled token embedding plus a learned positional embedding.

    The token vectors come from the ``token_embedding`` layer supplied by the
    caller (so the same table can be reused elsewhere), are multiplied by
    ``sqrt(dim)`` and summed with a learned vector for each position index
    ``0 .. sequence_length - 1``.

    Args:
        token_embedding: layer that maps token indices to vectors. Its output
            is added to the ``dim``-sized position vectors, so it is expected
            to produce vectors of size ``dim``. Its ``compute_mask`` is reused.
        seq_length: number of rows in the position table, i.e. the number of
            distinct position indices that can be looked up.
        dim: size of the positional embedding vectors.

    Input shape
      - token index 2D tensor with shape: ``(batch_size, sequence_length)``.

    Output shape
      - 3D tensor with shape: ``(batch_size, sequence_length, embedding_size)``.

    References
        - [Self-Attentive Sequential Recommendation](https://arxiv.org/pdf/1808.09781.pdf)
    """

    def __init__(self, token_embedding, seq_length=50, dim=50, **kwargs):
        super(PositionalEmbedding, self).__init__(**kwargs)
        # The previous `assert seq_length % 2 == 0` was dropped: its message
        # talked about the output dimension while it tested seq_length, and a
        # learned position table has no parity requirement on either value.
        self.length = seq_length
        self.dim = dim
        self.token_emb = token_embedding
        self.position_emb = tf.keras.layers.Embedding(
            input_dim=seq_length, output_dim=dim
        )

    def call(self, inputs, **kwargs):
        # Positions are derived from the runtime sequence length of the batch.
        # NOTE(review): indices >= seq_length are not present in the position
        # table; nothing here guards against inputs longer than seq_length.
        length = tf.shape(inputs)[1]
        embedded_tokens = self.token_emb(inputs)
        embedded_positions = self.position_emb(tf.range(length))
        # This factor sets the relative scale of the embedding and positional_encoding.
        # Cast to the embedding dtype (instead of a fixed float32) so the
        # multiplication also works when the layer computes in another dtype.
        embedded_tokens *= tf.math.sqrt(tf.cast(self.dim, embedded_tokens.dtype))
        # Add a leading batch axis so the position vectors broadcast over the batch.
        return embedded_tokens + embedded_positions[tf.newaxis, :, :]

    # Pass mask from token_emb, https://www.tensorflow.org/guide/keras/understanding_masking_and_padding#supporting_masking_in_your_custom_layers
    def compute_mask(self, inputs, mask=None):
        return self.token_emb.compute_mask(inputs, mask=mask)


class SASRecBlock(tf.keras.layers.Layer):
    """Single SASRec block: causal multi-head self-attention followed by a feed-forward layer.

    Only the attention -> feed-forward composition is done in this class; any
    layer normalisation or residual connection is presumably handled inside the
    imported ``MultiHeadSelfAttentionLayer`` / ``FeedForward`` layers, which
    are not visible here.

    Args:
        head_num: number of attention heads.
        dim: passed as ``key_dim`` to the attention layer and as both ``ff_dim``
            and ``model_dim`` to the feed-forward layer.
        dropout: dropout rate forwarded to both sub-layers.

    Input shape
      - token embedding 3D tensor with shape: ``(batch_size, sequence_length, embedding_size)``.

    Output shape
      - 3D tensor with shape: ``(batch_size, sequence_length, embedding_size)``.

    References
        - [Self-Attentive Sequential Recommendation](https://arxiv.org/pdf/1808.09781.pdf)
    """

    def __init__(self, head_num=1, dim=50, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        # Hyper-parameters are kept on the instance so get_config can report them.
        self.head_num, self.dim, self.dropout = head_num, dim, dropout
        self.attention = MultiHeadSelfAttentionLayer(
            head_num=self.head_num, key_dim=self.dim, dropout=self.dropout
        )
        self.ff = FeedForward(
            ff_dim=self.dim, dropout=self.dropout, model_dim=self.dim
        )

    def call(self, inputs, training=False):
        # Self-attention: the same tensor is passed three times (presumably as
        # query, key and value). The causal mask must stay enabled.
        attended = self.attention(
            inputs, inputs, inputs, training=training, use_causal_mask=True
        )
        return self.ff(attended, training=training)

    def get_config(self):
        # Base layer config extended with this block's constructor arguments.
        return {
            **super().get_config(),
            "head_num": self.head_num,
            "dim": self.dim,
            "dropout": self.dropout,
        }


class SASRec(tf.keras.layers.Layer):
    """SASRec model: embedded input sequence passed through a stack of SASRec blocks.

    The final sequence representation at each position is scored against the
    embeddings of a positive and a negative item by dot product.

    Args:
        vocab_size: number of rows in the token embedding table. Index 0 is
            masked by the embedding layer (``mask_zero=True``).
        head_num: attention heads per block.
        block_num: number of stacked ``SASRecBlock`` layers.
        seq_length: size of the positional embedding table.
        dim: embedding size used throughout.
        dropout: dropout rate forwarded to every block.

    Input shape
      - sequential token index 2D tensor with shape: ``(batch_size, sequence_length)``.
      - positive token index 2D tensor with shape: ``(batch_size, sequence_length)``.
      - negative token index 2D tensor with shape: ``(batch_size, sequence_length)``.

    Output shape
      - 3D tensor with shape: ``(batch_size, sequence_length, 2)``.

    References
        - [Self-Attentive Sequential Recommendation](https://arxiv.org/pdf/1808.09781.pdf)
    """

    def __init__(
        self,
        vocab_size,
        head_num=1,
        block_num=2,
        seq_length=50,
        dim=50,
        dropout=0.1,
        **kwargs
    ):
        super(SASRec, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.head_num = head_num
        self.block_num = block_num
        self.seq_length = seq_length
        self.dim = dim
        self.dropout = dropout
        # will be reused to generate pos and neg embeddings
        self.token_emb = tf.keras.layers.Embedding(
            input_dim=vocab_size, output_dim=dim, mask_zero=True
        )
        self.positional_emb = PositionalEmbedding(
            self.token_emb, seq_length=seq_length, dim=dim
        )
        self.sas_blocks = [
            SASRecBlock(head_num=head_num, dim=dim, dropout=dropout)
            for _ in range(block_num)
        ]

    def call(self, inputs, training=False):
        input_token, pos, neg = inputs
        # shape [batch_size, token_length, dim]
        output_emb = self.positional_emb(input_token)
        # Candidate items use the raw (unscaled, position-free) token embedding.
        pos_emb = self.token_emb(pos)
        neg_emb = self.token_emb(neg)
        # Chain the blocks: each block consumes the previous block's output.
        # (Previously every block was fed the initial embedding, so only the
        # last block influenced the result, and block_num=0 left the output
        # undefined; with zero blocks the embedded input is now scored directly.)
        for sas_block in self.sas_blocks:
            output_emb = sas_block(output_emb, training=training)
        # shape [batch_size, token_length, 1]
        pos_logits = tf.reduce_sum(output_emb * pos_emb, axis=-1, keepdims=True)
        neg_logits = tf.reduce_sum(output_emb * neg_emb, axis=-1, keepdims=True)
        # shape [batch_size, token_length, 2]; channel 0 = positive, channel 1 = negative
        return tf.concat([pos_logits, neg_logits], axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocab_size": self.vocab_size,
                "block_num": self.block_num,
                "seq_length": self.seq_length,
                "head_num": self.head_num,
                "dim": self.dim,
                "dropout": self.dropout,
            }
        )
        return config


In [132]:
# Toy setup: a 10-id vocabulary and one input sequence of ids 0..8.
vocab_size = 10
# expand_dims adds the leading batch axis -> shape (1, 9).
tokens = tf.expand_dims(tf.range(9), axis=0)
tokens

<tf.Tensor: shape=(1, 9), dtype=int32, numpy=array([[0, 1, 2, 3, 4, 5, 6, 7, 8]], dtype=int32)>

In [133]:
# Positive targets: ids 1..9, i.e. the input token at each position plus one; shape (1, 9).
pos = tf.expand_dims(tf.range(1, 10), axis=0)
pos

<tf.Tensor: shape=(1, 9), dtype=int32, numpy=array([[1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=int32)>

In [134]:
# Negative items: one uniformly drawn id per position, shape (1, 9).
# The upper bound follows vocab_size instead of repeating the literal 10.
# NOTE(review): no seed is set, so re-running this cell draws different ids.
# The range also includes id 0, which the SASRec token embedding masks
# (mask_zero=True), and a draw may equal the positive id at the same
# position - confirm both are acceptable for this demo.
neg = tf.random.uniform([1, 9], minval=0, maxval=vocab_size, dtype=tf.dtypes.int32)
neg

<tf.Tensor: shape=(1, 9), dtype=int32, numpy=array([[3, 7, 4, 1, 3, 7, 4, 1, 0]], dtype=int32)>

In [135]:
# Build the SASRec layer with default hyper-parameters for the toy vocabulary,
# then display the three input tensors that will be fed to it.
sas_rec = SASRec(vocab_size)
(tokens, pos, neg)

(<tf.Tensor: shape=(1, 9), dtype=int32, numpy=array([[0, 1, 2, 3, 4, 5, 6, 7, 8]], dtype=int32)>,
 <tf.Tensor: shape=(1, 9), dtype=int32, numpy=array([[1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=int32)>,
 <tf.Tensor: shape=(1, 9), dtype=int32, numpy=array([[3, 7, 4, 1, 3, 7, 4, 1, 0]], dtype=int32)>)

In [136]:
# Forward pass with the default training=False; the last axis of the result
# holds the (positive, negative) logits for each position.
sas_rec((tokens, pos, neg))

<tf.Tensor: shape=(1, 9, 2), dtype=float32, numpy=
array([[[-0.30672818, -0.1411577 ],
        [ 0.00989199,  0.00433109],
        [-0.19873658, -0.23819302],
        [-0.27444744, -0.39519155],
        [-0.09321697, -0.3705861 ],
        [ 0.14692208,  0.01091674],
        [ 0.11229304, -0.5366032 ],
        [ 0.18126515, -0.38814524],
        [-0.01035482, -0.18661167]]], dtype=float32)>

In [154]:
def loss(label, pred):
    """Binary cross-entropy over (positive, negative) logit pairs.

    Each position contributes ``-log(sigmoid(pos_logit)) - log(1 - sigmoid(neg_logit))``.
    The total is divided by the sum of all entries of ``label``.

    Args:
        label: numeric tensor; only its total sum is used, as the normaliser.
        pred: float tensor of shape ``(batch_size, sequence_length, 2)``;
            channel 0 is the positive logit, channel 1 the negative logit.

    Returns:
        Scalar tensor with the normalised loss, or 0 when ``sum(label)`` is 0.
    """
    pos_logits = pred[:, :, 0]
    neg_logits = pred[:, :, 1]
    # log_sigmoid is the numerically stable form of log(sigmoid(x)), and
    # 1 - sigmoid(x) == sigmoid(-x). This replaces the previous
    # `log(sigmoid(x) + 1e-24)` / `log(1 - sigmoid(x) + 1e-24)` expressions,
    # which lose precision and gradient once the sigmoid saturates.
    per_position = -tf.math.log_sigmoid(pos_logits) - tf.math.log_sigmoid(-neg_logits)
    total = tf.reduce_sum(per_position)
    normaliser = tf.cast(tf.reduce_sum(label), total.dtype)
    # NOTE(review): every position in pred contributes to the numerator while
    # only sum(label) normalises it; if padded positions carry all-zero labels
    # they still add loss - confirm whether a per-position mask is wanted.
    # divide_no_nan returns 0 instead of inf/NaN when there are no labels.
    return tf.math.divide_no_nan(total, normaliser)


def auc(label, pred):
    """Pairwise ranking accuracy: share of positions where the positive outranks the negative.

    A position scores 1 when the positive logit is larger, 0 when it is
    smaller and 0.5 on a tie. The scores are summed and divided by the sum of
    all entries of ``label``.

    Args:
        label: numeric tensor; only its total sum is used, as the normaliser.
        pred: float tensor of shape ``(batch_size, sequence_length, 2)``;
            channel 0 is the positive logit, channel 1 the negative logit.

    Returns:
        Scalar tensor with the normalised score, or 0 when ``sum(label)`` is 0.
    """
    margin = pred[:, :, 0] - pred[:, :, 1]
    # sign(margin) is -1 / 0 / +1, mapped here to 0 / 0.5 / 1.
    wins = tf.reduce_sum((tf.math.sign(margin) + 1) / 2)
    normaliser = tf.cast(tf.reduce_sum(label), wins.dtype)
    # divide_no_nan avoids a NaN metric when there are no labels.
    return tf.math.divide_no_nan(wins, normaliser)


In [155]:
class SASRecModel(tf.keras.Model):
    """Keras model that feeds stacked (tokens, pos, neg) sequences to a SASRec layer.

    Args:
        sas_rec_layer: layer to wrap. Defaults to the notebook-level ``sas_rec``
            instance, so ``SASRecModel()`` keeps working as before.

    Input shape
      - 3D tensor ``(batch_size, 3, sequence_length)`` whose axis 1 holds the
        input, positive and negative sequences, in that order. A single
        unbatched example of shape ``(3, sequence_length)`` is also accepted.

    Output shape
      - the wrapped layer's output with any Keras mask removed.
    """

    def __init__(self, sas_rec_layer=None, **kwargs):
        super().__init__(**kwargs)
        # Keras only tracks the weights of layers that are assigned as
        # attributes; calling a global layer from `call` leaves the model with
        # no trainable variables of its own.
        self.sas_rec = sas_rec if sas_rec_layer is None else sas_rec_layer

    def call(self, inputs, training=False) -> tf.Tensor:
        # A rank-2 input can only be one unbatched (3, sequence_length)
        # example (e.g. from an unbatched tf.data pipeline): add the batch axis.
        if inputs.shape.rank == 2:
            inputs = tf.expand_dims(inputs, axis=0)
        # unstack removes axis 1 so each piece is (batch_size, sequence_length),
        # the shape the SASRec layer documents; tf.split would keep a
        # size-1 axis and hand it rank-3 tensors instead.
        input_token, pos, neg = tf.unstack(inputs, num=3, axis=1)
        logits = self.sas_rec((input_token, pos, neg), training=training)
        try:
            # Drop the keras mask, so it doesn't scale the losses/metrics.
            # b/250038731
            del logits._keras_mask
        except AttributeError:
            # No mask was attached to the output; nothing to drop.
            pass
        return logits


In [156]:
# Instantiate the wrapper model; no constructor arguments are passed here.
sas_rec_model = SASRecModel()

In [157]:
# Optimise the pairwise loss with Adam and report the pairwise ranking metric.
sas_rec_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=loss,
    metrics=[auc],
)


In [158]:
# One [1, 0] row per position -> shape (1, 9, 2); loss and auc only use the
# total sum of this tensor (9 here) as their normaliser.
label = tf.tile([[[1, 0]]], multiples=[1, 9, 1])
label

<tf.Tensor: shape=(1, 9, 2), dtype=int32, numpy=
array([[[1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0]]], dtype=int32)>

In [159]:
# Sanity check of the normaliser expression used inside loss/auc: the sum of
# all label entries as float32.
tf.cast(tf.reduce_sum(label), tf.float32)

<tf.Tensor: shape=(), dtype=float32, numpy=9.0>

In [160]:
# Preview of the model input: the three (1, 9) sequences stacked on a new
# axis 1 -> (1, 3, 9), with rows ordered tokens, pos, neg.
tf.stack([tokens, pos, neg], axis=1)

<tf.Tensor: shape=(1, 3, 9), dtype=int32, numpy=
array([[[0, 1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 7, 8, 9],
        [3, 7, 4, 1, 3, 7, 4, 1, 0]]], dtype=int32)>

In [161]:
# from_tensor_slices slices both tensors along axis 0, so the single example
# becomes one (inputs (3, 9), label (9, 2)) element, repeated 1000 times.
# NOTE(review): no .batch() is applied, so consumers receive these elements
# without a leading batch axis - confirm this matches the input rank the
# model expects.
dataset = tf.data.Dataset.from_tensor_slices((tf.stack([tokens, pos, neg], axis=1), label)).repeat(1000)

In [152]:
# Fit the compiled model on the repeated toy dataset with Keras defaults.
# NOTE(review): this cell's recorded execution count (152) is lower than the
# cells above it (154-161), so the saved output predates the current
# loss/auc/model/dataset definitions - re-run after Restart & Run All.
sas_rec_model.fit(dataset)



<keras.callbacks.History at 0x7f9e646551f0>