In [1]:
import numpy as np
from tqdm import tqdm

import jax
import jax.numpy as jnp

import flax
from flax import linen as nn

In [135]:
from typing import Callable

class MultiHeadSelfAttention(nn.Module):
    seq_len: int
    d_model: int
    n_heads: int

    @nn.compact
    def __call__(self, x, mask, training):
        seq_len = x.shape[1]
        d_k = self.d_model // self.n_heads

        q = nn.Dense(self.d_model)(x)
        k = nn.Dense(self.d_model)(x)
        v = nn.Dense(self.d_model)(x)

        q = q.reshape((-1, seq_len, self.n_heads, d_k)).transpose((0, 2, 1, 3))
        k = k.reshape((-1, seq_len, self.n_heads, d_k)).transpose((0, 2, 1, 3))
        v = v.reshape((-1, seq_len, self.n_heads, d_k)).transpose((0, 2, 1, 3))

        a = jnp.matmul(q, k.transpose((0, 1, 3, 2))) / jnp.sqrt(d_k)

        mask = jnp.where(mask, 0, -jnp.inf)
        a += mask

        a = nn.softmax(a, axis=-1)
        a = nn.Dropout(0.1)(a, deterministic=not training)
        a = jnp.matmul(a, v)

        return a.transpose((0, 2, 1, 3)).reshape(-1, self.seq_len, self.d_model)


class MLP(nn.Module):
    d_model: int

    @nn.compact
    def __call__(self, x, training):
        x = nn.Dense(self.d_model * 4)(x)
        x = nn.gelu(x)
        x = nn.Dense(self.d_model)(x)
        x = nn.Dropout(0.1)(x, deterministic=not training)

        return x


class Block(nn.Module):
    seq_len: int
    d_model: int
    n_heads: int

    @nn.compact
    def __call__(self, x, mask, training):

        x = x + MultiHeadSelfAttention(self.seq_len, self.d_model,
                                       self.n_heads)(nn.LayerNorm()(x), mask, training)
        x = x + nn.Dropout(0.1)(MLP(self.d_model)(nn.LayerNorm()
                                                  (x), training), deterministic=not training)

        return x


class GPT2(nn.Module):
    seq_len: int
    n_layers: int
    vocab_size: int
    d_model: int
    n_heads: int

    @nn.compact
    def __call__(self, x, training=False):


        position_ids = jnp.arange(start=0, stop=self.seq_len, step=1)
        mask = jnp.triu(jnp.ones((1, self.seq_len, self.seq_len)), k=1) == 0

        content_embedding = nn.Embed(self.vocab_size, self.d_model)
        embeddings = content_embedding(x) + nn.Embed(self.seq_len, self.d_model)(position_ids)
        x = nn.Dropout(0.1)(embeddings)

        for _ in range(self.n_layers):
            x = Block(self.seq_len, self.d_model,
                      self.n_heads)(x, mask, training)

        x = nn.LayerNorm()(x)
        x = content_embedding.attend(x)

        return x


In [136]:
batch_size = 2
seq_len = 128
n_layers = 2
vocab_size = 1024
d_model = 768
n_heads = 8

d_k = d_model // n_heads

In [137]:
rng = jax.random.PRNGKey(42)
rng, dropout_rng = jax.random.split(rng)

x = jax.random.randint(rng, (batch_size, seq_len), 0, 1000, jnp.int32)

variables = GPT2(seq_len, n_layers, vocab_size, d_model, n_heads).init({'params': rng, 'dropout': dropout_rng}, x, training=False)
gpt2 = GPT2(seq_len, n_layers, vocab_size, d_model, n_heads)

In [138]:
def loss_fn(variables, batch, rng):
    x = batch[:, :-1]
    y = batch[:, 1:]
    y = jax.nn.one_hot(y, vocab_size)

    y_hat = gpt2.apply(variables, x, training=True, rngs={'dropout': rng})

    loss = jnp.sum(y * jax.nn.log_softmax(y_hat, axis=-1), axis=-1)
    return -jnp.mean(loss)

In [139]:
# @partial(jax.pmap, axis_name='batch')
def train_step(optimizer, batch, rng):
    rng, rng_dropout = jax.random.split(rng)

    loss, grad = jax.value_and_grad(loss_fn)(optimizer.target, batch, rng_dropout)

    # loss = jax.lax.pmean(loss, axis_name='batch')
    # grad = jax.lax.pmean(grad, axis_name='batch')

    optimizer = optimizer.apply_gradient(grad)

    return optimizer, loss, rng

In [140]:
optimizer = flax.optim.Adam(learning_rate=1e-4, beta1=0.5, beta2=0.9).create(variables)

In [141]:
x = jax.random.randint(rng, (batch_size, seq_len + 1), 0, 1000, jnp.int32)

In [142]:
optimizer, loss, rng = train_step(optimizer, x, rng)

In [143]:
loss

DeviceArray(7.443428, dtype=float32)

In [112]:
out = gpt2.apply(optimizer.target, x[:, :-1], training=True, rngs={'dropout': rng})

In [113]:
out2 = jnp.argmax(nn.softmax(out, axis=-1), axis=-1)

In [114]:
out2

DeviceArray([[371, 334, 451, 105, 244, 150, 707, 243, 501,  16, 627, 855,
              260, 138,  28, 626,  69, 934, 893, 146, 213, 556, 391, 170,
              391, 779, 957, 951, 239, 550, 451, 654,  18, 630, 278, 614,
               73, 424, 527, 503, 736, 970, 603, 707, 367, 458, 449, 159,
              949, 749,  15, 895, 523, 999,  89, 462, 957, 870, 898,  23,
              299, 339, 385, 109, 104, 678, 298, 178, 303, 617, 650, 485,
              391, 619, 941, 778, 122, 466, 634, 918, 871, 954, 848, 328,
              743, 244, 345,  17,  38, 775, 317, 283, 458, 838, 557, 121,
              396, 856, 761, 284, 234, 426, 541, 616, 873, 336, 671, 499,
               13, 784, 844, 225, 347, 412, 558, 714, 567, 926,  12, 864,
               92, 755, 417, 961, 998, 773,  18, 576],
             [400, 485, 625, 796, 682, 296, 637, 415,  68, 778, 977, 593,
              114, 305, 929, 118, 698, 735, 628, 932,   0, 424, 117, 684,
              874, 141, 113, 216, 169, 833, 116, 394, 487

In [115]:
x

Buffer([[662, 371, 334, 451, 105, 244, 150, 707, 243, 501,  16, 627, 855,
         260, 138,  28, 626,  69, 934, 893, 146, 213, 556, 391, 170, 391,
         779, 957, 951, 239, 550, 451, 654,  18, 630, 278, 614,  73, 424,
         527, 503, 736, 970, 603, 707, 367, 458, 449, 159, 949, 749,  15,
         895, 523, 999,  89, 462, 957, 870, 898,  23, 299, 339, 385, 109,
         104, 678, 298, 178, 303, 617, 650, 485, 391, 619, 941, 778, 122,
         466, 634, 918, 871, 954, 848, 328, 743, 244, 345,  17,  38, 775,
         317, 283, 458, 838, 557, 121, 396, 856, 761, 284, 234, 426, 541,
         616, 873, 336, 671, 499,  13, 784, 844, 225, 347, 412, 558, 714,
         567, 926,  12, 864,  92, 755, 417, 961, 998, 773,  18, 576],
        [586, 400, 485, 625, 796, 682, 296, 637, 415,  68, 778, 977, 593,
         114, 305, 929, 118, 698, 735, 628, 932,   0, 424, 117, 684, 874,
         141, 113, 216, 169, 833, 116, 394, 487,  75,  36, 830, 433, 609,
          72, 909, 681,   3, 993, 451,  91

In [116]:
out2 == x[:, 1:]

DeviceArray([[ True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,  True,  True,  True,  True,  True,  True,  True,
               True,