<a href="https://colab.research.google.com/github/heerboi/AI-from-scratch/blob/main/gpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Following Andrej's video: https://www.youtube.com/watch?v=kCc8FmEb1nY

In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/refs/heads/master/data/tinyshakespeare/input.txt

: 

In [2]:
with open('input.txt', 'r', encoding="utf-8") as f:
    text = f.read()

In [3]:
text[:100]

In [4]:
# Vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)

In [5]:
# Character <-> integer id lookup tables, both derived from the sorted vocabulary.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Return the list of integer ids for the characters of ``s``, in order."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the integer ids in ``l``."""
    return "".join(itos[i] for i in l)


print(encode("Hii"))
print(decode(encode("Hii")))

In [6]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:100])

In [7]:
split = int(0.9*len(data))
train_data = data[:split]
val_data = data[split:]
print(len(train_data))
print(len(val_data))

In [8]:
#context length

block_size = 8
train_data[:block_size+1]

In [9]:
x = train_data[:block_size]
y = train_data[1:block_size+1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(context, target)

In [10]:
# Pick the compute device once; everything below moves tensors/modules to it.
# `torch.accelerator` is missing from older PyTorch builds, so look it up
# defensively and fall back to the CUDA check, then to the CPU.
_accelerator = getattr(torch, "accelerator", None)
if _accelerator is not None and _accelerator.is_available():
    device = _accelerator.current_accelerator().type
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [11]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: "train" selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        A pair ``(x, y)`` of stacked windows moved to ``device``. Each row of
        ``x`` is ``block_size`` consecutive tokens and the matching row of
        ``y`` is the same window shifted one token to the right.
    """
    if split == "train":
        source = train_data
    else:
        source = val_data

    # One random start offset per batch row (batch_size / block_size are read
    # from the notebook globals at call time).
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + 1 + block_size])

    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

x, y = get_batch("train")
print(x)
print(y)

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BigramLanguageModel(nn.Module):
    """Bigram model: next-token logits are a direct table lookup on the current token.

    Args:
        vocab_size: Number of distinct tokens; the lookup table is
            ``vocab_size x vocab_size``.
    """

    def __init__(self, vocab_size):
        super().__init__()

        # Row i holds the logits for the token that follows token i.
        self.token_embedding_table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            idx: (B, T) integer token ids.
            targets: Optional (B, T) integer token ids to predict.

        Returns:
            ``(logits, loss)``. Without targets, logits are (B, T, C) and loss
            is None. With targets, logits are flattened to (B*T, C) (the layout
            ``F.cross_entropy`` expects) and loss is a scalar tensor.
        """
        logits = self.token_embedding_table(idx)

        # Identity check: `== None` on a tensor goes through tensor comparison
        # machinery instead of simply testing for "no targets".
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to each row of ``idx``.

        Args:
            idx: (B, T) integer token ids used as the prompt.
            max_new_tokens: Number of sampling steps.

        Returns:
            (B, T + max_new_tokens) integer tensor: the prompt followed by the
            sampled continuation. Sampling runs without gradient tracking.
        """
        for _ in range(max_new_tokens):
            logits, _ = self(idx)
            # Keep only the prediction made at the last time step: (B, C).
            logits = logits[:, -1, :]

            probabilities = F.softmax(logits, dim=-1)
            # (B, 1)
            next_idx = torch.multinomial(probabilities, num_samples=1)
            # (B, T+1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx

m = BigramLanguageModel(vocab_size).to(device)
out, loss = m(x, y)
print(out.shape)
print(out)

print(decode(m.generate(torch.zeros((1,1), dtype=torch.long, device=device), max_new_tokens=100)[0].tolist()))

In [13]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [14]:
# get_batch reads this global at call time, so batches grow from 4 to 32 rows.
batch_size = 32

# Train the bigram model: every step draws a fresh random batch.
for steps in range(5000):
    xb,yb = get_batch('train')

    logits, loss = m(xb,yb)

    # Drop the gradients of the previous step before backpropagating this one.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Loss of the last batch only, not an average over the run.
print(loss.item())

In [None]:
print(decode(m.generate(torch.zeros((1,1), dtype=torch.long, device=device), max_new_tokens=1000)[0].tolist()))

In [None]:
eval_iters = 200
@torch.no_grad()
def estimate_loss():
    """Average the loss of the global model ``m`` over random batches.

    ``m`` is put into eval mode for the measurement and switched back to train
    mode before returning. ``eval_iters`` batches are drawn per split.

    Returns:
        dict mapping "train" and "val" to a 0-dim tensor with the mean loss.
    """
    m.eval()
    out = {}
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            _, loss = m(*get_batch(split))
            losses[k] = loss.item()
        out[split] = losses.mean()
    m.train()
    return out
estimate_loss()

## Mathematical trick in self-attention!

- For every position t we need the average, along the time dimension, of the vectors at positions 0..t (the current token and everything before it).


In [None]:
B, T, C = 4, 8, 2
x = torch.randn(B,T,C)

In [None]:
# Lower-triangular matrix whose row t, once normalised, puts equal weight on positions 0..t.
div = torch.tril(torch.ones(T, T))
div = div / div.sum(dim=1, keepdim=True)
# (T, T) @ (B, T, C): the matrix product broadcasts over the batch dimension.
xbow = torch.matmul(div, x)

In [None]:
div

In [None]:
x[0], xbow[0]

### using softmax(infinity)

hint: e^-infinity = 0, and e^0 = 1

In [None]:
# Same running average written as a softmax: positions above the diagonal are
# set to -inf, so they receive exactly zero weight after normalisation.
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=1)
xbow3 = torch.matmul(wei, x)
wei

## A bit about attention

- Attention is just a mechanism that adds a set of values with a set of weights. The approach above takes the weights to be equally distributed for the node itself and the nodes before, and zero for all nodes after.

- But, the current node might find more of what it needs from some nodes rather than others; it won't necessarily be equally distributed.

- Paper proposes an attention function where each node (token) at time T emits a query vector that contains the information that the current node is looking for, and a key vector that contains the information that the current node has within itself.

- This query vector and key vector get multiplied together to get the "affinities" between what the nodes are looking for and what the nodes have (T, T dimension, so each combination)

- Instead of taking the average of each node, we perform softmax on this new matrix. Now, instead of multiplying the "original" values $x$, we multiply it with the "value" matrix, which is different for each attention "head"

- As each head has a different purpose, a node emits a different value in each head: the value it possesses that makes the most sense for that particular head!

In [None]:
head_size = 16
Q = nn.Linear(C, head_size, bias=False)
K = nn.Linear(C, head_size, bias=False)
V = nn.Linear(C, head_size, bias=False)

queries = Q(x)
keys = K(x)

print(queries.shape)
print(keys.shape)

In [None]:
T

In [None]:
tril = torch.tril(torch.ones(T, T))
# wei[b, t, a] = <query at position t, key at position a>: rows index the
# query, columns index the key. (The previous 'bat' layout put keys on the
# rows, so the row-wise mask and the `wei @ values` product did not line up.)
wei = torch.einsum('btd,bad->bta', queries, keys)

# Causal variant: a query may only look at keys at the same or an earlier position.
wei1 = wei.masked_fill(tril == 0, float('-inf'))
# Normalise over the key axis (last dim) so the weights of every query sum to 1;
# that is the axis `@ values` contracts over.
wei1 = F.softmax(wei1, dim=-1)
wei = F.softmax(wei, dim=-1)

values = V(x)

# Unmasked and causal weighted sums of the value vectors.
xbow4 = wei @ values
xbow5 = wei1 @ values
print(wei.shape)
print(xbow4.shape)

In [None]:
wei[0], xbow4[0]

In [None]:
wei1[0], xbow5[0]

There's a small problem, though:

In [None]:
query = torch.randn((4, 8, 16))
key = torch.randn((4, 8, 16))

print(query.var())
print(key.var())

In [None]:
qk = key @ query.transpose(-2, -1)
print(qk.var())

There is a HUGE difference in variance, and a high variance means the values are spread far apart. Since we apply softmax to these scores, very imbalanced values produce very imbalanced weights — almost all of the attention lands on a single node — especially while the network is still untrained.

The paper proposes dividing the multiplication by the square root of head size, let's try it.

In [None]:
qk = key @ query.transpose(-2, -1) * head_size**-0.5
print(qk.var())

looks good

In [None]:
# num of attn heads running in parallel
# num of attn heads running in parallel
n_heads = 16
# embedding size
# every block's output width must equal n_embd so the residual additions line up
n_embd = 512

# individual heads are concat at the end, so each head gets an equal slice of n_embd
head_size = n_embd // n_heads

# size of ffn hidden layer
hidden_size = 2048

# total number of stacked transformer blocks
n_blocks = 6

# context length for the transformer; rebinds the global that get_batch reads
block_size=64

In [None]:
num = torch.arange(0, n_embd, 2).float()
thetas = 1.0/10000**(num/n_embd)

m = torch.arange(0, 5)
freqs1 = torch.einsum('i,j->ij', [m, thetas])
freqs2 = torch.outer(m, thetas).float()
print(freqs1, freqs2)

In [None]:
x = torch.randn(4, 8, 4)
print(x)
x[...,0]


In [None]:
def create_thetas(seq_len, head_size, theta=10000):
    """Precompute the rotary position embedding (RoPE) rotation factors.

    Args:
        seq_len: Number of positions to cover (0 .. seq_len - 1).
        head_size: Width of the vectors that will be rotated. Channels are
            rotated in pairs, so one frequency is produced per pair.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape (seq_len, head_size // 2) for even
        ``head_size``; entry [p, j] is ``exp(1j * p * theta ** (-2j / head_size))``.
    """
    # Exponents 0, 2, 4, ...: one per channel pair.
    pair_exponents = torch.arange(0, head_size, 2).float()
    thetas = 1.0 / theta ** (pair_exponents / head_size)

    # Positions are converted to float explicitly so the product below never
    # relies on implicit integer/float promotion.
    positions = torch.arange(0, seq_len).float()
    freqs = torch.outer(positions, thetas)

    # Unit-magnitude complex numbers cos(freqs) + 1j * sin(freqs).
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rot_embd(x, freqs_complex):
    """Rotate consecutive channel pairs of ``x`` by position-dependent angles (RoPE).

    Args:
        x: Tensor of shape (..., T, D) with D even and at least two
            dimensions. Channels (2j, 2j + 1) form the pair rotated together.
        freqs_complex: Complex tensor of shape (S, D // 2) with S >= T, e.g.
            from ``create_thetas``; row p holds the rotation for position p.

    Returns:
        float32 tensor with the same shape as ``x``, on the same device as ``x``.
    """
    seq_len = x.shape[-2]
    # Use only the rows for the positions actually present; this lets inputs
    # shorter than the precomputed table (e.g. short prompts) broadcast.
    freqs = freqs_complex[:seq_len].to(x.device)
    cr, ci = freqs.real, freqs.imag

    # (..., T, D) -> (..., T, D/2, 2): split the last axis into (real, imag) pairs.
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    xr, xi = pairs[..., 0], pairs[..., 1]

    # Complex multiplication (xr + i*xi) * (cr + i*ci), written on real tensors.
    out_r = xr * cr - xi * ci
    out_i = xr * ci + xi * cr

    # Re-interleave the pairs. The result stays on x's device; forcing 'cuda'
    # here would crash on machines without a CUDA device.
    return torch.stack((out_r, out_i), dim=-1).reshape(*x.shape)

In [None]:
class SingleAttentionHead(nn.Module):
    """One scaled dot-product attention head with rotary position embeddings.

    Layer widths come from the notebook globals ``n_embd``, ``head_size`` and
    ``block_size``.

    Args:
        rope_freqs: Complex rotation factors as produced by ``create_thetas``;
            must cover at least as many positions as the sequences passed in.
        mask: If True, apply a causal mask so a position only attends to
            itself and earlier positions (decoder self-attention).
    """

    def __init__(self, rope_freqs, mask=False):
        super().__init__()
        self.Q = nn.Linear(n_embd, head_size, bias=False)
        self.K = nn.Linear(n_embd, head_size, bias=False)
        self.V = nn.Linear(n_embd, head_size, bias=False)
        self.act = nn.SiLU()
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.rope_freqs = rope_freqs
        # True for causal (decoder) self-attention, False for encoder/cross attention.
        self.mask = mask

    def forward(self, x, encoder_embd=None):
        """Attend from ``x`` to itself, or to ``encoder_embd`` when given.

        Args:
            x: (B, T, n_embd) tensor providing the queries.
            encoder_embd: Optional (B, S, n_embd) tensor; when given, keys and
                values are computed from it instead of from ``x``.

        Returns:
            (B, T, head_size) tensor of attention outputs.
        """
        _, T, _ = x.shape

        # `is None`, not truthiness: bool() of a multi-element tensor raises.
        source = x if encoder_embd is None else encoder_embd
        S = source.shape[1]

        # rope_freqs is a plain attribute (not a buffer), so it does not follow
        # the module across devices; align it with the input here.
        freqs = self.rope_freqs.to(x.device)

        # Activation first, rotation last, for queries AND keys: the rotation
        # must be the final operation for q.k to depend on relative position.
        queries = apply_rot_embd(self.act(self.Q(x)), freqs[:T])
        keys = apply_rot_embd(self.act(self.K(source)), freqs[:S])
        values = self.act(self.V(source))

        # (B, T, S) affinities, scaled to keep their variance near 1.
        wei = torch.einsum('btd,bad->bta', queries, keys) * head_size ** -0.5

        if self.mask:
            # Slice the mask so sequences shorter than block_size still work.
            wei = wei.masked_fill(self.tril[:T, :S] == 0, float('-inf'))

        wei = F.softmax(wei, dim=-1)

        return wei @ values

class FFN(nn.Module):
    """Position-wise feed-forward sub-layer with a residual connection.

    Computes ``x + LayerNorm(Linear(GELU(Linear(x))))``. The hidden width is
    the notebook global ``hidden_size``.

    Args:
        in_features: Width of the input.
        out_features: Width of the output; the residual addition needs it to
            be compatible with ``in_features``.
        bias: Whether the two linear layers carry a bias term.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        # (B, T, in_features) -> (B, T, hidden_size)
        expand = nn.Linear(in_features, hidden_size, bias=bias)
        # (B, T, hidden_size) -> (B, T, out_features)
        project = nn.Linear(hidden_size, out_features, bias=bias)
        self.layers = nn.Sequential(expand, nn.GELU(), project)
        self.layer_norm = nn.LayerNorm(out_features)

    def forward(self, x):
        transformed = self.layers(x)
        # Normalise the branch first, then add the residual.
        return x + self.layer_norm(transformed)

In [None]:
class MultiAttentionBlock(nn.Module):
    """Multi-head attention sub-layer: parallel heads, output projection, residual.

    Sizes come from the notebook globals ``n_heads``, ``head_size``,
    ``n_embd`` and ``block_size``; tensors are placed on the global ``device``.

    Args:
        mask: Forwarded to every head; True enables causal masking.
    """

    def __init__(self, mask=False):
        super().__init__()

        # One rotation table shared by all heads, created on the selected
        # device rather than a hard-coded 'cuda'.
        rope_freqs = create_thetas(seq_len=block_size, head_size=head_size).to(device)

        # nn.ModuleList registers the heads as submodules. In a plain Python
        # list their parameters are invisible to .parameters(), so the
        # optimizer never updates them and state_dict() never saves them.
        self.heads = nn.ModuleList(
            [SingleAttentionHead(mask=mask, rope_freqs=rope_freqs).to(device) for _ in range(n_heads)]
        )

        self.linear = nn.Linear(n_embd, n_embd)

        self.layer_norm = nn.LayerNorm(n_embd)

    def forward(self, x, encoder_embd=None):
        """Run all heads and merge them.

        Args:
            x: (B, T, n_embd) input (source of the queries).
            encoder_embd: Optional tensor forwarded to each head as the source
                of keys and values (cross-attention).

        Returns:
            (B, T, n_embd) tensor: ``x + LayerNorm(Linear(concat(heads)))``.
        """
        # each op: (B, T, head_size)
        head_outputs = [head(x, encoder_embd) for head in self.heads]
        # op: (B, T, n_embd)
        merged = self.linear(torch.cat(head_outputs, dim=-1))

        return x + self.layer_norm(merged)

class DecoderBlock(nn.Module):
    """Decoder layer: causal self-attention, optional cross-attention, feed-forward.

    Args:
        is_enc: If True, also build a cross-attention block so the layer can
            consume encoder output. Passing ``encoder_embd`` to ``forward``
            on a layer built with ``is_enc=False`` fails with AttributeError
            because no cross-attention block exists.
    """

    def __init__(self, is_enc=False):
        super().__init__()

        self.multi_self_attention_block = MultiAttentionBlock(mask=True).to(device)
        if is_enc:
            self.cross_attn_block = MultiAttentionBlock(mask=False).to(device)
        self.ffn = FFN(n_embd, n_embd).to(device)

    def forward(self, x, encoder_embd=None):
        """Apply the layer.

        Args:
            x: (B, T, n_embd) decoder activations.
            encoder_embd: Optional encoder output used as keys/values of the
                cross-attention block.

        Returns:
            (B, T, n_embd) tensor.
        """
        x = self.multi_self_attention_block(x)
        # `is not None`, not truthiness: bool() of a multi-element tensor raises.
        if encoder_embd is not None:
            x = self.cross_attn_block(x, encoder_embd)
        return self.ffn(x)

class EncoderBlock(nn.Module):
    """Encoder layer: unmasked multi-head self-attention followed by the feed-forward sub-layer."""

    def __init__(self):
        super().__init__()

        # mask=False: no causal mask is applied inside the heads.
        attention = MultiAttentionBlock(mask=False)
        feed_forward = FFN(n_embd, n_embd)

        self.multi_self_attention_block = attention.to(device)
        self.ffn = feed_forward.to(device)

    def forward(self, x):
        attended = self.multi_self_attention_block(x)
        return self.ffn(attended)

In [None]:
def positional_embed(seq_len, n_embd):
    """Build the fixed sinusoidal position table on the global ``device``.

    For position p and pair index i:
        PE[p, 2i]     = sin(p / 10000 ** (2i / n_embd))
        PE[p, 2i + 1] = cos(p / 10000 ** (2i / n_embd))

    Args:
        seq_len: Number of positions (rows).
        n_embd: Embedding width (columns); odd widths are supported.

    Returns:
        (seq_len, n_embd) float tensor on ``device``.
    """
    position = torch.arange(0, seq_len, device=device).unsqueeze(1).float()
    # `even` already holds 2i (0, 2, 4, ...), so it is used as-is in the
    # exponent, and the sin/cos of one pair share the same frequency.
    even = torch.arange(0, n_embd, 2, device=device).float()
    angles = position / 10000 ** (even / n_embd)

    pe = torch.zeros(seq_len, n_embd, device=device)
    pe[:, 0::2] = torch.sin(angles)
    # An odd n_embd has one fewer odd column than even columns.
    pe[:, 1::2] = torch.cos(angles[:, : n_embd // 2])
    return pe

In [None]:
class Transformer(nn.Module):
    """Character-level transformer language model (decoder-only or encoder-decoder).

    Sizes come from the notebook globals ``vocab_size``, ``n_embd``,
    ``n_blocks`` and ``block_size``; layers are placed on the global ``device``.

    Args:
        encoder: If True, also build an encoder stack and give every decoder
            layer a cross-attention block fed by the encoder output.
    """

    def __init__(self, encoder=False):
        super().__init__()
        self.encoder = encoder

        self.token_embedding_table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_embd)

        # Registered once so that .eval() / .train() switch it off and on;
        # a Dropout constructed inside forward() is always in training mode.
        self.dropout = nn.Dropout(0.1)

        # No positional table here: position information is injected by the
        # rotary embeddings inside the attention heads.

        # Stacks of attention + feed-forward layers; residual connections live
        # inside the respective sub-layer classes.
        if encoder:
            self.encoder_block = nn.Sequential(*[EncoderBlock().to(device) for _ in range(n_blocks)])
        self.decoder_block = nn.Sequential(*[DecoderBlock(is_enc=encoder).to(device) for _ in range(n_blocks)])

        self.nn = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the training loss.

        Args:
            idx: (B, T) integer token ids.
            targets: Optional (B, T) integer token ids to predict.

        Returns:
            ``(logits, loss)``. Without targets, logits are (B, T, vocab_size)
            and loss is None. With targets, logits are flattened to
            (B*T, vocab_size) and loss is the label-smoothed cross entropy.
        """
        tok_embd = self.token_embedding_table(idx)
        x = self.dropout(tok_embd)

        if self.encoder:
            x_enc = self.encoder_block(x)
            # nn.Sequential forwards a single input only, so the decoder layers
            # are stepped through by hand to pass the encoder output to each.
            for block in self.decoder_block:
                x = block(x, x_enc)
        else:
            x = self.decoder_block(x)

        logits = self.nn(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(-1), label_smoothing=0.1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to each row of ``idx``.

        Args:
            idx: (B, T) integer token ids used as the prompt.
            max_new_tokens: Number of sampling steps.

        Returns:
            (B, T + max_new_tokens) integer tensor. Sampling runs without
            gradient tracking; the module's train/eval mode is left untouched,
            so call ``.eval()`` first to sample without dropout.
        """
        for _ in range(max_new_tokens):
            # The model only sees the most recent block_size tokens.
            idx_next = idx[:, -block_size:]
            logits, _ = self(idx_next)
            # Keep only the prediction made at the last time step: (B, vocab_size).
            logits = logits[:, -1, :]

            probabilities = F.softmax(logits, dim=-1)
            # (B, 1)
            next_idx = torch.multinomial(probabilities, num_samples=1)
            # (B, T+1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx

In [None]:
xb, yb = get_batch('train')

In [None]:
m = Transformer(encoder=False).to(device)
out, loss = m(xb, yb)
print(out.shape)
print(out)
print("Total parameters:")
print(sum([p.nelement() for p in m.parameters()]))

print(decode(m.generate(torch.zeros((1,128), dtype=torch.long, device=device), max_new_tokens=100)[0].tolist()))

In [None]:
optimizer = torch.optim.SGD(m.parameters(), lr=5e-3)

In [None]:
# get_batch reads this global at call time.
batch_size = 64

# Train the transformer: every step draws a fresh random batch.
for steps in range(10000):
    xb,yb = get_batch('train')

    logits, loss = m(xb,yb)

    # Drop the gradients of the previous step before backpropagating this one.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # Progress report: the loss of the current batch, every 100 steps.
    if steps % 100 == 0:
        print(f"Loss at {steps}: {loss.item()}")

Training a bit longer because the loss is still decreasing.

In [None]:
# Larger batches for the continuation run; get_batch reads this global at call time.
batch_size = 512

# Continue training the same model with the same optimizer.
for steps in range(1500):
    xb,yb = get_batch('train')

    logits, loss = m(xb,yb)

    # Drop the gradients of the previous step before backpropagating this one.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # Progress report: the loss of the current batch, every 100 steps.
    if steps % 100 == 0:
        print(f"Loss at {steps}: {loss.item()}")

In [None]:
estimate_loss()

* transformer with single attention block, no layer norm, and no ffn: train 2.3442 val: 2.3719
* transformer with single attn, layer norm and ffn, no residual connection: train 2.2628 val 2.3023

* transformer with multi head attn + linear, layer norm, ffn and residual connection in after multihead attn: train 1.7584 val 1.9072

* transformer with multiple stacked attention-ffn blocks!: train 1.65 val 1.82

* using sinusoidal positional embedding converges much faster!! train 1.71 val 1.87

* The GPU makes it SO much faster! But the model stops learning when I apply layer norm after the residual addition (?); it works when I add the residual after the layer norm.


new best: train 1.49 val 1.75

In [None]:
print(decode(m.generate(torch.zeros((1,128), dtype=torch.long,device=device), max_new_tokens=1000)[0].tolist())[128:])

## positional encoding


In [None]:
from math import sin, cos

In [None]:
# Dimension indices (1-based) and the token positions to visualise.
i = list(range(1, 51))
pos = list(range(1, 9))
# Per position: values with the 10000**(2*dim/512) divisor, and without it.
embeddings = {n: [] for n in pos}
embeddings_no_div = {n: [] for n in pos}
for p in pos:
    for num in i:
        # The position is DIVIDED by 10000**(2*num/512), so the exponent must
        # be positive; a negative exponent would multiply the position and
        # make the frequency grow with the dimension instead of decay.
        scaled = p / 10000 ** (2 * num / 512)
        if num % 2 == 0:
            embed = sin(p)
            div_embed = sin(scaled)
        else:
            embed = cos(p)
            div_embed = cos(scaled)
        embeddings_no_div[p].append(embed)
        embeddings[p].append(div_embed)

In [None]:
sin(1)

In [None]:
import matplotlib.pyplot as plt

# Side-by-side comparison for positions 2 and 3: values with and without the
# frequency divisor, across the embedding dimensions.
fig, ax = plt.subplots(1, 2)
fig.set_size_inches(20, 5)
for axis, position in zip(ax, (2, 3)):
    axis.plot(embeddings[position], label="With division")
    axis.plot(embeddings_no_div[position], label="Without division")
    axis.set_title(f"Position {position}")
    # Per-axes legend: plt.legend() would only annotate the last subplot.
    axis.legend()
plt.show()

In [None]:
xb

In [None]:
torch.arange(0, 5).unsqueeze(1).shape

In [None]:
even = torch.arange(0,n_embd,2).float()
even+1

In [None]:
def positional_embed(seq_len, n_embd):
    """Build the fixed sinusoidal position table.

    For position p and pair index i:
        PE[p, 2i]     = sin(p / 10000 ** (2i / n_embd))
        PE[p, 2i + 1] = cos(p / 10000 ** (2i / n_embd))

    Args:
        seq_len: Number of positions (rows).
        n_embd: Embedding width (columns); odd widths are supported.

    Returns:
        (seq_len, n_embd) float tensor.
    """
    position = torch.arange(0, seq_len).unsqueeze(1).float()
    # `even` already holds 2i (0, 2, 4, ...), so it is used as-is in the
    # exponent, and the sin/cos of one pair share the same frequency.
    even = torch.arange(0, n_embd, 2).float()
    angles = position / 10000 ** (even / n_embd)

    pe = torch.zeros(seq_len, n_embd)
    pe[:, 0::2] = torch.sin(angles)
    # An odd n_embd has one fewer odd column than even columns.
    pe[:, 1::2] = torch.cos(angles[:, : n_embd // 2])
    return pe

In [None]:
positional_embed(8,8)