In [None]:
!pip install jax
!pip install flax

In [5]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-05-08 11:23:58--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-05-08 11:23:58 (60.4 MB/s) - ‘input.txt’ saved [1115394/1115394]

--2023-05-08 11:23:58--  http://./data/input.txt
Resolving . (.)... failed: No address associated with hostname.
wget: unable to resolve host address ‘.’
FINISHED --2023-05-08 11:23:58--
Total wall clock time: 0.2s
Downloaded: 1 files, 1.1M in 0.02s (60.4 MB/s)


In [92]:
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

vocab = sorted(list(set(text)))
vocab_size = len(vocab)

print(vocab_size)
print(''.join(vocab))


65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [93]:
itos = { i: s for i, s in enumerate(vocab) }
stoi = { s: i for i, s in enumerate(vocab) }

encode = lambda txt: [stoi[c] for c in txt ]
decode = lambda num: ''.join([itos[n] for n in num ])

encoded = encode("The test string")
decoded = decode(encoded)

print(encoded)
print(decoded)


[32, 46, 43, 1, 58, 43, 57, 58, 1, 57, 58, 56, 47, 52, 45]
The test string


In [94]:
import jax
import flax
import optax

from tqdm import tqdm
from typing import Callable
from jax import numpy as jnp
from flax import linen as nn
from functools import partial

In [95]:
data = jnp.array(encode(text))
data.shape

(1115394,)

In [96]:
n_val = int(0.9 * len(data))
train_data = data[:n_val]
val_data = data[n_val:]

train_data.shape, val_data.shape

((1003854,), (111540,))

In [184]:
block_size = 8
batch_size = 4
n_embed = 32

def get_batch(batch_key, *, split='train'):
    d = train_data if split == 'train' else val_data
    ix = jax.random.randint(batch_key, (batch_size,), 0, len(d)-block_size)
    x = jnp.stack([d[i:i+block_size] for i in ix])
    y = jnp.stack([d[i+1:i+block_size+1] for i in ix])
    return x, y

key = jax.random.PRNGKey(1337)

X, Y = get_batch(key)
print(X.shape)
print(Y.shape)

(4, 8)
(4, 8)


In [185]:
ones = jnp.ones(shape=(10, 10))
tril = jnp.tril(ones)
jax.lax.select(tril == 0, jax.lax.broadcast(jnp.NINF, ones.shape), ones)

Array([[  1., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       [  1.,   1., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       [  1.,   1.,   1., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
       [  1.,   1.,   1.,   1., -inf, -inf, -inf, -inf, -inf, -inf],
       [  1.,   1.,   1.,   1.,   1., -inf, -inf, -inf, -inf, -inf],
       [  1.,   1.,   1.,   1.,   1.,   1., -inf, -inf, -inf, -inf],
       [  1.,   1.,   1.,   1.,   1.,   1.,   1., -inf, -inf, -inf],
       [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1., -inf, -inf],
       [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1., -inf],
       [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.]],      dtype=float32)

In [214]:
class AttentionHead(nn.Module):

    wq_init: Callable = nn.initializers.lecun_normal()
    wk_init: Callable = nn.initializers.lecun_normal()
    wv_init: Callable = nn.initializers.lecun_normal()
    
    head_size: int = 32
    
    @nn.compact
    def __call__(self, inputs):
        
        B, T, C = inputs.shape

        # mask
        ones = jnp.ones(shape=(T, T))
        tril = jnp.tril(ones)
        tril = jnp.stack([tril] * B)
        
        WQ = self.param("WQ", self.wq_init, (C, self.head_size))
        WK = self.param("WK", self.wk_init, (C, self.head_size))
        WV = self.param("WV", self.wv_init, (C, self.head_size))
        
        Q = inputs @ WQ
        K = inputs @ WK
        V = inputs @ WV

        KT = K.transpose([0, 2, 1]) # B, H, T
        QK = Q @ KT / jnp.sqrt(self.head_size)
        QK = jax.lax.select(tril == 0, jax.lax.broadcast(jnp.NINF, QK.shape), QK)

        attention = nn.softmax(QK, axis=-1) @ V

        return attention

class MultiHeadAttention(nn.Module):

    head_number: int = 4
    head_size: int = 32 // 4

    def setup(self):
        self.heads = [AttentionHead(head_size=self.head_size) for i in range(self.head_number)]

    def __call__(self, inputs):
        results = [h(inputs) for h in self.heads]
        return jnp.concatenate(results, axis=-1)

class BigramLanguageModel(nn.Module):
    
    vocab_size: int
    block_size: int
    embedding_size: int
    head_number: int = 4

    def setup(self):
        self.token_embedding = nn.Embed(num_embeddings=self.vocab_size, features=self.embedding_size)
        self.position_embedding = nn.Embed(num_embeddings=self.block_size, features=self.embedding_size)
        self.sa_heads = MultiHeadAttention(head_number=self.head_number, head_size=self.embedding_size // self.head_number)
        self.lm_head = nn.Dense(self.vocab_size)
    
    def __call__(self, idx):
        token_embeddings = self.token_embedding(idx) # B, T, C

        positions = jnp.arange(0, self.block_size)
        positions_embeddings = self.position_embedding(positions) # T, C

        embeddings = token_embeddings + positions_embeddings
        attention = self.sa_heads(embeddings)
        logits = self.lm_head(attention)

        return logits

@jax.jit
def bigram_loss(y_hat, y):
    losses = optax.softmax_cross_entropy_with_integer_labels(y_hat, y)
    return jnp.mean(losses)

def bigram_generate(apply_fn: Callable, key, idx, max_tokens=1):
    for i in range(max_tokens):
        key_i = jax.random.fold_in(key, i)
        logits = apply_fn(params, idx[:, -block_size:])
        logits = logits[:, -1, :] # B, C
        probs = nn.softmax(logits, axis=-1) # B, C
        idx_next = jax.random.categorical(key_i, logits, axis=-1) # B
        idx_next = jnp.expand_dims(idx_next, -1)
        idx = jnp.concatenate([idx, idx_next], axis=1) # B, T+1
    return idx

@partial(jax.jit, static_argnums=[0])
def train_step(apply_fn: Callable, params, opt_state, x_batch, y_batch):

    @jax.jit
    def loss_fn(params, x_batch, y_batch):
        y_hat = apply_fn(params, x_batch)
        return bigram_loss(y_hat, y_batch)

    loss, grad = jax.value_and_grad(loss_fn)(params, x_batch, y_batch)
    updates, opt_state = opt.update(grad, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


In [216]:
key = jax.random.PRNGKey(1337)

model = BigramLanguageModel(vocab_size=vocab_size, block_size=block_size, embedding_size=n_embed)
params = model.init(key, X)

out = model.apply(params, X)
print(out.shape)

loss = bigram_loss(out, Y)
print(loss)

(4, 8, 65)
4.187803


In [217]:
key = jax.random.PRNGKey(1337)
key_init, key_shuffle = jax.random.split(key)

model = BigramLanguageModel(vocab_size=vocab_size, block_size=block_size, embedding_size=n_embed)
params = model.init(key, X)

opt = optax.adam(1e-3)
opt_state = opt.init(params)

In [220]:
for epoch in tqdm(range(10000)):
    key_batch = jax.random.fold_in(key=key_shuffle, data=epoch)
    X, Y = get_batch(key_batch)
    params, opt_state, loss = train_step(model.apply, params, opt_state, X, Y)

print(f"\n\nLoss: {loss}") # Loss: 2.5348286628723145

100%|██████████| 5000/5000 [01:22<00:00, 60.94it/s]



Loss: 2.577763557434082





In [221]:
Z = jnp.zeros(shape=(1, block_size), dtype=jnp.int32)
out = bigram_generate(model.apply, key_shuffle, Z, max_tokens=500)
out = decode(out[0][block_size:].tolist())
print(out)


Sornter itiese fwthe. cild ist
Now, Lor, Rou eurf lirscueber, here wut ADUS:
And som whind uporke king tosal Gldsiv da
Arer showncherith to thince.
STO:
Forerlinsow drvady ruar bergnt his weaus thiov thir gourege, wiole.

DUCHAV:
Ris hertorteswait, ist uto doo fithse thatUCEE:
Thatre with, Courn ithsou ihe my hate, inads aight
mo wrasto ander ldes mectorund trut for thith peney caimard thalampre, cich.
DUn KING ICHARARICARKhis Why hencyou
Cour ineas ccom, worveresy, my; weste.

NY:
And gher, son
