### Transformer

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext
from torchtext.datasets import TranslationDataset, Multi30k
from torchtext.data import Field, BucketIterator

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import spacy
import numpy as np

import random
import math
import time

In [2]:
# One seed shared by every random number generator used in this notebook.
SEED = 11747

# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

# Seed PyTorch (CPU and CUDA), then NumPy, then Python's built-in `random`.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [6]:
# Load the two spaCy pipelines (German first, then English); only their
# tokenizers are used by the helper functions below.
spacy_de, spacy_en = (spacy.load(model_name) for model_name in ('de', 'en'))

def tokenize_de(text):
    """Tokenize `text` with the `spacy_de` tokenizer and return the token strings as a list."""
    doc = spacy_de.tokenizer(text)
    return [token.text for token in doc]

def tokenize_en(text):
    """Tokenize `text` with the `spacy_en` tokenizer and return the token strings as a list."""
    doc = spacy_en.tokenizer(text)
    return [token.text for token in doc]

In [7]:
# Both fields lowercase their text and wrap every sequence in <sos> ... <eos>.
# NOTE(review): batch_first is not set here (torchtext's Field defaults it to
# False), while the Encoder's shape comments assume (b, s) inputs -- confirm
# that later code transposes the batches or that batch_first=True is intended.

# Source side: tokenized with the English tokenizer.
SRC = Field(tokenize=tokenize_en, init_token='<sos>', eos_token='<eos>', lower=True)

# Target side: tokenized with the German tokenizer.
TGT = Field(tokenize=tokenize_de, init_token='<sos>', eos_token='<eos>', lower=True)

In [8]:
# Load the Multi30k train / validation / test splits; the '.en' files are
# processed by SRC and the '.de' files by TGT.
train_data, valid_data, test_data = Multi30k.splits(
    exts=('.en', '.de'),
    fields=(SRC, TGT),
)

In [9]:
# Report how many examples each of the three splits contains.
print(f"Number of training examples: {len(train_data.examples)}")
print(f"Number of validation examples: {len(valid_data.examples)}")
print(f"Number of testing examples: {len(test_data.examples)}")

Number of training examples: 29000
Number of validation examples: 1014
Number of testing examples: 1000


In [10]:
# Build both vocabularies from the training split only; min_freq = 2 is the
# minimum number of occurrences a token needs to get its own vocabulary entry.
SRC.build_vocab(train_data, min_freq = 2)
TGT.build_vocab(train_data, min_freq = 2)

# SRC is the English field (tokenize_en, '.en') and TGT the German field
# (tokenize_de, '.de'), so the labels name the languages in that order.
print(f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (de) vocabulary: {len(TGT.vocab)}")

Unique tokens in source (de) vocabulary: 5893
Unique tokens in target (en) vocabulary: 7855


In [11]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [12]:
# Number of examples per batch, shared by all three iterators.
BATCH_SIZE = 128

# One BucketIterator per split, each yielding batches placed on `device`.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device,
)

In [14]:
class Encoder(nn.Module):
    """Transformer encoder: scaled token embeddings plus learned position
    embeddings, followed by a stack of `EncoderLayer` blocks.

    Args:
        ninp: number of entries in the token embedding table.
        nhid: embedding / hidden dimension.
        nlayers: number of `EncoderLayer` blocks to stack.
        nheads: forwarded to every `EncoderLayer`.
        pfdim: forwarded to every `EncoderLayer`.
        dropout: dropout probability for the summed embeddings; also
            forwarded to every `EncoderLayer`.
        device: device on which the position indices and the embedding
            scale are created; also forwarded to every `EncoderLayer`.
        max_len: number of entries in the position embedding table, i.e. the
            longest sequence `forward` accepts.
    """

    def __init__(self, ninp, nhid, nlayers, nheads, pfdim, dropout, device, max_len=100):
        super(Encoder, self).__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(ninp, nhid)
        self.pos_embedding = nn.Embedding(max_len, nhid)
        # Fixed: the original iterated over range(n_layers), a name that is
        # not defined in this scope; the constructor parameter is `nlayers`.
        self.layers = nn.ModuleList([EncoderLayer(nhid, nheads, pfdim, dropout, device)
                                     for _ in range(nlayers)])
        self.dropout = nn.Dropout(dropout)
        # Token embeddings are multiplied by sqrt(nhid) before the position
        # embeddings are added.
        self.scale = torch.sqrt(torch.FloatTensor([nhid])).to(device)

    def forward(self, src, src_mask):
        """Encode a batch of token-index sequences.

        Args:
            src: integer tensor of shape (b, s).
            src_mask: passed unchanged to every layer.
                NOTE(review): originally annotated as (b, s); the attention
                layer in this file applies its mask with `masked_fill` on
                (b, nheads, q_len, k_len) energies, so the mask has to
                broadcast to that shape -- confirm against the caller.

        Returns:
            Tensor of shape (b, s, nhid).

        Raises:
            ValueError: if `s` exceeds the size of the position embedding table.
        """
        b, s = src.shape
        max_len = self.pos_embedding.num_embeddings
        if s > max_len:
            # Fail with a readable message instead of an out-of-range
            # embedding lookup.
            raise ValueError(
                f"sequence length {s} exceeds the maximum supported length {max_len}")
        # Fixed: the original used the undefined name `src_len`; the sequence
        # length unpacked above is `s`.
        # pos: (b, s), each row is 0 .. s-1
        pos = torch.arange(0, s, device=self.device).unsqueeze(0).repeat(b, 1)
        src = self.dropout(self.tok_embedding(src) * self.scale + self.pos_embedding(pos))
        # src: (b, s, nhid)
        for layer in self.layers:
            src = layer(src, src_mask)
        return src

In [13]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and a position-wise feed-forward
    sub-layer, each followed by dropout, a residual connection and LayerNorm.

    Args:
        nhid: hidden dimension of the inputs and outputs.
        nheads: number of attention heads.
        pfdim: passed to `PositionwiseFeedforwardLayer`.
        dropout: dropout probability for both residual branches; also passed
            to the attention and feed-forward sub-layers.
        device: passed to `MultiHeadAttentionLayer`.
    """

    # Fixed: the original signature was (ninp, nhid, nlayers, nheads, pfdim,
    # dropout, device, max_len=100), but `ninp`, `nlayers` and `max_len` were
    # never used and `Encoder` constructs this layer with exactly the five
    # arguments below, so construction raised a TypeError.
    def __init__(self, nhid, nheads, pfdim, dropout, device):
        super(EncoderLayer, self).__init__()
        self.self_attn_layer_norm = nn.LayerNorm(nhid)
        self.ff_layer_norm = nn.LayerNorm(nhid)
        self.self_attn = MultiHeadAttentionLayer(nhid, nheads, dropout, device)
        self.positionwise_ff = PositionwiseFeedforwardLayer(nhid, pfdim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Apply the block to `src`.

        Args:
            src: tensor of shape (b, s, nhid).
            src_mask: passed to the self-attention layer as its `mask`.
                NOTE(review): originally annotated as (b, s); it must
                broadcast against the (b, nheads, s, s) attention energies --
                confirm against the caller.

        Returns:
            Tensor of shape (b, s, nhid).
        """
        # Self-attention: query, key and value are all `src`; the attention
        # weights returned alongside the output are discarded.
        _src, _ = self.self_attn(src, src, src, src_mask)
        src = self.self_attn_layer_norm(src + self.dropout(_src))
        # src: (b, s, nhid)
        _src = self.positionwise_ff(src)
        src = self.ff_layer_norm(src + self.dropout(_src))
        # src: (b, s, nhid)
        return src

In [15]:
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Query, key and value are each linearly projected, split into `nheads`
    heads of size `nhid // nheads`, attended per head, concatenated again and
    passed through an output projection.

    Args:
        nhid: hidden dimension; must be divisible by `nheads`.
        nheads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        device: device on which the scaling constant is created.
    """

    def __init__(self, nhid, nheads, dropout, device):
        super(MultiHeadAttentionLayer, self).__init__()
        assert(nhid % nheads == 0)

        self.nhid = nhid
        self.nheads = nheads
        self.head_dim = nhid // nheads

        self.fc_q = nn.Linear(nhid, nhid)
        self.fc_k = nn.Linear(nhid, nhid)
        self.fc_v = nn.Linear(nhid, nhid)
        self.fc_o = nn.Linear(nhid, nhid)

        self.dropout = nn.Dropout(dropout)
        # Fixed: the original read `torch.sqrt(...),to(device)` -- a comma
        # instead of a dot -- which raised NameError on `to` at construction.
        # Energies are divided by sqrt(head_dim).
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask=None):
        """Compute attention of `query` over `key` / `value`.

        Args:
            query: (b, ql, nhid)
            key:   (b, kl, nhid)
            value: (b, vl, nhid)
            mask: optional tensor broadcastable to (b, nheads, ql, kl);
                positions where it equals 0 are filled with -1e10 before the
                softmax.

        Returns:
            Tuple `(x, attn)`: `x` is the attended output of shape
            (b, ql, nhid); `attn` holds the softmax attention weights of
            shape (b, nheads, ql, kl), taken before dropout is applied.
        """
        b = query.shape[0]

        # query: (b, ql, nhid), key: (b, kl, nhid), value: (b, vl, nhid)
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        # Split the hidden dimension into heads and move the head axis forward.
        Q = Q.view(b, -1, self.nheads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(b, -1, self.nheads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(b, -1, self.nheads, self.head_dim).permute(0, 2, 1, 3)
        # Q: (b, nheads, ql, head_dim)
        # K: (b, nheads, kl, head_dim)
        # V: (b, nheads, vl, head_dim)

        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy: (b, nheads, ql, kl)
        if mask is not None:
            energy = energy.masked_fill(mask==0, -1e10)

        attn = torch.softmax(energy, dim=-1)
        # Dropout is applied to the weights used for the weighted sum only;
        # the returned `attn` is the un-dropped softmax output.
        x = torch.matmul(self.dropout(attn), V)
        # x: (b, nheads, ql, head_dim) -> (b, ql, nheads, head_dim) -> (b, ql, nhid)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(b, -1, self.nhid)
        x = self.fc_o(x)
        return x, attn