http://nlp.seas.harvard.edu/2018/04/03/attention.html
    

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchtext
from torchtext.datasets import TranslationDataset, Multi30k
from torchtext.data import Field, BucketIterator

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import spacy

import random
import math
import os

In [2]:
# Seed the Python and PyTorch random number generators so runs are repeatable.
SEED = 1

random.seed(SEED)
torch.manual_seed(SEED)
# Request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True

In [3]:
# Load the spaCy pipelines whose tokenizers are used by tokenize_de / tokenize_en.
# NOTE(review): 'de' and 'en' are shortcut model names; presumably they resolve to
# installed models in this environment — confirm for the spaCy version in use.
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

In [4]:
def tokenize_de(text):
    """
    Split a German string into its tokens and return them as a list of strings.
    """
    tokens = spacy_de.tokenizer(text)
    return [token.text for token in tokens]

def tokenize_en(text):
    """
    Split an English string into its tokens and return them as a list of strings.
    """
    tokens = spacy_en.tokenizer(text)
    return [token.text for token in tokens]

In [5]:
# Source (German) and target (English) fields: both lowercase the text, wrap each
# sentence in <sos>/<eos> tokens and are configured with batch_first=True.
SRC = Field(tokenize=tokenize_de, init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)
TRG = Field(tokenize=tokenize_en, init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)

In [6]:
# Load the Multi30k splits; '.de' files are processed with SRC and '.en' files with TRG.
train_data, valid_data, test_data = Multi30k.splits(exts=('.de', '.en'), fields=(SRC, TRG))

In [7]:
# Build both vocabularies from the training split only, with a minimum token frequency of 2.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

In [8]:
BATCH_SIZE = 128

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One iterator per split; examples are keyed by source-sentence length and
# sorted within each batch.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
     batch_size=BATCH_SIZE,
     sort_key = lambda x : len(x.src),
     sort_within_batch=True,
     device=device)

In [9]:
class Encoder(nn.Module):
    """Token embedding followed by a stack of ``n_layers`` encoder layers.

    The layer, attention and feed-forward building blocks are passed in as
    constructors (``encoder_layer``, ``self_attention``,
    ``positionwise_feedforward``) and instantiated here, once per layer.
    """

    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, encoder_layer, self_attention, positionwise_feedforward, dropout):
        super().__init__()

        # Keep the hyper-parameters and the injected constructors on the module.
        self.input_dim = input_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.encoder_layer = encoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout

        self.tok_embedding = nn.Embedding(input_dim, hid_dim)

        # Build the layers one at a time, in order, then register them.
        stack = []
        for _ in range(n_layers):
            stack.append(
                encoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout)
            )
        self.layers = nn.ModuleList(stack)

    def forward(self, src, src_mask):
        """Embed ``src`` and pass it through every layer.

        src is [batch size, src sent len]; ``src_mask`` is handed unchanged to
        each layer. Returns a tensor of [batch size, src sent len, hid dim].
        """
        hidden = self.tok_embedding(src)

        # hidden = [batch size, src sent len, hid dim]
        for layer in self.layers:
            hidden = layer(hidden, src_mask)

        return hidden

![](http://nlp.seas.harvard.edu/images/the-annotated-transformer_14_0.png)

In [10]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a position-wise feed-forward.

    Each sublayer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        hid_dim: size of the hidden representation.
        n_heads: number of attention heads, forwarded to ``self_attention``.
        pf_dim: inner size forwarded to ``positionwise_feedforward``.
        self_attention: constructor called as ``self_attention(hid_dim, n_heads)``.
        positionwise_feedforward: constructor called as
            ``positionwise_feedforward(hid_dim, pf_dim)``.
        dropout: dropout probability applied to each sublayer output.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout):
        super().__init__()

        # One LayerNorm per sublayer: reusing a single instance would force the
        # attention and feed-forward outputs to share the same gain and bias.
        self.ln_sa = nn.LayerNorm(hid_dim)
        self.ln_pf = nn.LayerNorm(hid_dim)
        self.sa = self_attention(hid_dim, n_heads)
        self.pf = positionwise_feedforward(hid_dim, pf_dim)
        self.do = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Run both sublayers and return a tensor shaped like ``src``.

        src is [batch size, src sent len, hid dim]. ``src_mask`` is passed
        straight to the attention module as its ``mask`` argument, so it must
        be in whatever layout that module expects.
        """
        # Residual self-attention sublayer.
        src = self.ln_sa(src + self.do(self.sa(src, src, src, src_mask)))

        # Residual feed-forward sublayer.
        src = self.ln_pf(src + self.do(self.pf(src)))

        return src

In [11]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``n_heads``
    heads of size ``hid_dim // n_heads``, attended over, merged back and
    passed through a final linear layer.

    Args:
        hid_dim: size of the input and output representation.
        n_heads: number of heads; must divide ``hid_dim``.
    """

    def __init__(self, hid_dim, n_heads):
        super().__init__()

        self.hid_dim = hid_dim
        self.n_heads = n_heads

        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"

        self.head_dim = hid_dim // n_heads

        self.w_q = nn.Linear(hid_dim, hid_dim)
        self.w_k = nn.Linear(hid_dim, hid_dim)
        self.w_v = nn.Linear(hid_dim, hid_dim)

        self.fc = nn.Linear(hid_dim, hid_dim)

        # Keep the scale as a plain Python float. A bare tensor attribute is
        # neither a parameter nor a buffer, so module.to(device) would leave it
        # on the CPU and the division in forward() would fail on GPU inputs.
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, x, bsz):
        """Reshape [batch, len, hid dim] into [batch, n heads, len, head dim]."""
        return x.view(bsz, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: [batch size, query len, hid dim]
            key: [batch size, key len, hid dim]
            value: [batch size, key len, hid dim]
            mask: optional tensor broadcastable to
                [batch size, n heads, query len, key len]; entries equal to 0
                have their energy replaced by -1e10 before the softmax.

        Returns:
            Tensor of [batch size, query len, hid dim].
        """
        bsz = query.shape[0]

        Q = self._split_heads(self.w_q(query), bsz)
        K = self._split_heads(self.w_k(key), bsz)
        V = self._split_heads(self.w_v(value), bsz)

        #Q, K, V = [batch size, n heads, sent len, hid dim // n heads]

        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        #energy = [batch size, n heads, query len, key len]

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        # Normalise over the key axis.
        attention = F.softmax(energy, dim=-1)

        x = torch.matmul(attention, V)

        #x = [batch size, n heads, query len, hid dim // n heads]

        # Move heads next to the feature axis; contiguous() is needed before view().
        x = x.permute(0, 2, 1, 3).contiguous()

        x = x.view(bsz, -1, self.hid_dim)

        #x = [batch size, query len, hid dim]

        return self.fc(x)

In [12]:
class PositionwiseFeedforward(nn.Module):
    """Two-layer feed-forward network applied independently at every position.

    Implemented with kernel-size-1 ``Conv1d`` layers: ``hid_dim -> pf_dim``,
    ReLU, then ``pf_dim -> hid_dim``.
    """

    def __init__(self, hid_dim, pf_dim):
        super().__init__()

        self.hid_dim = hid_dim
        self.pf_dim = pf_dim

        self.fc_1 = nn.Conv1d(hid_dim, pf_dim, 1)
        self.fc_2 = nn.Conv1d(pf_dim, hid_dim, 1)

    def forward(self, x):
        """Map [batch size, sent len, hid dim] to a tensor of the same shape."""
        # Conv1d convolves over the last axis, so put the features on dim 1:
        # [batch size, hid dim, sent len]
        channels_first = x.permute(0, 2, 1)

        # [batch size, pf dim, sent len]
        expanded = F.relu(self.fc_1(channels_first))

        # [batch size, hid dim, sent len]
        projected = self.fc_2(expanded)

        # Back to [batch size, sent len, hid dim]
        return projected.permute(0, 2, 1)

In [13]:
# Encoder hyper-parameters; the input dimension is the source vocabulary size.
input_dim = len(SRC.vocab)
hid_dim = 512
n_layers = 6
n_heads = 8
pf_dim = 2048
dropout = 0.1

# The layer, attention and feed-forward classes are passed in as constructors.
enc = Encoder(input_dim, hid_dim, n_layers, n_heads, pf_dim, EncoderLayer, SelfAttention, PositionwiseFeedforward, dropout)

In [14]:
# Dummy batch for a shape smoke test of the encoder.
bsz = 32
seq_len = 25

# x is all zeros, so (x != 0) is False everywhere; only the shapes are meaningful here.
x = torch.zeros(bsz, seq_len).long()
# NOTE(review): this mask is [bsz, 1, seq_len, 1], which broadcasts along the last
# (key) axis of the [batch size, n heads, sent len, sent len] energy in SelfAttention,
# i.e. it selects whole query rows. Confirm whether a key-axis layout
# ([bsz, 1, 1, seq_len]) was intended.
x_mask = (x != 0).unsqueeze(1).unsqueeze(3)
print(x_mask.shape)
#x_mask = torch.zeros(bsz, 1, seq_len, 1).long()


torch.Size([32, 1, 25, 1])


In [15]:
# Smoke test: run the encoder on the dummy batch and inspect the output shape.
enc(x, x_mask).shape

torch.Size([32, 25, 512])

In [16]:
# Disabled earlier draft of Encoder, kept as a bare string literal (never executed).
# Unlike the active Encoder, its forward() permutes a [src sent len, batch size] input.
"""class Encoder(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, encoder_layer, self_attention, positionwise_feedforward, dropout):
        super().__init__()

        self.input_dim = input_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.encoder_layer = encoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout
        
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        
        
        self.layers = nn.ModuleList([encoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout) 
                                     for _ in range(n_layers)])
        
    def forward(self, src, mask):
        
        #src = [src sent len, batch size]
        
        embedded = self.tok_embedding(src.permute(1, 0))
        
        #embedded = [batch size, src sent len, hid dim]
        
        for layer in self.layers:
            embedded = layer(embedded, mask)
            
        return embedded"""

'class Encoder(nn.Module):\n    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, encoder_layer, self_attention, positionwise_feedforward, dropout):\n        super().__init__()\n\n        self.input_dim = input_dim\n        self.hid_dim = hid_dim\n        self.n_layers = n_layers\n        self.n_heads = n_heads\n        self.pf_dim = pf_dim\n        self.encoder_layer = encoder_layer\n        self.self_attention = self_attention\n        self.positionwise_feedforward = positionwise_feedforward\n        self.dropout = dropout\n        \n        self.tok_embedding = nn.Embedding(input_dim, hid_dim)\n        \n        \n        self.layers = nn.ModuleList([encoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout) \n                                     for _ in range(n_layers)])\n        \n    def forward(self, src, mask):\n        \n        #src = [src sent len, batch size]\n        \n        embedded = self.tok_embedding(src.permute(1, 0))

In [26]:
class Decoder(nn.Module):
    """Token embedding followed by a stack of ``n_layers`` decoder layers.

    The layer, attention and feed-forward building blocks are passed in as
    constructors and instantiated here, once per layer.

    Args:
        output_dim: size of the target vocabulary (number of embeddings).
        hid_dim: size of the hidden representation.
        n_layers: number of decoder layers to stack.
        n_heads: number of attention heads, forwarded to each layer.
        pf_dim: feed-forward inner size, forwarded to each layer.
        decoder_layer: constructor called as ``decoder_layer(hid_dim, n_heads,
            pf_dim, self_attention, positionwise_feedforward, dropout)``.
        self_attention: attention constructor forwarded to each layer.
        positionwise_feedforward: feed-forward constructor forwarded to each layer.
        dropout: dropout probability forwarded to each layer.
    """

    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, decoder_layer, self_attention, positionwise_feedforward, dropout):
        super().__init__()

        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.decoder_layer = decoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout

        self.tok_embedding = nn.Embedding(output_dim, hid_dim)

        self.layers = nn.ModuleList([decoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout)
                                     for _ in range(n_layers)])

    def forward(self, trg, src, trg_mask, src_mask):
        """Embed ``trg`` and run it through every decoder layer.

        Args:
            trg: target token ids, [batch size, trg sent len].
            src: encoder output handed to every layer as its ``src`` argument.
            trg_mask: mask forwarded unchanged to every layer.
            src_mask: mask forwarded unchanged to every layer.

        Returns:
            Tensor of [batch size, trg sent len, hid dim].
        """
        # trg is already batch-first, so it is embedded as-is. Permuting it here
        # would swap the batch and time axes and no longer line up with src.
        trg = self.tok_embedding(trg)

        #trg = [batch size, trg sent len, hid dim]

        for layer in self.layers:
            trg = layer(trg, src, trg_mask, src_mask)

        return trg

In [27]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder attention, then feed-forward.

    Each sublayer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        hid_dim: size of the hidden representation.
        n_heads: number of attention heads, forwarded to ``self_attention``.
        pf_dim: inner size forwarded to ``positionwise_feedforward``.
        self_attention: constructor called as ``self_attention(hid_dim, n_heads)``;
            used for both the self-attention and the encoder-attention modules.
        positionwise_feedforward: constructor called as
            ``positionwise_feedforward(hid_dim, pf_dim)``.
        dropout: dropout probability applied to each sublayer output.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout):
        super().__init__()

        # One LayerNorm per sublayer: reusing a single instance would force all
        # three sublayer outputs to share the same gain and bias.
        self.ln_sa = nn.LayerNorm(hid_dim)
        self.ln_ea = nn.LayerNorm(hid_dim)
        self.ln_pf = nn.LayerNorm(hid_dim)
        self.sa = self_attention(hid_dim, n_heads)
        self.ea = self_attention(hid_dim, n_heads)
        self.pf = positionwise_feedforward(hid_dim, pf_dim)
        self.do = nn.Dropout(dropout)

    def forward(self, trg, src, trg_mask, src_mask):
        """Run the three sublayers and return a tensor shaped like ``trg``.

        trg is [batch size, trg sent len, hid dim] and src is
        [batch size, src sent len, hid dim]. ``trg_mask`` and ``src_mask`` are
        passed straight to the attention modules as their ``mask`` argument, so
        they must be in whatever layout those modules expect.
        """
        # Self-attention over the target sequence.
        trg = self.ln_sa(trg + self.do(self.sa(trg, trg, trg, trg_mask)))

        # Attention from target positions (queries) over src (keys and values).
        trg = self.ln_ea(trg + self.do(self.ea(trg, src, src, src_mask)))

        # Position-wise feed-forward.
        trg = self.ln_pf(trg + self.do(self.pf(trg)))

        return trg

In [28]:
# Re-create the encoder with the same hyper-parameters as before;
# the input dimension is the source vocabulary size.
input_dim = len(SRC.vocab)
hid_dim = 512
n_layers = 6
n_heads = 8
pf_dim = 2048
dropout = 0.1

enc = Encoder(input_dim, hid_dim, n_layers, n_heads, pf_dim, EncoderLayer, SelfAttention, PositionwiseFeedforward, dropout)

In [29]:
# Decoder hyper-parameters; the output dimension is the target vocabulary size.
output_dim = len(TRG.vocab)
hid_dim = 512
n_layers = 6
n_heads = 8
pf_dim = 2048
dropout = 0.1

# The layer, attention and feed-forward classes are passed in as constructors.
dec = Decoder(output_dim, hid_dim, n_layers, n_heads, pf_dim, DecoderLayer, SelfAttention, PositionwiseFeedforward, dropout)

In [34]:
# Dummy source batch for a shape smoke test.
bsz = 32
seq_len = 25
pad_idx = 1

# src is all zeros and pad_idx is 1, so the mask is True at every position.
src = torch.zeros(bsz, seq_len).long()
# NOTE(review): the mask is laid out as [bsz, 1, seq len, 1]; in the decoder's
# src attention the energy is [batch size, n heads, trg len, src len], so this layout
# presumably only broadcasts when trg len equals src len — confirm the intended axis.
src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(3)
print(src_mask.shape)
#src_mask = [bsz, 1, seq len, 1]

torch.Size([32, 1, 25, 1])


In [35]:
# Encode the dummy source batch and print the shape of the result.
enc_src = enc(src, src_mask)

print(enc_src.shape)

torch.Size([32, 25, 512])


In [61]:
import numpy as np

pad_idx = 1


# Dummy target batch, five positions longer than the source batch.
trg = torch.zeros(bsz, seq_len+5).long()

# True where the target token is not padding -> [bsz, 1, trg len, 1]
trg_pad_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(3)

# Subsequent-position mask: np.triu(..., k=1) marks everything above the diagonal,
# so comparing with 0 keeps the diagonal and everything below it -> [1, trg len, trg len]
trg_sub_mask = torch.from_numpy(np.triu(np.ones((1, seq_len+5, seq_len+5)), k=1)) == 0

# The combined mask has to be built from the two masks defined in this cell
# before it is inspected -> [bsz, 1, trg len, trg len]
trg_mask = trg_pad_mask & trg_sub_mask

print(trg_mask.shape, trg_mask.dtype)

print(trg_mask[0], trg_mask[0].shape)

torch.Size([32, 1, 30, 30]) torch.uint8
tensor([[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [114]:
import numpy as np

pad_idx = 1

bsz = 32
seq_len = 5

# Dummy all-zero target batch; with pad_idx = 1 no position counts as padding.
trg = torch.zeros(bsz, seq_len).long()
# Padding mask -> [bsz, 1, seq_len, 1]
trg_pad_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(3)

print(trg_pad_mask.shape, trg_pad_mask.dtype)

# First way to build the subsequent-position mask: zero entries of the strictly
# upper triangle, via NumPy.
trg_sub_mask = torch.from_numpy(np.triu(np.ones((seq_len, seq_len)), k=1)) == 0

print(trg_sub_mask)

# Second way, in pure PyTorch: lower triangle of a ones matrix. This overwrites the
# NumPy version; the two printed matrices can be compared.
trg_sub_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.uint8))

print(trg_sub_mask, trg_sub_mask.shape, trg_sub_mask.dtype)

# Broadcasting [bsz, 1, seq_len, 1] against [seq_len, seq_len] gives [bsz, 1, seq_len, seq_len].
trg_mask = trg_pad_mask & trg_sub_mask

print(trg_mask.shape, trg_mask.dtype)

torch.Size([32, 1, 5, 1]) torch.uint8
tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]], dtype=torch.uint8)
tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]], dtype=torch.uint8) torch.Size([5, 5]) torch.uint8
torch.Size([32, 1, 5, 5]) torch.uint8


In [36]:
# Attempted decoder smoke test.
# NOTE(review): `y` and `y_mask` are not defined in any visible cell; presumably
# `trg` / `trg_mask` (and `src_mask` rather than `x_mask`) were intended — confirm.
dec(y, enc_src, y_mask, x_mask)

NameError: name 'y' is not defined

In [None]:
class Seq2Seq(nn.Module):
    """Sequence-to-sequence wrapper; currently only builds the attention masks.

    Args:
        pad_idx: vocabulary index of the padding token. Positions holding this
            index are masked out. Defaults to 1.
    """

    def __init__(self, pad_idx=1):
        super().__init__()

        # Stored on the module so make_masks() does not depend on a global.
        self.pad_idx = pad_idx

    def make_masks(self, src, trg):
        """Build the source and target attention masks for a batch.

        Both masks are laid out so that they broadcast against an attention
        energy of [batch size, n heads, query len, key len]; non-zero entries
        mark key positions that may be attended to.

        Args:
            src: source token ids, [batch size, src sent len].
            trg: target token ids, [batch size, trg sent len].

        Returns:
            Tuple ``(src_mask, trg_mask)`` where ``src_mask`` is
            [batch size, 1, 1, src sent len] and ``trg_mask`` is
            [batch size, 1, trg sent len, trg sent len].
        """
        # Padding is masked along the key (last) axis so the same source mask
        # works for any query length, including the decoder's.
        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)

        trg_pad_mask = (trg != self.pad_idx).unsqueeze(1).unsqueeze(2)

        trg_len = trg.shape[1]

        # Lower-triangular mask: position i may only look at positions <= i.
        # Created on trg's device and cast to the padding mask's dtype so the
        # bitwise AND below is well defined.
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).to(trg_pad_mask.dtype)

        trg_mask = trg_pad_mask & trg_sub_mask

        return src_mask, trg_mask