In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np


In [None]:
class PositionalEncoding(layers.Layer):
    """Adds the fixed sinusoidal position signal ("Attention Is All You Need").

    The table is precomputed once in ``__init__`` for ``max_len`` positions and
    sliced to the input's sequence length on every call; it is a constant, not
    a trainable weight.

    Args:
        max_len: number of positions the precomputed table covers. Inputs with
            a longer sequence axis are rejected in ``call``.
        d_model: feature width of the table; it is added to the inputs, so it
            is expected to match their last axis.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, max_len, d_model, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        self.d_model = d_model
        self.pos_encoding = self.positional_encoding(max_len, d_model)

    def positional_encoding(self, max_len, d_model):
        """Build the (1, max_len, d_model) float32 sinusoid table.

        Feature ``2k`` holds ``sin(pos / 10000**(2k / d_model))`` and feature
        ``2k + 1`` holds the ``cos`` of the same angle. The leading axis of
        size 1 lets the table broadcast across the batch.
        """
        pos = np.arange(max_len)[:, np.newaxis]
        i = np.arange(d_model)[np.newaxis, :]

        # Features 2k and 2k+1 share one frequency, hence the i // 2.
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
        angle_rads = pos * angle_rates

        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

        return tf.cast(angle_rads[np.newaxis, ...], dtype=tf.float32)

    def call(self, x):
        """Return ``x`` plus the position signal for its first ``seq_len`` steps.

        Assumes ``x`` is (batch, seq_len, d_model); the slice on axis 1 and the
        broadcast add rely on that layout.

        Raises:
            tf.errors.InvalidArgumentError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = tf.shape(x)[1]
        # The slice below silently stops at max_len. Without this check, a
        # too-long input would surface as an unhelpful broadcast error in the add.
        tf.debugging.assert_less_equal(
            seq_len,
            self.max_len,
            message="PositionalEncoding: sequence length exceeds max_len",
        )
        # The table is float32; cast so the add also works for float16/bfloat16
        # inputs (e.g. under a mixed-precision policy).
        position_signal = tf.cast(self.pos_encoding[:, :seq_len, :], x.dtype)
        return x + position_signal

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"max_len": self.max_len, "d_model": self.d_model})
        return config


In [None]:
class SelfAttention(layers.Layer):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are three separate ``Dense(d_model)`` projections
    of the same input. Scores are divided by ``sqrt(d_model)`` before softmax.

    Args:
        d_model: width of the Q/K/V projections and therefore of the output.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, d_model, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.q = layers.Dense(d_model)
        self.k = layers.Dense(d_model)
        self.v = layers.Dense(d_model)
        self.scale = tf.math.sqrt(tf.cast(d_model, tf.float32))
        # Output row j is the attention result for input position j, so an
        # incoming padding mask is still valid for the output.
        self.supports_masking = True

    def call(self, x, mask=None):
        """Let every position of ``x`` attend over all non-masked positions.

        Args:
            x: assumed (batch, seq_len, features); TODO confirm if other ranks
                are ever fed.
            mask: optional (batch, seq_len) tensor, truthy for real tokens and
                falsy for padding. Masked positions receive zero attention
                weight as keys. Keras can supply this automatically from an
                upstream mask-producing layer. When ``None`` (the default),
                every position is attended to, as before.

        Returns:
            Tensor with the same leading axes as ``x`` and last axis ``d_model``.
        """
        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)

        # self.scale is float32; cast it so the division also works when the
        # projections run in float16/bfloat16.
        scores = tf.matmul(Q, K, transpose_b=True) / tf.cast(self.scale, Q.dtype)

        if mask is not None:
            # Insert a query axis so each query row ignores the same padded keys.
            key_mask = tf.cast(mask, tf.bool)[:, tf.newaxis, :]
            # Use the dtype's most negative finite value instead of a literal
            # like -1e9, which overflows to -inf in float16.
            fill = tf.constant(scores.dtype.min, dtype=scores.dtype)
            scores = tf.where(key_mask, scores, fill)

        weights = tf.nn.softmax(scores, axis=-1)

        output = tf.matmul(weights, V)
        return output


In [None]:
class TransformerEncoder(layers.Layer):
    """One post-norm encoder block: self-attention, then a position-wise FFN.

    Each sub-layer is wrapped as ``norm(x + sublayer(x))``: a residual
    connection followed by layer normalization.

    Args:
        d_model: output width of the attention and FFN sub-layers. Because of
            the residual adds, the input's last axis is expected to match it.
        dff: hidden width of the feed-forward network (default 128).
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, d_model, dff=128, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.dff = dff

        self.attention = SelfAttention(d_model)
        self.norm1 = layers.LayerNormalization()
        self.norm2 = layers.LayerNormalization()

        # Expand to dff with ReLU, then project back to d_model so the
        # residual add below lines up.
        self.ffn = tf.keras.Sequential([
            layers.Dense(dff, activation="relu"),
            layers.Dense(d_model)
        ])

    def call(self, x):
        """Apply attention and the FFN, each with residual + LayerNorm."""
        attn_output = self.attention(x)
        x = self.norm1(x + attn_output)

        ffn_output = self.ffn(x)
        x = self.norm2(x + ffn_output)

        return x

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"d_model": self.d_model, "dff": self.dff})
        return config


In [None]:
class SimpleTransformer(tf.keras.Model):
    """Token ids -> embedding -> sinusoidal positions -> one encoder block.

    Args:
        vocab_size: number of rows in the embedding table (valid token ids are
            0 .. vocab_size - 1).
        max_len: number of positions the positional-encoding table is built for.
        d_model: width of the embeddings and of every encoder output vector.
    """

    def __init__(self, vocab_size, max_len, d_model):
        super().__init__()

        # Attribute names and creation order are kept stable; Keras tracks
        # sub-layers (and hence weight/checkpoint order) by them.
        self.embedding = layers.Embedding(input_dim=vocab_size, output_dim=d_model)
        self.pos_encoding = PositionalEncoding(max_len=max_len, d_model=d_model)
        self.encoder = TransformerEncoder(d_model=d_model)

    def call(self, x):
        """Run the three stages in order and return the encoder's output."""
        hidden = x
        for stage in (self.embedding, self.pos_encoding, self.encoder):
            hidden = stage(hidden)
        return hidden


In [None]:
vocab_size = 1000
max_len = 10
d_model = 64

model = SimpleTransformer(vocab_size, max_len, d_model)

# One sequence of length max_len: four non-zero ids followed by zeros.
# NOTE(review): if 0 is meant as padding, nothing in this model masks it
# (Embedding is built without mask_zero), so attention still attends to it.
sample_input = tf.constant([[1, 5, 23, 45, 0, 0, 0, 0, 0, 0]])
output = model(sample_input)

# The encoder keeps one d_model-wide vector per input token; fail loudly on a
# fresh Restart & Run All if that ever changes.
expected_shape = (sample_input.shape[0], sample_input.shape[1], d_model)
assert tuple(output.shape) == expected_shape, (tuple(output.shape), expected_shape)

print(output.shape)


(1, 10, 64)
