In [1]:
import numpy as np
from tqdm import tqdm

import jax
import jax.numpy as jnp

import flax
from flax import linen as nn

In [287]:
class MultiHeadSelfAttention(nn.Module):
    seq_len: int
    d_model: int
    n_heads: int

    @nn.compact
    def __call__(self, x, mask, training):
        seq_len = x.shape[0]
        d_k = self.d_model // self.n_heads

        q = nn.Dense(self.d_model)(x)
        k = nn.Dense(self.d_model)(x)
        v = nn.Dense(self.d_model)(x)

        q = q.reshape((seq_len, self.n_heads, d_k)).transpose((1, 0, 2))
        k = k.reshape((seq_len, self.n_heads, d_k)).transpose((1, 0, 2))
        v = v.reshape((seq_len, self.n_heads, d_k)).transpose((1, 0, 2))

        a = jnp.matmul(q, k.transpose((0, 2, 1))) / jnp.sqrt(d_k)

        mask = jnp.where(mask, 0, -jnp.inf)
        a += mask
    
        a = nn.softmax(a, axis=-1)
        a = nn.Dropout(0.1)(a, deterministic=not training)
        a = jnp.matmul(a, v)

        return a.transpose((1, 0, 2)).reshape(self.seq_len, self.d_model)

In [288]:
class MLP(nn.Module):
    d_model: int

    @nn.compact
    def __call__(self, x, training):
        x = nn.Dense(self.d_model * 4)(x)
        x = nn.gelu(x)
        x = nn.Dense(self.d_model)(x)
        x = nn.Dropout(0.1)(x, deterministic=not training)

        return x

In [289]:
class Block(nn.Module):
    seq_len: int
    d_model: int
    n_heads: int

    @nn.compact
    def __call__(self, x, mask, training):

        x = x + MultiHeadSelfAttention(self.seq_len, self.d_model, self.n_heads)(nn.LayerNorm()(x), mask, training)
        x = x + nn.Dropout(0.1)(MLP(self.d_model)(nn.LayerNorm()(x), training), deterministic=not training)

        return x

In [290]:
class GPT2(nn.Module):
    seq_len: int
    n_layers: int
    vocab_size: int
    d_model: int
    n_heads: int

    @nn.compact
    def __call__(self, x, training=False):
        position_ids = jnp.arange(start=0, stop=self.seq_len, step=1)
        mask = jnp.triu(jnp.ones((1, self.seq_len, self.seq_len)), k=1) == 0

        embeddings = nn.Embed(self.vocab_size, self.d_model)(x) + nn.Embed(self.seq_len, self.d_model)(position_ids)
        x = nn.Dropout(0.1)(embeddings)
        
        for _ in range(self.n_layers):
            x = Block(self.seq_len, self.d_model, self.n_heads)(x, mask, training)
        
        x = nn.LayerNorm()(x)
        x = nn.Dense(self.vocab_size)(x)

        return x

In [291]:
seq_len = 128
n_layers = 2
vocab_size = 1024
d_model = 768
n_heads = 8

d_k = d_model // n_heads

x = jax.random.normal(rng, (seq_len, d_model))

In [292]:
rng = jax.random.PRNGKey(42)
rng, dropout_rng = jax.random.split(rng)

mask = jnp.triu(jnp.ones((1, seq_len, seq_len)), k=1) == 0

init = jnp.ones((seq_len, d_model), jnp.float32)

init_gpt2 = jnp.array(1)#jnp.ones((seq_len), jnp.float32)

# variables_attention = MultiHeadSelfAttention(seq_len, d_model, n_heads).init({'params': rng, 'dropout': dropout_rng}, init, mask, training=False)
# variables_mlp = MLP(d_model).init({'params': rng, 'dropout': dropout_rng}, init, training=False)
# variables_block = Block(seq_len, d_model, n_heads).init({'params': rng, 'dropout': dropout_rng}, init, mask, training=False)
variables_gpt2 = GPT2(seq_len, n_layers, vocab_size, d_model, n_heads).init({'params': rng, 'dropout': dropout_rng}, init_gpt2, training=False)

In [172]:
out = MultiHeadSelfAttention(seq_len, d_model, n_heads).apply(variables_attention, x, mask, training=True, rngs={'dropout': rng})

In [202]:
out = MLP(d_model).apply(variables_mlp, out, training=True, rngs={'dropout': rng})

In [300]:
out = Block(seq_len, d_model, n_heads).apply(variables_block, x, mask, training=True, rngs={'dropout': rng})

0

In [295]:
out = GPT2(seq_len, n_layers, vocab_size, d_model, n_heads).apply(variables_gpt2, init_gpt2, training=False, rngs={'dropout': rng})

In [296]:
out.shape

(128, 1024)

In [301]:
out

DeviceArray([[-0.07671511, -1.7705754 ,  0.2934252 , ...,  1.9849722 ,
              -2.142795  , -1.6366477 ],
             [-1.7738787 , -0.36792904,  2.158223  , ...,  0.09209104,
              -0.5125929 ,  1.3942482 ],
             [-1.5357774 , -0.16519678,  0.7202017 , ..., -0.80888844,
              -1.4017371 , -0.7294917 ],
             ...,
             [ 0.7833952 , -0.35864604, -0.6492446 , ..., -1.5745726 ,
              -0.18905675, -0.5896662 ],
             [ 0.49872887,  1.8160241 , -0.7537211 , ..., -1.7014841 ,
              -1.2140181 ,  1.3204899 ],
             [ 1.5718806 ,  0.09945551,  0.9722837 , ...,  0.93707865,
              -2.2785974 ,  0.32326937]], dtype=float32)