Simple decoder-only transformer

- Only the decoder is needed because we generate text from the training data using self-attention alone (i.e. "babbling" Shakespeare). There is no encoder feeding in keys and values, as there would be when translating a French sentence into an English sentence.

In [1]:
# # tiny shakespeare dataset
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-01-17 18:46:23--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-01-17 18:46:24 (3.68 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

<torch._C.Generator at 0x7fa47169c450>

In [4]:
with open('input.txt', 'r', encoding="utf-8") as f:
  text = f.read()
len(text) # number of chars

1115394

In [5]:
# vocab
chars = sorted(list(set(text))) # get all unique chars in data
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [6]:
# map chars and ints
stoi = {ch:i for i,ch in enumerate(chars)} # dict
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder: string -> list of int
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: list of int -> string

print(encode("hello world"))
print(decode(encode("hello world")))

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
hello world


In [7]:
# encode entire dataset
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


In [8]:
# train val split
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [9]:
block_size = 8
x = train_data[:block_size]
print(x)
y = train_data[1:block_size+1]
print(y)
for t in range(block_size):
  context = x[:t+1]
  target = y[t]
  print(f"when input is {context} the target is {target}")

tensor([18, 47, 56, 57, 58,  1, 15, 47])
tensor([47, 56, 57, 58,  1, 15, 47, 58])
when input is tensor([18]) the target is 47
when input is tensor([18, 47]) the target is 56
when input is tensor([18, 47, 56]) the target is 57
when input is tensor([18, 47, 56, 57]) the target is 58
when input is tensor([18, 47, 56, 57, 58]) the target is 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target is 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is 58


In [10]:
batch_size = 4 # how many independent sequences processed in parallel
block_size = 8 # maximum context length for preds

def get_batch(split):
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data)-block_size, (batch_size,)) # batch_size random start offsets into the data
  x = torch.stack([data[i:i+block_size] for i in ix]) # (4,8)
  y = torch.stack([data[i+1:i+block_size+1] for i in ix]) # (4,8)
  return x,y

xb, yb = get_batch("train")
print(xb)
print(yb)

print('----')

for b in range(batch_size): # batch dimension; same loop as before but add B
  for t in range(block_size): # time dimension
    context = xb[b, :t+1]
    target = yb[b,t]
    print(f"when input is {context.tolist()} the target: {target}")

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----
when input is [24] the target: 43
when input is [24, 43] the target: 58
when input is [24, 43, 58] the target: 5
when input is [24, 43, 58, 5] the target: 57
when input is [24, 43, 58, 5, 57] the target: 1
when input is [24, 43, 58, 5, 57, 1] the target: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
when input is [44] the target: 53
when input is [44, 53] the target: 56
when input is [44, 53, 56] the target: 1
when input is [44, 53, 56, 1] the target: 58
when input is [44, 53, 56, 1, 58] the target: 46
when input is [44, 53, 56, 1, 58, 46] the target: 39
when input is [44, 53, 

[Baseline] Bigram model

In [11]:
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
  
  def __init__(self, vocab_size):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size) # vocab_size rows, each a vector of size vocab_size that is read directly as next-token logits (a smaller embedding dim would give a lower-dim embed space, but would then need a linear layer to map back to vocab_size logits)
  
  def forward(self, idx, targets=None):
    # idx and targets are both (B,T)
    logits = self.token_embedding_table(idx) # (B,T,C) where C is the second vocab_size above (additional dim we get from embedding)
    
    if targets is None: # during generation
      loss = None
    else:
      B,T,C = logits.shape
      logits = logits.view(B*T,C) # stack the batches on top of each other; F.cross_entropy expects (N,C) logits and (N,) targets
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)
    
    return logits, loss
  
  def generate(self, idx, max_new_tokens):
    # idx is (B, T) array of indices in current context
    for _ in range(max_new_tokens):
      logits, loss = self(idx) # calls self.forward
      logits = logits[:,-1,:] # (B,C); keep only the last time step, whose logits are the prediction for the next token (the bigram model has no self-attention, so it only looks at the last character)
      probs = F.softmax(logits, dim=-1) # (B,C)
      idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
      idx = torch.cat((idx,idx_next),dim=1) # (B,T+1) append to running sequence
    
    return idx
    
m = BigramLanguageModel(vocab_size)
# m.token_embedding_table(torch.tensor(encode("h"))) # plucks out the row for "h" and does this for each letter in input
logits, loss = m(xb,yb)
# print(logits[0])
print(logits.shape)
print(loss)

print(decode(m.generate(idx=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ
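
As a sanity check on the initial loss printed above, it helps to know what a model with no knowledge at all would score. A model that spreads probability evenly over all 65 characters has a cross-entropy of -ln(1/65). The snippet below is only that arithmetic; `uniform_loss` is an illustrative name, not something from the notebook.

```python
import math

vocab_size = 65  # number of unique characters found in the dataset above

# cross-entropy of a model that assigns probability 1/vocab_size to every character
uniform_loss = -math.log(1.0 / vocab_size)
print(f"{uniform_loss:.4f}")
```

The measured 4.8786 is somewhat higher than this value, which is consistent with the randomly initialized embedding table producing logits that are not perfectly uniform.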


In [12]:
# testing stuff
a = torch.tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])
print(a)
a.view(8,3)

tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])


tensor([[-0.0251, -1.6902,  0.7172],
        [-0.6431,  0.0748,  0.6969],
        [ 1.4970,  1.3448, -0.9685],
        [-0.3677, -2.7265, -0.1685],
        [ 1.4970,  1.3448, -0.9685],
        [ 0.4362, -0.4004,  0.9400],
        [-0.6431,  0.0748,  0.6969],
        [ 0.9124, -2.3616,  1.1151]])

In [13]:
# training
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32

for steps in range(100):
  xb,yb = get_batch("train")
  logits,loss = m(xb,yb)
  optimizer.zero_grad(set_to_none=True) # clear grads from the previous step before backward() so they don't accumulate
  loss.backward()
  optimizer.step()
  
print(loss.item())

4.65630578994751
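
The loss printed above comes from a single random batch, so it is noisy. One possible way to get a steadier number is to average over several batches for both splits. The sketch below assumes a hypothetical helper `estimate_loss` and an `eval_iters` setting, neither of which is in the notebook; `TinyBigram`, `toy_get_batch` and the random toy data are stand-ins so the block runs on its own.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

@torch.no_grad()  # no gradients needed for evaluation
def estimate_loss(model, get_batch, eval_iters=20):
    """Average the loss over eval_iters random batches per split (hypothetical helper)."""
    model.eval()
    out = {}
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out

# --- stand-ins so the sketch is self-contained (not the notebook's objects) ---
class TinyBigram(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.table(idx)  # (B,T,C)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        return logits, F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

torch.manual_seed(0)
toy = torch.randint(0, 10, (1000,))  # random token ids instead of Shakespeare
splits = {"train": toy[:900], "val": toy[900:]}

def toy_get_batch(split, batch_size=4, block_size=8):
    d = splits[split]
    ix = torch.randint(len(d) - block_size, (batch_size,))
    x = torch.stack([d[i:i + block_size] for i in ix])
    y = torch.stack([d[i + 1:i + block_size + 1] for i in ix])
    return x, y

losses = estimate_loss(TinyBigram(10), toy_get_batch)
print(losses)
```

In the notebook itself the call would presumably be `estimate_loss(m, get_batch)`, since `get_batch` already takes the split name.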


In [18]:
for name, param in m.named_parameters():
  if param.requires_grad:
    print(name, param.data.shape)

token_embedding_table.weight torch.Size([65, 65])


Self-attention

- Think of B,T,C as a B by T grid of tokens, each token a vector with C dims
- The self-attention insight is to use non-uniform, data-dependent weights in the tril (lower-triangular) matrix instead of a plain average; some tokens are more "interesting" than others, and each token gathers this information independently through K, Q and V
  - x, which is what K, Q and V are computed from, is the sum of the token embedding and the position embedding
  - x is kind of like a token's private information; if you find anything interesting in me, V is what I will communicate to you
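
To make the "token and pos embedding" point concrete, here is a minimal sketch of how x could be built before attention is applied. The embedding width `n_embd = 32` and the table names are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)
vocab_size, block_size, n_embd = 65, 8, 32  # n_embd is an assumed embedding width

token_embedding_table = nn.Embedding(vocab_size, n_embd)     # what the token is
position_embedding_table = nn.Embedding(block_size, n_embd)  # where the token sits

idx = torch.randint(vocab_size, (4, block_size))              # (B,T) batch of token ids
tok_emb = token_embedding_table(idx)                          # (B,T,C)
pos_emb = position_embedding_table(torch.arange(block_size))  # (T,C)
x = tok_emb + pos_emb                                         # (B,T,C); pos_emb broadcasts over B
print(x.shape)
```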

In [59]:
# version 3: use Softmax
# aggregation is just mean of previous tokens

torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)

tril = torch.tril(torch.ones(T, T)) # masking so only previous tokens in context are used
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf')) # for ones that we don't want to see, weight it -inf
wei = F.softmax(wei, dim=-1) # e^-inf = 0 -> same effect as taking average
xbow3 = wei @ x
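
The claim that the masked softmax reproduces a plain average of the previous tokens can be checked against an explicit loop. This is a small self-contained sketch; `xbow` is the loop-based result and is recomputed here rather than taken from an earlier cell.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# explicit version: for every position t, average the tokens x[b, 0..t]
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, :t + 1].mean(dim=0)

# masked-softmax version: uniform weights over the allowed (previous) positions
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x  # (T,T) @ (B,T,C) broadcasts to (B,T,C)

print(torch.allclose(xbow, xbow3, atol=1e-5))
```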

In [70]:
# version 4: self-attention
# aggregation is based on attention mechanism with KQV instead of simple mean

torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

# single head
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# each node produces a key and a query
k = key(x)  # (B,T,C) @ (C,16) = (B,T,16)
q = query(x) # (B,T,16)

# weight is based on how aligned key and query are; communicate by dot product queries and keys
wei = q @ k.transpose(-2,-1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T); matmul for higher D just operates on last 2D

# similar to version 3
tril = torch.tril(torch.ones(T,T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v
print(out.shape)
print(wei[0])

torch.Size([4, 8, 16])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)
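
One refinement that the cell above does not apply is the scaling used in scaled dot-product attention: multiplying `q @ k^T` by `head_size ** -0.5`. With roughly unit-variance keys and queries, the raw dot products have variance close to `head_size`, which can push softmax towards one-hot rows. The sketch below uses random stand-ins for `key(x)` and `query(x)` and only illustrates the variance effect.

```python
import torch

torch.manual_seed(1337)
B, T, head_size = 4, 8, 16
k = torch.randn(B, T, head_size)  # unit-variance stand-in for key(x)
q = torch.randn(B, T, head_size)  # unit-variance stand-in for query(x)

wei_raw = q @ k.transpose(-2, -1)         # (B,T,T); variance grows with head_size
wei_scaled = wei_raw * head_size ** -0.5  # scaled back towards unit variance
print(wei_raw.var().item(), wei_scaled.var().item())
```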


Attention heads

- Head size is how much info each token "communicates" to other nodes
  - It sets the dimension of the key, query and value vectors
- Having multiple attention heads is like having multiple communication channels, each carrying different information
  - Several heads with smaller channels tend to work better than 1 head with larger channels
  - Can gather a lot of diff types of information
  - Similar to having more convolutional filters!
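
The notes above can be sketched as a small module that runs several heads in parallel and concatenates their outputs along the channel dimension. The class names `Head` and `MultiHeadAttention` and the choice `head_size = n_embd // num_heads` are assumptions for this sketch; each head follows the same steps as the single-head cell earlier (no scaling applied).

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    """One causal self-attention head."""
    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # buffer, not a parameter: the mask is fixed and never trained
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)  # each (B,T,head_size)
        wei = q @ k.transpose(-2, -1)                         # (B,T,T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        return wei @ v                                        # (B,T,head_size)

class MultiHeadAttention(nn.Module):
    """Several smaller heads in parallel, concatenated over channels."""
    def __init__(self, n_embd, num_heads, block_size):
        super().__init__()
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList(
            [Head(n_embd, head_size, block_size) for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)  # (B,T,n_embd)

torch.manual_seed(1337)
x = torch.randn(4, 8, 32)
mha = MultiHeadAttention(n_embd=32, num_heads=4, block_size=8)
out = mha(x)
print(out.shape)
```

With 4 heads of size 8, the concatenated output has the same 32 channels as the input, so the module can be dropped in where a single 32-wide head would go.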

In [3]:
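class LayerNorm1d: # adapted from BatchNorm1d: normalizes each example across its features (dim 1) instead of across the batch (dim 0); LayerNorm is used more in LMs because its statistics don't depend on the batch or on how long each input is (aka how many tokens in sentence)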
  
  def __init__(self, dim, eps=1e-5):
    self.eps = eps
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
  
  def __call__(self, x):
    # calculate the forward pass
    xmean = x.mean(1, keepdim=True)
    xvar = x.var(1, keepdim=True)
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    return self.out
  
  def parameters(self):
    return [self.gamma, self.beta]

torch.manual_seed(1337)
module = LayerNorm1d(100)
x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors
x = module(x)
x.shape

torch.Size([32, 100])

In [8]:
x.mean(1) # across the 100 dimensions instead of batch of 32

tensor([-9.5367e-09, -2.3842e-09, -2.0266e-08,  1.7881e-08,  1.6689e-08,
         9.8348e-09,  4.7684e-09,  1.9073e-08, -1.4305e-08, -4.7684e-09,
        -1.3113e-08, -5.9605e-09,  0.0000e+00, -7.1526e-09, -2.0266e-08,
         7.0035e-09, -1.2815e-08,  1.7881e-08,  6.5565e-09, -4.7684e-09,
         9.5367e-09, -3.5763e-09, -2.8610e-08,  4.7684e-09,  3.5763e-09,
        -7.1526e-09, -4.7684e-09,  0.0000e+00,  5.3644e-09, -1.1921e-08,
         4.7684e-09,  1.9073e-08])

In [10]:
x.var(1) # each row now has mean ~0 (previous cell) and variance 1

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
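
The hand-written normalization can be compared with PyTorch's built-in `F.layer_norm`. One detail to be aware of: `x.var(1)` defaults to the unbiased estimator (divide by N-1), while the built-in layer norm uses the biased one (divide by N), so the sketch below passes `unbiased=False` to match it. The names `manual` and `builtin` are illustrative.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(1337)
x = torch.randn(32, 100)  # batch of 32 vectors with 100 features each
eps = 1e-5

mean = x.mean(1, keepdim=True)
var_biased = x.var(1, keepdim=True, unbiased=False)  # divide by N, like F.layer_norm
manual = (x - mean) / torch.sqrt(var_biased + eps)

builtin = F.layer_norm(x, (100,), eps=eps)  # no learned scale/shift supplied
print(torch.allclose(manual, builtin, atol=1e-5))
```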

### Summary of how transformers work:

Blocks (Self attention + Feed forward -> Communication + Computation)
- Each block has self attention + feedforward, and blocks stacked together create the transformer
- Self attention is where tokens communicate with each other; feed forward is then applied to each token separately, "thinking" on the data that the token gathered
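
The "computation" half of a block can be sketched as a small per-token MLP. The 4x hidden width and the ReLU are assumed, commonly used choices rather than something stated in these notes. The last lines check that a token processed on its own gets the same output as when it is processed inside the full (B,T,C) batch, i.e. the layer does no communication between tokens.

```python
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Per-token MLP; hidden width 4 * n_embd is an assumption."""
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        return self.net(x)  # nn.Linear acts on the last dim only

torch.manual_seed(1337)
ffwd = FeedForward(32)
x = torch.randn(4, 8, 32)  # (B,T,C)
out = ffwd(x)              # (B,T,C)
single = ffwd(x[1, 5])     # one token processed on its own
print(out.shape, torch.allclose(out[1, 5], single, atol=1e-5))
```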

Add + Norm (Residual connections and layer norm for optimizing larger networks)
- There is a residual pathway that carries information straight through the network; each block is free to fork off, do some computation, and then project back onto the residual pathway and add its info to the flow
  - Because addition passes gradients through unchanged, the gradient can flow directly from the loss back to the earliest layers
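
Putting the summary together, a block with residual connections and layer norm might look like the sketch below. It is written under several assumptions: it uses PyTorch's built-in `nn.MultiheadAttention` instead of a hand-written head, applies LayerNorm before each sub-layer (pre-norm), and uses a 4x feed-forward width. The name `Block` and the sizes are illustrative.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Communication (self-attention) followed by computation (feed-forward)."""
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.sa = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), nn.ReLU(), nn.Linear(4 * n_embd, n_embd)
        )

    def forward(self, x):
        T = x.shape[1]
        # boolean mask: True above the diagonal marks future positions as not attendable
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.ln1(x)
        attn_out, _ = self.sa(h, h, h, attn_mask=mask, need_weights=False)
        x = x + attn_out                # fork off, then add back onto the residual pathway
        x = x + self.ffwd(self.ln2(x))  # same pattern for the feed-forward part
        return x

torch.manual_seed(1337)
blocks = nn.Sequential(*[Block(32, 4) for _ in range(3)])  # three stacked blocks
x = torch.randn(4, 8, 32, requires_grad=True)
out = blocks(x)
out.sum().backward()  # gradients reach the input through the residual additions
print(out.shape, x.grad.abs().sum().item() > 0)
```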