In [1]:
import numpy as np

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

from torch.nn import functional as F
from torch import optim

import pytorch_lightning as pl

## Model

In [2]:
class Encoder(nn.Module):
    """Bidirectional multi-layer GRU encoder.

    Each direction uses ``hidden_size / 2`` units, so the concatenated
    per-step output has ``hidden_size`` features.
    """

    def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2):
        super(Encoder, self).__init__()

        self.rnn = nn.GRU(
            input_size=word_vec_size,
            hidden_size=int(hidden_size / 2),
            num_layers=n_layers,
            dropout=dropout_p,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, emb):
        """Encode a batch of embedded sequences.

        Args:
            emb: either a tensor ``(bs, n, word_vec_size)`` or a tuple
                ``(tensor, lengths)``. ``lengths`` may be a tensor, a list of
                ints, or None (None means "no padding info, run unpacked").
                Lengths need not be sorted.

        Returns:
            ``(y, h)``. ``|y| = (bs, n, hidden_size)``, where n is max(lengths)
            when lengths are given. ``|h| = (n_layers * 2, bs, hidden_size / 2)``.
            Both are in the caller's batch order.
        """
        if isinstance(emb, tuple):
            x, lengths = emb
        else:
            x, lengths = emb, None

        if lengths is not None:
            if torch.is_tensor(lengths):
                lengths = lengths.tolist()
            # enforce_sorted=False lets PyTorch sort/unsort internally, so the
            # caller does not have to pre-sort the batch by length.
            x = pack(x, lengths, batch_first=True, enforce_sorted=False)

        y, h = self.rnn(x)

        if lengths is not None:
            y, _ = unpack(y, batch_first=True)

        return y, h

In [3]:
class Attention(nn.Module):
    """Multiplicative attention: score(h_dec, h_enc) = (W h_dec) . h_enc."""

    def __init__(self, hidden_size):
        super(Attention, self).__init__()

        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, decoder_hidden, encoder_hidden, mask=None):
        """Return the attention-weighted sum of encoder states.

        Args:
            decoder_hidden: (bs, 1, hidden_size) current decoder output.
            encoder_hidden: (bs, n, hidden_size) encoder outputs.
            mask: optional bool (bs, n); True positions get zero weight.

        Returns:
            Context vector of shape (bs, 1, hidden_size).
        """
        keys = encoder_hidden.transpose(1, 2)
        # |keys| = (bs, hidden_size, n)
        scores = self.linear(decoder_hidden).bmm(keys)
        # |scores| = (bs, 1, n)

        if mask is not None:
            # -inf before softmax -> exactly zero weight on padded positions.
            scores.masked_fill_(mask.unsqueeze(1), -float('inf'))

        attention_weights = self.softmax(scores)
        return attention_weights.bmm(encoder_hidden)

In [4]:
class Decoder(nn.Module):
    """Unidirectional GRU decoder with input feeding.

    At every step the target embedding is concatenated with the previous
    attentional vector h~_{t-1}, so the RNN input size is
    ``word_vec_size + hidden_size``.
    """

    def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2):
        super(Decoder, self).__init__()

        self.rnn = nn.GRU(
            input_size=word_vec_size + hidden_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            dropout=dropout_p,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, emb_t, h_prev_tilde, h_prev):
        """Run a single decoding step.

        Args:
            emb_t: (bs, 1, word_vec_size) embedding of the current input token.
            h_prev_tilde: (bs, 1, hidden_size) previous attentional vector, or
                None on the first step (treated as zeros).
            h_prev: (n_layers, bs, hidden_size) previous GRU hidden state.

        Returns:
            The GRU's ``(output, hidden)`` pair for this step.
        """
        if h_prev_tilde is None:
            h_prev_tilde = emb_t.new_zeros(emb_t.size(0), 1, h_prev.size(-1))

        rnn_input = torch.cat((emb_t, h_prev_tilde), dim=-1)
        return self.rnn(rnn_input, h_prev)

In [5]:
class Generator(nn.Module):
    """Projects hidden states to log-probabilities over the output vocabulary."""

    def __init__(self, hidden_size, output_size):
        super(Generator, self).__init__()

        self.output = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Map (..., hidden_size) to (..., output_size) log-probabilities."""
        return self.softmax(self.output(x))

In [6]:
class Seq2Seq(nn.Module):
    """Attention-based encoder-decoder translator.

    A bidirectional GRU ``Encoder`` reads the source. A GRU ``Decoder`` with
    input feeding and multiplicative ``Attention`` produces log-probabilities
    over the target vocabulary through ``Generator``.

    Args:
        input_size: source vocabulary size.
        word_vec_size: embedding size for source and target tokens.
        hidden_size: decoder hidden size. The encoder uses hidden_size / 2 per
            direction, so this must be even for ``merge_z`` to work.
        output_size: target vocabulary size.
        n_layers: GRU layers in both encoder and decoder.
        dropout_p: dropout between GRU layers.
    """

    def __init__(self,
                 input_size,
                 word_vec_size,
                 hidden_size,
                 output_size,
                 n_layers=4,
                 dropout_p=.2
                ):
        super(Seq2Seq, self).__init__()

        self.input_size = input_size
        self.word_vec_size = word_vec_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p

        self.encoder_emb = nn.Embedding(input_size, word_vec_size)
        self.decoder_emb = nn.Embedding(output_size, word_vec_size)

        self.encoder = Encoder(word_vec_size, hidden_size, n_layers=n_layers, dropout_p=dropout_p)
        self.attention = Attention(hidden_size)
        self.decoder = Decoder(word_vec_size, hidden_size, n_layers=n_layers, dropout_p=dropout_p)

        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.tanh = nn.Tanh()
        self.generator = Generator(hidden_size, output_size)

    def merge_z(self, z):
        """Turn bidirectional encoder final states into decoder initial states.

        |z| = (n_layers * 2, bs, hidden_size / 2) -> (n_layers, bs, hidden_size).
        The forward and backward states of the same layer are adjacent in z, so
        the view concatenates them along the feature axis.
        """
        batch_size = z.size(1)

        z = z.transpose(0, 1).contiguous().view(batch_size, -1, self.hidden_size)
        # |z| = (bs, n_layers, hidden_size)
        return z.transpose(0, 1).contiguous()

    def generate_mask(self, x, length):
        """Build the attention padding mask.

        Args:
            x: source index tensor; only its device is used.
            length: per-sample source lengths (list of ints or 1-D tensor).

        Returns:
            Bool tensor of shape (bs, max(length)). True marks padded positions
            that attention must ignore.
        """
        lengths = torch.as_tensor(length, device=x.device).view(-1)
        max_length = int(lengths.max())
        positions = torch.arange(max_length, device=x.device).unsqueeze(0)
        return positions >= lengths.unsqueeze(1)

    def _encode(self, src):
        """Embed and encode ``src``; return (x, encoder_hidden, decoder_init, mask).

        ``src`` is either a (bs, n) index tensor or a tuple (indices, lengths).
        When lengths are given, the encoder receives them so padding is packed
        away. The batch is sorted by length for packing and restored afterwards,
        so callers need not pre-sort.
        """
        if isinstance(src, tuple):
            x, x_length = src
            mask = self.generate_mask(x, x_length)
        else:
            x, x_length, mask = src, None, None

        encoder_emb_vec = self.encoder_emb(x)
        # |encoder_emb_vec| = (bs, n, word_vec_size)

        if x_length is None:
            encoder_hidden, z = self.encoder(encoder_emb_vec)
        else:
            lengths = torch.as_tensor(x_length, device=x.device)
            sorted_lengths, order = lengths.sort(descending=True)
            restore = order.argsort()
            encoder_hidden, z = self.encoder(
                (encoder_emb_vec.index_select(0, order), sorted_lengths)
            )
            encoder_hidden = encoder_hidden.index_select(0, restore)
            z = z.index_select(1, restore)
        # |encoder_hidden| = (bs, n, hidden_size)
        # |z| = (n_layers * 2, bs, hidden_size / 2)

        return x, encoder_hidden, self.merge_z(z), mask

    def _decode_step(self, emb_t, h_t_tilde, decoder_hidden, encoder_hidden, mask):
        """Run one decoder step with attention.

        Returns (h_t_tilde, decoder_hidden), where |h_t_tilde| = (bs, 1, hidden_size)
        is the new attentional vector. It is fed back into the next step.
        """
        decoder_output, decoder_hidden = self.decoder(emb_t, h_t_tilde, decoder_hidden)
        # |decoder_output| = (bs, 1, hidden_size)
        context_vector = self.attention(decoder_output, encoder_hidden, mask)
        # |context_vector| = (bs, 1, hidden_size)
        h_t_tilde = self.tanh(self.concat(torch.cat([decoder_output, context_vector], dim=-1)))
        return h_t_tilde, decoder_hidden

    def forward(self, src, tgt):
        """Teacher-forced decoding.

        Args:
            src: (bs, n) source index tensor, or a tuple (indices, lengths).
            tgt: (bs, m) target index tensor used as decoder input.

        Returns:
            |y_hat| = (bs, m, output_size) log-probabilities.
        """
        _, encoder_hidden, decoder_hidden, mask = self._encode(src)

        decoder_emb_vec = self.decoder_emb(tgt)
        # |decoder_emb_vec| = (bs, m, word_vec_size)

        h_t_tilde = None
        h_tilde = []
        for t in range(tgt.size(1)):
            emb_t = decoder_emb_vec[:, t:t + 1, :]
            # |emb_t| = (bs, 1, word_vec_size)
            h_t_tilde, decoder_hidden = self._decode_step(
                emb_t, h_t_tilde, decoder_hidden, encoder_hidden, mask
            )
            h_tilde.append(h_t_tilde)

        h_tilde = torch.cat(h_tilde, dim=1)
        # |h_tilde| = (bs, m, hidden_size)

        return self.generator(h_tilde)

    def search(self, src, is_greedy=True, max_length=255,
               bos_index=2, eos_index=3, pad_index=1):
        """Decode autoregressively from ``src``.

        Decoding runs until every sample has emitted ``eos_index`` or
        ``max_length`` steps have been taken. After a sample emits EOS, its
        remaining positions are filled with ``pad_index``. The default
        ``pad_index`` of 1 matches the pad index that the trainer's loss
        ignores. The other defaults keep the original BOS=2 / EOS=3 choice,
        which presumably follows torchtext's specials ordering; verify if the
        vocabulary changes.

        Args:
            src: (bs, n) source index tensor, or a tuple (indices, lengths).
            is_greedy: argmax decoding if True, otherwise multinomial sampling.
            max_length: maximum number of generated steps.

        Returns:
            (y_hats, indice): (bs, steps, output_size) log-probabilities and
            (bs, steps) generated indices.
        """
        x, encoder_hidden, decoder_hidden, mask = self._encode(src)
        batch_size = x.size(0)

        y = x.new_full((batch_size, 1), bos_index)
        is_decoding = x.new_ones(batch_size, 1).bool()

        h_t_tilde, y_hats, indice = None, [], []

        while is_decoding.any() and len(indice) < max_length:
            emb_t = self.decoder_emb(y)
            h_t_tilde, decoder_hidden = self._decode_step(
                emb_t, h_t_tilde, decoder_hidden, encoder_hidden, mask
            )

            y_hat = self.generator(h_t_tilde)
            # |y_hat| = (bs, 1, output_size)
            y_hats.append(y_hat)

            if is_greedy:
                y = y_hat.argmax(dim=-1)
            else:
                # Random sampling from the predicted distribution.
                y = torch.multinomial(y_hat.exp().view(batch_size, -1), 1)

            # Samples that already emitted EOS in an earlier step emit PAD from now on.
            y = y.masked_fill(~is_decoding, pad_index)
            # A sample stops decoding once it emits EOS.
            is_decoding = is_decoding & (y != eos_index)

            indice.append(y)

        y_hats = torch.cat(y_hats, dim=1)
        indice = torch.cat(indice, dim=1)

        return y_hats, indice

## Trainer

In [41]:
class CustomModule(pl.LightningModule):
    """Lightning wrapper that trains a Seq2Seq model with teacher forcing.

    Args:
        model: network called as ``model(src, tgt_in)`` that returns
            (bs, m, output_size) log-probabilities.
        output_size: target vocabulary size (length of the NLL class weights).
        pad_index: target index whose loss weight is zeroed, so padding does
            not contribute to the loss.
    """

    def __init__(self, model, output_size, pad_index=1):
        super(CustomModule, self).__init__()

        self.model = model
        self.crit = self._build_crit(output_size, pad_index)
        # Created eagerly rather than inside configure_optimizers so callers can
        # reach `module.optimizer` directly (e.g. to checkpoint its state).
        self.optimizer = optim.Adam(self.model.parameters())

    @staticmethod
    def _build_crit(output_size, pad_index=1):
        """Summed NLL loss with the padding class weighted to zero."""
        loss_weight = torch.ones(output_size)
        loss_weight[pad_index] = 0.
        return nn.NLLLoss(weight=loss_weight, reduction='sum')

    def forward(self, src, tgt):
        return self.model(src, tgt)

    def _shared_step(self, batch):
        """Compute the teacher-forced loss for one batch.

        The decoder input is tgt[:, :-1] and the prediction target is
        tgt[:, 1:], i.e. the target sequence shifted left by one token.
        """
        mini_batch, _ = batch
        src, tgt = mini_batch[0], mini_batch[1]
        y = tgt[:, 1:]
        y_hat = self(src, tgt[:, :-1])
        return self.crit(y_hat.contiguous().view(-1, y_hat.size(-1)),
                         y.contiguous().view(-1))

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        metrics = {'val_loss': loss}
        self.log_dict(metrics)

    def configure_optimizers(self):
        return self.optimizer

## DataLoader

In [8]:
import torchtext
from torchtext.legacy import data
from torch.utils.data import DataLoader
from typing import Optional

In [19]:
class CustomDataLoader:
    """Builds torchtext fields, vocabularies and bucketed iterators for the
    Korean-English TSV corpus.

    Args:
        batch_size: examples per mini-batch.
        max_length: sentences are truncated to at most this many tokens.
        shuffle: whether the training iterator shuffles.
        train: when True, read train/valid TSVs, build iterators and
            vocabularies. When False, only the fields are created; attach
            vocabularies later with ``load_vocab``.
        data_path: directory that contains ``train.tsv`` and ``valid.tsv``.
        device: device for produced batches. Defaults to 'cuda:0' when CUDA is
            available, otherwise 'cpu'.
    """

    def __init__(self, batch_size=64, max_length=70, shuffle=True, train=True,
                 data_path='./kor_eng_translation/', device=None):
        self.batch_size = batch_size

        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        def truncate(tokens):
            # Keep at most max_length tokens per sentence.
            return tokens[:max_length]

        self.SRC = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            preprocessing=truncate,
            # Source lengths are returned with each batch for packing / masking.
            include_lengths=True,
        )
        self.TGT = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            preprocessing=truncate,
            init_token='<BOS>',
            eos_token='<EOS>'
        )

        if train:
            train_data, valid_data = data.TabularDataset.splits(
                path=data_path,
                train='train.tsv',
                validation='valid.tsv',
                format='tsv',
                fields=[('src', self.SRC), ('tgt', self.TGT)]
            )

            def sort_key(example):
                # Order primarily by source length (so batches sorted within the
                # batch have descending source lengths for the encoder), then by
                # target length.
                return (len(example.src), len(example.tgt))

            self.train_loader = data.BucketIterator(
                train_data,
                batch_size,
                device=device,
                shuffle=shuffle,
                sort_key=sort_key,
                sort_within_batch=True,
            )
            self.valid_loader = data.BucketIterator(
                valid_data,
                batch_size,
                device=device,
                sort_key=sort_key,
                sort_within_batch=True,
            )

            self.SRC.build_vocab(train_data, max_size=30000, min_freq=5)
            self.TGT.build_vocab(train_data, max_size=30000, min_freq=5)

    def load_vocab(self, src_vocab, tgt_vocab):
        """Attach previously built vocabularies (e.g. restored from a checkpoint)."""
        self.SRC.vocab = src_vocab
        self.TGT.vocab = tgt_vocab

## Train

In [20]:
# Mini-batch size shared by the train and validation iterators.
batch_size = 128
dm = CustomDataLoader(batch_size=batch_size, max_length=70, shuffle=True, train=True)

In [21]:
# Vocabulary sizes set the source embedding and target output layer sizes.
input_size, output_size = len(dm.SRC.vocab), len(dm.TGT.vocab)

In [22]:
(input_size, output_size)  # (source vocab size, target vocab size)

(30002, 30004)

In [23]:
model = Seq2Seq(
    input_size=input_size,
    word_vec_size=512,
    hidden_size=512,
    output_size=output_size,
).cuda()

In [42]:
# Wrap the network with its loss/optimizer; train one epoch on a single GPU.
trainer = pl.Trainer(gpus=1, max_epochs=1)
module = CustomModule(model=model, output_size=output_size)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [43]:
# Train `module` on dm.train_loader, using dm.valid_loader for validation (val_loss).
trainer.fit(module, dm.train_loader, dm.valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type    | Params
----------------------------------
0 | model | Seq2Seq | 58.7 M
1 | crit  | NLLLoss | 0     
----------------------------------
58.7 M    Trainable params
0         Non-trainable params
58.7 M    Total params
234.892   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



In [44]:
# Persist weights, optimizer state and both vocabularies so translation can
# be run later without rebuilding the vocab from the training data.
checkpoint = {
    'model': model.state_dict(),
    'opt': module.optimizer.state_dict(),
    'src_vocab': dm.SRC.vocab,
    'tgt_vocab': dm.TGT.vocab,
}
torch.save(checkpoint, './model.pth')

## Translate

In [45]:
# `torch.cuda.is_available` must be *called*; the bare function object is
# always truthy and would force 'cuda:0' even on CPU-only machines.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# NOTE: torch.load unpickles the saved vocab objects; only load trusted checkpoints.
saved_data = torch.load('./model.pth', map_location=device)

# Fields only (train=False): the vocabularies come from the checkpoint.
loader = CustomDataLoader(train=False)
loader.load_vocab(saved_data['src_vocab'], saved_data['tgt_vocab'])

input_size = len(loader.SRC.vocab)
output_size = len(loader.TGT.vocab)

# Place the model on the same device the weights were mapped to.
model2 = Seq2Seq(input_size, 512, 512, output_size).to(device)
model2.load_state_dict(saved_data['model'])
model2.eval()

Seq2Seq(
  (encoder_emb): Embedding(30002, 512)
  (decoder_emb): Embedding(30004, 512)
  (encoder): Encoder(
    (rnn): GRU(512, 256, num_layers=4, batch_first=True, dropout=0.2, bidirectional=True)
  )
  (attention): Attention(
    (linear): Linear(in_features=512, out_features=512, bias=False)
    (softmax): Softmax(dim=-1)
  )
  (decoder): Decoder(
    (rnn): GRU(1024, 512, num_layers=4, batch_first=True, dropout=0.2)
  )
  (concat): Linear(in_features=1024, out_features=512, bias=True)
  (tanh): Tanh()
  (generator): Generator(
    (output): Linear(in_features=512, out_features=30004, bias=True)
    (softmax): LogSoftmax(dim=-1)
  )
)

In [46]:
from konlpy.tag import Mecab
# Mecab morpheme tokenizer for raw Korean input; presumably it matches how the
# source side of the training TSVs was tokenized — verify against the corpus.
tokenizer = Mecab()

In [62]:
def to_text(indice, vocab, eos_index=3):
    """Convert rows of token indices into space-joined strings.

    Each row is cut at the first ``eos_index`` (exclusive). Indices before it
    are mapped through ``vocab.itos``.

    Args:
        indice: (bs, length) indices as a tensor or a nested sequence of ints.
        vocab: object with an ``itos`` list (index -> token).
        eos_index: index that terminates a row.

    Returns:
        List of ``bs`` strings.
    """
    # Convert once to Python ints. Otherwise every element comparison on a
    # CUDA tensor would trigger its own device sync.
    if torch.is_tensor(indice):
        indice = indice.tolist()

    lines = []
    for row in indice:
        words = []
        for index in row:
            if index == eos_index:
                break
            words.append(vocab.itos[index])
        lines.append(' '.join(words))

    return lines
        
with torch.no_grad():
    sentence = [input()]
    tokens = tokenizer.morphs(sentence[0])

    # Numericalize with the vocab restored from the checkpoint (`loader`), not the
    # training-time `dm`, so this cell works on a fresh kernel. Put the batch on
    # whatever device model2 lives on.
    x = loader.SRC.numericalize(
        ([tokens], [len(tokens)]),
        device=next(model2.parameters()).device
    )

    y_hats, indice = model2.search(x)
    output = to_text(indice, loader.TGT.vocab)

    print(output)

나는 내일 학교에 가야 한다.
["i ' m going to go to school ."]
