<a href="https://colab.research.google.com/github/imxj/imxj.github.io/blob/master/colabs/llms/jax_gpt_dev_gpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Changed Andrej Karpathy's minGPT from pytorch to Jax code (see [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT).

## import

In [8]:
import time
import jax
from jax import device_put
import jax.numpy as jnp
from jax import lax
import jax.random as random
import jax.nn as jnn
from jax.nn.initializers import normal

import flax.linen as nn

import optax
from optax import adam

## Building a GPT

### load data

In [9]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-04-03 21:22:52--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-04-03 21:22:53 (31.4 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [10]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [11]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [12]:
# let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [13]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [14]:
'!' in chars

True

In [15]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [16]:
# # let's now encode the entire text dataset and store it into a torch.Tensor
# import torch # we use PyTorch: https://pytorch.org
# data = torch.tensor(encode(text), dtype=torch.long)
# print(data.shape, data.dtype)
# print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this

# Assuming `encode` is a function defined elsewhere
data = jnp.array(encode(text), dtype=jnp.int64)
print(data.shape, data.dtype)
print(data[:1000])

  data = jnp.array(encode(text), dtype=jnp.int64)


(1115394,) int32
[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43
  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43
 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49
  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10
  0 37 53 59  1 39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39
 58 46 43 56  1 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47
 57 46 12  0  0 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53
 50 60 43 42  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47
 56 57 58  6  1 63 53 59  1 49 52 53 61  1 15 39 47 59 57  1 25 39 56 41
 47 59 57  1 47 57  1 41 46 47 43 44  1 43 52 43 51 63  1 58 53  1 58 46
 43  1 54 43 53 54 50 43  8  0  0 13 50 50 10  0 35 43  1 49 52 53 61  5
 58  6  1 61 43  1 49 52 53 61  5 58  8  0  0 18 47 56 57 58  1 15 47 58
 47 64 43 52 10  0 24 43 58  1 59 57  1 49 47 50 50  1 46 47 51  6  1 39
 52 42  1 61 43  5 50 50  1 46 39 

#### split into train and test

In [17]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [18]:
val_data.shape

(111540,)

In [19]:
block_size = 8
train_data[:block_size+1]

Array([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int32)

#### build feature input x and target output y

In [20]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    print(t)
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

0
when input is [18] the target: 47
1
when input is [18 47] the target: 56
2
when input is [18 47 56] the target: 57
3
when input is [18 47 56 57] the target: 58
4
when input is [18 47 56 57 58] the target: 1
5
when input is [18 47 56 57 58  1] the target: 15
6
when input is [18 47 56 57 58  1 15] the target: 47
7
when input is [18 47 56 57 58  1 15 47] the target: 58


In [21]:
prng = jax.random.PRNGKey(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
ix = random.randint(random.PRNGKey(0), (batch_size,), 0, len(data) - block_size)

def get_batch(split, subkey, batch_size=4, block_size=8):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = random.randint(subkey, (batch_size,), 0, len(data) - block_size)
    # x = jnp.stack([data[i:i+block_size] for i in ix])
    # y = jnp.stack([data[i+1:i+block_size+1] for i in ix])
    # optimize the above code ^^.
    # speed up by using dynamic_slice and vmap
    def slice_data(i):
        return jax.lax.dynamic_slice(data, (i,), (block_size,))

    x = jax.vmap(slice_data)(ix)
    y = jax.vmap(slice_data)(ix+1)
    x, y = device_put(x), device_put(y)
    return x, y

xb, yb = get_batch('train', prng)
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
(4, 8)
[[ 1 51 39 52 58 50 43  0]
 [25 17 27 10  0 27  6  1]
 [47 51  8  0 14 59 58  1]
 [ 1 57 59 41 46  1 50 43]]
targets:
(4, 8)
[[51 39 52 58 50 43  0 53]
 [17 27 10  0 27  6  1 50]
 [51  8  0 14 59 58  1 46]
 [57 59 41 46  1 50 43 52]]
----
when input is [1] the target: 51
when input is [1, 51] the target: 39
when input is [1, 51, 39] the target: 52
when input is [1, 51, 39, 52] the target: 58
when input is [1, 51, 39, 52, 58] the target: 50
when input is [1, 51, 39, 52, 58, 50] the target: 43
when input is [1, 51, 39, 52, 58, 50, 43] the target: 0
when input is [1, 51, 39, 52, 58, 50, 43, 0] the target: 53
when input is [25] the target: 17
when input is [25, 17] the target: 27
when input is [25, 17, 27] the target: 10
when input is [25, 17, 27, 10] the target: 0
when input is [25, 17, 27, 10, 0] the target: 27
when input is [25, 17, 27, 10, 0, 27] the target: 6
when input is [25, 17, 27, 10, 0, 27, 6] the target: 1
when input is [25, 17, 27, 10, 0, 27, 6, 1] the target: 5

In [22]:
print(xb) # our input to the transformer

[[ 1 51 39 52 58 50 43  0]
 [25 17 27 10  0 27  6  1]
 [47 51  8  0 14 59 58  1]
 [ 1 57 59 41 46  1 50 43]]


## The mathematical trick in 'causal' self-attention

In [23]:
import jax
import jax.numpy as jnp
from jax import random

# set the random key and seed
key = random.PRNGKey(42)

# create a lower triangular matrix
a = jnp.tril(jnp.ones((3, 3)))
print(a)
a = a / jnp.sum(a, axis=1, keepdims=True)
print(a)
# create a random matrix
b = random.randint(key, (3, 2), 0, 10, dtype=jnp.int32)

# perform the matrix multiplication
c = jnp.matmul(a, b)

# print the matrices and the result
print("a=")
print(a)
print("--")
print("b=")
print(b)
print("--")
print("c=")
print(c)

[[1. 0. 0.]
 [1. 1. 0.]
 [1. 1. 1.]]
[[1.         0.         0.        ]
 [0.5        0.5        0.        ]
 [0.33333334 0.33333334 0.33333334]]
a=
[[1.         0.         0.        ]
 [0.5        0.5        0.        ]
 [0.33333334 0.33333334 0.33333334]]
--
b=
[[2 3]
 [9 9]
 [4 6]]
--
c=
[[2.  3. ]
 [5.5 6. ]
 [5.  6. ]]


In [24]:
# We want x[b,t] = mean_{i<=t} x[b,i]
# Initialize xbow tensor
# Set the random key and seed
rng = random.PRNGKey(1337)
B,T,C = 4,8,2 # batch, time, channels
x = random.normal(rng, (B, T, C))
xbow = jnp.zeros((B, T, C))

# Loop over batch and time dimensions
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]  # (t, C)
        mean_xprev = jnp.mean(xprev, axis=0)
        xbow = xbow.at[b, t].set(mean_xprev)

# Print the result
print(xbow)

[[[ 1.3654243  -1.3698599 ]
  [ 1.8501079  -1.18029   ]
  [ 1.2796469  -1.1314313 ]
  [ 1.2234701  -0.95123225]
  [ 1.0816249  -0.8464665 ]
  [ 0.9756752  -0.40811488]
  [ 1.089724   -0.345884  ]
  [ 1.0498141  -0.29275966]]

 [[-0.02413952  1.4920624 ]
  [-0.37343726  1.7603257 ]
  [-0.48404503  1.4237347 ]
  [-0.18112068  1.2025388 ]
  [-0.28402337  0.9092719 ]
  [-0.21978617  0.8809384 ]
  [-0.3614355   1.045876  ]
  [-0.27126977  0.8439261 ]]

 [[ 0.67737067  0.45489657]
  [-0.02651259  0.6725235 ]
  [ 0.3897322   0.32223004]
  [-0.05313486  0.49777415]
  [-0.09553531  0.17514853]
  [ 0.09583326  0.24674283]
  [ 0.17541157  0.14904366]
  [ 0.36874425 -0.14085239]]

 [[-1.1729926  -1.0436211 ]
  [-0.76886547 -0.83241093]
  [-0.1617876  -0.5863479 ]
  [-0.23202893 -0.7893201 ]
  [-0.09417699 -0.3836251 ]
  [-0.03415637 -0.4088895 ]
  [-0.10969827 -0.2769311 ]
  [ 0.08133804 -0.42021018]]]


##### version 2: using matrix multiply for a weighted aggregation


In [25]:
# version 2: using matrix multiply for a weighted aggregation

# Create a lower triangular matrix with normalized rows
wei = jnp.tril(jnp.ones((T, T)))
wei = wei / jnp.sum(wei, axis=1, keepdims=True)

# Perform the weighted aggregation using matrix multiplication
xbow2 = jnp.matmul(wei, x)  # (T, T) @ (B, T, C) ----> (B, T, C)

# Check if the results are close
all_close = jnp.allclose(xbow, xbow2)
print(all_close)

True


In [26]:
wei = jnp.tril(jnp.ones((T, T)))
wei = wei / jnp.sum(wei, axis=1, keepdims=True)
wei

Array([[1.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.33333334, 0.33333334, 0.33333334, 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      , 0.        ,
        0.        , 0.        , 0.        ],
       [0.2       , 0.2       , 0.2       , 0.2       , 0.2       ,
        0.        , 0.        , 0.        ],
       [0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,
        0.16666667, 0.        , 0.        ],
       [0.14285715, 0.14285715, 0.14285715, 0.14285715, 0.14285715,
        0.14285715, 0.14285715, 0.        ],
       [0.125     , 0.125     , 0.125     , 0.125     , 0.125     ,
        0.125     , 0.125     , 0.125     ]], dtype=float32)

##### version 3: use Softmax


In [27]:
# version 3: use Softmax
from jax.nn import softmax

# Create a lower triangular mask
tril = jnp.tril(jnp.ones((T, T)))

# Create a mask filled with negative infinity where tril is 0
wei = jnp.zeros((T, T))
wei = jnp.where(tril == 0, -jnp.inf, wei)

# Apply softmax along the last dimension
wei = jnn.softmax(wei, axis=-1)

# Perform the weighted aggregation using matrix multiplication
xbow3 = jnp.matmul(wei, x)

# Check if the results are close
all_close = jnp.allclose(xbow, xbow3)
print(all_close)


True


#### causal self attention

In [28]:
# Set the random key and seed
rng = random.PRNGKey(1337)

# Create a random input tensor
B,T,C = 4,8,32 # batch, time, channels
x = random.normal(rng, (B, T, C))

# Define the head size for self-attention
head_size = 16

# Define the linear layers for key, query, and value
key = nn.Dense(head_size, kernel_init=nn.initializers.glorot_normal())
query = nn.Dense(head_size, kernel_init=nn.initializers.glorot_normal())
value = nn.Dense(head_size, kernel_init=nn.initializers.glorot_normal())

# Compute the key, query, and value projections
k_variables = key.init(rng, x)
print(x.shape, k_variables['params']['kernel'].shape)
k = key.apply(k_variables, x) # (B, T, 16)
print(k.shape)

q_variables = query.init(rng, x)
q = query.apply(q_variables, x) # (B, T, 16)
print(q.shape)

# Compute the attention weights
wei = jnp.matmul(q, k.transpose((0, 2, 1)))  # (B, T, 16) @ (B, 16, T) ---> (B, T, T)
# Create a lower triangular mask
tril = jnp.tril(jnp.ones((T, T)))

# Apply the mask and then softmax along the last dimension
wei = jnp.where(tril == 0, -jnp.inf, wei)
wei = nn.softmax(wei, axis=-1)
print(wei.shape, wei[0].shape)

# Compute the output using the attention weights and the value projection
v_variables = value.init(rng, x)
v = value.apply(v_variables, x)
out = jnp.matmul(wei, v)

# Print the shape of the output
print(out.shape)


(4, 8, 32) (32, 16)
(4, 8, 16)
(4, 8, 16)
(4, 8, 8) (8, 8)
(4, 8, 16)


In [29]:
wei[0], wei.shape

(Array([[1.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [3.02152942e-10, 1.00000000e+00, 0.00000000e+00, 0.00000000e+00,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [4.84214890e-11, 9.70194233e-05, 9.99902964e-01, 0.00000000e+00,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [3.24488383e-06, 1.23785867e-04, 3.93391110e-06, 9.99868989e-01,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.25760309e-08, 1.02552214e-07, 2.14070397e-08, 4.20416164e-08,
         9.99999881e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.27010195e-16, 9.92109958e-20, 2.51674267e-16, 2.21403877e-16,
         9.01286152e-18, 1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [4.48613235e-09, 1.74030134e-07, 2.45671767e-08, 7.91097576e-10,
         8.77842105e-11, 3.65485463e-12, 9.99

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additional divides `wei` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below

In [30]:

# Create random key
key = random.PRNGKey(1337)

# Create random query and key tensors
q = random.normal(key, (B, T, head_size))
k = random.normal(key, (B, T, head_size))

# Compute the dot product between q and k and scale by head_size^-0.5
wei = jnp.matmul(q, k.transpose((0, 2, 1))) * head_size**-0.5

# Print the shape of the output
print(wei.shape)

(4, 8, 8)


In [31]:
k.var()

Array(0.9194436, dtype=float32)

In [32]:
q.var()

Array(0.9194436, dtype=float32)

In [33]:
wei.var()

Array(2.3990507, dtype=float32)

In [34]:
nn.softmax(jnp.array([0.1, -0.2, 0.3, -0.2, 0.5]))

Array([0.19249782, 0.1426059 , 0.23511738, 0.1426059 , 0.287173  ],      dtype=float32)

In [35]:
nn.softmax(jnp.array([0.1, -0.2, 0.3, -0.2, 0.5])*8) # gets too peaky, converges to one-hot

Array([0.03260834, 0.00295816, 0.1615102 , 0.00295816, 0.79996514],      dtype=float32)

## GPT

In [36]:
class Head(nn.Module):
    head_size: int
    dropout: float

    @nn.compact
    def __call__(self, x, training: bool):
        B, T, C = x.shape

        # Key query embeding
        key = nn.Dense(self.head_size, use_bias=False)(x)
        query = nn.Dense(self.head_size, use_bias=False)(x)

        # Attention weights
        wei = jnp.matmul(query, key.transpose((0, 2, 1))) * C **(-0.5)

        # Apply the mask and then softmax along the last dimension
        tril = jnp.tril(jnp.ones((T, T)))
        wei = jnp.where(tril == 0, -jnp.inf, wei)
        wei = nn.softmax(wei, axis=-1)

        # Dropout (if you have dropout implemented in Flax)
        wei = nn.Dropout(rate=self.dropout)(wei, deterministic=not training)
        # value
        value = nn.Dense(self.head_size, use_bias=False)(x)

        out = jnp.matmul(wei, value)
        return out

In [37]:
x = jnp.ones((1, 2, 3))
out = nn.Dense(x, 50)
print(out)

Dense(
    # attributes
    features = Array([[[1., 1., 1.],
            [1., 1., 1.]]], dtype=float32)
    use_bias = 50
    dtype = None
    param_dtype = float32
    precision = None
    kernel_init = init
    bias_init = zeros
    dot_general = None
    dot_general_cls = None
)


In [38]:
class MultiHead(nn.Module):
    num_heads: int
    head_size: int
    dropout: float

    @nn.compact
    def __call__(self, x, training=False):
        n_embd = self.head_size * self.num_heads
        out = jnp.concatenate([Head(head_size=self.head_size, dropout=self.dropout)(x, training) for _ in range(self.num_heads)], axis=-1)
        out = nn.Dense(n_embd)(out)
        out = nn.Dropout(rate=self.dropout)(out, deterministic=not training)
        return out

In [39]:
class FeedForward(nn.Module):
    n_embd: int
    dropout: float

    @nn.compact
    def __call__(self, x, training=False):
        out = nn.Dense(4 * self.n_embd)(x)
        out = nn.relu(out)
        out = nn.Dense(self.n_embd)(out)
        out = nn.Dropout(rate=self.dropout)(out, deterministic=not training)
        return out

In [40]:
class Block(nn.Module):
    n_embd: int
    n_head: int
    dropout: float

    @nn.compact
    def __call__(self, x, training=False):
        llm_head = MultiHead(num_heads=self.n_head, head_size=self.n_embd // self.n_head, dropout=self.dropout)
        ffd = FeedForward(n_embd=self.n_embd, dropout=self.dropout)
        ln1 = nn.LayerNorm()
        ln2 = nn.LayerNorm()
        x = x + llm_head(ln1(x), training=training) # (B, T, C)
        x = x + ffd(ln2(x), training=training) # (B, T, C)
        return x

In [41]:
# Set the random key and seed
rng = random.PRNGKey(1337)

# Create a random input tensor
B,T,C = 4,8,32 # batch, time, channels
x = random.normal(rng, (B, T, C))

model = Head(head_size=2, dropout=0.2)
variables = model.init(rng, x, training=False)
print(variables['params']['Dense_0']['kernel'].shape)

(32, 2)


In [42]:
class GPTModel(nn.Module):
    vocab_size: int
    n_embd: int
    n_head: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, idx, targets=None, training=False):
        # Token embedding table
        embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embd)
        position_embedding_table = nn.Embed(num_embeddings=self.block_size, features=self.n_embd)
        blocks = nn.Sequential([Block(n_embd=self.n_embd, n_head=self.n_head, dropout=self.dropout) for _ in range(n_layer)])
        lm_head = nn.Dense(self.vocab_size)
        ln1 = nn.LayerNorm()

        B, T = idx.shape
        tok_emb = embedding_table(idx)  # (B,T,C)
        pos_emb = position_embedding_table(jnp.arange(T))
        x = tok_emb + pos_emb # (B, T, C)
        x = blocks(x, training=training) # (B, T, C)
        x = ln1(x)
        logits = lm_head(x) # B, T, vocab_size
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.reshape(B*T, C)
            targets = targets.reshape(B*T)
            loss = -jnp.sum(jax.nn.one_hot(targets, C) * jax.nn.log_softmax(logits), axis=1).mean()

        return logits, loss

In [43]:
# Example usage
# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
eval_iters = 10
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.2

print('vocab_size', vocab_size)
model = GPTModel(vocab_size=vocab_size, block_size=block_size, n_embd=n_embd, n_head=n_head, dropout=dropout)

# Initialize model parameters and optimizer
key = random.PRNGKey(1337)
main_key, dropout_key = random.split(key)
params = model.init(main_key, jnp.zeros((1, block_size), jnp.int32))

# Create a random input tensor
print(xb.shape, yb.shape)
logits, loss = model.apply(params, xb, yb, rngs={'dropout': dropout_key})
print('loss', loss)

vocab_size 65
(4, 8) (4, 8)
loss 4.7885876


In [44]:
# # Final hparams
# # hyperparameters
# batch_size = 64 # how many independent sequences will we process in parallel?
# block_size = 256 # what is the maximum context length for predictions?
# max_iters = 5000
# eval_interval = 500
# learning_rate = 3e-4
# eval_iters = 10
# n_embd = 384
# n_head = 6
# n_layer = 6
# dropout = 0.2

# print('vocab_size', vocab_size)
# model = GPTModel(vocab_size=vocab_size, block_size=block_size, n_embd=n_embd, n_head=n_head, dropout=dropout)

# # Initialize model parameters and optimizer
# key = random.PRNGKey(1337)
# main_key, dropout_key = random.split(key)
# params = model.init(main_key, jnp.zeros((1, block_size), jnp.int32))

# # Create a random input tensor
# print(xb.shape, yb.shape)
# logits, loss = model.apply(params, xb, yb, rngs={'dropout': dropout_key})
# print('loss', loss)

In [45]:
# Define the optimizer
tx = adam(learning_rate)

# Initialize model parameters and optimizer state
prng = random.PRNGKey(1337)
params = model.init(prng, jnp.ones((1, 1), jnp.int32))
opt_state = tx.init(params)

# Loss function (assuming you have a batch of data: xb, yb)
def loss_fn(params, xb, yb, dropout_key):
    logits, loss = model.apply(params, xb, yb, training=True, rngs={'dropout': dropout_key})
    return loss

# Update function for a single training step
@jax.jit
def update_step(params, opt_state, xb, yb, dropout_key):
    loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb, dropout_key)
    updates, opt_state = tx.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

@jax.jit
def estimate_loss(params, prng):
    out = {}
    for split in ['train', 'val']:
        losses = jnp.zeros(eval_iters)
        for k in range(eval_iters):
            prng, subkey = random.split(prng)
            X, Y = get_batch(split, subkey)
            logits, loss = model.apply(params, X, Y, training=False)
            losses = losses.at[k].set(loss)
        out[split] = losses.mean()
    return out

# Training loop (example)
for steps in range(max_iters):
    if steps == max_iters - 1: # or steps % eval_interval == 0:
      losses = estimate_loss(params, prng)
      print(f"step {steps}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    prng, subkey, dropout_key = random.split(prng, 3)
    xb, yb = get_batch('train', subkey, batch_size=batch_size, block_size=block_size)

    params, opt_state, loss = update_step(params, opt_state, xb, yb, dropout_key)
    print(f"Epoch {steps},Loss: {loss}")

flax_apply_jitted = jax.jit(lambda params, xb, yb: model.apply(jax.lax.stop_gradient(params), xb, yb))

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch 1,Loss: 4.014387607574463
Epoch 2,Loss: 3.8010220527648926
Epoch 3,Loss: 3.618058681488037
Epoch 4,Loss: 3.5852298736572266
Epoch 5,Loss: 3.5638856887817383
Epoch 6,Loss: 3.412853956222534
Epoch 7,Loss: 3.4588723182678223
Epoch 8,Loss: 3.4619557857513428
Epoch 9,Loss: 3.30143404006958
Epoch 10,Loss: 3.3536880016326904
Epoch 11,Loss: 3.401754856109619
Epoch 12,Loss: 3.553654909133911
Epoch 13,Loss: 3.340710401535034
Epoch 14,Loss: 3.2985382080078125
Epoch 15,Loss: 3.4168903827667236
Epoch 16,Loss: 3.31078839302063
Epoch 17,Loss: 3.326681613922119
Epoch 18,Loss: 3.472869396209717
Epoch 19,Loss: 3.2958321571350098
Epoch 20,Loss: 3.330526113510132
Epoch 21,Loss: 3.316978931427002
Epoch 22,Loss: 3.4605517387390137
Epoch 23,Loss: 3.2397422790527344
Epoch 24,Loss: 3.298964023590088
Epoch 25,Loss: 3.2691707611083984
Epoch 26,Loss: 3.4556989669799805
Epoch 27,Loss: 3.4703245162963867
Epoch 28,Loss: 3.2902615070343018
Epoch 2

In [46]:
print(batch_size, block_size)

16 32


In [47]:
def generate(params, flax_apply_jitted, key, idx, max_new_tokens):
  for _ in range(max_new_tokens):
      idx_crop = idx[:, -block_size:]
      logits, _ = flax_apply_jitted(params, idx_crop, None)
      logits = logits[:, -1, :]  # (B, C)
      key, subkey = random.split(key)
      idx_next = random.categorical(subkey, logits)[:, None]  # (B, 1)
      idx = jnp.concatenate((idx, idx_next), axis=1)  # (B, T+1)
  return idx

%time print(decode(generate(params, flax_apply_jitted, key, jnp.zeros((1, block_size), jnp.int32), 1000)[0][block_size-1:].tolist()))


MAMENERENERE:
Will the that of stroctage
Clainderces our him.

CORIRDY:
UnEslo, his in the untown word, by madrard
Sucion: God I see their,Haven, who I slays dewould we daught:
that was my have play, this now so dail cairall O,
Four so for tunder.

CICINIUS:
Ay, in, he a the our nevery, ratch and you spidater be thy no:
For knower anwill thou
night any a down your hadought.

LEONTES:
I nor that I straiget Jonce: by aut you hodnour much uptay.

GREET:
Comey some so sried of The fair
Our And a
Hith with gove singballet pray cy gonder's
Have to would and courd for rove!

iSOMO:
Aand aliest ard it:
This out; you and thou ammansts live this this to thou a wn him right you,
'Sparrom my leice your grace, onell trock hath what the words
In that indened their spark by thine welle:
If know I loy; let maints troducused thed maan that
sir mading tent! groop. I have with you to my crook
Cavance me ler, of I'll hath it,
First than the palouty cicirts:
Nor nor graver to her stord.

Should I good tho