# Lecture 5: Transformers

In [71]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

<torch._C.Generator at 0x111f00d10>

In [72]:
# consider the following toy example:
B,T,C = 4,8,2 # batch, time, channels(embed_zise)
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [73]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [74]:
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

| 2. | 7. |  
| 6. | 4. |
| 6. | 5. |

| 2. | 7. |
| . | . |
| . | . |

In [75]:
# toy example illustrating how matrix multiplication can be used for a "zghted aggregation"
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [76]:
# toy example illustrating how matrix multiplication can be used for a "zghted aggregation"
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [77]:
# version 2: using matrix multiply for a zghted aggregation
z = torch.tril(torch.ones(T, T))
z = z / z.sum(1, keepdim=True)
xbow2 = z @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)
torch.allclose(xbow, xbow2)

True

In [78]:
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T))
z = torch.zeros((T,T))
print(tril, "\n \n",z)

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]]) 
 
 tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])


In [79]:
z = z.masked_fill(tril == 0, float('-inf'))
z

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [80]:
z = F.softmax(z, dim=-1)
z

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [81]:
xbow3 = z @ x
torch.allclose(xbow, xbow3)

True

In [82]:
# version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels (embed_zise)
x = torch.randn(B,T,C)

In [83]:
# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
z =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

In [84]:
tril = torch.tril(torch.ones(T, T))
#z = torch.zeros((T,T))
z = z.masked_fill(tril == 0, float('-inf'))
z = F.softmax(z, dim=-1)

v = value(x)
out = z @ v
#out = z @ x

out.shape

torch.Size([4, 8, 16])

In [85]:
z[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

In [86]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        z = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        z = z.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        z = F.softmax(z, dim=-1) # (B, T, T)
        z = self.dropout(z)
        # perform the zghted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = z @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [55]:
data = open('names.txt').read().splitlines()
data[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [56]:
token_to_index = {tok: i for i, tok in enumerate('abcdefghijklmnopqrstuvwxyz')}
# Start/stop token
token_to_index['[S]'] = 26
# Padding token
token_to_index['[PAD]'] = 27

index_to_token = {i: tok for tok, i in token_to_index.items()}

In [57]:
import torch

def build_dataset(data):
    X, Y = [], []
    for item in data:
        tokens = ['[S]'] + list(item) + ['[S]']
        indices = [token_to_index[token] for token in tokens]
        X.append(indices[:-1])
        Y.append(indices[1:])
    return X, Y

# Split into train, dev, test
import random
random.seed(123)
random.shuffle(data)

n1 = int(0.8 * len(data))
n2 = int(0.9 * len(data))

X_train, Y_train = build_dataset(data[:n1])
X_dev, Y_dev = build_dataset(data[n1:n2])
X_test, Y_test = build_dataset(data[n2:])

len(X_train), len(Y_train)

(25626, 25626)

In [58]:
X_train[0], len(X_train[0]), X_train[1], len(X_train[1]), max(len(x) for x in X_train)

([26, 11, 20, 0, 13, 13], 6, [26, 18, 7, 0, 8, 13], 6, 16)

In [59]:
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, d_model, nhead, dim_ff=64, max_len=128):
        super(Block, self).__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0, batch_first=True)
        self.ff1 = nn.Linear(d_model, dim_ff)
        self.ff2 = nn.Linear(dim_ff, d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.act = nn.ReLU()
        self.register_buffer('mask', torch.triu(torch.ones(max_len, max_len), diagonal=1).bool())

    def forward(self, x):
        B, T, D = x.size()
        # Pre-normalization
        x = self.ln1(x)
        # Self-attention
        x2 = self.attn(x, x, x, is_causal=True, attn_mask=self.mask[:T,:T])[0]
        # Residual connection
        x = x + x2
        # Pre-normalization
        x = self.ln2(x)
        # Feed-forward
        x2 = self.ff2(self.act(self.ff1(x)))
        # Residual connection
        x = x + x2
        return x

In [60]:
# test out the block
block = Block(10, 2)
x = torch.randn(10, 32, 10)
block(x).shape

torch.Size([10, 32, 10])

In [61]:
class TransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_ff, max_len=128):
        super(TransformerLM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList([
            Block(d_model, nhead, dim_ff) for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def forward(self, x):
        pos = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        x = self.embedding(x) + self.pos_encoder(pos)
        for block in self.blocks:
            x = block(x)
        logits = self.fc(x)
        return logits


In [62]:
model = TransformerLM(len(token_to_index), 64, 2, 2, 64)

x = torch.tensor(X_train[:1])

logits = model(x)
logits.size()

torch.Size([1, 6, 28])

In [63]:
def pad_batch(X_batch, Y_batch, pad_index):
    max_len = max(len(x) for x in X_batch)
    X_padded = torch.zeros(len(X_batch), max_len, dtype=torch.long) + pad_index
    Y_padded = torch.zeros(len(Y_batch), max_len, dtype=torch.long) + pad_index
    for i, (x, y) in enumerate(zip(X_batch, Y_batch)):
        X_padded[i, :len(x)] = torch.tensor(x)
        Y_padded[i, :len(y)] = torch.tensor(y)
    return X_padded, Y_padded

xp, yp = pad_batch(X_train[:4], Y_train[:4], token_to_index['[PAD]'])

print(xp)
for x in xp:
    print([index_to_token[i.item()] for i in x])

tensor([[26, 11, 20,  0, 13, 13, 27, 27, 27, 27],
        [26, 18,  7,  0,  8, 13, 27, 27, 27, 27],
        [26, 17, 20, 15,  4, 17, 19, 27, 27, 27],
        [26, 12, 14, 10, 18,  7,  0,  6, 13,  0]])
['[S]', 'l', 'u', 'a', 'n', 'n', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['[S]', 's', 'h', 'a', 'i', 'n', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['[S]', 'r', 'u', 'p', 'e', 'r', 't', '[PAD]', '[PAD]', '[PAD]']
['[S]', 'm', 'o', 'k', 's', 'h', 'a', 'g', 'n', 'a']


In [64]:
X_batch, Y_batch = pad_batch(X_train[:2], Y_train[:2], token_to_index['[PAD]'])

logits = model(X_batch)
logits.size()

torch.Size([2, 6, 28])

In [65]:
import torch.optim as optim

model = TransformerLM(len(token_to_index), 64, 2, 2, 64)

# Count model parameters
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")

# Hyperparameters
learning_rate = 0.001
num_epochs = 10
batch_size = 16

# Loss function and optimizer
# NOTE: We ignore the loss whenever the target token is a padding token
criterion = nn.CrossEntropyLoss(ignore_index=token_to_index['[PAD]'])

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    # Reshuffle the data
    perm = torch.randperm(len(X_train))
    X_train = [X_train[i] for i in perm]
    Y_train = [Y_train[i] for i in perm]
    
    model.train()
    total_loss = 0
    for i in range(0, len(X_train), batch_size):
        X_batch = X_train[i:i+batch_size]
        Y_batch = Y_train[i:i+batch_size]
        X_batch, Y_batch = pad_batch(X_batch, Y_batch, token_to_index['[PAD]'])

        # Forward pass
        outputs = model(X_batch) # [batch_size, seq_len, vocab_size]
        outputs = outputs.view(-1, len(token_to_index)) # [batch_size * seq_len, vocab_size]
        Y_batch = Y_batch.view(-1) # [batch_size * seq_len]
        loss = criterion(outputs, Y_batch)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / (len(X_train) // batch_size)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

    # Evaluate validation loss
    eval_loss = 0
    model.eval()
    with torch.no_grad():
        for i in range(0, len(X_dev), batch_size):
            X_batch = X_dev[i:i+batch_size]
            Y_batch = Y_dev[i:i+batch_size]
            X_batch, Y_batch = pad_batch(X_batch, Y_batch, token_to_index['[PAD]'])

            outputs = model(X_batch)
            outputs = outputs.view(-1, len(token_to_index))
            Y_batch = Y_batch.view(-1)
            loss = criterion(outputs, Y_batch)

            eval_loss += loss.item()
    avg_eval_loss = eval_loss / (len(X_dev) // batch_size)
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_eval_loss:.4f}')


Model parameters: 62236
Epoch [1/10], Loss: 2.2517
Epoch [1/10], Validation Loss: 2.1796
Epoch [2/10], Loss: 2.1419
Epoch [2/10], Validation Loss: 2.1359
Epoch [3/10], Loss: 2.1061
Epoch [3/10], Validation Loss: 2.1189
Epoch [4/10], Loss: 2.0831
Epoch [4/10], Validation Loss: 2.0974
Epoch [5/10], Loss: 2.0634
Epoch [5/10], Validation Loss: 2.0906
Epoch [6/10], Loss: 2.0492
Epoch [6/10], Validation Loss: 2.0908
Epoch [7/10], Loss: 2.0363
Epoch [7/10], Validation Loss: 2.0760
Epoch [8/10], Loss: 2.0265
Epoch [8/10], Validation Loss: 2.0682
Epoch [9/10], Loss: 2.0173
Epoch [9/10], Validation Loss: 2.0649
Epoch [10/10], Loss: 2.0093
Epoch [10/10], Validation Loss: 2.0596


In [66]:
# Sample from the model
def sample(model, context, max_length=100):
    model.eval()
    output = []
    with torch.no_grad():
        x = torch.tensor([[token_to_index['[S]']] + context])
        for _ in range(max_length):
            logits = model(x)
            y = torch.softmax(logits[0, -1], dim=0)
            y = torch.multinomial(y, 1)
            token = index_to_token[y.item()]
            if token == '[S]':
                break
            output.append(token)
            x = torch.cat([x, y.unsqueeze(0)], dim=1)
    return ''.join(output)

In [70]:
for i in range(10):
    print(sample(model, []))

ellohe
colston
maldi
tillana
kochan
sara
selan
isahya
lyde
allayah


In [68]:
prompt = 's'
for i in range(10):
    out = sample(model, [token_to_index[tok] for tok in prompt])
    print(prompt + out)

samanah
sisa
sultan
shanna
sisrah
samnanda
sian
selanor
sejvi
savelen
