<a href="https://colab.research.google.com/github/imxj/imxj.github.io/blob/master/colabs/llms%20/jax_gpt_dev.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Changed Andrej Karpathy's minGPT from pytorch to Jax code (see [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT).

## import

In [2]:
import time
import jax
from jax import device_put
import jax.numpy as jnp
from jax import lax
import jax.random as random
import jax.nn as jnn
from jax.nn.initializers import normal

import flax.linen as nn

import optax
from optax import adam

## Building a GPT

### load data

In [3]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-04-02 04:10:38--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-04-02 04:10:39 (25.5 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [6]:
# let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [7]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [8]:
'!' in chars

True

In [9]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [10]:
# # let's now encode the entire text dataset and store it into a torch.Tensor
# import torch # we use PyTorch: https://pytorch.org
# data = torch.tensor(encode(text), dtype=torch.long)
# print(data.shape, data.dtype)
# print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this

# Assuming `encode` is a function defined elsewhere
data = jnp.array(encode(text), dtype=jnp.int64)
print(data.shape, data.dtype)
print(data[:1000])

  data = jnp.array(encode(text), dtype=jnp.int64)


(1115394,) int32
[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43
  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43
 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49
  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10
  0 37 53 59  1 39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39
 58 46 43 56  1 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47
 57 46 12  0  0 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53
 50 60 43 42  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47
 56 57 58  6  1 63 53 59  1 49 52 53 61  1 15 39 47 59 57  1 25 39 56 41
 47 59 57  1 47 57  1 41 46 47 43 44  1 43 52 43 51 63  1 58 53  1 58 46
 43  1 54 43 53 54 50 43  8  0  0 13 50 50 10  0 35 43  1 49 52 53 61  5
 58  6  1 61 43  1 49 52 53 61  5 58  8  0  0 18 47 56 57 58  1 15 47 58
 47 64 43 52 10  0 24 43 58  1 59 57  1 49 47 50 50  1 46 47 51  6  1 39
 52 42  1 61 43  5 50 50  1 46 39 

#### split into train and test

In [11]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [12]:
val_data.shape

(111540,)

In [13]:
block_size = 8
train_data[:block_size+1]

Array([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int32)

#### build feature input x and target output y

In [14]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    print(t)
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

0
when input is [18] the target: 47
1
when input is [18 47] the target: 56
2
when input is [18 47 56] the target: 57
3
when input is [18 47 56 57] the target: 58
4
when input is [18 47 56 57 58] the target: 1
5
when input is [18 47 56 57 58  1] the target: 15
6
when input is [18 47 56 57 58  1 15] the target: 47
7
when input is [18 47 56 57 58  1 15 47] the target: 58


In [15]:
prng = jax.random.PRNGKey(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
ix = random.randint(random.PRNGKey(0), (batch_size,), 0, len(data) - block_size)

def get_batch(split, subkey):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = random.randint(subkey, (batch_size,), 0, len(data) - block_size)
    # x = jnp.stack([data[i:i+block_size] for i in ix])
    # y = jnp.stack([data[i+1:i+block_size+1] for i in ix])
    # optimize the above code ^^.
    # speed up by using dynamic_slice and vmap
    def slice_data(i):
        return jax.lax.dynamic_slice(data, (i,), (block_size,))

    x = jax.vmap(slice_data)(ix)
    y = jax.vmap(slice_data)(ix+1)
    x, y = device_put(x), device_put(y)
    return x, y

xb, yb = get_batch('train', prng)
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
(4, 8)
[[ 1 51 39 52 58 50 43  0]
 [25 17 27 10  0 27  6  1]
 [47 51  8  0 14 59 58  1]
 [ 1 57 59 41 46  1 50 43]]
targets:
(4, 8)
[[51 39 52 58 50 43  0 53]
 [17 27 10  0 27  6  1 50]
 [51  8  0 14 59 58  1 46]
 [57 59 41 46  1 50 43 52]]
----
when input is [1] the target: 51
when input is [1, 51] the target: 39
when input is [1, 51, 39] the target: 52
when input is [1, 51, 39, 52] the target: 58
when input is [1, 51, 39, 52, 58] the target: 50
when input is [1, 51, 39, 52, 58, 50] the target: 43
when input is [1, 51, 39, 52, 58, 50, 43] the target: 0
when input is [1, 51, 39, 52, 58, 50, 43, 0] the target: 53
when input is [25] the target: 17
when input is [25, 17] the target: 27
when input is [25, 17, 27] the target: 10
when input is [25, 17, 27, 10] the target: 0
when input is [25, 17, 27, 10, 0] the target: 27
when input is [25, 17, 27, 10, 0, 27] the target: 6
when input is [25, 17, 27, 10, 0, 27, 6] the target: 1
when input is [25, 17, 27, 10, 0, 27, 6, 1] the target: 5

In [16]:
print(xb) # our input to the transformer

[[ 1 51 39 52 58 50 43  0]
 [25 17 27 10  0 27  6  1]
 [47 51  8  0 14 59 58  1]
 [ 1 57 59 41 46  1 50 43]]


### Baseline model: Bigram

In [17]:
class BigramLanguageModel(nn.Module):
    vocab_size: int

    @nn.compact
    def __call__(self, idx, targets=None):
        # Token embedding table
        embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)
        logits = embedding_table(idx)  # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.reshape(B*T, C)
            targets = targets.reshape(B*T)
            loss = -jnp.sum(jax.nn.one_hot(targets, C) * jax.nn.log_softmax(logits), axis=1).mean()

        return logits, loss

# Example usage
model = BigramLanguageModel(vocab_size)

# Initialize model parameters and optimizer
key = random.PRNGKey(1337)
params = model.init(key, jnp.ones((1, 1), jnp.int32))

# jax jit the model apply to speed up
flax_apply_jitted = jax.jit(lambda params, xb, yb: model.apply(params, xb, yb))

logits, loss = flax_apply_jitted(params, xb, yb)
print(loss)

(32, 65)
4.1973763


#### training

In [35]:
# Define the optimizer
learning_rate = 1e-3  # Adjust as needed
tx = adam(learning_rate)

# Initialize model parameters and optimizer state
prng = random.PRNGKey(1337)
params = model.init(prng, jnp.ones((1, 1), jnp.int32))
opt_state = tx.init(params)

# Loss function (assuming you have a batch of data: xb, yb)
def loss_fn(params, xb, yb):
    print(xb, yb)
    logits, loss = model.apply(params, xb, yb)
    return loss

# Update function for a single training step
@jax.jit
def update_step(params, opt_state, xb, yb):
    loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb)
    updates, opt_state = tx.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

# Training loop (example)
batch_size = 32
for steps in range(100):
    t1 = time.time()
    prng, subkey = random.split(prng)
    xb, yb = get_batch('train', subkey)
    t2 = time.time()
    params, opt_state, loss = update_step(params, opt_state, xb, yb)
    print(f"Epoch {steps},Loss: {loss}, sample batch: {t2-t1}, forward pass and backward grad: {time.time()-t1}")

flax_apply_jitted = jax.jit(lambda params, xb, yb: model.apply(jax.lax.stop_gradient(params), xb, yb))

Traced<ShapedArray(int32[32,32])>with<DynamicJaxprTrace(level=1/0)> Traced<ShapedArray(int32[32,32])>with<DynamicJaxprTrace(level=1/0)>
Epoch 0,Loss: 4.176701545715332, sample batch: 0.006464958190917969, forward pass and backward grad: 0.3173556327819824
Epoch 1,Loss: 4.168148040771484, sample batch: 0.006571531295776367, forward pass and backward grad: 0.007528066635131836
Epoch 2,Loss: 4.176245212554932, sample batch: 0.006763458251953125, forward pass and backward grad: 0.007700681686401367
Epoch 3,Loss: 4.169358730316162, sample batch: 0.007287740707397461, forward pass and backward grad: 0.009604454040527344
Epoch 4,Loss: 4.175686836242676, sample batch: 0.009287357330322266, forward pass and backward grad: 0.010252714157104492
Epoch 5,Loss: 4.170393943786621, sample batch: 0.0073015689849853516, forward pass and backward grad: 0.008276939392089844
Epoch 6,Loss: 4.169349193572998, sample batch: 0.006832599639892578, forward pass and backward grad: 0.007781982421875
Epoch 7,Loss: 

#### inference
Slower than torch code
 1. logits[:, -1, :]
 2. random.categorical  

In [37]:
def generate(params, flax_apply_jitted, key, idx, max_new_tokens):
  for _ in range(max_new_tokens):
      logits, _ = flax_apply_jitted(params, idx, None)
      logits = logits[:, -1, :]  # (B, C)
      key, subkey = random.split(key)
      idx_next = random.categorical(subkey, logits)[:, None]  # (B, 1)
      idx = jnp.concatenate((idx, idx_next), axis=1)  # (B, T+1)
  return idx

%time print(decode(generate(params, flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 500)[0].tolist()))


yD.P.e'wn,CZsvq gP-f$f&W3aypokkuSEz?Paw:YCj?M;x
pctpxMvdJMlTZrmCZhPRjYRJUfrgld,bqlwXxBlCHIWu'FYEBTwJrbX;b!HR'Fr;rI?&Nui3;woGFdW pAZYho3YO!hHPv:F3uMAHbG:slLyWXd;woxmBMTexUpY ZEP
tTk?BlWOP&ZP.zNS YjFV,OxrO?!$wNDsXCd;iM:c!elaw'uOPGCJJDBsSf,E.XguCoK-rJP-kybvHsxxwu,:i3UJgZbBMO;s:coPALGSTE-hJWOStcI3$VaeVYfJsTPqaqT-ebJqAWy
Ev:WFmCykXrvetkGbw-3-N!'oW
nKqi:FgOyU3XdQwNr gVItNvRo,JbtDAvcfHSKDkh.caNKrf CMrJIGs?lbiNDbgJg'cHB:rRwAuGq&UDPhOdnmc:&jU,ZCuG?mF.An-r,EMDfCHfITHsvztXPL U3iSE-dAsTxeqf??i
OUQfArTnZ.Hgv
CPU times: user 1.03 s, sys: 9 ms, total: 1.04 s
Wall time: 1.05 s


In [38]:
# Try speed up with stop gradient
t1=time.time()
print(decode(generate(jax.lax.stop_gradient(params), flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 100)[0].tolist()))
print('TIME total', time.time()-t1)


yD.P.e'wn,CZsvq gP-f$f&W3aypokkuSEz?Paw:YCj?M;x
pctpxMvdJMlTZrmCZhPRjYRJUfrgld,bqlwXxBlCHIWu'FYEBTwJ
TIME total 0.3211863040924072


#### put together

In [43]:
prng = jax.random.PRNGKey(1337)

# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 500
eval_interval = 100
learning_rate = 1e-3
eval_iters = 10
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0


# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = jnp.array(encode(text), dtype=jnp.int32)
data = device_put(data)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

def get_batch(split, subkey):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    t1 = time.time()
    ix = random.randint(subkey, (batch_size,), 0, len(data) - block_size)
    t2 = time.time()
    # x = jnp.stack([data[i:i+block_size] for i in ix])
    # y = jnp.stack([data[i+1:i+block_size+1] for i in ix])
    def slice_data(i):
        return jax.lax.dynamic_slice(data, (i,), (block_size,))

    x = jax.vmap(slice_data)(ix)
    y = jax.vmap(slice_data)(ix+1)
    x, y = device_put(x), device_put(y)
    # print('TIME rand idx', t2-t1)
    # print('TIME rand idx fetch', time.time()-t2)
    return x, y

def estimate_loss(params, prng):
    out = {}
    for split in ['train', 'val']:
        losses = jnp.zeros(eval_iters)
        for k in range(eval_iters):
            prng, subkey = random.split(prng)
            X, Y = get_batch(split, subkey)
            logits, loss = model.apply(params, X, Y)
            losses = losses.at[k].set(loss)
        out[split] = losses.mean()
    return out

class BigramLanguageModel(nn.Module):
    vocab_size: int

    @nn.compact
    def __call__(self, idx, targets=None):
        # Token embedding table
        embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)
        logits = embedding_table(idx)  # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.reshape(B*T, C)
            targets = targets.reshape(B*T)
            loss = -jnp.sum(jax.nn.one_hot(targets, C) * jax.nn.log_softmax(logits), axis=1).mean()

        return logits, loss

    def generate(self, params, key, idx, max_new_tokens):
      for _ in range(max_new_tokens):
          logits, _ = self.apply(params, idx)
          logits = logits[:, -1, :]  # (B, C)
          key, subkey = random.split(key)
          idx_next = random.categorical(subkey, logits)[:, None]  # (B, 1)
          idx = jnp.concatenate((idx, idx_next), axis=1)  # (B, T+1)
      return idx

# Define the optimizer
tx = adam(learning_rate)

# Initialize model parameters and optimizer state
model = BigramLanguageModel(vocab_size)

# Initialize model parameters and optimizer
params = model.init(prng, jnp.ones((1, 1), jnp.int32))
opt_state = tx.init(params)

# Loss function (assuming you have a batch of data: xb, yb)
def loss_fn(params, xb, yb):
    logits, loss = model.apply(params, xb, yb)
    return loss

# Update function for a single training step
@jax.jit
def update_step(params, opt_state, xb, yb):
    loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb)
    updates, opt_state = tx.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

# Training loop (example)
batch_size = 32
t = time.time()
for steps in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if steps == max_iters - 1 or steps % eval_interval == 0:
      losses = estimate_loss(params, prng)
      print(f"step {steps}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    prng, subkey = random.split(prng)
    xb, yb = get_batch('train', subkey)
    t2 = time.time()
    params, opt_state, loss = update_step(params, opt_state, xb, yb)
    print(f"Epoch {steps},Loss: {loss}")
print('TIME jax', time.time()-t)

step 0: train loss 4.1747, val loss 4.1776
Epoch 0,Loss: 4.176701545715332
Epoch 1,Loss: 4.168148040771484
Epoch 2,Loss: 4.176245212554932
Epoch 3,Loss: 4.169358730316162
Epoch 4,Loss: 4.175686836242676
Epoch 5,Loss: 4.170393943786621
Epoch 6,Loss: 4.169349193572998
Epoch 7,Loss: 4.162716388702393
Epoch 8,Loss: 4.157233238220215
Epoch 9,Loss: 4.163629055023193
Epoch 10,Loss: 4.16217565536499
Epoch 11,Loss: 4.1645026206970215
Epoch 12,Loss: 4.1598124504089355
Epoch 13,Loss: 4.157349586486816
Epoch 14,Loss: 4.163051128387451
Epoch 15,Loss: 4.15546178817749
Epoch 16,Loss: 4.154058456420898
Epoch 17,Loss: 4.14748477935791
Epoch 18,Loss: 4.153148174285889
Epoch 19,Loss: 4.15305757522583
Epoch 20,Loss: 4.148075580596924
Epoch 21,Loss: 4.1545729637146
Epoch 22,Loss: 4.1454854011535645
Epoch 23,Loss: 4.146124839782715
Epoch 24,Loss: 4.152686595916748
Epoch 25,Loss: 4.148232460021973
Epoch 26,Loss: 4.143317222595215
Epoch 27,Loss: 4.135959625244141
Epoch 28,Loss: 4.133705139160156
Epoch 29,Loss

In [49]:
flax_apply_jitted = jax.jit(lambda params, xb, yb: model.apply(params, xb, yb))

In [54]:
def generate(params, flax_apply_jitted, key, idx, max_new_tokens):
  for _ in range(max_new_tokens):
      logits, _ = flax_apply_jitted(params, idx, None)
      logits = logits[:, -1, :]  # (B, C)
      key, subkey = random.split(key)
      idx_next = random.categorical(subkey, logits)[:, None]  # (B, 1)
      idx = jnp.concatenate((idx, idx_next), axis=1)  # (B, T+1)
  return idx

%time print(decode(generate(params, flax_apply_jitted, key, jnp.zeros((1, 1), jnp.int32), 500)[0].tolist()))


yD.P.e'wn,CZsvq gPrf$f&W3aypokkuSEz?Paw:YCj?M;x
pctexMadJMlTZr,CyhaRoYRJUfrsld,bqlwXxclCHIWu'FYEBldJrby;b!HR'Frcr,?&Nui3;woGFdW psZYhosYO!hHPv:F3uMAHbGoslLIWXd;woxmBMTe UpY ZEP
tTk?BlWOPrZP.zNS pjFR,OxrO?!$wNDsXCd;il:c!'lal'uOPGCJeDusSf,E.XgunoK-rJP-ky oHsxxwu,:i3UJgZbBMO;s:
oPALGSTE-heWO,tcI3$VaeVY JsTPqaqT-ebedAWhoEv:WFiCykXrvetkGbw'3-N!'oW
n
qi:FgOyU3Xd wrr gVItNvRo,JbtDAvcfHSKDWh.caNKrf CMr IGs?lbiNDerJg'cHB:rRwAuGq&UDUhOdnmc:&jUSZCuG?mF.An--,EMDfCHfITHs ztXPL U3iSE--AsTxeqf??imOUQfArTnZ.Hgv
CPU times: user 1.02 s, sys: 7.95 ms, total: 1.03 s
Wall time: 1.03 s
