In [250]:
import numpy as np
from tqdm import tqdm

import jax
import jax.numpy as jnp

import flax
from flax import linen as nn

In [281]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    q, k and v are each projected by their own ``nn.Dense(d_model)``, split
    into ``n_heads`` heads of width ``d_model // n_heads``, attended
    independently, and the heads are concatenated back to ``d_model`` features.

    Attributes:
        seq_len: Sequence length the notebook constructs the module with. Kept
            for backward compatibility only; the actual query and key/value
            lengths are read from the inputs at call time.
        d_model: Model (feature) width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
    """
    seq_len: int
    d_model: int
    n_heads: int

    @nn.compact
    def __call__(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: Queries, shape ``(..., len_q, features)``.
            k: Keys, shape ``(..., len_kv, features)``.
            v: Values, shape ``(..., len_kv, features)``; same length as ``k``.
            mask: Optional array broadcastable to
                ``(..., n_heads, len_q, len_kv)``. Truthy entries may be
                attended to; falsy entries are excluded. A leading singleton
                axis, e.g. ``(1, len_q, len_kv)``, is broadcast over heads.

        Returns:
            Array of shape ``(..., len_q, d_model)``.

        Raises:
            ValueError: If ``d_model`` is not divisible by ``n_heads``.
        """
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )
        d_k = self.d_model // self.n_heads

        # Created in the same order as before (Dense_0/1/2 = q/k/v), so
        # previously initialised parameter trees remain compatible.
        q = nn.Dense(self.d_model)(q)
        k = nn.Dense(self.d_model)(k)
        v = nn.Dense(self.d_model)(v)

        def split_heads(x):
            # (..., length, d_model) -> (..., n_heads, length, d_k); each input
            # uses its own length so k/v may differ in length from q.
            *lead, length, _ = x.shape
            return x.reshape((*lead, length, self.n_heads, d_k)).swapaxes(-3, -2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        scores = jnp.matmul(q, k.swapaxes(-1, -2)) / jnp.sqrt(d_k)

        if mask is not None:
            # Replace (rather than add -inf to) masked scores with the most
            # negative finite value: partially masked rows get exactly zero
            # weight there, and fully masked rows give uniform weights
            # instead of NaN.
            scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)

        weights = nn.softmax(scores, axis=-1)
        out = jnp.matmul(weights, v)  # (..., n_heads, len_q, d_k)

        # Merge heads back: (..., n_heads, len_q, d_k) -> (..., len_q, d_model).
        *lead, _, len_q, _ = out.shape
        return out.swapaxes(-3, -2).reshape((*lead, len_q, self.d_model))

In [282]:
class MLP(nn.Module):
    """Feed-forward block: Dense -> GELU -> Dropout -> Dense.

    Attributes:
        d_model: Output feature width; the hidden width is
            ``d_model * mlp_ratio``.
        mlp_ratio: Hidden-width multiplier (default 4).
        dropout_rate: Dropout probability applied after the GELU (default 0.1).
    """
    d_model: int
    mlp_ratio: int = 4
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        """Apply the feed-forward block along the last axis of ``x``.

        Args:
            x: Input features; the Dense layers act on the last axis.
            deterministic: If True (default), dropout is disabled. Pass False
                for training and supply ``rngs={'dropout': key}`` to ``apply``.

        Returns:
            Array with the same leading axes as ``x`` and last axis ``d_model``.
        """
        # Use the module's own attributes, not notebook globals.
        x = nn.Dense(self.d_model * self.mlp_ratio)(x)
        x = nn.gelu(x)
        # Flax Dropout must be told whether it is deterministic.
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = nn.Dense(self.d_model)(x)

        return x

In [283]:
# Toy dimensions for a single, unbatched sequence.
seq_len, d_model, n_heads = 128, 768, 8
d_k = d_model // n_heads  # width of each attention head

# Dummy query / key / value inputs, each of shape (seq_len, d_model).
q, k, v = (jnp.ones((seq_len, d_model)) for _ in range(3))

In [284]:
# Fixed PRNG key so parameter initialisation is reproducible.
rng = jax.random.PRNGKey(42)

# A dummy float32 (seq_len, d_model) input lets Flax infer parameter shapes.
init_multi_head = jnp.ones((seq_len, d_model), dtype=jnp.float32)

variables_attention = MultiHeadAttention(
    seq_len=seq_len, d_model=d_model, n_heads=n_heads
).init(
    rng,
    init_multi_head,  # q
    init_multi_head,  # k
    init_multi_head,  # v
)

In [285]:
# Causal mask: query row i may attend to key column j only when j <= i.
# The leading singleton axis broadcasts the mask over attention heads.
mask = (jnp.arange(seq_len)[None, :] <= jnp.arange(seq_len)[:, None])[None, :, :]
# NOTE(review): `temp` is not used by any of the visible cells.
temp = jnp.ones(mask.shape)

In [286]:
# Forward pass under the causal mask, using the parameters from the init cell.
out = MultiHeadAttention(
    seq_len=seq_len, d_model=d_model, n_heads=n_heads
).apply(variables_attention, q, k, v, mask=mask)

In [287]:
# Sanity check: attention should preserve the (seq_len, d_model) input shape.
assert out.shape == (seq_len, d_model), out.shape
out.shape

(128, 768)