In [1]:
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import pytorch_lightning as pl

## Model

In [2]:
class Encoder(pl.LightningModule):
    """Bidirectional GRU encoder over already-embedded source tokens.

    Each direction gets ``int(hidden_size / 2)`` units, so the concatenated
    forward/backward output has ``hidden_size`` features when ``hidden_size``
    is even.

    Args:
        word_vec_size: feature size of each input vector.
        hidden_size: total output feature size across both directions.
        n_layers: number of stacked GRU layers.
        dropout_p: dropout applied between stacked GRU layers.
    """

    def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2):
        super().__init__()

        units_per_direction = int(hidden_size / 2)
        self.rnn = nn.GRU(
            input_size=word_vec_size,
            hidden_size=units_per_direction,
            num_layers=n_layers,
            dropout=dropout_p,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        """Encode ``x`` (assumed (bs, n, word_vec_size)).

        Returns:
            The ``(output, h_n)`` pair produced by the bidirectional GRU:
            output is (bs, n, hidden_size) and h_n is
            (n_layers * 2, bs, hidden_size / 2).
        """
        return self.rnn(x)

In [3]:
class Attention(pl.LightningModule):
    """Attention over encoder states with a learned linear query projection.

    Scores are ``(W d + b) . e_i`` for decoder state ``d`` and each encoder
    state ``e_i``, normalised with softmax over the source positions.

    Args:
        hidden_size: feature size of both decoder and encoder hidden states.
    """

    def __init__(self, hidden_size):
        super(Attention, self).__init__()

        self.linear = nn.Linear(hidden_size, hidden_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, decoder_hidden, encoder_hidden, mask=None):
        """Return the attention-weighted sum of ``encoder_hidden``.

        Args:
            decoder_hidden: query states, (bs, 1, hidden_size).
            encoder_hidden: key/value states, (bs, n, hidden_size).
            mask: optional bool tensor, assumed (bs, n); True marks source
                positions (e.g. padding) that must receive zero weight.
                A row that is entirely True yields NaN weights.

        Returns:
            Context vector of shape (bs, 1, hidden_size).
        """
        query = self.linear(decoder_hidden)
        # |query| = (bs, 1, hidden_size)

        weight = torch.bmm(query, encoder_hidden.transpose(1, 2))
        # |weight| = (bs, 1, hidden_size) x (bs, hidden_size, n) = (bs, 1, n)

        if mask is not None:
            # Insert the query axis so the (bs, n) mask lines up with
            # |weight| = (bs, 1, n). Unsqueezing the last axis instead gives
            # (bs, n, 1), which cannot broadcast onto (bs, 1, n).
            # After softmax, the -inf scores become exactly 0 weight.
            weight = weight.masked_fill(mask.unsqueeze(1), float('-inf'))

        weight = self.softmax(weight)

        value = torch.bmm(weight, encoder_hidden)
        # |value| = (bs, 1, n) x (bs, n, hidden_size) = (bs, 1, hidden_size)

        return value

In [4]:
class Decoder(pl.LightningModule):
    """Unidirectional GRU decoder with input feeding.

    At every step the embedded target token is concatenated with the previous
    attentional state (h_tilde) before entering the GRU. That is why the GRU
    input size is ``word_vec_size + hidden_size``.

    Args:
        word_vec_size: size of the target embeddings.
        hidden_size: GRU hidden size (also the size of h_tilde).
        n_layers: number of stacked GRU layers.
        dropout_p: dropout applied between stacked GRU layers.
    """

    def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2):
        super(Decoder, self).__init__()

        self.rnn = nn.GRU(
            input_size=word_vec_size + hidden_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            dropout=dropout_p,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, emb_t, h_prev_tilde, h_prev):
        """Advance the decoder by one time step.

        Args:
            emb_t: embedded input token, (bs, 1, word_vec_size).
            h_prev_tilde: previous attentional state, (bs, 1, hidden_size),
                or None on the first step (zeros are used instead).
            h_prev: previous GRU hidden state, (n_layers, bs, hidden_size).

        Returns:
            ``(output, h)`` from the GRU: (bs, 1, hidden_size) and
            (n_layers, bs, hidden_size).
        """
        if h_prev_tilde is None:
            batch_size = emb_t.size(0)
            hidden_size = h_prev.size(-1)
            # new_zeros matches emb_t's dtype and device.
            h_prev_tilde = emb_t.new_zeros(batch_size, 1, hidden_size)

        x = torch.cat([emb_t, h_prev_tilde], dim=-1)
        # |x| = (bs, 1, word_vec_size + hidden_size)

        y, h = self.rnn(x, h_prev)

        return y, h

In [5]:
class Generator(pl.LightningModule):
    """Project hidden states onto the target vocabulary as log-probabilities.

    Args:
        hidden_size: feature size of the incoming hidden states.
        output_size: size of the target vocabulary.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()

        # Attribute names determine state_dict keys; keep them stable.
        self.output = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return log-softmax over the last axis of the linear projection of ``x``."""
        return self.softmax(self.output(x))

In [6]:
class Seq2Seq(pl.LightningModule):
    """Attention-based GRU sequence-to-sequence model.

    A bidirectional encoder reads the source. Its final hidden states, merged
    across directions, initialise a unidirectional decoder that uses input
    feeding and attention over all encoder outputs. The generator returns
    log-probabilities, to be paired with ``nn.NLLLoss``.

    Args:
        input_size: source vocabulary size.
        word_vec_size: embedding size used for both source and target embeddings.
        hidden_size: model hidden size. It must be even, because each encoder
            direction gets half of it and ``merge_z`` re-joins the halves.
        output_size: target vocabulary size.
        n_layers: number of GRU layers in both the encoder and the decoder.
        dropout_p: dropout applied between stacked GRU layers.

    Raises:
        ValueError: if ``hidden_size`` is odd.
    """

    def __init__(self,
                 input_size,
                 word_vec_size,
                 hidden_size,
                 output_size,
                 n_layers=4,
                 dropout_p=.2
                ):
        # Initialise nn.Module / LightningModule state before setting any attribute.
        super(Seq2Seq, self).__init__()

        if hidden_size % 2 != 0:
            raise ValueError(
                f'hidden_size must be even for the bidirectional encoder, got {hidden_size}'
            )

        self.input_size = input_size
        self.word_vec_size = word_vec_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p

        self.encoder_emb = nn.Embedding(input_size, word_vec_size)
        self.decoder_emb = nn.Embedding(output_size, word_vec_size)

        self.encoder = Encoder(word_vec_size, hidden_size, n_layers=n_layers, dropout_p=dropout_p)
        self.attention = Attention(hidden_size)
        self.decoder = Decoder(word_vec_size, hidden_size, n_layers=n_layers, dropout_p=dropout_p)

        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.tanh = nn.Tanh()
        self.generator = Generator(hidden_size, output_size)

    def merge_z(self, z):
        """Join the forward/backward encoder states of each layer.

        Args:
            z: (n_layers * 2, bs, hidden_size / 2), with the two directions of
                each layer adjacent along dim 0.

        Returns:
            (n_layers, bs, hidden_size), usable as the decoder's initial hidden state.
        """
        batch_size = z.size(1)

        z = z.transpose(0, 1).contiguous().view(batch_size,
                                                -1,
                                                self.hidden_size).transpose(0, 1).contiguous()
        # |.transpose(0, 1)| = (bs, n_layers * 2, hidden_size / 2)
        # |.view|            = (bs, n_layers, hidden_size)
        # |.transpose(0, 1)| = (n_layers, bs, hidden_size)

        return z

    def forward(self, src, tgt):
        """Teacher-forced forward pass.

        Args:
            src: source token indices, (bs, n).
            tgt: decoder input token indices, (bs, m). Step t consumes
                ``tgt[:, t]``. The caller must align the prediction targets,
                normally ``tgt`` shifted left by one position.

        Returns:
            Log-probabilities over the target vocabulary, (bs, m, output_size).
        """
        # No source padding mask is built, so attention also attends to padded positions.
        mask = None

        encoder_emb_vec = self.encoder_emb(src)
        # |encoder_emb_vec| = (bs, n, word_vec_size)

        encoder_hidden, z = self.encoder(encoder_emb_vec)
        # |encoder_hidden| = (bs, n, hidden_size)
        # |z| = (n_layers * 2, bs, hidden_size / 2)

        decoder_hidden = self.merge_z(z)
        # |decoder_hidden| = (n_layers, bs, hidden_size)

        decoder_emb_vec = self.decoder_emb(tgt)
        # |decoder_emb_vec| = (bs, m, word_vec_size)

        h_tilde = []
        h_t_tilde = None  # on the first step the decoder substitutes zeros

        for t in range(tgt.size(1)):
            emb_t = decoder_emb_vec[:, t, :].unsqueeze(1)
            # |emb_t| = (bs, 1, word_vec_size)

            decoder_output, decoder_hidden = self.decoder(emb_t, h_t_tilde, decoder_hidden)
            # |decoder_output| = (bs, 1, hidden_size)
            # |decoder_hidden| = (n_layers, bs, hidden_size)

            context_vector = self.attention(decoder_output, encoder_hidden, mask)
            # |context_vector| = (bs, 1, hidden_size)

            h_t_tilde = torch.cat([decoder_output, context_vector], dim=-1)
            # |h_t_tilde| = (bs, 1, hidden_size * 2)

            h_t_tilde = self.tanh(self.concat(h_t_tilde))
            # |h_t_tilde| = (bs, 1, hidden_size)

            h_tilde.append(h_t_tilde)

        h_tilde = torch.cat(h_tilde, dim=1)
        # |h_tilde| = (bs, m, hidden_size)

        y_hat = self.generator(h_tilde)
        # |y_hat| = (bs, m, output_size)

        return y_hat

## Trainer

In [7]:
class CustomModule(pl.LightningModule):
    """LightningModule wrapper that trains a seq2seq model with teacher forcing.

    Args:
        model: a module called as ``model(src, tgt_in)`` that returns
            log-probabilities of shape (bs, m, vocab).
        pad_idx: optional target padding index. When given, positions whose
            target equals it are excluded from the loss. The default None keeps
            ``nn.NLLLoss``'s default ``ignore_index`` (-100), so every real
            token position, padding included, is scored.
    """

    def __init__(self, model, pad_idx=None):
        super(CustomModule, self).__init__()

        self.model = model
        # NLLLoss expects log-probabilities, which the model's generator produces.
        self.crit = nn.NLLLoss(ignore_index=-100 if pad_idx is None else pad_idx)
        # Created eagerly so callers can read/save its state via `module.optimizer`.
        self.optimizer = optim.Adam(model.parameters())

    def forward(self, src, tgt):
        return self.model(src, tgt)

    def _compute_loss(self, batch):
        """Teacher-forced NLL loss for one batch unpacking as ``(src, tgt), _``.

        The decoder is fed ``tgt[:, :-1]`` and scored against ``tgt[:, 1:]``,
        so step t predicts the *next* token. Scoring against the same tokens
        that were fed in would let the model simply copy its input.
        """
        x, _ = batch
        src, tgt = x[0], x[1]

        decoder_input = tgt[:, :-1]
        target = tgt[:, 1:]

        log_probs = self(src, decoder_input)
        # |log_probs| = (bs, m - 1, vocab)
        return self.crit(log_probs.reshape(-1, log_probs.size(-1)),
                         target.reshape(-1))

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        metrics = {'val_loss': loss}
        self.log_dict(metrics)

    def configure_optimizers(self):
        return self.optimizer
        

## DataLoader

In [8]:
import torchtext
from torchtext.legacy import data
from torch.utils.data import DataLoader
from typing import Optional

In [9]:
class CustomDataLoader:
    """Build torchtext (legacy) fields, vocabularies and bucket iterators.

    Reads tab-separated ``src<TAB>tgt`` files from ``path``. Both vocabularies
    are built from the training split only.

    Args:
        batch_size: examples per batch.
        path: directory holding the train/validation TSV files.
        train_file: training file name inside ``path``.
        valid_file: validation file name inside ``path``.
        device: device the iterators place batches on.
        fix_length: every sequence is padded/truncated to this many tokens.
        max_vocab_size: ``max_size`` passed to ``build_vocab`` for both fields.

    Attributes:
        SRC, TGT: the torchtext fields; each holds ``.vocab`` after construction.
        train_loader, valid_loader: ``BucketIterator`` objects.
    """

    def __init__(self,
                 batch_size=64,
                 path='./kor_eng_translation/',
                 train_file='train copy.tsv',
                 valid_file='valid copy.tsv',
                 device='cuda:0',
                 fix_length=256,
                 max_vocab_size=10000):
        self.batch_size = batch_size

        self.SRC = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            fix_length=fix_length,
        )
        self.TGT = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            fix_length=fix_length,
            init_token='<BOS>',
            eos_token='<EOS>'
        )

        train, valid = data.TabularDataset.splits(
            path=path,
            train=train_file,
            validation=valid_file,
            format='tsv',
            fields=[('src', self.SRC), ('tgt', self.TGT)]
        )
        self.SRC.build_vocab(train, max_size=max_vocab_size)
        self.TGT.build_vocab(train, max_size=max_vocab_size)

        # Examples expose attributes named after the field keys above ('src',
        # 'tgt'). `x.SRC` does not exist on an Example and raises AttributeError
        # as soon as the iterator actually sorts, so read `x.src`.
        def sort_by_src_len(example):
            return len(example.src)

        self.train_loader = data.BucketIterator(
            train,
            batch_size,
            device=device,
            sort_key=sort_by_src_len,
        )
        self.valid_loader = data.BucketIterator(
            valid,
            batch_size,
            device=device,
            sort_key=sort_by_src_len,
        )

In [10]:
# Build the fields, vocabularies and train/valid iterators (reads the TSV files from disk).
dm = CustomDataLoader(batch_size=256)

In [11]:
# Vocabulary sizes. These exceed max_size=10000 because torchtext adds special
# tokens (presumably <unk>/<pad> for SRC, plus <BOS>/<EOS> for TGT).
len(dm.SRC.vocab), len(dm.TGT.vocab)

(10002, 10004)

In [12]:
# The vocabulary sizes fix the source-embedding and generator dimensions.
input_size, output_size = len(dm.SRC.vocab), len(dm.TGT.vocab)

In [13]:
# Model hyperparameters. HIDDEN_SIZE must be even: the bidirectional encoder
# gives each direction half of it.
WORD_VEC_SIZE = 256
HIDDEN_SIZE = 256

model = Seq2Seq(input_size, WORD_VEC_SIZE, HIDDEN_SIZE, output_size).cuda()

In [14]:
# Wrap the model for Lightning, then train for one epoch on one GPU with
# 16-bit mixed precision.
# NOTE(review): CustomModule is built without a padding index, so the loss
# also scores the padded target positions produced by fix_length.
module = CustomModule(model)
trainer = pl.Trainer(precision=16, max_epochs=1, gpus=1)
trainer.fit(module, dm.train_loader, dm.valid_loader)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type    | Params
----------------------------------
0 | model | Seq2Seq | 10.9 M
1 | crit  | NLLLoss | 0     
----------------------------------
10.9 M    Trainable params
0         Non-trainable params
10.9 M    Total params
21.702    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



In [15]:
# Save weights, optimizer state and both vocabularies together, so the model
# can be reloaded with the same token indices.
# NOTE: the vocab objects are pickled; only load this file from a trusted source.
checkpoint = {
    'model': model.state_dict(),
    'opt': module.optimizer.state_dict(),
    'src_vocab': dm.SRC.vocab,
    'tgt_vocab': dm.TGT.vocab,
}
torch.save(checkpoint, './model.pth')
