In [51]:
from rich import print as rprint
import torch
import torch.nn as nn
import torch.nn.functional as F

In [59]:
# Read the entire training corpus into memory as one string.
with open('./input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Corpus size in characters.
rprint(f'Text length: {len(text)}')

In [61]:
# Show the first 100 characters of the corpus as a quick sanity check.
rprint(text[:100])

In [4]:
# Vocabulary: every distinct character of the corpus, in sorted order.
# `sorted` accepts any iterable, so the set can be passed directly.
chars = sorted(set(text))
VOCAB_SIZE = len(chars)

rprint(f'Vocabulary size: {VOCAB_SIZE}')
rprint(f'{"".join(chars)}')

## Defining embeddings

In [5]:
# Lookup tables between characters and integer ids (an id is the character's position in `chars`).
i_to_s = dict(enumerate(chars))
s_to_i = {ch: idx for idx, ch in i_to_s.items()}


def encode(s):
    """Return the list of integer ids for the characters of string `s`."""
    return [s_to_i[c] for c in s]


def decode(l):
    """Return the string spelled by the iterable of integer ids `l`."""
    return ''.join(i_to_s[i] for i in l)


rprint(f'Encoded: {encode("Hello World")}')

rprint(f'Decoded: {decode([56, 50, 42])}')

In [6]:
# Turn the whole corpus into a 1-D tensor of 64-bit integer token ids.
data = torch.tensor(encode(text), dtype=torch.int64)
rprint(f'Encoded data shape: {data.shape}')
rprint(f'Encoded data: {data[:100]}')

In [7]:
# First 90% of the tokens are used for training, the remaining 10% for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

## Generating batches

In [8]:
BLOCK_SIZE = 8

# `y` is `X` shifted by one position, so y[t] is the token that follows X[:t+1].
X, y = train_data[:BLOCK_SIZE], train_data[1:BLOCK_SIZE + 1]

# Every prefix of the block is one training example.
for t in range(BLOCK_SIZE):
    rprint(f'Input: {X[:t+1]} , Output:{y[t]}')

In [9]:
torch.manual_seed(1337)

BATCH_SIZE = 4
BLOCK_SIZE = 8


def get_batch(split):
    """Sample a random batch of (input, target) windows.

    `split == 'train'` draws from `train_data`; any other value draws from
    `val_data`. Returns two tensors of shape (BATCH_SIZE, BLOCK_SIZE), where
    the targets are the inputs shifted one position to the right.
    """
    source = train_data if split == 'train' else val_data
    # Random start offsets; the upper bound keeps the shifted window in range.
    starts = torch.randint(len(source) - BLOCK_SIZE, (BATCH_SIZE,)).tolist()

    def windows(shift):
        # One BLOCK_SIZE-long slice per start offset, stacked into a batch.
        return torch.stack(
            [source[start + shift:start + shift + BLOCK_SIZE] for start in starts]
        )

    return windows(0), windows(1)

# Draw one training batch and show the input and target tensors with their shapes.
xb, yb = get_batch('train')
rprint(f'Inputs')
rprint(f'X shape: {xb.shape}')
rprint(f'X: {xb}')
rprint(f'Outputs')
rprint(f'y shape: {yb.shape}')
rprint(f'y: {yb}')

# Visual separator.
rprint('-' * 50)

In [10]:
# Spell out every (context, target) training pair contained in the batch,
# first as token ids and then decoded back to characters.
for b in range(BATCH_SIZE):
    inputs_row, targets_row = xb[b], yb[b]
    for t in range(BLOCK_SIZE):
        context = inputs_row[:t + 1]
        target = targets_row[t]
        rprint(f'Context: {context} , Target: {target}')
        rprint(f'Context: {decode(context.tolist())} , Target: {decode([int(target)])}')
    rprint('-' * 50)

In [11]:
# Show the input batch `xb` with the built-in print (no rich formatting).
print(xb)

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


## Bigram Language Model

In [12]:
torch.manual_seed(1337)


class BigramLanguageModel(nn.Module):
    """Bigram language model: the next-token logits depend only on the current token.

    The only parameter is a (vocab_size, vocab_size) embedding table whose row
    `i` holds the logits over the next token given that the current token is `i`.
    """

    def __init__(self, vocab_size):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            idx: integer tensor of token ids, shape (B, T).
            targets: optional integer tensor of the same shape as `idx`.

        Returns:
            `(logits, None)` with logits of shape (B, T, C) when `targets` is
            None; otherwise `(logits, loss)` where logits have been flattened
            to (B * T, C) and `loss` is the cross-entropy against `targets`.
        """
        logits = self.token_embedding_table(idx)

        if targets is None:
            return logits, None

        # F.cross_entropy expects (N, C) logits and (N,) targets.
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; skip autograd bookkeeping
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `idx`.

        Args:
            idx: integer tensor of token ids, shape (B, T).
            max_new_tokens: number of tokens to append to every row.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens) that starts with `idx`.
        """
        for _ in range(max_new_tokens):
            logits, _ = self(idx)
            # Only the last time step is used to predict the next token.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)

            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)

        return idx

model = BigramLanguageModel(vocab_size=VOCAB_SIZE)

# Forward pass with targets: out[0] holds the logits, out[1] the loss.
out = model(xb, targets=yb)
rprint(f'Output shape: {out[0].shape}')
rprint(f'Loss: {out[1]}')
rprint(f'Generate: {model.generate(xb, 5)}')

In [13]:
# Sample 100 tokens starting from a single token with id 0, then decode them to text.
seed_context = torch.zeros((1, 1), dtype=torch.long)
generated = model.generate(seed_context, 100)
rprint(decode(generated[0].tolist()))

## Train the model

In [14]:
# AdamW over all parameters of the model; the bare name below displays its configuration.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)
optimizer

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0.01
)

In [15]:
BATCH_SIZE = 32
MAX_STEPS = 10000   # total optimizer updates
LOG_EVERY = 1000    # report the loss every this many steps

for steps in range(MAX_STEPS):
    # Fresh random training batch (get_batch reads the BATCH_SIZE set above).
    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if steps % LOG_EVERY == 0:
        # .item() logs the plain scalar (not the tensor repr with grad_fn),
        # matching the final report below.
        rprint(f'Step: {steps} , Loss: {loss.item()}')

rprint(f'Loss: {loss.item()}')

In [16]:
# Sample 300 tokens starting from a single token with id 0, then decode them to text.
seed_context = torch.zeros((1, 1), dtype=torch.long)
generated = model.generate(seed_context, 300)
rprint(decode(generated[0].tolist()))

## Building up to self attention

In [17]:
torch.manual_seed(1337)
# Toy dimensions: batch, time, channels.
B, T, C = 4, 8, 2
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [18]:
# Reference implementation: xbow[b, t] is the mean of x[b, 0..t] (inclusive),
# computed with explicit Python loops.
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]  # (t + 1, C)
        xbow[b, t] = xprev.mean(dim=0)

In [19]:
# Show the first batch element of the raw input next to its running mean.
rprint(f'X[0]: {x[0]}')
rprint(f'Xbow[0]: {xbow[0]}')

Notice that the first row of `xbow[0]` is identical to the first row of `x[0]`, while every subsequent row $i$ is the element-wise average of rows $0$ through $i$ (inclusive) of `x[0]`.

### Matrix multiplication trick

In [20]:
# Multiplying by an all-ones matrix: every row of `c` is the column-wise sum of `b`.
torch.manual_seed(42)
a = torch.ones((4, 4))
b = torch.randint(low=0, high=10, size=(4, 3)).float()
c = torch.matmul(a, b)

rprint(f'a: {a}')
rprint(f'b: {b}')
rprint(f'c: {c}')

In [21]:
# A lower-triangular matrix of ones acts as a mask: row i of `a_ @ b`
# is the sum of rows 0..i of `b`.
a_ = torch.ones(4, 4).tril()
rprint(f'a_: {a_}')
rprint(f'b: {b}')
rprint(f'a_ @ b: {a_ @ b}')

In [22]:
# Divide each row by its sum so the rows become averaging weights instead of summing weights.
a_normalized = a_ / a_.sum(dim=1, keepdim=True)
rprint(f'a_normalized: {a_normalized}')

rprint(f'a_normalized @ b: {a_normalized @ b}')

### Using matrix multiplication to make it efficient

In [23]:
# (T, T) averaging matrix: row t puts equal weight on positions 0..t and zero on later ones.
weights = torch.ones(T, T).tril()
weights = weights / weights.sum(dim=1, keepdim=True)
rprint(f'Weights: {weights}')

In [24]:
# (T, T) @ (B, T, C) broadcasts over the batch dimension and yields (B, T, C).
xbow_matrix = torch.matmul(weights, x)
rprint(f'Xbow matrix: {xbow_matrix}')

### Using softmax


In [25]:
tril = torch.ones(T, T).tril()

# Start from all-zero scores and set the positions above the diagonal to -inf;
# softmax then gives uniform weights over positions 0..t and exactly zero elsewhere.
weights = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
weights = weights.softmax(dim=-1)

xbow_softmax = torch.matmul(weights, x)
rprint(f'Xbow softmax: {xbow_softmax}')

# Check against the normalized-matrix version.
torch.allclose(xbow_matrix, xbow_softmax)

True

## Self-attention

In [29]:
torch.manual_seed(1337)
# Batch, time, channels.
B, T, C = 4, 8, 32
x = torch.randn((B, T, C))

# Random scores stand in for learned affinities; positions above the diagonal
# are masked out before the softmax so each row only weights positions 0..t.
tril = torch.ones(T, T).tril()
weights = torch.randn(T, T).masked_fill(tril == 0, float('-inf'))
weights = weights.softmax(dim=-1)

out = torch.matmul(weights, x)


### Single head of self attention

In [45]:
HEAD_SIZE = 16

# Three bias-free linear projections from C channels to HEAD_SIZE, created in
# the order query, key, value.
query, key, value = (nn.Linear(C, HEAD_SIZE, bias=False) for _ in range(3))

rprint(f'Query: {query}')
rprint(f'Key: {key}')
rprint(f'Value: {value}')

In [46]:
# Project the same input with each of the three linear layers.
q, k, v = query(x), key(x), value(x)

rprint(f'Query shape: {q.shape}')
rprint(f'Key shape: {k.shape}')
rprint(f'Value shape: {v.shape}')

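The attention output is computed as follows (the softmax term is the attention weight matrix, and $d_k$ is the key dimension, here `HEAD_SIZE`):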
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

Reference: https://arxiv.org/pdf/1706.03762.pdf

In [47]:
tril = torch.tril(torch.ones(T, T))

# Scaled dot-product scores: (q @ k^T) / sqrt(HEAD_SIZE).
# The scale must be a multiplication applied to the matmul result. Writing
# `** (HEAD_SIZE ** -0.5)` instead raises the keys element-wise to that power
# before the matmul, because `**` binds tighter than `@`; negative entries
# then become NaN.
weights = (q @ k.transpose(-2, -1)) * HEAD_SIZE ** -0.5
# Mask positions above the diagonal so they get zero weight after the softmax.
weights = weights.masked_fill(tril == 0, float("-inf"))
weights = F.softmax(weights, dim=-1)

rprint(f'Weights: {weights[0]}')

In [50]:
# Weighted sum of the value vectors using the attention weights.
out = torch.matmul(weights, v)
rprint(f'Out: {out.shape}')