In [1]:
# default_exp attention

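# Attention

Minimal PyTorch implementations of scaled dot-product attention, a single self-attention head and multi-head attention, exercised on a toy one-hot "number" corpus.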

> Sources:
* http://nlp.seas.harvard.edu/2018/04/03/attention.html
* http://jalammar.github.io/illustrated-transformer/



In [3]:
#hide
# from nbdev.showdoc import *

In [143]:
import math
import numpy as np
import torch
import torch.nn as nn

In [54]:
# A random 2x3 matrix used to illustrate how transpose swaps axes.
x = torch.randn((2, 3))
x.size()

torch.Size([2, 3])

In [55]:
# Swapping axes 0 and 1 turns the 2x3 matrix into a 3x2 one.
y = x.transpose(0, 1)
y.size()

torch.Size([3, 2])

### Create data

In [88]:
# Toy parallel corpus: "<word> is <word>" in language A pairs with
# "<digit> la <digit>" in language B, for the numbers one..nine.
copus_a = [f"{word} is {word}"
           for word in "one two three four five six seven eight nine".split()]
copus_b = [f"{digit} la {digit}" for digit in range(1, 10)]

In [89]:
# One-hot embeddings for the two toy vocabularies: the word at position `row`
# in each vocabulary list maps to a 10-element list with 1.0 at index `row`
# and the int 0 everywhere else.
embed_a = {
    word: [1.0 if col == row else 0 for col in range(10)]
    for row, word in enumerate(
        ["is", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"])
}

embed_b = {
    token: [1.0 if col == row else 0 for col in range(10)]
    for row, token in enumerate(["9", "8", "7", "6", "5", "4", "3", "2", "1", "la"])
}

In [91]:
def sentence_embed(sentence, embed_dict):
    """Look up the vector for every whitespace-separated word of `sentence`.

    Returns a list with one entry per word, in order; each entry is the very
    object stored in `embed_dict` for that word. Raises KeyError if a word is
    not present in `embed_dict`.
    """
    return [embed_dict[token] for token in sentence.split()]

[[0, 1.0, 0, 0, 0, 0, 0, 0, 0, 0],
 [1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 1.0, 0, 0, 0, 0, 0, 0, 0, 0]]

In [99]:
# Encode one parallel sentence pair as float32 tensors with one row per word.
inp = torch.tensor(sentence_embed("one is one", embed_a), dtype=torch.float32)
out = torch.tensor(sentence_embed("1 la 1", embed_b), dtype=torch.float32)
inp.shape, out.shape

(torch.Size([3, 10]), torch.Size([3, 10]))

### Scaled dot-product attention

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

In [140]:
def dot_attention(inp, dk):
    """Scaled dot-product self-attention with freshly sampled projection weights.

    Q, K and V are computed by multiplying `inp` with three new random
    (d_in, dk) matrices drawn from `torch.rand` on every call, so repeated calls
    on the same input return different results.

    inp: tensor of shape (..., seq_len, d_in); leading dims act as batch dims
    dk:  dimension of the query/key/value projections

    Returns (weights, res):
        weights: (..., seq_len, seq_len) attention matrix, each row sums to 1
        res:     (..., seq_len, dk) attention-weighted values
    """
    # Sample weight matrices for Query, Key and Value
    wq, wk, wv = [torch.rand(inp.size(-1), dk, requires_grad=True) for _ in range(3)]
    q, k, v = inp @ wq, inp @ wk, inp @ wv
    # Swap only the last two axes: identical to transpose(0, -1) for 2-D input,
    # but keeps batch dims in place for (..., seq_len, dk) input.
    logits = (q @ k.transpose(-2, -1)) / math.sqrt(dk)
    weights = torch.softmax(logits, dim=-1)
    res = weights @ v
    return weights, res

In [141]:
attn_weights, attn_res = dot_attention(inp, 8)
# every row of the attention matrix is a softmax, i.e. sums to 1
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(inp.size(0)))
attn_weights, attn_res

(tensor([[0.3219, 0.3561, 0.3219],
         [0.2966, 0.4068, 0.2966],
         [0.3219, 0.3561, 0.3219]], grad_fn=<SoftmaxBackward>),
 tensor([[0.3128, 0.4050, 0.9274, 0.2796, 0.2709, 0.6891, 0.6944, 0.5752],
         [0.3533, 0.4207, 0.9196, 0.2763, 0.2897, 0.6940, 0.6767, 0.5981],
         [0.3128, 0.4050, 0.9274, 0.2796, 0.2709, 0.6891, 0.6944, 0.5752]],
        grad_fn=<MmBackward>))

In [197]:
class SelfAttention(nn.Module):
    """Single self-attention head with learnable query/key/value projections.

    inp: a tensor whose last dimension is the input feature size, or that
         feature size given directly as an int
    dk:  dimension of the query/key/value projections
    """
    def __init__(self, inp, dk):
        super().__init__()
        # `inp` is kept as an attribute for callers that read `self.inp`;
        # only its last dimension is needed to size the projections.
        self.inp, self.dk = inp, dk
        d_in = inp if isinstance(inp, int) else inp.size(-1)
        # nn.Parameter registers the projections with the module so they appear
        # in .parameters()/state_dict, receive gradients and follow .to(device).
        self.wq, self.wk, self.wv = [nn.Parameter(torch.rand(d_in, dk))
                                     for _ in range(3)]

    def forward(self, inp):
        """Attend over `inp` of shape (..., seq_len, d_in); returns (..., seq_len, dk)."""
        return self._dot_attention(inp, self.dk)

    def _dot_attention(self, inp, dk):
        """Scaled dot-product attention using this head's projection weights."""
        q, k, v = inp @ self.wq, inp @ self.wk, inp @ self.wv
        # Swap only the last two axes so batched (..., seq_len, dk) input works.
        logits = (q @ k.transpose(-2, -1)) / math.sqrt(dk)
        weights = torch.softmax(logits, dim=-1)
        return weights @ v

In [198]:
# test: a single head maps each of the input rows to dk = 8 features
dk = 8
satten = SelfAttention(inp, dk)
assert satten(inp).shape == (inp.size(0), dk)
satten(inp)

tensor([[0.5275, 0.7369, 0.8443, 0.7131, 0.4083, 0.2168, 0.2937, 0.7003],
        [0.5221, 0.7481, 0.8478, 0.7108, 0.4193, 0.1963, 0.2772, 0.7107],
        [0.5275, 0.7369, 0.8443, 0.7131, 0.4083, 0.2168, 0.2937, 0.7003]])

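### Multi-head attention

`nh` independent self-attention heads run on the same input; their outputs are concatenated along the feature axis and multiplied by an output projection matrix.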

In [260]:
class MultiHeadAttention(nn.Module):
    def __init__(self, inp, dk, nh):
        """
        inp: input, passed to every SelfAttention head to size its projections
        dk: key/query/value dimension of each head
        nh: number of heads
        """
        super().__init__()
        self.inp, self.dk, self.nh = inp, dk, nh
        # nn.ModuleList (not a plain list) registers every head as a submodule,
        # so their weights are visible to .parameters(), optimizers and .to().
        self.layers = nn.ModuleList([SelfAttention(inp, dk) for _ in range(nh)])
        # Output projection from the concatenated heads (dk*nh) to dk features.
        self.out = nn.Parameter(torch.rand(dk * nh, dk))

    def forward(self, inp):
        """Run all heads on `inp`, concatenate their outputs and project to dk."""
        heads = [layer(inp) for layer in self.layers]
        # Concatenate along the last (feature) axis so batched input also works.
        ccat = torch.cat(heads, dim=-1)
        return ccat @ self.out

In [261]:
# Six heads of size 8; their concatenated 48-wide output is projected back to 8.
nh = 6
dk = 8
mul_head = MultiHeadAttention(inp, dk=dk, nh=nh)

In [262]:
# one dk-wide output row per input position
assert mul_head(inp).shape == (inp.size(0), mul_head.dk)
mul_head(inp)

tensor([[13.4013, 15.1956, 16.2374, 16.2639, 11.9684, 11.9073, 13.4298, 14.3592],
        [13.3845, 15.1707, 16.2030, 16.2503, 11.9526, 11.9179, 13.4255, 14.3424],
        [13.4013, 15.1956, 16.2374, 16.2639, 11.9684, 11.9073, 13.4298, 14.3592]],
       grad_fn=<MmBackward>)