In [9]:
import tensorflow as tf
import numpy as np

class RotaryPositionalEmbedding(tf.keras.layers.Layer):
    """Add a fixed cosine position table to the inputs, then mix features with a fixed cosine matrix.

    Two constant (non-trainable) float32 tables are built once at construction:

    * ``positional_embedding`` with shape ``(max_seq_len, d_model)`` and entries
      ``cos(position * feature * 0.01)``.
    * ``rotation_matrix`` with shape ``(d_model, d_model)`` and entries
      ``cos(row * col * 0.01)``.

    ``call`` computes ``(inputs + positional_embedding[:seq_len]) @ rotation_matrix``.

    NOTE(review): despite the class name, this is not the pairwise feature
    rotation usually referred to as rotary position embedding (RoPE). The
    computation is kept exactly as written; confirm whether true RoPE was intended.

    Args:
        d_model: Size of the last (feature) axis of the inputs.
        max_seq_len: Number of positions in the precomputed table; inputs may
            not be longer than this along axis 1.
        **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments
            (e.g. ``name``), forwarded to the base class.
    """

    # Scale applied to every index product before taking the cosine.
    _ANGLE_SCALE = 0.01

    def __init__(self, d_model, max_seq_len, **kwargs):
        super().__init__(**kwargs)

        # Plain Python ints keep get_config() serialisable and make the
        # length check in call() compare like-typed integers.
        self.d_model = int(d_model)
        self.max_seq_len = int(max_seq_len)

        feature_idx = np.arange(self.d_model)
        position_idx = np.arange(self.max_seq_len)

        # np.outer yields the same index products as the former nested loops,
        # without O(n^2) interpreter-level iterations.
        rotation_matrix = np.cos(
            np.outer(feature_idx, feature_idx) * self._ANGLE_SCALE
        ).astype(np.float32)
        self.rotation_matrix = tf.convert_to_tensor(rotation_matrix)

        positional_embedding = np.cos(
            np.outer(position_idx, feature_idx) * self._ANGLE_SCALE
        ).astype(np.float32)
        self.positional_embedding = tf.convert_to_tensor(positional_embedding)

    def call(self, inputs):
        """Apply the position offset and the feature mixing matrix.

        Args:
            inputs: Float tensor whose axis 1 is treated as the sequence axis
                and whose last axis has size ``d_model``.

        Returns:
            Tensor with the same shape as ``inputs``.

        Raises:
            tf.errors.InvalidArgumentError: If the sequence length exceeds
                ``max_seq_len``.
        """
        seq_len = tf.shape(inputs)[1]

        # Without this check the slice below silently truncates to
        # max_seq_len rows, which then either fails with an opaque shape
        # error or (when max_seq_len == 1) broadcasts one row everywhere.
        tf.debugging.assert_less_equal(
            seq_len,
            self.max_seq_len,
            message="Input sequence length exceeds max_seq_len of the positional table.",
        )

        # Build a new tensor rather than using `+=` on the argument, so the
        # caller's object is never the target of an in-place operation.
        shifted = inputs + self.positional_embedding[:seq_len, :]
        return tf.matmul(shifted, self.rotation_matrix)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update({"d_model": self.d_model, "max_seq_len": self.max_seq_len})
        return config





# Demo configuration: feature width, positional-table length, and the shape
# of the dummy batch fed through the layer.
d_model = 64
max_seq_len = 100
batch_size = 32
seq_len = 50
rotary_embedding_layer = RotaryPositionalEmbedding(d_model, max_seq_len)

# Stateless sampling with a fixed seed: the same values are produced on every
# run (Restart Kernel -> Run All) without touching TensorFlow's global RNG state.
dummy_input = tf.random.stateless_uniform(
    (batch_size, seq_len, d_model), seed=[0, 42], dtype=tf.float32
)

output = rotary_embedding_layer(dummy_input)

# Sanity check: the layer must preserve the input shape.
assert tuple(output.shape) == (batch_size, seq_len, d_model)

print("Output shape:", output.shape)


Output shape: (32, 50, 64)


In [10]:
# Display the tensor produced by the layer call in the previous cell.
output

<tf.Tensor: shape=(32, 50, 64), dtype=float32, numpy=
array([[[ 9.5585991e+01,  8.9043701e+01,  7.0989044e+01, ...,
          6.0270667e+00,  6.9009361e+00,  5.9935856e+00],
        [ 8.7726990e+01,  8.2514854e+01,  6.8032700e+01, ...,
          3.1373153e+00,  3.4993145e+00,  3.0325756e+00],
        [ 8.4850082e+01,  8.0066010e+01,  6.6779739e+01, ...,
          2.1989565e+00,  2.5980859e+00,  2.5391860e+00],
        ...,
        [ 3.4925240e+01,  3.3035671e+01,  2.7746836e+01, ...,
          3.5989296e+00,  1.9517485e+00,  5.6525409e-02],
        [ 2.9421074e+01,  2.8060596e+01,  2.4223621e+01, ...,
          2.3800254e+00,  1.7422734e+00,  5.9251755e-01],
        [ 3.5072178e+01,  3.2647770e+01,  2.5946465e+01, ...,
          6.8680978e+00,  6.7413759e+00,  4.8270411e+00]],

       [[ 9.8080315e+01,  9.1772514e+01,  7.4329483e+01, ...,
          2.7503026e+00,  3.9117270e+00,  4.1997595e+00],
        [ 9.1774818e+01,  8.5958504e+01,  6.9836006e+01, ...,
          1.1767038e+00,  1.7