Basic seq2seq. #5
lizeyan committed Sep 16, 2018
1 parent 93bd495 commit e5da9c3
53 changes: 53 additions & 0 deletions snippets/modules/sequence/rnn_encoder_decoder.py
@@ -88,3 +88,56 @@ def forward_n(self, x, hidden=None, *, n_steps: int, **kwargs):
            decoder_outputs[di] = decoder_output[0]
            output_length += 1
        return decoder_outputs


class LuongGlobalAttentionDecoderSeq(DecoderSeq):
    def __init__(self, hidden_size: int, output_size: int, embedding_size: int, *,
                 n_layers: int = 1, dropout_p: float = 0.5,
                 bidirectional: bool = False):
        super().__init__(hidden_size, output_size, embedding_size,
                         n_layers=n_layers, dropout_p=dropout_p, bidirectional=bidirectional)
        # Concat-style score layer: projects each [decoder state; encoder state]
        # pair down to a single unnormalized attention score.
        self.attn = nn.Linear(2 * hidden_size * self.n_direction, 1, bias=False)
        self.out = nn.Sequential(
            nn.Linear(2 * hidden_size * self.n_direction, output_size),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, x, hidden=None, **kwargs):
        """
        :param x: a LongTensor of token indices in shape (seq_length, batch_size)
        :param hidden: corresponding hidden state; pass None to start from a fresh hidden state
        :param kwargs:
            encoder_outputs: required, a Tensor in shape (encoder_seq_length, batch_size, hidden_size * n_direction)
        :return: (output, hidden), where output holds log-probabilities in shape (seq_length, batch_size, output_size)
        """
        encoder_outputs = kwargs["encoder_outputs"]
        target_seq_length = x.size(0)
        source_seq_length = encoder_outputs.size(0)
        batch_size = x.size(1)
        assert encoder_outputs.size(1) == batch_size

        embedded = self.embedding(x)
        embedded = self.input_dropout(embedded)
        embedded = func.relu(embedded)
        assert embedded.size() == (target_seq_length, batch_size, self.embedding_size)

        rnn_outputs, hidden = self.rnn(embedded, hidden)
        assert rnn_outputs.size() == (target_seq_length, batch_size, self.hidden_size * self.n_direction)
        assert hidden.size() == (self.n_layers * self.n_direction, batch_size, self.hidden_size)

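        # Luong-style global attention: score every (decoder step, encoder step)
        # pair, normalize the scores over the source positions, and take a
        # weighted mix of the encoder outputs as the context vector.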
        cat = torch.cat([
            rnn_outputs.unsqueeze(1).expand(-1, source_seq_length, -1, -1),
            encoder_outputs.unsqueeze(0).expand(target_seq_length, -1, -1, -1)
        ], dim=-1)
        assert cat.size() == (target_seq_length, source_seq_length, batch_size, self.hidden_size * self.n_direction * 2)
        attn_prod = self.attn(cat).permute(0, 2, 3, 1)
        attn_weights = func.softmax(attn_prod, dim=-1)
        assert attn_weights.size() == (target_seq_length, batch_size, 1, source_seq_length)
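        # Batched matmul with broadcasting: (T, B, 1, S) @ (B, S, H * D) gives
        # (T, B, 1, H * D), i.e. one context vector per decoder step.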
        context_vector = (attn_weights @ encoder_outputs.permute(1, 0, 2)).squeeze(2)
        assert context_vector.size() == (target_seq_length, batch_size, self.hidden_size * self.n_direction)

        output = self.out(torch.cat([
            context_vector, rnn_outputs
        ], dim=-1))
        return output, hidden
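A minimal usage sketch (not part of the commit): the random stand-in for encoder outputs and all shapes below are assumptions for illustration, and the default arguments imply n_direction == 1.

    import torch

    hidden_size, vocab_size, embedding_size = 128, 1000, 64
    decoder = LuongGlobalAttentionDecoderSeq(hidden_size, vocab_size, embedding_size)

    source_len, target_len, batch_size = 12, 7, 4
    # Random stand-in for real encoder outputs:
    # (encoder_seq_length, batch_size, hidden_size * n_direction).
    encoder_outputs = torch.randn(source_len, batch_size, hidden_size)
    decoder_input = torch.randint(0, vocab_size, (target_len, batch_size))

    log_probs, hidden = decoder(decoder_input, encoder_outputs=encoder_outputs)
    assert log_probs.size() == (target_len, batch_size, vocab_size)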
