In [1]:
import torch
import torch.nn.functional as F

In [86]:
# Toy-problem dimensions used by the cells below.
batch_size, hidden_size, vocab_size, max_len = 4, 128, 100, 50

# Reserved token ids, numbered consecutively from 0. sos/eos are the <SOS>/<EOS>
# markers passed to beam_decode; pad/unk are presumably padding and
# unknown-token ids (not used in the visible cells).
pad_id, sos_id, eos_id, unk_id = range(4)

In [87]:
class Decoder(torch.nn.Module):
    """GRU decoder mapping token-id prefixes to per-step vocabulary distributions.

    Args:
        hidden_size: Width of both the token embeddings and the GRU state.
        vocab_size: Number of token ids (embedding rows and output classes).
        num_layers: Number of stacked GRU layers.
        dropout: Dropout applied between GRU layers (PyTorch only applies it
            when ``num_layers > 1``).
    """

    def __init__(self, hidden_size, vocab_size, num_layers, dropout):
        super().__init__()

        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        # batch_first=True is required: forward() receives [batch, time] ids.
        # With the GRU default (time-major) the recurrence would run across
        # batch rows, mixing unrelated sequences, and an initial hidden state
        # of shape [num_layers, batch, hidden] would be rejected.
        self.rnn = torch.nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.linear = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, inputs, encoder_hidden):
        """Run the decoder over a batch of token-id sequences.

        Args:
            inputs: Integer tensor of token ids, shape [batch, time].
            encoder_hidden: Initial GRU state of shape
                [num_layers, batch, hidden_size], or None for a zero state.

        Returns:
            Tensor of shape [batch, time, vocab] holding, for every time step,
            a softmax distribution over the vocabulary.
        """
        embedded = self.embedding(inputs)                # [batch, time, hidden]
        rnn_out, _ = self.rnn(embedded, encoder_hidden)  # final state is not needed
        logits = self.linear(rnn_out)                    # [batch, time, vocab]
        return F.softmax(logits, dim=-1)

In [88]:
# Single-layer decoder without dropout, sized from the config constants.
decoder = Decoder(
    hidden_size=hidden_size,
    vocab_size=vocab_size,
    num_layers=1,
    dropout=0.0,
)

In [89]:
# Smoke test: feed one <SOS> token per batch row ([batch_size, 1]) through the
# decoder. Note that `x` is then rebound to the decoder's output.
x = torch.full((batch_size, 1), sos_id, dtype=torch.int64)
x = decoder(x, None)

In [129]:
def _last_step_log_probs(decoder, prefixes, encoder_hidden):
    """Return next-token log-probabilities for each prefix, shape [batch, vocab]."""
    probs = decoder(prefixes, encoder_hidden)[:, -1, :]
    # Clamp before the log so zero-probability tokens get a very low finite
    # score instead of -inf.
    return torch.log(probs.clamp(min=torch.finfo(probs.dtype).tiny))


def beam_decode(decoder, encoder_hidden, beam_size, batch_size, max_len, sos_id, eos_id):
    """Beam search decoding.

    Hypotheses are ranked by their cumulative log-probability (the sum of the
    log-probabilities of every token on the path), not by the probability of
    the last token alone. Once a hypothesis emits ``eos_id`` it is frozen: it
    keeps its score and is only ever extended with further ``eos_id`` tokens,
    so every returned path has the same length.

    Args:
        decoder: Callable ``decoder(inputs, encoder_hidden)`` taking token ids
            of shape [batch, time] and returning probabilities of shape
            [batch, time, vocab] (e.g. the Pytorch RNN decoder).
        encoder_hidden: Hidden state of RNN encoder, passed to the decoder
            unchanged on every call (may be None).
        beam_size: Beam width.
        batch_size: Batch size.
        max_len: Maximum steps of beam search. At least one step is always
            taken, so paths hold ``max(max_len, 1) + 1`` tokens including <SOS>.
        sos_id: Id of <SOS>
        eos_id: Id of <EOS>

    Returns:
        Nested list ``paths[batch][beam]`` of token-id lists. Each path starts
        with ``sos_id``; beams are ordered from best to worst cumulative score.

    Raises:
        ValueError: If ``beam_size`` exceeds the decoder's vocabulary size.
    """
    # Decoding never needs gradients; skipping them saves memory.
    with torch.no_grad():
        # <SOS> batch: [batch, 1]
        sos_batch = torch.full((batch_size, 1), sos_id, dtype=torch.int64)

        # First step: all beams share the <SOS> prefix, so pick the top-k
        # distinct tokens of a single distribution to avoid duplicate beams.
        log_probs = _last_step_log_probs(decoder, sos_batch, encoder_hidden)
        # Vocabulary size comes from the decoder output, not from a global.
        vocab = log_probs.size(-1)
        if beam_size > vocab:
            raise ValueError(
                "beam_size ({}) cannot exceed the vocabulary size ({})".format(beam_size, vocab)
            )

        # scores, tokens: [batch, beam]
        scores, tokens = torch.topk(log_probs, beam_size, dim=-1)
        # paths: [batch, beam, length]
        paths = torch.cat(
            [sos_batch.unsqueeze(1).expand(-1, beam_size, -1), tokens.unsqueeze(-1)],
            dim=-1,
        )
        finished = tokens.eq(eos_id)

        for _ in range(2, max_len + 1):
            if finished.all():
                # Nothing left to extend: fill the remaining steps with <EOS>
                # without calling the decoder again.
                missing = max_len + 1 - paths.size(-1)
                filler = paths.new_full((batch_size, beam_size, missing), eos_id)
                paths = torch.cat([paths, filler], dim=-1)
                break

            # step_scores: [batch, beam, vocab]
            step_scores = torch.stack(
                [
                    _last_step_log_probs(decoder, paths[:, beam, :], encoder_hidden)
                    for beam in range(beam_size)
                ],
                dim=1,
            )

            if finished.any():
                # A finished hypothesis may only continue with <EOS>, at no cost.
                eos_only = step_scores.new_full((vocab,), float("-inf"))
                eos_only[eos_id] = 0.0
                step_scores = torch.where(finished.unsqueeze(-1), eos_only, step_scores)

            # Cumulative score of every (beam, token) continuation,
            # flattened to [batch, beam * vocab].
            candidates = (scores.unsqueeze(-1) + step_scores).reshape(batch_size, beam_size * vocab)
            scores, flat_idx = torch.topk(candidates, beam_size, dim=-1)

            # Decompose the flat index into (parent beam, token id).
            parent = torch.div(flat_idx, vocab, rounding_mode="floor")
            tokens = flat_idx - parent * vocab

            parent_paths = paths.gather(
                1, parent.unsqueeze(-1).expand(-1, -1, paths.size(-1))
            )
            paths = torch.cat([parent_paths, tokens.unsqueeze(-1)], dim=-1)
            finished = finished.gather(1, parent) | tokens.eq(eos_id)

    return paths.tolist()

In [135]:
# Beam width 3, batch of 4, two search steps; show the beams of the first batch row.
beam_decode(
    decoder,
    None,
    beam_size=3,
    batch_size=4,
    max_len=2,
    sos_id=1,
    eos_id=2,
)[0]

[[1, 21, 67], [1, 65, 97], [1, 21, 60]]