In [None]:
!pip install jax
!pip install flax

In [5]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-05-08 11:23:58--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-05-08 11:23:58 (60.4 MB/s) - ‘input.txt’ saved [1115394/1115394]

--2023-05-08 11:23:58--  http://./data/input.txt
Resolving . (.)... failed: No address associated with hostname.
wget: unable to resolve host address ‘.’
FINISHED --2023-05-08 11:23:58--
Total wall clock time: 0.2s
Downloaded: 1 files, 1.1M in 0.02s (60.4 MB/s)


In [92]:
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

vocab = sorted(list(set(text)))
vocab_size = len(vocab)

print(vocab_size)
print(''.join(vocab))


65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [93]:
itos = { i: s for i, s in enumerate(vocab) }
stoi = { s: i for i, s in enumerate(vocab) }

encode = lambda txt: [stoi[c] for c in txt ]
decode = lambda num: ''.join([itos[n] for n in num ])

encoded = encode("The test string")
decoded = decode(encoded)

print(encoded)
print(decoded)


[32, 46, 43, 1, 58, 43, 57, 58, 1, 57, 58, 56, 47, 52, 45]
The test string


In [94]:
import jax
import flax
import optax

from tqdm import tqdm
from typing import Callable
from jax import numpy as jnp
from flax import linen as nn
from functools import partial

In [95]:
data = jnp.array(encode(text))
data.shape

(1115394,)

In [96]:
n_val = int(0.9 * len(data))
train_data = data[:n_val]
val_data = data[n_val:]

train_data.shape, val_data.shape

((1003854,), (111540,))

In [97]:
block_size = 8
batch_size = 32

def get_batch(batch_key, *, split='train'):
    d = train_data if split == 'train' else val_data
    ix = jax.random.randint(batch_key, (batch_size,), 0, len(d)-block_size)
    x = jnp.stack([d[i:i+block_size] for i in ix])
    y = jnp.stack([d[i+1:i+block_size+1] for i in ix])
    return x, y

key = jax.random.PRNGKey(1337)

X, Y = get_batch(key)
print(X.shape)
print(Y.shape)

(32, 8)
(32, 8)


In [98]:
class BigramLanguageModel(nn.Module):
    
    vocabulary_size: int
    embedding_size: int
    
    @nn.compact
    def __call__(self, idx):
        logits = nn.Embed(num_embeddings=self.vocabulary_size, features=self.embedding_size)(idx)
        return logits

@jax.jit
def bigram_loss(y_hat, y):
    losses = optax.softmax_cross_entropy_with_integer_labels(y_hat, y)
    return jnp.mean(losses)

def bigram_generate(apply_fn: Callable, key, idx, max_tokens=1):
    for i in range(max_tokens):
        key_i = jax.random.fold_in(key, i)
        logits = apply_fn(params, idx)
        logits = logits[:, -1, :] # B, C
        probs = nn.softmax(logits, axis=-1) # B, C
        idx_next = jax.random.categorical(key_i, logits, axis=-1) # B
        idx_next = jnp.expand_dims(idx_next, -1)
        idx = jnp.concatenate([idx, idx_next], axis=1)
    return idx

@partial(jax.jit, static_argnums=[0])
def train_step(apply_fn: Callable, params, opt_state, x_batch, y_batch):

    @jax.jit
    def loss_fn(params, x_batch, y_batch):
        y_hat = apply_fn(params, x_batch)
        return bigram_loss(y_hat, y_batch)

    loss, grad = jax.value_and_grad(loss_fn)(params, x_batch, y_batch)
    updates, opt_state = opt.update(grad, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


key = jax.random.PRNGKey(1337)

model = BigramLanguageModel(vocab_size, vocab_size)
params = model.init(key, X)
out = model.apply(params, X)
print(out.shape)

loss = bigram_loss(out, Y)
print(loss)


(32, 8, 65)
4.1763463


In [99]:
key = jax.random.PRNGKey(1337)
key_init, key_shuffle = jax.random.split(key)

model = BigramLanguageModel(vocab_size, vocab_size)
params = model.init(key, X)

opt = optax.adam(1e-3)
opt_state = opt.init(params)

In [102]:
for epoch in tqdm(range(3000)):
    key_batch = jax.random.fold_in(key=key_shuffle, data=epoch)
    X, Y = get_batch(key_batch)
    params, opt_state, loss = train_step(model.apply, params, opt_state, X, Y)

print(f"\n\nLoss: {loss}")

100%|██████████| 3000/3000 [05:11<00:00,  9.63it/s]



Loss: 2.5348286628723145





In [103]:
Z = jnp.zeros(shape=(1, 1), dtype=jnp.int32)
out = bigram_generate(model.apply, key_shuffle, Z, max_tokens=500)
out = decode(out[0].tolist())
print(out)





QBINGBitisin f the.jctito ad

Tof L:
YBRLLoes f lir.

Gber, he, kwak!AD:

NI HE om wNThe urNGRI bee,LO:

And dsis ds
Arer showh herith tofouthan.
STOFis herlin d yor:
UE:r ar bergnt hisewe,$Fou boved sthethueRGour h tur y wer wamy lhertof$K:
Rit, ist uilldoo fisas t malUCOMoominor: wimuld jararit sou iho my he bbyouses aigathmo whVUKE? muerLOdis mectorundo,

SULONCUME hooune s zd?
IAro!QIK:
LIZ$Tiche.DURimanoncethatht?

NhighWh whengyof

?
HVI:
Louccor t.
WI BGsiskn!;qLINot.
Thartane gh f, spA
