## Building a GPT

Links:
- Let's build GPT: from scratch, in code, spelled out.: https://www.youtube.com/watch?v=kCc8FmEb1nY
- nanoGPT GitHub repository: https://github.com/karpathy/nanoGPT
- Attention Is All You Need paper: https://arxiv.org/abs/1706.03762
- GPT-3 paper: https://arxiv.org/abs/2005.14165

In [3]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!mkdir -p /project/inputs/
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O /project/inputs/input.txt

--2023-04-25 15:44:25--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘/project/inputs/input.txt’


2023-04-25 15:44:26 (33.7 MB/s) - ‘/project/inputs/input.txt’ saved [1115394/1115394]



In [1]:
# read it in to inspect it
with open("/project/inputs/input.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [2]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [3]:
# let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [4]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [5]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: "".join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [6]:
# let's now encode the entire text dataset and store it in a JAX array
import jax # we use JAX
import jax.numpy as jnp
data = jnp.array(encode(text))
print(data.shape, data.dtype)
print(data[:1000])

(1115394,) int32
[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43
  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43
 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49
  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10
  0 37 53 59  1 39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39
 58 46 43 56  1 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47
 57 46 12  0  0 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53
 50 60 43 42  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47
 56 57 58  6  1 63 53 59  1 49 52 53 61  1 15 39 47 59 57  1 25 39 56 41
 47 59 57  1 47 57  1 41 46 47 43 44  1 43 52 43 51 63  1 58 53  1 58 46
 43  1 54 43 53 54 50 43  8  0  0 13 50 50 10  0 35 43  1 49 52 53 61  5
 58  6  1 61 43  1 49 52 53 61  5 58  8  0  0 18 47 56 57 58  1 15 47 58
 47 64 43 52 10  0 24 43 58  1 59 57  1 49 47 50 50  1 46 47 51  6  1 39
 52 42  1 61 43  5 50 50  1 46 39 

In [7]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [8]:
block_size = 8
train_data[:block_size+1]

DeviceArray([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int32)

In [9]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is [18] the target: 47
when input is [18 47] the target: 56
when input is [18 47 56] the target: 57
when input is [18 47 56 57] the target: 58
when input is [18 47 56 57 58] the target: 1
when input is [18 47 56 57 58  1] the target: 15
when input is [18 47 56 57 58  1 15] the target: 47
when input is [18 47 56 57 58  1 15 47] the target: 58


In [10]:
random_key = jax.random.PRNGKey(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

dynamic_slice_vmap = jax.vmap(jax.lax.dynamic_slice, in_axes=(None, 0, None))

@jax.jit
def get_batch(random_key, data):
    # generate a small batch of data of inputs x and targets y
    ix = jax.random.randint(random_key, shape=(batch_size, 1), minval=0, maxval=len(data)-block_size)
    x = dynamic_slice_vmap(data, ix, (block_size,))
    y = dynamic_slice_vmap(data, ix+1, (block_size,))
    return x, y

random_key, random_subkey = jax.random.split(random_key)
xb, yb = get_batch(random_subkey, train_data)
print("inputs:")
print(xb.shape)
print(xb)
print("targets:")
print(yb.shape)
print(yb)

print("----")

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
(4, 8)
[[53 60 53 57 58  6  1 50]
 [ 1 39 52 42  1 42 43 39]
 [46 47 52 45  1 46 43 56]
 [63  1 40 56 53 58 46 43]]
targets:
(4, 8)
[[60 53 57 58  6  1 50 43]
 [39 52 42  1 42 43 39 58]
 [47 52 45  1 46 43 56  1]
 [ 1 40 56 53 58 46 43 56]]
----
when input is [53] the target: 60
when input is [53, 60] the target: 53
when input is [53, 60, 53] the target: 57
when input is [53, 60, 53, 57] the target: 58
when input is [53, 60, 53, 57, 58] the target: 6
when input is [53, 60, 53, 57, 58, 6] the target: 1
when input is [53, 60, 53, 57, 58, 6, 1] the target: 50
when input is [53, 60, 53, 57, 58, 6, 1, 50] the target: 43
when input is [1] the target: 39
when input is [1, 39] the target: 52
when input is [1, 39, 52] the target: 42
when input is [1, 39, 52, 42] the target: 1
when input is [1, 39, 52, 42, 1] the target: 42
when input is [1, 39, 52, 42, 1, 42] the target: 43
when input is [1, 39, 52, 42, 1, 42, 43] the target: 39
when input is [1, 39, 52, 42, 1, 42, 43, 39] the target: 5
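
The vmapped `jax.lax.dynamic_slice` in `get_batch` gathers one window of length `block_size` per random start offset. As a rough sketch of the same idea, the windows can also be gathered by building the index grid explicitly with broadcasting. The `data` array here is a stand-in (an assumption, not the Shakespeare tensor), and `offsets` is a name introduced only for this illustration.

```python
import jax
import jax.numpy as jnp

block_size = 8
batch_size = 4

# Stand-in for the encoded dataset (assumption: any 1-D integer array works).
data = jnp.arange(100, dtype=jnp.int32)

key = jax.random.PRNGKey(0)
# One random start offset per sequence in the batch.
ix = jax.random.randint(key, shape=(batch_size,), minval=0, maxval=len(data) - block_size)

# Broadcasting (batch_size, 1) + (block_size,) gives a (batch_size, block_size) index grid.
offsets = ix[:, None] + jnp.arange(block_size)
x = data[offsets]      # inputs
y = data[offsets + 1]  # targets: the same windows shifted by one position

print(x.shape, y.shape)
```

Because the stand-in data is just `0, 1, 2, ...`, every target equals its input plus one, which makes the one-step shift easy to verify by eye.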

In [11]:
import jax
import flax.linen as nn
import optax


class BigramLanguageModel(nn.Module):
    @nn.compact
    def __call__(self, idx):
        return nn.Embed(vocab_size, vocab_size)(idx)
    
    def generate(self, random_key, params, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions for the last token only: (B, vocab_size)
            logits = self.apply(params, idx[:, -1])
            # sample from the distribution
            random_key, random_subkey = jax.random.split(random_key)
            idx_next = jax.random.categorical(random_subkey, logits, axis=-1) # (B,)
            # append sampled index to the running sequence
            idx = jnp.concatenate((idx, idx_next.reshape(logits.shape[0], -1)), axis=1) # (B, T+1)
        return idx

m = BigramLanguageModel()
random_key, random_subkey = jax.random.split(random_key)
params = m.init(random_subkey, idx=xb)

logits = m.apply(params, xb)
labels = jax.nn.one_hot(yb, vocab_size)
print(logits.shape)
loss = jnp.mean(optax.softmax_cross_entropy(logits, labels))
print(loss)

random_key, random_subkey = jax.random.split(random_key)
print(decode(m.generate(random_subkey, params, idx=jnp.zeros((1, 1), dtype=jnp.int32), max_new_tokens=100)[0].tolist()))


(4, 8, 65)
4.1723547

 'Q?'NqOzwr;lDvEnA!pGwNzOJ3AZ.?.ulxrvDENapoBYpWs- EYMwPkx o.aQeXddXQmSsoUaQha.WBxD$-3O
K$E au-dmbFq3
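
A quick sanity check on the initial loss: an untrained model should predict roughly uniformly over the vocabulary, so the cross-entropy should be close to `-ln(1/vocab_size)`. The small sketch below computes that reference value, assuming the `vocab_size` of 65 printed earlier.

```python
import math

vocab_size = 65  # number of unique characters found above

# Cross-entropy of a uniform prediction over the vocabulary: -ln(1/vocab_size)
uniform_loss = -math.log(1.0 / vocab_size)
print(f"{uniform_loss:.4f}")  # 4.1744
```

The measured loss of about 4.17 sits right next to this value, which suggests the freshly initialised embedding table is close to uninformative, as expected.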


In [12]:
batch_size = 32
@jax.jit
def get_batch(random_key, data):
    # generate a small batch of data of inputs x and targets y
    ix = jax.random.randint(random_key, shape=(batch_size, 1), minval=0, maxval=len(data)-block_size)
    x = dynamic_slice_vmap(data, ix, (block_size,))
    y = dynamic_slice_vmap(data, ix+1, (block_size,))
    return x, y

@jax.jit
def cross_entropy_loss(params, xb, yb):
    logits = m.apply(params, xb)
    one_hot_encoded_labels = jax.nn.one_hot(yb, num_classes=vocab_size)
    return optax.softmax_cross_entropy(
        logits=logits, labels=one_hot_encoded_labels
    ).mean()

# create an Optax optimizer
optimizer = optax.adam(learning_rate=1e-3)
optimizer_state = optimizer.init(params)

In [13]:
for steps in range(10000): # increase number of steps for good results... 
    # sample a batch of data
    random_key, random_subkey = jax.random.split(random_key)
    xb, yb = get_batch(random_subkey, train_data)

    # evaluate the loss
    loss, grad = jax.value_and_grad(cross_entropy_loss)(params, xb, yb)

    # update params
    update, optimizer_state = optimizer.update(
        grad, optimizer_state
    )
    params = optax.apply_updates(params, update)

print(loss.item())


2.499870777130127


In [14]:
print(decode(m.generate(random_subkey, params, idx=jnp.zeros((1, 1), dtype=jnp.int32), max_new_tokens=500)[0].tolist()))


An:
INIINThy We an.
S: o. ald tccrus,
ONIVIO:

Bous.
Wrothay, whemy y thowood it maithothe:
Alilecorcchay vik:
ireraithot t
Sthayo aind hes,
Theens han felithe bok h med:
K:
UENor'd Ginge bather bl din reanof peled theno d, kent

And g be, My y aur sellig tea: hinonghybe ty husthit, o uld ony iale V! uf.

TESThyovit bafuco, ndicr theate't ant, dofe canghe, aldd t b y my;


Thithr
Toonguche, I aranodangothinoul t coondongou! es selend tswshols s dillatoseve!
Amatha?

Bopr prtce Herd fortthous ied


## The mathematical trick in self-attention

In [15]:
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
random_key, random_subkey = jax.random.split(random_key)
a = jnp.tril(jnp.ones((3, 3)))
a = a / jnp.sum(a, 1, keepdims=True)
b = jax.random.randint(random_subkey, (3, 2), 0, 10)
c = a @ b
print("a=")
print(a)
print("--")
print("b=")
print(b)
print("--")
print("c=")
print(c)

a=
[[1.         0.         0.        ]
 [0.5        0.5        0.        ]
 [0.33333334 0.33333334 0.33333334]]
--
b=
[[4 2]
 [9 9]
 [0 5]]
--
c=
[[4.        2.       ]
 [6.5       5.5      ]
 [4.3333335 5.3333335]]


In [16]:
# consider the following toy example:
B, T, C = 4, 8, 2 # batch, time, channels
random_key, random_subkey = jax.random.split(random_key)
x = jax.random.normal(random_subkey, (B, T, C))
x.shape

(4, 8, 2)

In [17]:
# We want xbow[b,t] = mean_{i<=t} x[b,i]
xbow = jnp.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t+1, C)
        xbow = xbow.at[b, t].set(jnp.mean(xprev, 0))

In [18]:
# version 2: using matrix multiply for a weighted aggregation
wei = jnp.tril(jnp.ones((T, T)))
wei = wei / wei.sum(1, keepdims=True)
xbow2 = wei @ x # (T, T) @ (B, T, C) ----> (B, T, C), wei is broadcast over the batch dimension
jnp.allclose(xbow, xbow2)

DeviceArray(True, dtype=bool)

In [19]:
tril = jnp.tril(jnp.ones((T, T)))
nn.softmax(jnp.where(tril == 0, -jnp.inf, 0.), axis=-1)

DeviceArray([[1.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        , 0.        ],
             [0.5       , 0.5       , 0.        , 0.        , 0.        ,
              0.        , 0.        , 0.        ],
             [0.33333334, 0.33333334, 0.33333334, 0.        , 0.        ,
              0.        , 0.        , 0.        ],
             [0.25      , 0.25      , 0.25      , 0.25      , 0.        ,
              0.        , 0.        , 0.        ],
             [0.2       , 0.2       , 0.2       , 0.2       , 0.2       ,
              0.        , 0.        , 0.        ],
             [0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,
              0.16666667, 0.        , 0.        ],
             [0.14285715, 0.14285715, 0.14285715, 0.14285715, 0.14285715,
              0.14285715, 0.14285715, 0.        ],
             [0.125     , 0.125     , 0.125     , 0.125     , 0.125     ,
              0.125     , 0.125     , 0.125     ]],

In [20]:
# version 3: use Softmax
tril = jnp.tril(jnp.ones((T, T)))
wei = nn.softmax(jnp.where(tril == 0, -jnp.inf, 0.), axis=-1)
xbow3 = wei @ x
jnp.allclose(xbow, xbow3)

DeviceArray(True, dtype=bool)

In [21]:
# version 4: self-attention!
B, T, C = 4, 8, 32 # batch, time, channels
random_key, random_subkey = jax.random.split(random_key)
x = jax.random.normal(random_subkey, (B, T, C))

# let's see a single Head perform self-attention
head_size = 16
key = nn.Dense(head_size, use_bias=False)
query = nn.Dense(head_size, use_bias=False)
value = nn.Dense(head_size, use_bias=False)

# Key
random_key, random_subkey = jax.random.split(random_key)
params_key = key.init(random_subkey, x)
k = key.apply(params_key, x) # (B, T, 16)

# Query
random_key, random_subkey = jax.random.split(random_key)
params_query = query.init(random_subkey, x)
q = query.apply(params_query, x) # (B, T, 16)
wei =  q @ jnp.transpose(k, axes=(0, 2, 1)) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = jnp.tril(jnp.ones((T, T)))
wei = nn.softmax(jnp.where(tril == 0, -jnp.inf, wei), axis=-1)

# Value
random_key, random_subkey = jax.random.split(random_key)
params_value = value.init(random_subkey, x)
v = value.apply(params_value, x)
out = wei @ v

out.shape

(4, 8, 16)

In [22]:
wei[0]

DeviceArray([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
             [3.5935980e-01, 6.4064020e-01, 0.0000000e+00, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
             [8.5110113e-02, 4.4948000e-02, 8.6994189e-01, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
             [9.3279088e-01, 6.6124596e-02, 3.7592615e-06, 1.0807238e-03,
              0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
             [1.2928504e-02, 9.0569223e-04, 3.0436894e-04, 1.8784864e-01,
              7.9801279e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
             [1.7108712e-05, 1.9536851e-10, 3.1041036e-05, 8.6323842e-02,
              9.1347349e-01, 1.5456323e-04, 0.0000000e+00, 0.0000000e+00],
             [3.5270318e-04, 1.7646615e-05, 1.3294152e-04, 7.8084207e-01,
              1.7192341e-01, 3.5

Notes:
- Attention is a **communication mechanism**. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across the batch dimension is processed completely independently; examples never "talk" to each other.
- In an "encoder" attention block, just delete the single line that does masking with `tril`, allowing all tokens to communicate. The block here is called a "decoder" attention block because it has triangular masking, and it is usually used in autoregressive settings such as language modeling.
- "Self-attention" just means that the keys and values are produced from the same source as the queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module).
- "Scaled" attention additionally multiplies `wei` by 1/sqrt(head_size). This makes it so that when the inputs Q, K have unit variance, `wei` has unit variance too, and softmax stays diffuse and does not saturate too much. Illustration below.
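
The notes above can be condensed into one small function. This is only a sketch in plain JAX, not the notebook's Flax code: `attention_head`, its `context` and `causal` arguments, the explicit weight matrices `w_q`, `w_k`, `w_v`, and the `memory` array are all hypothetical names introduced here. With `causal=True` it behaves like the decoder block, with `causal=False` like an encoder block, and passing `context` gives cross-attention.

```python
import jax
import jax.numpy as jnp

def attention_head(x, w_q, w_k, w_v, context=None, causal=True):
    """Single attention head. Shapes: x (B, T, C), w_* (C, head_size)."""
    source = x if context is None else context  # self- vs. cross-attention
    q = x @ w_q        # (B, T, head_size)
    k = source @ w_k   # (B, S, head_size)
    v = source @ w_v   # (B, S, head_size)
    head_size = q.shape[-1]
    # Scaled dot-product scores: (B, T, S)
    wei = q @ jnp.swapaxes(k, -1, -2) * head_size ** -0.5
    if causal:
        # Decoder-style masking: position t may only look at positions <= t.
        T, S = wei.shape[-2:]
        mask = jnp.tril(jnp.ones((T, S), dtype=bool))
        wei = jnp.where(mask, wei, -jnp.inf)
    wei = jax.nn.softmax(wei, axis=-1)
    return wei @ v, wei

B, T, C, head_size = 4, 8, 32, 16
keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (B, T, C))
w_q, w_k, w_v = (jax.random.normal(k, (C, head_size)) * C ** -0.5 for k in keys[1:4])

out_dec, wei_dec = attention_head(x, w_q, w_k, w_v, causal=True)   # decoder block
out_enc, wei_enc = attention_head(x, w_q, w_k, w_v, causal=False)  # encoder block
memory = jax.random.normal(keys[4], (B, 5, C))  # hypothetical encoder output
out_x, wei_x = attention_head(x, w_q, w_k, w_v, context=memory, causal=False)  # cross-attention
print(out_dec.shape, out_enc.shape, out_x.shape, wei_x.shape)
```

In the cross-attention call the weight matrix has one column per position of `memory` rather than per position of `x`, which is why no causal mask is applied there.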

In [23]:
random_key, random_subkey = jax.random.split(random_key)
k = jax.random.normal(random_subkey, (B, T, head_size))
random_key, random_subkey = jax.random.split(random_key)
q = jax.random.normal(random_subkey, (B, T, head_size))
wei = q @ jnp.transpose(k, axes=(0, 2, 1)) * head_size**-0.5

In [24]:
k.var()

DeviceArray(1.0069916, dtype=float32)

In [25]:
q.var()

DeviceArray(1.0187279, dtype=float32)

In [26]:
wei.var()

DeviceArray(0.88002664, dtype=float32)
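
For comparison, it can help to look at the same scores without the scaling factor. Each score is a sum of `head_size` products of independent, roughly unit-variance terms, so its variance should be on the order of `head_size`. The sketch below uses its own random key and fresh `q`, `k` arrays (assumptions, so the numbers will not match the cells above exactly).

```python
import jax
import jax.numpy as jnp

B, T, head_size = 4, 8, 16
k_key, q_key = jax.random.split(jax.random.PRNGKey(0))
k = jax.random.normal(k_key, (B, T, head_size))
q = jax.random.normal(q_key, (B, T, head_size))

raw = q @ jnp.swapaxes(k, -1, -2)  # unscaled scores, (B, T, T)
scaled = raw * head_size ** -0.5   # scaled scores

# Multiplying by head_size**-0.5 divides the variance by exactly head_size.
print(raw.var(), scaled.var())
```

Feeding the unscaled scores into softmax is what produces the "peaky" behaviour, since larger-magnitude logits push the output toward one-hot.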

In [27]:
nn.softmax(jnp.array([0.1, -0.2, 0.3, -0.2, 0.5]), axis=-1)

DeviceArray([0.19249782, 0.1426059 , 0.23511738, 0.1426059 , 0.287173  ],            dtype=float32)

In [28]:
nn.softmax(jnp.array([0.1, -0.2, 0.3, -0.2, 0.5])*8, axis=-1) # gets too peaky, converges to one-hot

DeviceArray([0.03260834, 0.00295816, 0.1615102 , 0.00295816, 0.79996514],            dtype=float32)

In [29]:
class LayerNorm(nn.Module):
    epsilon: float = 1e-6
    reduction_axes = -1

    @nn.compact
    def __call__(self, x):
        """Applies layer normalization on the input."""
        # compute statistics
        mean2 = jnp.mean(jax.lax.square(x), self.reduction_axes, keepdims=True)
        mean = jnp.mean(x, self.reduction_axes, keepdims=True)
        var = jnp.maximum(0., mean2 - jax.lax.square(mean))

        # compute normalized inputs
        x_norm = (x - mean) * jax.lax.rsqrt(var + self.epsilon)
        return x_norm * self.param("scale", nn.initializers.ones, x.shape[-1]) + self.param("bias", nn.initializers.zeros, x.shape[-1])

random_key, random_subkey = jax.random.split(random_key)
module = LayerNorm()
x = jax.random.normal(random_subkey, (32, 100))
params = module.init(random_subkey, x)
x = module.apply(params, x)
x.shape

(32, 100)

In [30]:
x[:,0].mean(), x[:,0].std() # mean,std of one feature across all batch inputs

(DeviceArray(-0.05891827, dtype=float32),
 DeviceArray(1.0908911, dtype=float32))

In [31]:
x[0,:].mean(), x[0,:].std() # mean,std of a single input from the batch, of its features

(DeviceArray(-1.1920929e-09, dtype=float32),
 DeviceArray(0.9999992, dtype=float32))
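
The two checks above differ only in the axis along which the statistics are inspected. As a sketch of why that matters, the hypothetical helper `normalize` below (no learnable scale or bias, an assumption made for brevity) normalizes the same array along the feature axis, as layer norm does, and along the batch axis, as a batch-norm-style layer would.

```python
import jax
import jax.numpy as jnp

def normalize(x, axis, eps=1e-6):
    # Zero mean, unit variance along `axis` (no learnable parameters in this sketch).
    mean = x.mean(axis, keepdims=True)
    var = x.var(axis, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps)

# Shifted and scaled inputs so that the effect of normalization is visible.
x = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(0), (32, 100))

ln = normalize(x, axis=-1)  # layer norm: statistics per example, over its features
bn = normalize(x, axis=0)   # batch-norm-style: statistics per feature, over the batch

print(ln[0, :].mean(), ln[0, :].std())  # a row of `ln` is normalized
print(bn[:, 0].mean(), bn[:, 0].std())  # a column of `bn` is normalized
```

Layer norm therefore does not depend on the other examples in the batch, which fits the earlier note that examples are processed independently.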

In [None]:
# French to English translation example:

# <--------- ENCODE ------------------><--------------- DECODE ----------------->
# les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>



### Full finished code, for reference

Please refer to the `nanoGPT_jax.py` script for the training loop.