### Lab 9.1 Attention Implementation

This week you will experimenet with attention-based models.

In [1]:
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

1. Complete the following implementation of scaled dot-product attention.   Run the code cell to verify that the output shape is what it should be.

*Note: you can use `scores = scores.masked_fill(...)` to fill in values where the mask is True.  Fill in -1e9 as the score for masked values.*

In [4]:
def attention(Q, K, V, mask=None):
    """
    Computes scaled dot-product attention.

    Compute scores as Q*K^T. 
    Optionally mask out score values to -1e9 where the mask is True.
    Divide by sqrt(d_k).
    Compute softmax on scores along the rows to obtain attention weights.
    Matrix multiply attention weights by values.

    Arguments:
      Q: queries [B,L,d_k]
      K: keys    [B,S,d_k]
      V: values  [B,S,d_v]
      mask: optional Boolean mask where True means hidden [B,L,S]

    Returns:
      Sequence of context vectors of shape [B,L,d_v]
    """
    scores = torch.matmul(Q, K.transpose(-2, -1))
    if mask is not None:
        scores = scores.masked_fill(mask, -1e9)
    
    scores = scores / np.sqrt(Q.size(-1))
    weights = torch.softmax(scores, dim=-1)
    out = torch.matmul(weights, V)
    return out


Q = torch.rand(1, 5, 64)
K = torch.rand(1, 10, 64)
V = torch.rand(1, 10, 8)
mask = (torch.rand(1, 5, 10) > 0.5)

y = attention(Q, K, V, mask=mask)

y.shape

torch.Size([1, 5, 8])

The following code creates classes to build a Transformer-style decoder for generating sequences.

In [5]:
class AttentionHead(nn.Module):
    def __init__(self,d_model,d_k):
        super().__init__()
        self.WQ = nn.Linear(d_model,d_k)
        self.WK = nn.Linear(d_model,d_k)
        self.WV = nn.Linear(d_model,d_k)

    def forward(self,Q,K,V,mask=None):
        """ Compute attention head.

            Project the input to queries, keys, and values, and then apply attention.
            Arguments:
                Q: queries [B,L,d_model]
                K: keys    [B,S,d_model]
                V: values  [B,L,d_model]
                mask: optional Boolean mask where True means hidden [B,L,S]
            Output:
                Context vectors [B,L,d_k]
        """
        # apply linear projections to queries, keys, and values followed by masked attention
        return attention(self.WQ(Q),self.WK(K),self.WV(V),mask=mask)

class MultiHeadAttention(nn.Module):
    def __init__(self,d_model=512,num_heads=8):
        super().__init__()
        self.heads = []
        d_k = d_model // num_heads
        self.heads = nn.ModuleList([AttentionHead(d_model,d_k) for head in range(num_heads)])
        self.W = nn.Linear(d_model,d_model)

    def forward(self,Q,K,V,mask=None):
        """ Compute multi-head attention.

            Applies attention num_heads times, concatenates the results, and applies a final linear projection.
            Arguments:
                Q: queries [B,L,d_model]
                K: keys    [B,S,d_model]
                V: values  [B,L,d_model]
                mask: optional Boolean mask where True means hidden [B,L,S]
            Output:
               result of multi-head attention [B,L,d_model]
        """
        # compute each attention head and concatenate
        h = torch.cat([head(Q,K,V,mask=mask) for head in self.heads],dim=-1)

        # apply output projection
        return self.W(h)

class SelfAttentionBlock(nn.Module):
    def __init__(self,d_model=512,num_heads=8,d_ff=2048):
        super().__init__()
        self.multi_head_attention = MultiHeadAttention(d_model,num_heads)
        self.ln1 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model,d_ff),
            nn.ReLU(),
            nn.Linear(d_ff,d_model),
        )
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self,x,mask=None):
        """ Compute self attention block.

            Arguments:
                x: input sequence [B,S,d_model]
                mask: optional Boolean mask where True means hidden [B,L,S]
            Output:
               result of attention block [B,L,d_model]
        """
        # compute multi-head attention
        mha = self.multi_head_attention(x,x,x,mask=mask)

        # residual connection and layer normalization
        x = self.ln1(mha + x)

        # compute feed-forward network
        ff = self.feed_forward(x)

        # residual connection and layer normalization
        x = self.ln2(ff + x)

        return x
    
class PositionalEmbedding(nn.Module):
    def __init__(self,max_seq_len,d_model):
        super().__init__()
        self.positional_embedding = nn.Embedding(max_seq_len,d_model)

    def forward(self,x):
        """ Adds a positional embedding.

            Arguments:
                x: input token sequence [B,S,d_model]
            Output:
               sequence with positional embedding added [B,S,d_model]
        """
        # get sequence length
        N = x.shape[1]

        # look up positional embedding vectors
        pe = self.positional_embedding(torch.arange(N).to(x.device)) # [N,d_model]

        # add to input
        x = x + pe[None,...] # [B,N,d_model]
        
        return x

class TransformerDecoder(nn.Module):
    def __init__(self,vocabulary_size,max_seq_len,
                      d_model=512,num_heads=8,d_ff=2048,num_blocks=6):
        super().__init__()
        self.blocks = nn.ModuleList([SelfAttentionBlock(d_model,num_heads,d_ff) for b in range(num_blocks)])
        self.token_embedding = nn.Embedding(vocabulary_size,d_model)
        self.output = nn.Linear(d_model,vocabulary_size)
        self.positional_embedding = PositionalEmbedding(max_seq_len,d_model)

    def forward(self,x,mask=None):
        """ Computes the decoded sequence:

            Convert input to token embedding vectors
            Add positional embedding to input
            Apply self-attention blocks with mask
            Compute output

            Arguments:
                x: input token sequence [B,S]
                mask: optional Boolean mask where false means hidden [B,S]
            Output:
               sequence predictions [B,S,output_size]
        """
        # look up embedding vectors for tokens
        x = self.token_embedding(x) # [B,S,d_model]

        # apply positional embedding
        x = self.positional_embedding(x) # [B,S,d_model]
        
        # apply sequence of masked self-attention blocks
        for block in self.blocks:
            x = block(x,mask=mask) # [B,S,d_model]
        
        # produce sequence of output vectors
        y = self.output(x) # [B,S,vocabulary_size]

        return y


This function produces masks appropriate for sequence prediction.  The mask ensures that the output token at time t+1 only sees the generated sequence up to time t.

In [6]:
def make_mask(seq_len):
    """ Make a mask for sequence prediction. """
    return (torch.triu(torch.ones((1,seq_len,seq_len)), diagonal=1)==1)

make_mask(10)

tensor([[[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True],
         [False, False, False, False, False,  True,  True,  True,  True,  True],
         [False, False, False, False, False, False,  True,  True,  True,  True],
         [False, False, False, False, False, False, False,  True,  True,  True],
         [False, False, False, False, False, False, False, False,  True,  True],
         [False, False, False, False, False, False, False, False, False,  True],
         [False, False, False, False, False, False, False, False, False, False]]])

Now we will make a sequence of integers and see if the Transformer decoder can learn the sequence.

In [7]:
seq = torch.arange(100)
x = seq[:-1][None,...]
y = seq[1:][None,...]
mask = make_mask(x.shape[1])
x,y

(tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
          36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
          54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
          72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
          90, 91, 92, 93, 94, 95, 96, 97, 98]]),
 tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
          19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
          37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
          55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
          73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
          91, 92, 93, 94, 95, 96, 97, 98, 99]]))

In [8]:
steps = 100

model = TransformerDecoder(vocabulary_size=100,max_seq_len=x.shape[1],
                           d_model=64,num_heads=8,d_ff=512,num_blocks=3
                           )


opt = torch.optim.Adam(model.parameters(),lr=.01)
loss_fn = nn.CrossEntropyLoss()

for step in range(steps):
    model.train()
    opt.zero_grad()

    y_pred = model(x,mask)
    loss = loss_fn(y_pred.view(-1,y_pred.shape[-1]),y.view(-1))
    loss.backward()

    opt.step()

    print(step,loss.item())

0 4.686927795410156
1 2.9648141860961914
2 2.588106393814087
3 2.3566484451293945
4 1.7879502773284912
5 1.2888712882995605
6 0.7589181661605835
7 0.5663195252418518
8 0.39252418279647827
9 0.2695852220058441
10 0.1874636709690094
11 0.12777379155158997
12 0.08888237923383713
13 0.0635567381978035
14 0.04658680781722069
15 0.03535320609807968
16 0.027867989614605904
17 0.022476395592093468
18 0.01819978654384613
19 0.014668811112642288
20 0.011812885291874409
21 0.00959496758878231
22 0.007916177622973919
23 0.006650303490459919
24 0.005682090763002634
25 0.00492187449708581
26 0.004310456104576588
27 0.0038083535619080067
28 0.0033909459598362446
29 0.0030427398160099983
30 0.0027510453946888447
31 0.002506152493879199
32 0.0023002822417765856
33 0.0021264569368213415
34 0.001978531712666154
35 0.0018513250397518277
36 0.0017400712240487337
37 0.0016414851415902376
38 0.0015528165968135
39 0.0014722332125529647
40 0.0013984597753733397
41 0.0013306493638083339
42 0.0012683414388448
43

If the Transformer has learned the sequence correctly, this output will read 1, 2, 3, ..., 97, 98, 99.

In [9]:
torch.argmax(model(x),-1)

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
         19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
         37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
         55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
         73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
         91, 92, 93, 94, 95, 96, 97, 98, 99]])

2. What size context does the Transformer need in order to learn the above sequence?

In [11]:
max_seq_len = x.shape[1]
max_seq_len

99

The model need to be able to attend over the entire sequence of length 99 in order to learn the sequence.

3. Now design a pattern that requires a larger context and see if the Transformer can learn it.

In [15]:
# larger context in the form of reversed sequence
batch_size = 16
seq_len = 50
vocab_size = 100
x = torch.randint(0, vocab_size, (batch_size,seq_len))
# reverse sequence as target
y = torch.flip(x, dims=(-1,))

print("Input sequences:\n", x)
print("Target (reversed) sequences:\n", y)


Input sequences:
 tensor([[82, 73, 64, 46, 30, 25, 90, 43, 45, 73, 83, 52, 97,  9, 11, 95, 45, 50,
         69,  5, 22, 83, 39,  9, 84,  2, 21, 70, 29, 97, 67, 38, 93, 22, 67, 59,
         29,  5, 14, 64, 42, 82, 35, 37, 12, 75, 15, 95, 33, 61],
        [59, 64, 78, 55,  4, 52, 63,  1,  6, 25, 68, 57, 57, 28, 65, 39, 48, 36,
         69, 54, 15, 60, 75, 47, 22, 96, 42, 93, 22, 17, 50, 91, 18, 80, 80, 63,
         14,  6, 64, 43, 99, 16, 76, 13, 83, 28, 94, 86, 13, 31],
        [82, 26, 84,  1, 44, 70,  3, 35, 19,  8, 93, 58, 63,  9, 98, 68,  6, 14,
         53, 89, 49, 53,  8, 22, 59,  3, 71, 45, 95, 57, 16, 89, 34, 19, 13, 14,
          8, 99, 53, 67, 36,  1, 93, 72, 90, 76, 52, 48, 24, 18],
        [33, 61,  0, 94, 97, 68,  2, 56, 26, 41, 92, 51,  6, 24, 72, 37, 43, 62,
         71, 68, 78, 59, 89, 44, 51, 53, 45, 41, 35, 28,  3, 16, 69, 97, 30, 89,
          3, 65, 88, 33, 79, 22, 76, 87, 57, 26, 91, 69, 71, 49],
        [97, 44, 93, 89,  9, 93, 12,  2, 11, 98, 66, 75, 66, 14, 76, 7

In [16]:
steps = 100

model = TransformerDecoder(vocabulary_size=vocab_size,max_seq_len=max_seq_len,
                            d_model=64,num_heads=8,d_ff=512,num_blocks=3,
                            )
mask = make_mask(x.shape[1])


opt = torch.optim.Adam(model.parameters(),lr=.01)
loss_fn = nn.CrossEntropyLoss()

for step in range(steps):
    model.train()
    opt.zero_grad()

    y_pred = model(x,mask)
    loss = loss_fn(y_pred.view(-1,y_pred.shape[-1]),y.view(-1))
    loss.backward()

    opt.step()

    print(step,loss.item())

0 4.744674205780029
1 4.48797082901001
2 4.23170280456543
3 3.927212715148926
4 3.7159228324890137
5 3.534147024154663
6 3.3498973846435547
7 3.208298683166504
8 2.814455270767212
9 2.6157360076904297
10 2.432138681411743
11 2.062748670578003
12 1.8448920249938965
13 1.5506446361541748
14 1.241349458694458
15 1.0094467401504517
16 0.765720546245575
17 0.5727940201759338
18 0.4164573550224304
19 0.2993224263191223
20 0.2102164328098297
21 0.14678436517715454
22 0.10846157371997833
23 0.08212029188871384
24 0.06059759482741356
25 0.045069798827171326
26 0.035211049020290375
27 0.02803529053926468
28 0.022875657305121422
29 0.019254906103014946
30 0.0159899964928627
31 0.013663635589182377
32 0.011799123138189316
33 0.010077040642499924
34 0.008975471369922161
35 0.007924433797597885
36 0.007059736642986536
37 0.006503175478428602
38 0.005859775003045797
39 0.005433380138128996
40 0.0050832186825573444
41 0.0046897088177502155
42 0.004481224808841944
43 0.004209522157907486
44 0.003995881

In [17]:
torch.argmax(model(x),-1)

tensor([[45, 48, 53, 52, 75, 12, 37,  1, 82, 42, 64, 14,  5, 89, 59, 67, 22, 93,
         38, 97, 97, 29, 70, 21,  2, 84,  9, 39, 83, 22,  5, 69, 50, 45, 95, 11,
          9, 97, 52, 83, 73, 45, 43, 90, 25, 30, 46, 64, 73, 82],
        [69, 57, 52, 94, 61, 83, 13, 76, 16, 99, 94, 64,  6, 14, 63, 80, 80, 18,
         38, 50, 17, 22, 93, 42, 96, 22, 47, 75, 60, 15, 54, 69, 36, 48, 39, 65,
         28, 57, 57, 68, 25,  6,  1, 63, 52,  4, 55, 78, 64, 59],
        [ 2, 97, 40, 52, 19, 62, 59, 93,  1, 36, 67, 29, 99, 26, 14, 91, 19, 21,
         21, 16, 57, 84, 45, 71,  3, 59, 22,  8, 53, 49, 89, 53, 14,  6, 68, 98,
          9, 63, 58, 93,  8, 19, 35,  3, 70, 44,  1, 84, 26, 82],
        [92, 98, 63, 91, 26, 57, 95, 76,  9, 23, 50, 88, 65,  3, 21, 59, 97, 69,
         16,  3, 28, 35, 41, 41, 53, 51, 44, 89, 59, 78, 68, 71, 62, 43, 37, 72,
         24,  6, 92, 92, 41, 26,  1,  2, 68, 97, 94,  0, 61, 33],
        [59, 11, 31, 31, 93, 95, 44,  1, 22, 42, 83,  9, 83,  9, 90, 93, 66,  7,
       

Increasing the length of the sequence led to the model being unable to learn the pattern.  This is likely due to the model's limited capacity to attend over long sequences.