In [17]:
from dataclasses import dataclass
from functools import partial

import numpy as np
from tqdm import tqdm

import jax
import jax.numpy as jnp

import flax
from flax import linen as nn

In [2]:
from tokenizers import Tokenizer, ByteLevelBPETokenizer
files = [f"./wikitext-2-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]]
tokenizer = ByteLevelBPETokenizer()
tokenizer.train(files, vocab_size=32768)
tokenizer.save('./tokenizer.json')
tokenizer = Tokenizer.from_file('./tokenizer.json')

In [3]:
class MultiHeadSelfAttention(nn.Module):
    d_model: int
    n_heads: int

    @nn.compact
    def __call__(self, x, mask, training):
        seq_len = x.shape[1]
        d_k = self.d_model // self.n_heads

        q = nn.Dense(self.d_model)(x)
        k = nn.Dense(self.d_model)(x)
        v = nn.Dense(self.d_model)(x)

        q = q.reshape((-1, seq_len, self.n_heads, d_k)).transpose((0, 2, 1, 3))
        k = k.reshape((-1, seq_len, self.n_heads, d_k)).transpose((0, 2, 1, 3))
        v = v.reshape((-1, seq_len, self.n_heads, d_k)).transpose((0, 2, 1, 3))

        a = jnp.matmul(q, k.transpose((0, 1, 3, 2))) / jnp.sqrt(d_k)

        mask = jnp.where(mask, 0, -jnp.inf)
        a += mask

        a = nn.softmax(a, axis=-1)
        a = nn.Dropout(0.1)(a, deterministic=not training)
        a = jnp.matmul(a, v)

        return a.transpose((0, 2, 1, 3)).reshape(-1, seq_len, self.d_model)


class MLP(nn.Module):
    d_model: int

    @nn.compact
    def __call__(self, x, training):
        x = nn.Dense(self.d_model * 4)(x)
        x = nn.gelu(x)
        x = nn.Dense(self.d_model)(x)
        x = nn.Dropout(0.1)(x, deterministic=not training)

        return x


class Block(nn.Module):
    d_model: int
    n_heads: int

    @nn.compact
    def __call__(self, x, mask, training):

        x = x + MultiHeadSelfAttention(self.d_model,
                                       self.n_heads)(nn.LayerNorm()(x), mask, training)
        x = x + nn.Dropout(0.1)(MLP(self.d_model)(nn.LayerNorm()
                                                  (x), training), deterministic=not training)

        return x


class GPT2(nn.Module):
    config: dataclass

    @nn.compact
    def __call__(self, x, training=False):
        seq_len = x.shape[-1]

        position_ids = jnp.arange(start=0, stop=seq_len, step=1)
        mask = jnp.triu(
            jnp.ones((1, seq_len, seq_len)), k=1) == 0

        content_embedding = nn.Embed(
            self.config.vocab_size, self.config.d_model)
        embeddings = content_embedding(
            x) + nn.Embed(self.config.max_seq_len, self.config.d_model)(position_ids)
        x = nn.Dropout(0.1)(embeddings)

        for _ in range(self.config.n_layers):
            x = Block(self.config.d_model, self.config.n_heads)(
                x, mask, training)

        x = nn.LayerNorm()(x)
        x = content_embedding.attend(x)

        return x

In [4]:
@dataclass
class Config:
    fast: bool = False

    batch_size: int = 1
    epochs: int = 1

    max_seq_len: int = 128
    n_layers: int = 2
    vocab_size: int = 32768
    d_model: int = 768
    n_heads: int = 8

In [5]:
config = Config()

In [9]:
rng = jax.random.PRNGKey(42)
rng, dropout_rng = jax.random.split(rng)

x = jax.random.randint(rng, (config.batch_size, config.max_seq_len), 0, 1000, jnp.int32)
variables = GPT2(config).init({'params': rng, 'dropout': rng}, x, training=False)

In [11]:
x = jax.random.randint(rng, (config.batch_size, 1), 0, 1000, jnp.int32)

y_hat = GPT2(config).apply(variables, x, training=True, rngs={'dropout': rng})

In [12]:
y_hat.shape

(1, 1, 32768)

In [45]:
generated = tokenizer.encode(' ').ids

In [46]:
for i in tqdm(range(16)):
    rng, _ = jax.random.split(rng, 2)

    x = jnp.array(generated).reshape(1, -1)
    logits = GPT2(config).apply(
        variables, x, training=False, rngs={'dropout': rng})
    preds = nn.softmax(logits, axis=-1)

    next_token = jax.random.categorical(rng, preds[0, -1])
    generated += [int(next_token)]

# print(f'Dataset: {tokenizer.decode(batches[:config.seq_len])}')
print("\n")
print(f'Continuation: {tokenizer.decode(generated)}')


100%|██████████| 16/16 [00:42<00:00,  2.66s/it]

Continuation:   Im ni slot resistance chail selected Ornette Sting Coward nucleoplasm Blonde Cruz � precisely rit Wheelchair



In [82]:
y_hat.shape

(2, 1, 32768)