## BeamSearch
- Generation == Search Problem
- Greedy Search: 지금의 최선이 나중에는 나쁜 선택이 될 수 있음
- Beam Search: top-k를 tracking하여 greedy search를 조금 더 안전하게 수행
- Beam Search를 병렬처리하면 속도와 성능 모두 만족스러울 것

In [1]:
import numpy as np

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

from torch.nn import functional as F
from torch import optim

import pytorch_lightning as pl

## Model

In [6]:
class Encoder(nn.Module):
    """Multi-layer bidirectional GRU encoder.

    Each direction uses ``hidden_size // 2`` units, so the forward and
    backward outputs concatenated by ``nn.GRU`` give roughly ``hidden_size``
    features per time step.
    """

    def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2):
        super(Encoder, self).__init__()

        self.rnn = nn.GRU(
            input_size=word_vec_size,
            hidden_size=hidden_size // 2,
            num_layers=n_layers,
            dropout=dropout_p,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, emb):
        """Encode a batch of embedded source sequences.

        Args:
            emb: either a batch-first tensor of embeddings, or a tuple
                ``(embeddings, lengths)``. ``lengths`` may be a tensor, a
                sequence of ints, or ``None`` (treated like a bare tensor).
                When lengths are given the batch is packed, which with the
                default ``enforce_sorted=True`` requires the sequences to be
                sorted by decreasing length.

        Returns:
            ``(y, h)``: the per-step outputs of the last GRU layer (padded
            back to a batch-first tensor if the input was packed) and the
            final hidden states of every layer and direction.
        """
        lengths = None
        if isinstance(emb, tuple):
            x, lengths = emb
        else:
            x = emb

        is_packed = lengths is not None
        if is_packed:
            # pack_padded_sequence wants plain ints (or a CPU tensor).
            if torch.is_tensor(lengths):
                lengths = lengths.tolist()
            x = pack(x, lengths, batch_first=True)

        y, h = self.rnn(x)

        if is_packed:
            y, _ = unpack(y, batch_first=True)

        return y, h


class Attention(nn.Module):
    """Dot-product attention whose query is a bias-free linear projection."""

    def __init__(self, hidden_size):
        super(Attention, self).__init__()

        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, decoder_hidden, encoder_hidden, mask=None):
        """Return the context vector for one decoder step.

        Args:
            decoder_hidden: (bs, 1, hidden_size) decoder output.
            encoder_hidden: (bs, n, hidden_size) encoder outputs.
            mask: optional boolean (bs, n) tensor; positions set to True
                receive zero attention weight.

        Returns:
            (bs, 1, hidden_size) weighted sum of the encoder outputs.
        """
        keys_t = encoder_hidden.transpose(1, 2)
        # (bs, 1, hidden_size) x (bs, hidden_size, n) -> (bs, 1, n)
        scores = torch.bmm(self.linear(decoder_hidden), keys_t)

        if mask is not None:
            # -inf scores become exactly zero after the softmax.
            scores = scores.masked_fill(mask.unsqueeze(1), float('-inf'))

        attn_weights = self.softmax(scores)

        # (bs, 1, n) x (bs, n, hidden_size) -> (bs, 1, hidden_size)
        return torch.bmm(attn_weights, encoder_hidden)
    
    
class Decoder(nn.Module):
    """Unidirectional GRU decoder fed with the word embedding and the
    previous attentional hidden state concatenated together."""

    def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2):
        super(Decoder, self).__init__()

        self.rnn = nn.GRU(
            input_size=word_vec_size + hidden_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            dropout=dropout_p,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, emb_t, h_prev_tilde, h_prev):
        """Run one decoding step.

        Args:
            emb_t: (bs, 1, word_vec_size) embedding of the current token.
            h_prev_tilde: (bs, 1, hidden_size) previous attentional state,
                or None on the first step (zeros are used instead).
            h_prev: previous GRU hidden state; its last dimension gives the
                hidden size.

        Returns:
            ``(y, h)`` exactly as returned by the underlying ``nn.GRU``.
        """
        if h_prev_tilde is None:
            # First step: there is no previous attentional state yet.
            h_prev_tilde = emb_t.new_zeros(emb_t.size(0), 1, h_prev.size(-1))

        rnn_input = torch.cat((emb_t, h_prev_tilde), dim=-1)

        return self.rnn(rnn_input, h_prev)
    
    
class Generator(nn.Module):
    """Linear projection to the output vocabulary followed by log-softmax."""

    def __init__(self, hidden_size, output_size):
        super(Generator, self).__init__()

        self.output = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Map hidden states (..., hidden_size) to log-probabilities
        (..., output_size) normalised over the last dimension."""
        logits = self.output(x)
        return self.softmax(logits)
    
    
class Seq2Seq(nn.Module):
    """GRU encoder-decoder with attention.

    At every decoding step the previous attentional hidden state
    (``h_tilde``) is concatenated to the word embedding before it enters the
    decoder GRU. The class offers teacher-forced training (``forward``),
    greedy / sampling inference (``search``) and mini-batched beam search
    (``batch_beam_search``).

    The inference methods hard-code the special-token indices used in the
    original code: 0 for <PAD>, 2 for <BOS>, 3 for <EOS>.
    NOTE(review): confirm these match the vocabulary actually in use.
    """

    def __init__(self,
                 input_size,
                 word_vec_size,
                 hidden_size,
                 output_size,
                 n_layers=4,
                 dropout_p=.2
                 ):
        super(Seq2Seq, self).__init__()

        self.input_size = input_size
        self.word_vec_size = word_vec_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p

        self.encoder_emb = nn.Embedding(input_size, word_vec_size)
        self.decoder_emb = nn.Embedding(output_size, word_vec_size)

        self.encoder = Encoder(word_vec_size, hidden_size, n_layers=n_layers, dropout_p=dropout_p)
        self.attention = Attention(hidden_size)
        self.decoder = Decoder(word_vec_size, hidden_size, n_layers=n_layers, dropout_p=dropout_p)

        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.tanh = nn.Tanh()
        self.generator = Generator(hidden_size, output_size)

    def merge_z(self, z):
        """Merge the two directions of the encoder's final hidden state.

        Args:
            z: (n_layers * 2, bs, hidden_size / 2) final hidden states of the
                bidirectional encoder.

        Returns:
            (n_layers, bs, hidden_size) tensor usable as the decoder's
            initial hidden state.
        """
        batch_size = z.size(1)

        # (n_layers * 2, bs, hs / 2) -> (bs, n_layers * 2, hs / 2)
        z = z.transpose(0, 1).contiguous()
        # Adjacent rows (the two directions of one layer) are fused:
        # (bs, n_layers * 2, hs / 2) -> (bs, n_layers, hs)
        z = z.view(batch_size, -1, self.hidden_size)
        # (bs, n_layers, hs) -> (n_layers, bs, hs)
        return z.transpose(0, 1).contiguous()

    def generate_mask(self, x, length):
        """Build the attention padding mask.

        Args:
            x: any tensor living on the target device (only its device is
                used).
            length: per-sample source lengths (tensor or sequence of ints).

        Returns:
            Boolean tensor of shape (bs, max(length)); True marks padded
            positions that must receive no attention weight.
        """
        lengths = torch.as_tensor(length, device=x.device)
        max_length = int(lengths.max())

        positions = torch.arange(max_length, device=x.device).unsqueeze(0)
        # A position is padding when its index is >= the sample's length.
        return positions >= lengths.unsqueeze(1)

    def _split_src(self, src):
        """Split ``src`` into (token ids, lengths, mask).

        Lengths and mask are None when ``src`` is a bare tensor.
        """
        if isinstance(src, tuple):
            x, x_length = src
            return x, x_length, self.generate_mask(x, x_length)

        return src, None, None

    def _encode(self, x, x_length):
        """Embed and encode the source.

        Returns the encoder outputs and the merged final state that serves
        as the decoder's initial hidden state.
        """
        encoder_emb_vec = self.encoder_emb(x)
        # |encoder_emb_vec| = (bs, n, word_vec_size)

        # Only hand a tuple to the encoder when lengths really exist,
        # otherwise it would try to pack with ``None`` lengths.
        if x_length is None:
            encoder_input = encoder_emb_vec
        else:
            encoder_input = (encoder_emb_vec, x_length)

        encoder_hidden, z = self.encoder(encoder_input)
        # |encoder_hidden| = (bs, n, hidden_size)
        # |z| = (n_layers * 2, bs, hidden_size / 2)

        return encoder_hidden, self.merge_z(z)

    def _decode_step(self, emb_t, h_prev_tilde, decoder_hidden, encoder_hidden, mask):
        """Run decoder GRU, attention and the concat layer for one step.

        Returns the new attentional state (bs, 1, hidden_size) and the new
        decoder hidden state (n_layers, bs, hidden_size).
        """
        decoder_output, decoder_hidden = self.decoder(emb_t, h_prev_tilde, decoder_hidden)
        # |decoder_output| = (bs, 1, hidden_size)

        context_vector = self.attention(decoder_output, encoder_hidden, mask)
        # |context_vector| = (bs, 1, hidden_size)

        h_t_tilde = torch.cat([decoder_output, context_vector], dim=-1)
        # |h_t_tilde| = (bs, 1, hidden_size * 2)
        h_t_tilde = self.tanh(self.concat(h_t_tilde))
        # |h_t_tilde| = (bs, 1, hidden_size)

        return h_t_tilde, decoder_hidden

    def forward(self, src, tgt):
        """Teacher-forced forward pass.

        Args:
            src: (bs, n) source token ids, or a tuple ``(ids, lengths)``.
            tgt: (bs, m) target token ids fed to the decoder.

        Returns:
            (bs, m, output_size) log-probabilities for every target step.
        """
        x, x_length, mask = self._split_src(src)

        encoder_hidden, decoder_hidden = self._encode(x, x_length)

        decoder_emb_vec = self.decoder_emb(tgt)
        # |decoder_emb_vec| = (bs, m, word_vec_size)

        h_tilde = []
        h_t_tilde = None

        for t in range(tgt.size(1)):
            emb_t = decoder_emb_vec[:, t, :].unsqueeze(1)
            # |emb_t| = (bs, 1, word_vec_size)

            h_t_tilde, decoder_hidden = self._decode_step(
                emb_t, h_t_tilde, decoder_hidden, encoder_hidden, mask
            )
            h_tilde.append(h_t_tilde)

        h_tilde = torch.cat(h_tilde, dim=1)
        # |h_tilde| = (bs, m, hidden_size)

        y_hat = self.generator(h_tilde)
        # |y_hat| = (bs, m, output_size)

        return y_hat

    def search(self, src, is_greedy=True, max_length=255):
        """Decode by greedy arg-max or by random sampling.

        Args:
            src: (bs, n) source token ids, or a tuple ``(ids, lengths)``.
            is_greedy: take the arg-max token when True, otherwise sample
                from the predicted distribution.
            max_length: maximum number of decoding steps.

        Returns:
            ``(y_hats, indice)``: per-step log-probabilities
            (bs, steps, output_size) and the chosen token ids (bs, steps).
        """
        x, x_length, mask = self._split_src(src)
        batch_size = x.size(0)

        # Lengths are forwarded so the encoder packs the input the same way
        # it does in ``forward``.
        encoder_hidden, decoder_hidden = self._encode(x, x_length)

        y = x.new_full((batch_size, 1), 2)  # index of <BOS>
        is_decoding = x.new_ones(batch_size, 1).bool()

        h_t_tilde, y_hats, indice = None, [], []

        while is_decoding.sum() > 0 and len(indice) < max_length:
            emb_t = self.decoder_emb(y)

            h_t_tilde, decoder_hidden = self._decode_step(
                emb_t, h_t_tilde, decoder_hidden, encoder_hidden, mask
            )

            y_hat = self.generator(h_t_tilde)
            y_hats.append(y_hat)

            if is_greedy:
                y = y_hat.argmax(dim=-1)
            else:
                # Random sampling from the predicted distribution.
                y = torch.multinomial(y_hat.exp().view(batch_size, -1), 1)

            # Samples that already finished keep emitting <PAD>.
            y = y.masked_fill_(~is_decoding, 0)  # index of <PAD>
            # A sample stops decoding once it has produced <EOS>.
            is_decoding = is_decoding * torch.ne(y, 3)  # index of <EOS>

            indice.append(y)

        y_hats = torch.cat(y_hats, dim=1)
        indice = torch.cat(indice, dim=1)

        return y_hats, indice

    def batch_beam_search(
        self,
        src,
        beam_size=5,
        max_length=255,
        n_best=1,
        length_penalty=.2
    ):
        """Beam search over a whole mini-batch at once.

        Every source sentence owns a ``SingleBeamSearchBoard``. At each step
        the beams of all unfinished boards are stacked into one "fabricated"
        mini-batch so the decoder runs only once per step.

        Args:
            src: (bs, n) source token ids, or a tuple ``(ids, lengths)``.
            beam_size: number of beams per sentence.
            max_length: maximum number of decoding steps.
            n_best: number of hypotheses returned per sentence.
            length_penalty: exponent forwarded to the boards' ``get_n_best``.

        Returns:
            ``(batch_sentences, batch_probs)``: for each sentence, the
            ``n_best`` hypotheses and their scores as returned by the board.
        """
        x, x_length, mask = self._split_src(src)
        batch_size = x.size(0)

        encoder_hidden, z = self._encode(x, x_length)
        # |z| = (n_layers, bs, hidden_size)

        boards = [SingleBeamSearchBoard(
            encoder_hidden.device,
            {
                'hidden_state': {
                    'init_status': z[:, i, :].unsqueeze(1),
                    'batch_dim_index': 1,
                },
                'h_prev_tilde': {
                    'init_status': None,
                    'batch_dim_index': 0
                }
            },
            beam_size=beam_size,
            max_length=max_length,
        ) for i in range(batch_size)]

        n_done = sum(bool(board.is_done()) for board in boards)
        length = 0

        while n_done < batch_size and length <= max_length:
            active_boards = []
            fab_input, fab_hidden, fab_h_t_tilde = [], [], []
            fab_encoder_hidden, fab_mask = [], []

            for i, board in enumerate(boards):
                if board.is_done():
                    continue

                y_hat_i, prev_status = board.get_batch()

                active_boards.append(board)
                fab_input.append(y_hat_i)
                fab_hidden.append(prev_status['hidden_state'])
                # Every beam of a sentence attends over the same source.
                fab_encoder_hidden += [encoder_hidden[i, :, :]] * beam_size
                if mask is not None:
                    fab_mask += [mask[i, :]] * beam_size
                # h_prev_tilde is None only before the first step.
                if prev_status['h_prev_tilde'] is not None:
                    fab_h_t_tilde.append(prev_status['h_prev_tilde'])

            # Build the fabricated mini-batch.
            fab_input = torch.cat(fab_input, dim=0)
            fab_hidden = torch.cat(fab_hidden, dim=1)
            fab_encoder_hidden = torch.stack(fab_encoder_hidden)
            fab_mask = torch.stack(fab_mask) if fab_mask else None
            fab_h_t_tilde = torch.cat(fab_h_t_tilde, dim=0) if fab_h_t_tilde else None

            decoder_emb_vec = self.decoder_emb(fab_input)

            fab_h_t_tilde, fab_hidden = self._decode_step(
                decoder_emb_vec, fab_h_t_tilde, fab_hidden, fab_encoder_hidden, fab_mask
            )

            y_hat = self.generator(fab_h_t_tilde)

            # Hand each active board its own slice of the results.
            for k, board in enumerate(active_boards):
                begin = k * beam_size
                end = begin + beam_size

                board.collect_result(
                    y_hat[begin:end],
                    {
                        'hidden_state': fab_hidden[:, begin:end, :],
                        'h_prev_tilde': fab_h_t_tilde[begin:end]
                    }
                )

            n_done = sum(bool(board.is_done()) for board in boards)
            length += 1

        batch_sentences, batch_probs = [], []

        for board in boards:
            sentences, probs = board.get_n_best(n_best, length_penalty=length_penalty)

            batch_sentences.append(sentences)
            batch_probs.append(probs)

        return batch_sentences, batch_probs

## Search

In [4]:
from operator import itemgetter

LENGTH_PENALTY = .2
MIN_LENGTH = 5


class SingleBeamSearchBoard():
    """Book-keeping for the beam search of one source sentence.

    The board stores, for every time step, the chosen word of each beam, the
    beam it was expanded from, its cumulative log-probability and whether it
    ended with <EOS>. Token indices 2 (<BOS>) and 3 (<EOS>) are hard-coded.
    """

    def __init__(
        self,
        device,
        prev_status_config,
        beam_size=5,
        max_length=255,
    ):
        """
        Args:
            device: device on which the history tensors are created.
            prev_status_config: maps a status name to a dict with keys
                ``'init_status'`` (a tensor with batch size 1, or None) and
                ``'batch_dim_index'`` (index of its batch dimension).
            beam_size: number of beams tracked.
            max_length: stored on the board; not used by the board itself.
        """
        self.beam_size = beam_size
        self.max_length = max_length

        self.device = device
        # Every beam starts from <BOS>.
        self.word_indice = [torch.full((beam_size,), 2, dtype=torch.long, device=device)]
        # Index of the parent beam; -1 means "no parent" at step 0.
        self.beam_indice = [torch.full((beam_size,), -1, dtype=torch.long, device=device)]
        # Cumulative log-probabilities. All beams hold the same <BOS> at
        # step 0, so every beam but the first starts at -inf to avoid
        # expanding duplicates.
        self.cumulative_probs = [torch.tensor(
            [.0] + [-float('inf')] * (beam_size - 1),
            dtype=torch.float,
            device=device,
        )]
        # True where a beam ended with <EOS> at that step.
        self.masks = [torch.zeros(beam_size, dtype=torch.bool, device=device)]

        self.prev_status = {}
        self.batch_dims = {}

        # Replicate each initial status once per beam.
        for prev_status_name, each_config in prev_status_config.items():
            init_status = each_config['init_status']
            # e.g. for hidden_state, |init_status| = (n_layers, 1, hidden_size)
            batch_dim_index = each_config['batch_dim_index']

            if init_status is not None:
                self.prev_status[prev_status_name] = torch.cat(
                    [init_status] * beam_size,
                    dim=batch_dim_index,
                )
                # e.g. for hidden_state, -> (n_layers, beam_size, hidden_size)
            else:
                self.prev_status[prev_status_name] = None
            self.batch_dims[prev_status_name] = batch_dim_index

        self.current_time_step = 0
        self.done_cnt = 0

    def get_length_penalty(
        self,
        length,
        alpha=LENGTH_PENALTY,
        min_length=MIN_LENGTH,
    ):
        """Return ``((min_length + 1) / (min_length + length)) ** alpha``."""
        p = ((min_length + 1) / (min_length + length)) ** alpha

        return p

    def is_done(self):
        """Return True once at least ``beam_size`` <EOS> tokens were seen."""
        return self.done_cnt >= self.beam_size

    def get_batch(self):
        """Return the decoder input for the next step.

        Returns:
            ``(y_hat, prev_status)``: the last chosen words as a
            (beam_size, 1) tensor and the dict of per-beam statuses.
        """
        y_hat = self.word_indice[-1].unsqueeze(-1)
        # |y_hat| = (beam_size, 1)

        return y_hat, self.prev_status

    def collect_result(self, y_hat, prev_status):
        """Record one decoding step.

        Args:
            y_hat: (beam_size, 1, output_size) log-probabilities.
            prev_status: dict of new status tensors, keyed like the config
                given to ``__init__``; they are re-ordered to follow the
                surviving beams.
        """
        output_size = y_hat.size(-1)

        self.current_time_step += 1

        # Beams that already emitted <EOS> must not be expanded further.
        # The fill is NOT in place: the stored history keeps the real score
        # of finished hypotheses, which get_n_best() needs for ranking.
        prev_cumulative = self.cumulative_probs[-1].masked_fill(self.masks[-1], -float('inf'))
        cumulative_prob = y_hat + prev_cumulative.view(-1, 1, 1).expand(self.beam_size, 1, output_size)

        # Top-k over all (beam, word) pairs at once.
        top_log_prob, top_indice = torch.topk(
            cumulative_prob.view(-1),
            self.beam_size,
            dim=-1
        )

        # Flat index -> word index inside the vocabulary ...
        self.word_indice.append(top_indice.fmod(output_size))
        # ... and index of the beam it came from.
        self.beam_indice.append(top_indice.div(float(output_size)).long())

        self.cumulative_probs.append(top_log_prob)
        self.masks.append(torch.eq(self.word_indice[-1], 3))  # index of <EOS>
        # Keep the counter a plain int so is_done() returns a plain bool.
        self.done_cnt += int(self.masks[-1].sum())

        for status_name, status in prev_status.items():
            self.prev_status[status_name] = torch.index_select(
                status,
                dim=self.batch_dims[status_name],
                index=self.beam_indice[-1]
            ).contiguous()

    def get_n_best(self, n=1, length_penalty=.2):
        """Return the ``n`` best hypotheses found so far.

        Candidates are every beam that ended with <EOS> at any step, plus
        the beams still alive at the last step. Each is scored by its
        cumulative log-probability times the length penalty.

        Returns:
            ``(sentences, probs)``: each sentence is a list of 0-dim word
            index tensors (without <BOS>); ``probs`` holds the matching
            scores, best first.
        """
        sentences, probs, founds = [], [], []

        for t in range(len(self.word_indice)):
            for b in range(self.beam_size):
                if self.masks[t][b]:  # this beam ended with <EOS> at step t
                    probs.append(
                        self.cumulative_probs[t][b]
                        * self.get_length_penalty(t, alpha=length_penalty)
                    )
                    founds.append((t, b))

        # Also consider unfinished beams of the last step.
        last_t = len(self.cumulative_probs) - 1
        for b in range(self.beam_size):
            if self.cumulative_probs[-1][b] != -float('inf') and (last_t, b) not in founds:
                probs.append(
                    self.cumulative_probs[-1][b]
                    * self.get_length_penalty(len(self.cumulative_probs), alpha=length_penalty)
                )
                founds.append((last_t, b))

        sorted_founds_with_probs = sorted(
            zip(founds, probs),
            key=itemgetter(1),
            reverse=True
        )[:n]
        probs = []

        for (end_index, b), prob in sorted_founds_with_probs:
            sentence = []

            # Follow the parent pointers back to the first step.
            for t in range(end_index, 0, -1):
                sentence.append(self.word_indice[t][b])
                b = self.beam_indice[t][b]
            sentence.reverse()

            sentences.append(sentence)
            probs.append(prob)

        return sentences, probs

## DataLoader

In [5]:
class CustomDataLoader:
    """Builds the source/target fields and, optionally, the train/valid
    bucket iterators for the TSV translation corpus under
    ``./kor_eng_translation/``.

    NOTE(review): relies on a module named ``data`` (presumably
    ``torchtext.data``) being imported elsewhere in the notebook.
    """

    def __init__(self, batch_size=64, max_length=70, shuffle=True, train=True, device=None):
        """
        Args:
            batch_size: mini-batch size of both iterators.
            max_length: sequences longer than this are truncated.
            shuffle: whether the training iterator shuffles.
            train: when True, load the datasets, create the iterators and
                build the vocabularies; when False only the fields are
                created (see ``load_vocab``).
            device: device the iterators put batches on. Defaults to
                ``'cuda:0'`` when CUDA is available, otherwise ``'cpu'``.
        """
        self.batch_size = batch_size

        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        def truncate(tokens):
            # Keep at most max_length tokens.
            return tokens[:max_length]

        # Weight of the source length in the sort key. It must exceed the
        # longest possible target so batches are ordered by source length
        # first and target length second. 80 is the historical value and is
        # only raised when max_length would make it too small.
        src_weight = max(80, max_length + 1)

        def sort_key(example):
            return len(example.tgt) + (src_weight * len(example.src))

        self.SRC = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            preprocessing=truncate,
            include_lengths=True,
        )
        self.TGT = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            preprocessing=truncate,
            init_token='<BOS>',
            eos_token='<EOS>'
        )

        if train:
            train_set, valid_set = data.TabularDataset.splits(
                path='./kor_eng_translation/',
                train='train.tsv',
                validation='valid.tsv',
                format='tsv',
                fields=[('src', self.SRC), ('tgt', self.TGT)]
            )

            self.train_loader = data.BucketIterator(
                train_set,
                batch_size,
                device=device,
                shuffle=shuffle,
                sort_key=sort_key,
                sort_within_batch=True,
            )
            self.valid_loader = data.BucketIterator(
                valid_set,
                batch_size,
                device=device,
                sort_key=sort_key,
                sort_within_batch=True,
            )

            # Vocabularies come from the training split only.
            self.SRC.build_vocab(train_set, max_size=30000, min_freq=5)
            self.TGT.build_vocab(train_set, max_size=30000, min_freq=5)

    def load_vocab(self, src_vocab, tgt_vocab):
        """Attach previously built vocabularies to the fields."""
        self.SRC.vocab = src_vocab
        self.TGT.vocab = tgt_vocab