In [180]:
#### from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator, interleave_keys

def itos(field, batch):  # batch에서 원본 sentence 얻는 함수
    with torch.cuda.device_of(batch):
        #batch = batch.T.tolist()
        batch = batch.tolist()
    batch = [[field.vocab.itos[ind] for ind in ex] for ex in batch]  # denumericalize
    
    def trim(s, t):  # 현재 token ~ <EOS> token 사이의 문장 return
        sentence = []
        for w in s:
            if w == t:
                break
            sentence.append(w)
        return sentence

    batch = [trim(ex, field.eos_token) for ex in batch]  # batch를 문장으로 
    
    def filter_special(tok):
        return tok not in (field.init_token, field.pad_token)

    batch = [list(filter(filter_special, ex)) for ex in batch]
    return batch

In [113]:
SRC = Field(tokenize = "spacy",
            tokenizer_language = "de",
            init_token = '<sos>',
            eos_token = '<eos>',
            lower = True)

TRG = Field(tokenize = "spacy",
            tokenizer_language = "en",
            init_token = '<sos>',
            eos_token = '<eos>',
            lower = True)

train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), fields = (SRC, TRG))

In [115]:
(
    'SRC', SRC.init_token, SRC.eos_token, SRC.pad_token,
    'TRG', TRG.init_token, TRG.eos_token, TRG.pad_token,
)

('SRC', '<sos>', '<eos>', '<pad>', 'TRG', '<sos>', '<eos>', '<pad>')

In [117]:
SRC.build_vocab(train_data)
TRG.build_vocab(train_data)

In [5]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchtext

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 128

# get iterator (train, valid, test)
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = BATCH_SIZE,
    device = device)

In [144]:
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')  # masking with upper triangle
    return torch.from_numpy(subsequent_mask) == 0 # reverse (masking=False, non-masking=True)

class Batch:
    
    "Object for holding a batch of data with masking during training." 
    def __init__(self, src, trg=None, pad=1):
        self.src = src.T
        self.src_mask = (src != pad).unsqueeze(-2)  # source mask, <pad>: False, other tokens: True
        if trg is not None:
            #self.trg = trg[:, :-1]  # target sentence 0 ~ -1
            #self.trg_y = trg[:, 1:]  # target sentence 1 ~ end
            self.trg = trg.T[:, :-1]  # target sentence 0 ~ -1
            self.trg_y = trg.T[:, 1:]  # target sentence 1 ~ end
            self.trg_mask = self.make_std_mask(self.trg, pad) # target mask
            self.ntokens = (self.trg_y != pad).data.sum() # number of tokens
    
    @staticmethod
    def make_std_mask(tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2) # <pad>: False, other tokens: True, reshape (batch_size, seq_len) -> (batch_size, 1, seq_len)
        tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)) # not <pad> && non-masking: True, others: False
        return tgt_mask
    
    def run_epoch(data_iter, model, loss_compute):
        "Standard Training and Logging Function"
        start = time.time()
        total_tokens = 0
        total_loss = 0
        tokens = 0
        pad_index = SRC.vocab[SRC.pad_token]
        
        for i, batch_without_mask in enumerate(data_iter):
            # mask 적용
            batch = Batch(batch_without_mask.src, batch_without_mask.trg, pad_index)
            
            out = model.forward(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
            loss = loss_compute(out, batch.trg_y, batch.ntokens)
            total_loss += loss
            total_tokens += batch_ntokens
            tokens += batch.ntokens
            if i % 50 == 1:
                elapsed = time.time() - start
                print("Epoch Step: %d Loss: %f Tokens per Sec: %f" % (i, loss / batch.ntokens, tokens / elapsed))
                start = time.time()
                tokens = 0
            
        return total_loss / total_tokens

In [293]:
# 데이터 확인

pad_index = SRC.vocab[SRC.pad_token]

for batch_without_mask in train_iterator:
    batch = Batch(batch_without_mask.src, batch_without_mask.trg, pad_index)
    print(batch.src.shape, batch.src_mask.shape, batch.trg.shape, batch.trg_mask.shape)
    print(batch.src, batch.src_mask, batch.trg, batch.trg_mask)
    print(itos(TRG, batch.trg))
    
    print(' '.join(itos(SRC, batch.src)[0]))
    print()
    
    print(' '.join(itos(TRG, batch.trg)[0]))
    print()
    
    print(' '.join(itos(TRG, batch.trg)[1]))
    print()
    
    print(' '.join(itos(TRG, batch.trg_y)[0]))
    print()
    
    break

torch.Size([128, 32]) torch.Size([32, 1, 128]) torch.Size([128, 31]) torch.Size([128, 31, 31])
tensor([[   2,    5,   70,  ...,    1,    1,    1],
        [   2,    5, 2446,  ...,    1,    1,    1],
        [   2,    8,   16,  ...,    1,    1,    1],
        ...,
        [   2,    5,   49,  ...,    1,    1,    1],
        [   2,  227,  191,  ...,    1,    1,    1],
        [   2,   43,   45,  ...,    1,    1,    1]], device='cuda:0') tensor([[[ True,  True,  True,  ...,  True,  True,  True]],

        [[ True,  True,  True,  ...,  True,  True,  True]],

        [[ True,  True,  True,  ...,  True,  True,  True]],

        ...,

        [[False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False]]], device='cuda:0') tensor([[   2,    4, 7382,  ...,    1,    1,    1],
        [   2,    4,   41,  ...,    1,    1,    1],
        [   2,    4,   14,  ...,    1,    1,    1],
        .

# Model

In [284]:
class Transformer(nn.Module):
    
    def __init__(self, src_vocab, trg_vocab, seq_len=32, d=512, num_layers=6, h=8):
        super(Transformer, self).__init__()
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.seq_len = 32
        self.d = 512
        self.encoder = Encoder(num_layers, seq_len, d, h)
        self.deccoder = Decoder(num_layers, seq_len, d, h)
        
    def forward(self, src, trg, src_mask, trg_mask):
        return self.decode(trg, trg_mask, self.encode(src, src_mask), src_mask) # Decoder's src: Encoder's output
    
    def encode(self, x, mask):
        return self.encoder(x, mask)
    
    def decode(self, x, mask, encoder_output):
        return self.decoder(x, mask, encoder_output)

In [230]:
class Encoder(nn.Module):
    
    def __init__(self, num_layers, seq_len, d, h):
        super(Encoder, self).__init__()
        self.layers = []
        self.seq_len = seq_len
        self.d = d
        self.h = h
        for i in range(num_layers):
            self.layers.append(EncoderLayer(seq_len, d, h))
    
    def forward(self, x, mask):
        out = x
        for layer in self.layers: 
            out = self.layers[i](out, mask)
        return out

In [231]:
class Decoder(nn.Module):
    
    def __init__(self, num_layers, seq_len, d, h):
        super(Decoder, self).__init__()
        self.layers = []
        self.seq_len = seq_len
        self.d = d
        self.h = h
        for i in range(num_layers):
            self.layers.append(DecoderLayer(seq_len, d, h))
    
    def forward(self, x, mask, encoder_output):
        out = x
        for layer in self.layers: 
            out = self.layers[i](x, mask, encoder_output)
        return out

In [232]:
class EncoderLayer(nn.Module):
    
    def __init__(self, seq_len, d, h):
        super(EncoderLayer, self).__init__()
        self.seq_len = seq_len
        self.d = d
        self.h = h
        self.multi_head_attention_layer = MultiHeadAttentionLayer(seq_len, d, h)
        self.norm_layer = NormLayer(seq_len)
    
    def forward(self, x, mask):
        out = self.multi_head_attention_layer(x, mask)
        out = self.norm_layer(out)
        return out

In [289]:
class DecoderLayer(nn.Module):
    
    def __init__(self, seq_len, h):
        super(DecoderLayer, self).__init__()
        self.seq_len = seq_len
        self.d = d
        self.h = h
        self.masked_multi_head_attention_layer = MultiHeadAttentionLayer(seq_len, d, h)
        self.multi_head_attention_layer = MultiHeadAttentionLayer(seq_len, d, h)
        self.norm_layer = NormLayer(seq_len)
    
    def forward(self, x, mask, encoder_output):
        out = self.masked_multi_head_attention_layer(trg, trg, trg, mask)
        out = self.multi_head_attention_layer(trg, encoder_output, encoder_output, mask)
        out = self.norm_layer(out)
        return out

In [294]:
class MultiHeadAttentionLayer(nn.Module):
    
    def __init__(self, seq_len, d, h):
        super(MultiHeadAttentionLayer, self).__init__()
        self.seq_len = seq_len
        self.d = d
        self.h = h
        self.query_fc_layer = nn.Linear(in_features=d, out_features=d)
        self.key_fc_layer = nn.Linear(in_features=d, out_features=d)
        self.value_fc_layer = nn.Linear(in_features=d, out_features=d)
        self.fc_layer = nn.Linear(in_features=d, out_features=d)
    
    def forward(self, query, key, value, mask=None):
        # query, key, value's shape: (n_batch, seq_len, d)
        
        # reshape (n_batch, seq_len, d) to (n_batch, h, seq_len, d_k)
        def transform(x, fc_layer):
            # x's shape: (n_batch, seq_len, d)
            n_batch = x.shape[0]
            out = fc_layer(x) # out's shape: (n_batch, seq_len, d)
            out = out.view(n_batch, self.seq_len, self.h, self.d//self.h) # out's shape: (n_batch, seq_len, h, d_k)
            out = out.transpose(1, 2) # out's shape: (n_batch, h, seq_len, d_k)
            return out
        
        query = transform(query, self.query_fc_layer)      # out's shape: (n_batch, h, seq_len, d_k)
        key = transform(key, self.key_fc_layer)
        value = transform(value, self.value_fc_layer)
        out = calculate_attention(query, key, value, mask) # out's shape: (n_batch, h, seq_len, d_k)
        out = out.transpose(1, 2)  # out's shape: (n_batch, seq_len, h, d_k)
        out = out.view(n_batch, seq_len, d)  # out's shape: (n_batch, seq_len, d)
        out = self.fc_layer(out)  # out's shape: (n_batch, seq_len, d)
        return out
    
    @staticmethod
    def calculate_attention(query, key, value, mask): 
        d_k = query.size(-1) # key's dimension, query, key, and value's dimensions are same value
        score = torch.matmul(query, key.transpose(-2, -1)) # Q x K^T
        score = score / math.sqrt(d_k)  # scaling
        if mask is not None:
            score = score.masked_fill(mask==0, -1e9)  # masking (Decoder's Masked Multi-Attention Layer)
        out = F.softmax(scores, dim = -1) # get softmax score
        out = torch.matmul(out, value) # score x V
        return out

1. residual connection
2. norm layer
3. embedding

In [220]:
class EncoderOrDecoder(nn.Module):
    
    def __init__(self, kind, num_layers, seq_len, d, h):
        super(EncoderOrDecoder, self).__init__()
        self.kind = kind
        self.layers = []
        self.seq_len = seq_len
        self.d = d
        self.h = h
        for i in range(num_layers):
            self.layers.append(EncoderOrDecoderLayer(kind, seq_len, d, h))
        self.layers = nn.ModuleList(self.layers)
    
    def forward(self, x, src_mask, encoder_output, trg_mask):
        out = x
        for layer in self.layers: 
            out = self.layers[i](out, src_mask, encoder_output, trg_mask)
        return out

In [207]:
class EncoderOrDecoderLayer(nn.Module):
    
    def __init__(self, kind, seq_len, d, h):
        super(EncoderOrDecodeLayer, self).__init__()
        self.kind = kind
        self.seq_len = seq_len
        self.d = d
        self.h = h
        if kind == 'Encoder':
            self.multi_head_attention_layer = MultiHeadAttentionLayer(seq_len, d, h)
        elif kind == 'Decoder':
            self.masked_multi_head_attention_layer = MultiHeadAttentionLayer(seq_len, d, h)
            self.multi_head_attention_layer = MultiHeadAttentionLayer(seq_len, d, h)
        self.norm_layer = NormLayer(seq_len)
    
    def forward(self, src, src_mask):
        if kind == 'Encoder':
            out = self.multi_head_attention_layer(src, src_mask)
        elif kind == 'Decoder':
            out = self.masked_multi_head_attention_layer(trg, trg_mask)
            out = self.multi_head_attention_layer(src, src_mask, trg, trg_mask)
        out = self.norm_layer(out)
        return out