In [1]:

#stepが小さい、バッチサイズ小さい、positional embedingなし
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchtext
from torchtext.datasets import TranslationDataset, Multi30k, IWSLT
from torchtext.data import Field, BucketIterator

import spacy

import random
import math
import os
import time


In [2]:
SEED = 1

random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [3]:
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

In [4]:
def tokenize_de(text):
    """
    Tokenizes German text from a string into a list of strings
    """
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    """
    Tokenizes English text from a string into a list of strings
    """
    return [tok.text for tok in spacy_en.tokenizer(text)]

In [5]:
SRC = Field(tokenize=tokenize_de, init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)
TRG = Field(tokenize=tokenize_en, init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)

In [6]:
train_data, valid_data, test_data = Multi30k.splits(exts=('.de', '.en'), fields=(SRC, TRG))
#train_data, valid_data, test_data = IWSLT.splits(exts=('.de', '.en'), fields=(SRC, TRG))#, filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and len(vars(x)['trg']) <= MAX_LEN

In [7]:
SRC.build_vocab(train_data, min_freq=3)
TRG.build_vocab(train_data, min_freq=3)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
BATCH_SIZE = 128

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
     batch_size=BATCH_SIZE,
     device=device)

In [10]:
class PositionalEncoding(nn.Module):
    "Implement the PE function."
    def __init__(self, hid_dim, max_len=5000):
        super().__init__()
        
        # Compute the positional encodings once in log space.
        
        pe = torch.zeros(max_len, hid_dim).float().to(device)
        pe.require_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)#.to(device)
        div_term = (torch.arange(0, hid_dim, 2).float() *
                    -(math.log(10000.0) / hid_dim)).exp()#.to(device)
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        
        x = self.pe[:, :x.size(1)]
        return x

In [11]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, encoder_layer, self_attention, PositionalEncoding, positionwise_feedforward, dropout, device):
        super().__init__()

        self.input_dim = input_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.encoder_layer = encoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout
        self.device = device
        
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = PositionalEncoding(hid_dim)
        
        self.layers = nn.ModuleList([encoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device) 
                                     for _ in range(n_layers)])
        
        self.do = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, src, src_mask):
        
        #src = [batch size, src sent len]
        #src_mask = [batch size, src sent len]
        
        #pos = torch.arange(0, src.shape[1]).unsqueeze(0).repeat(src.shape[0], 1).to(self.device)
        
        src = self.do((self.tok_embedding(src) * self.scale) + self.pos_embedding(src))
        
        #src = [batch size, src sent len, hid dim]
        
        for layer in self.layers:
            src = layer(src, src_mask)
            
        return src

In [12]:
class EncoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device):
        super().__init__()
        
        self.ln = nn.LayerNorm(hid_dim)
        self.sa = self_attention(hid_dim, n_heads, dropout, device)
        self.pf = positionwise_feedforward(hid_dim, pf_dim, dropout)
        self.do = nn.Dropout(dropout)
        
    def forward(self, src, src_mask):
        
        #src = [batch size, src sent len, hid dim]
        #src_mask = [batch size, src sent len]
        
        src = self.ln(src + self.do(self.sa(src, src, src, src_mask)))
        
        src = self.ln(src + self.do(self.pf(src)))
        
        return src

In [13]:
class SelfAttention(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        
        assert hid_dim % n_heads == 0
        
        self.w_q = nn.Linear(hid_dim, hid_dim)
        self.w_k = nn.Linear(hid_dim, hid_dim)
        self.w_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc = nn.Linear(hid_dim, hid_dim)
        
        self.do = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)
        
    def forward(self, query, key, value, mask=None):
        
        bsz = query.shape[0]
        
        #query = key = value [batch size, sent len, hid dim]
                
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)
        
        #Q, K, V = [batch size, sent len, hid dim]
        
        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        
        #Q, K, V = [batch size, n heads, sent len, hid dim // n heads]
        
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        
        #energy = [batch size, n heads, sent len, sent len]
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        
        attention = self.do(F.softmax(energy, dim=-1))
        
        #attention = [batch size, n heads, sent len, sent len]
        
        x = torch.matmul(attention, V)
        
        #x = [batch size, n heads, sent len, hid dim // n heads]
        
        x = x.permute(0, 2, 1, 3).contiguous()
        
        #x = [batch size, sent len, n heads, hid dim // n heads]
        
        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))
        
        #x = [batch size, src sent len, hid dim]
        
        x = self.fc(x)
        
        #x = [batch size, sent len, hid dim]
        
        return x

In [14]:
class PositionwiseFeedforward(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.pf_dim = pf_dim
        
        self.fc_1 = nn.Conv1d(hid_dim, pf_dim, 1)
        self.fc_2 = nn.Conv1d(pf_dim, hid_dim, 1)
        
        self.do = nn.Dropout(dropout)
        
    def forward(self, x):
        
        #x = [batch size, sent len, hid dim]
        
        x = x.permute(0, 2, 1)
        
        #x = [batch size, hid dim, sent len]
        
        x = self.do(F.relu(self.fc_1(x)))
        
        #x = [batch size, ff dim, sent len]
        
        x = self.fc_2(x)
        
        #x = [batch size, hid dim, sent len]
        
        x = x.permute(0, 2, 1)
        
        #x = [batch size, sent len, hid dim]
        
        return x


In [15]:
class Decoder(nn.Module):
    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, decoder_layer, self_attention, positionwise_feedforward, dropout, device):
        super().__init__()
        
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.decoder_layer = decoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout
        self.device = device
        
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(1000, hid_dim)
        
        self.layers = nn.ModuleList([decoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device)
                                     for _ in range(n_layers)])
        
        self.fc = nn.Linear(hid_dim, output_dim)
        
        self.do = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, trg, src, trg_mask, src_mask):
        
        #trg = [batch_size, trg sent len]
        #src = [batch_size, src sent len]
        #trg_mask = [batch size, trg sent len]
        #src_mask = [batch size, src sent len]
        
        pos = torch.arange(0, trg.shape[1]).unsqueeze(0).repeat(trg.shape[0], 1).to(self.device)
                
        trg = self.do((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
        
        #trg = [batch size, trg sent len, hid dim]
        
        for layer in self.layers:
            trg = layer(trg, src, trg_mask, src_mask)
            
        return self.fc(trg)

In [16]:
class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device):
        super().__init__()
        
        self.ln = nn.LayerNorm(hid_dim)
        self.sa = self_attention(hid_dim, n_heads, dropout, device)
        self.ea = self_attention(hid_dim, n_heads, dropout, device)
        self.pf = positionwise_feedforward(hid_dim, pf_dim, dropout)
        self.do = nn.Dropout(dropout)
        
    def forward(self, trg, src, trg_mask, src_mask):
        
        #trg = [batch size, trg sent len, hid dim]
        #src = [batch size, src sent len, hid dim]
        #trg_mask = [batch size, trg sent len]
        #src_mask = [batch size, src sent len]
                
        trg = self.ln(trg + self.do(self.sa(trg, trg, trg, trg_mask)))
                
        trg = self.ln(trg + self.do(self.ea(trg, src, src, src_mask)))
        
        trg = self.ln(trg + self.do(self.pf(trg)))
        
        return trg

In [17]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, pad_idx, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.device = device
        
    def make_masks(self, src, trg):
        
        #src = [batch size, src sent len]
        #trg = [batch size, trg sent len]
        
        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)
        
        trg_pad_mask = (trg != self.pad_idx).unsqueeze(1).unsqueeze(3)

        trg_len = trg.shape[1]
        
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), dtype=torch.uint8, device=self.device))
        
        trg_mask = trg_pad_mask & trg_sub_mask
        
        return src_mask, trg_mask
    
    def forward(self, src, trg):
        
        #src = [batch size, src sent len]
        #trg = [batch size, trg sent len]
                
        src_mask, trg_mask = self.make_masks(src, trg)
        
        enc_src = self.encoder(src, src_mask)
        
        #enc_src = [batch size, src sent len, hid dim]
                
        out = self.decoder(trg, enc_src, trg_mask, src_mask)
        
        #out = [batch size, trg sent len, output dim]
        
        return out


In [18]:
input_dim = len(SRC.vocab)
hid_dim = 512
n_layers = 6
n_heads = 8
pf_dim = 2048
dropout = 0.1

enc = Encoder(input_dim, hid_dim, n_layers, n_heads, pf_dim, EncoderLayer, SelfAttention, PositionalEncoding, PositionwiseFeedforward, dropout, device)

In [19]:
output_dim = len(TRG.vocab)
hid_dim = 512
n_layers = 6
n_heads = 8
pf_dim = 2048
dropout = 0.1

dec = Decoder(output_dim, hid_dim, n_layers, n_heads, pf_dim, DecoderLayer, SelfAttention, PositionwiseFeedforward, dropout, device)

In [20]:
pad_idx = SRC.vocab.stoi['<pad>']

model = Seq2Seq(enc, dec, pad_idx, device).to(device)

In [21]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 52,054,476 trainable parameters


In [22]:
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [23]:
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))

In [24]:
optimizer = NoamOpt(hid_dim, 1, 2000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

In [25]:
def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        
        src = batch.src
        trg = batch.trg
        
        optimizer.optimizer.zero_grad()
        
        output = model(src, trg[:,:-1])
                
        #output = [batch size, trg sent len - 1, output dim]
        #trg = [batch size, trg sent len]
            
        output = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:,1:].contiguous().view(-1)
                
        #output = [batch size * trg sent len - 1, output dim]
        #trg = [batch size * trg sent len - 1]
            
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [26]:
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            src = batch.src
            trg = batch.trg

            output = model(src, trg[:,:-1])
            
            #output = [batch size, trg sent len - 1, output dim]
            #trg = [batch size, trg sent len]
            
            output = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:,1:].contiguous().view(-1)
            
            #output = [batch size * trg sent len - 1, output dim]
            #trg = [batch size * trg sent len - 1]
            
            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [27]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [28]:
N_EPOCHS = 100
CLIP = 1
SAVE_DIR = 'models'
MODEL_SAVE_PATH = os.path.join(SAVE_DIR, 'transformer-seq2seq.pt')

best_valid_loss = float('inf')

if not os.path.isdir(f'{SAVE_DIR}'):
    os.makedirs(f'{SAVE_DIR}')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
    
    print(f'| Epoch: {epoch+1:03} | Time: {epoch_mins}m {epoch_secs}s| Train Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f} | Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f} |')

| Epoch: 001 | Time: 0m 53s| Train Loss: 5.772 | Train PPL: 321.177 | Val. Loss: 4.058 | Val. PPL:  57.838 |
| Epoch: 002 | Time: 0m 54s| Train Loss: 3.760 | Train PPL:  42.956 | Val. Loss: 3.230 | Val. PPL:  25.279 |
| Epoch: 003 | Time: 0m 55s| Train Loss: 3.173 | Train PPL:  23.877 | Val. Loss: 2.811 | Val. PPL:  16.631 |
| Epoch: 004 | Time: 0m 55s| Train Loss: 2.816 | Train PPL:  16.705 | Val. Loss: 2.588 | Val. PPL:  13.304 |
| Epoch: 005 | Time: 0m 55s| Train Loss: 2.559 | Train PPL:  12.923 | Val. Loss: 2.410 | Val. PPL:  11.133 |
| Epoch: 006 | Time: 0m 55s| Train Loss: 2.372 | Train PPL:  10.717 | Val. Loss: 2.307 | Val. PPL:  10.041 |
| Epoch: 007 | Time: 0m 55s| Train Loss: 2.241 | Train PPL:   9.402 | Val. Loss: 2.286 | Val. PPL:   9.834 |
| Epoch: 008 | Time: 0m 55s| Train Loss: 2.163 | Train PPL:   8.700 | Val. Loss: 2.286 | Val. PPL:   9.832 |
| Epoch: 009 | Time: 0m 55s| Train Loss: 2.121 | Train PPL:   8.339 | Val. Loss: 2.325 | Val. PPL:  10.226 |
| Epoch: 010 | Time

KeyboardInterrupt: 

In [29]:
model.load_state_dict(torch.load(MODEL_SAVE_PATH))

test_loss = evaluate(model, test_iterator, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

| Test Loss: 2.283 | Test PPL:   9.809 |


In [49]:
batch=next(iter(test_iterator))
num=2
output = model(batch.src[num].unsqueeze(0), batch.trg[num,:-1].unsqueeze(0))
prediction = torch.max(output,2)[1]
print(prediction)
for j in range(1):
    print([TRG.vocab.itos[batch.trg[num,:-1].unsqueeze(0)[j][i].item()] for i in range(len(batch.trg[num,:-1].unsqueeze(0)[j]))])
    print([TRG.vocab.itos[prediction[j][i].item()] for i in range(len(prediction[j]))])

tensor([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor([[   4,   34, 1296,    4,  264,  432,    4,  432,  432,    5,    5,    5,
            5]], device='cuda:0')
['<sos>', 'a', 'boy', 'riding', 'a', 'skateboard', 'on', 'a', 'skateboarding', 'ramp', '<eos>', '<pad>', '<pad>']
['a', 'boy', 'skateboards', 'a', 'skateboard', 'ramp', 'a', 'ramp', 'ramp', '.', '.', '.', '.']


In [96]:
def greedy_decode(model, src, max_len, start_symbol):
    model.eval()
    t = torch.tensor([[start_symbol]]).to(device)
    src_mask = model.make_masks(src,t)[0]
    
    memory = model.encoder(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data).to(device)
    
    for i in range(max_len-1):
        
        prob = model.decoder(ys, memory, model.make_masks(src,ys)[1] , src_mask)
        
        #prob = model.generator(out[:, -1])
        
        _, next_word = torch.max(prob, dim = 2)
        
        next_word = next_word.data[0][0]
        
        ys = torch.cat([ys, 
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        
        
    return ys



In [97]:
model.load_state_dict(torch.load(MODEL_SAVE_PATH))
batch=next(iter(test_iterator))
src = batch.src[0].unsqueeze(0)
print(src)
    
out = greedy_decode(model, src, max_len=60, start_symbol=TRG.vocab.stoi["<s>"])

print("Translation:", end="\t")
for i in range(1, out.size(1)):
        
    sym = TRG.vocab.itos[out[0, i]]
    if sym == "</s>": break
    print(sym, end =" ")
    print()
    print("Target:", end="\t")
for i in range(1, batch.trg.size(1)):
    sym = TRG.vocab.itos[batch.trg.data[0, i]]
    if sym == "</s>": break
    print(sym, end =" ")
print()
#break

tensor([[  2,  18,   0, 121, 157,  42,  33, 123,   4,   3]], device='cuda:0')
Translation:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 
Target:	through 