Commit 0.4

yuhsianghuang committed Aug 24, 2018
1 parent 0b0eabb commit 1d4df3b
Showing 2 changed files with 41 additions and 9 deletions.
15 changes: 13 additions & 2 deletions transformer/Layers.py
@@ -4,6 +4,7 @@

__author__ = "Yu-Hsiang Huang"


class EncoderLayer(nn.Module):
''' Compose with two layers '''

@@ -13,12 +14,17 @@ def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

def forward(self, enc_input, slf_attn_mask=None):
def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
enc_output, enc_slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask)
enc_output *= non_pad_mask

enc_output = self.pos_ffn(enc_output)
enc_output *= non_pad_mask

return enc_output, enc_slf_attn
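
Note: the new enc_output *= non_pad_mask lines rely on broadcasting a (batch, seq_len, 1) float mask over the model dimension, so every feature at a padded time step is zeroed after each sub-layer. A minimal standalone sketch of that broadcast (the shapes below are illustrative, not taken from the commit):

import torch

batch, seq_len, d_model = 2, 4, 6
enc_output = torch.randn(batch, seq_len, d_model)

# 1.0 at real tokens, 0.0 at padding; the trailing dim of 1 broadcasts over d_model
non_pad_mask = torch.tensor([[1., 1., 1., 0.],
                             [1., 1., 0., 0.]]).unsqueeze(-1)   # (2, 4, 1)

enc_output = enc_output * non_pad_mask
print(enc_output[0, 3])   # all zeros: position 3 of the first sample is padding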


class DecoderLayer(nn.Module):
''' Compose with three layers '''

@@ -28,11 +34,16 @@ def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
dec_output, dec_slf_attn = self.slf_attn(
dec_input, dec_input, dec_input, mask=slf_attn_mask)
dec_output *= non_pad_mask

dec_output, dec_enc_attn = self.enc_attn(
dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
dec_output *= non_pad_mask

dec_output = self.pos_ffn(dec_output)
dec_output *= non_pad_mask

return dec_output, dec_slf_attn, dec_enc_attn
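
Note: a hedged usage sketch of the updated DecoderLayer signature. The hyper-parameters below are illustrative, and it assumes the repo's transformer package is importable and that its attention masks are byte/bool tensors of shape (batch, len_q, len_k):

import torch
from transformer.Layers import DecoderLayer  # assumes the repo is on PYTHONPATH

d_model, d_inner, n_head, d_k, d_v = 512, 2048, 8, 64, 64
layer = DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=0.1)

b, len_t, len_s = 2, 5, 7
dec_input = torch.randn(b, len_t, d_model)
enc_output = torch.randn(b, len_s, d_model)

non_pad_mask = torch.ones(b, len_t, 1)                            # no padding in this toy batch
slf_attn_mask = torch.zeros(b, len_t, len_t, dtype=torch.uint8)   # nothing masked
dec_enc_attn_mask = torch.zeros(b, len_t, len_s, dtype=torch.uint8)

dec_output, dec_slf_attn, dec_enc_attn = layer(
    dec_input, enc_output,
    non_pad_mask=non_pad_mask,
    slf_attn_mask=slf_attn_mask,
    dec_enc_attn_mask=dec_enc_attn_mask)
print(dec_output.shape)   # expected: torch.Size([2, 5, 512])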
35 changes: 28 additions & 7 deletions transformer/Models.py
@@ -7,6 +7,9 @@

__author__ = "Yu-Hsiang Huang"

def get_non_pad_mask(seq):
assert seq.dim() == 2
return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)
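
Note: a standalone sketch of what get_non_pad_mask returns; Constants.PAD == 0 is an assumption here, not something stated in this diff:

import torch

PAD = 0  # assumed padding index, standing in for Constants.PAD

def get_non_pad_mask(seq):
    assert seq.dim() == 2
    return seq.ne(PAD).type(torch.float).unsqueeze(-1)

seq = torch.tensor([[5, 7, 2, 0, 0],
                    [3, 9, 0, 0, 0]])
print(get_non_pad_mask(seq).squeeze(-1))
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 0., 0., 0.]])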

def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
''' Sinusoid position encoding table '''
@@ -28,9 +31,10 @@ def get_posi_angle_vec(position, d_hid):

return torch.FloatTensor(sinusoid_table)

def get_padding_mask(seq_q, seq_k):
''' For masking out the padding part. '''
def get_attn_key_pad_mask(seq_k, seq_q):
''' For masking out the padding part of key sequence. '''

# Expand to fit the shape of key query attention matrix.
len_q = seq_q.size(1)
padding_mask = seq_k.eq(Constants.PAD)
padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk
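
Note: the renamed helper repeats one row of key-padding flags per query position, so the result matches the (batch, len_q, len_k) attention matrix. A sketch with a toy batch, again assuming PAD == 0 and that the elided tail of the function simply returns padding_mask:

import torch

PAD = 0  # assumed padding index

def get_attn_key_pad_mask(seq_k, seq_q):
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(PAD)                              # (b, len_k)
    return padding_mask.unsqueeze(1).expand(-1, len_q, -1)    # (b, len_q, len_k)

src = torch.tensor([[4, 8, 0]])      # key sequence with one padded position
tgt = torch.tensor([[6, 6, 6, 6]])   # four query positions
mask = get_attn_key_pad_mask(seq_k=src, seq_q=tgt)
print(mask.shape)   # torch.Size([1, 4, 3]); the padded key column is flagged in every query row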
@@ -74,12 +78,19 @@ def __init__(
def forward(self, src_seq, src_pos, return_attns=False):

enc_slf_attn_list = []
slf_attn_mask = get_padding_mask(src_seq, src_seq)

# -- Prepare masks
slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
non_pad_mask = get_non_pad_mask(src_seq)

# -- Forward
enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

for enc_layer in self.layer_stack:
enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=slf_attn_mask)
enc_output, enc_slf_attn = enc_layer(
enc_output,
non_pad_mask=non_pad_mask,
slf_attn_mask=slf_attn_mask)
if return_attns:
enc_slf_attn_list += [enc_slf_attn]
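
Note: the encoder prepares both masks from the same source sequence: a float multiplier for zeroing padded positions and a key-padding flag matrix for self-attention. A brief sketch of that preparation and of the per-layer call (PAD == 0 is assumed):

import torch

PAD = 0  # assumed padding index
src_seq = torch.tensor([[11, 12, 13, 0, 0]])   # batch of one, two padded positions

non_pad_mask = src_seq.ne(PAD).type(torch.float).unsqueeze(-1)       # (1, 5, 1) float multiplier
slf_attn_mask = src_seq.eq(PAD).unsqueeze(1).expand(
    -1, src_seq.size(1), -1)                                         # (1, 5, 5) key-padding flags

# Each layer in the stack then receives both masks:
# enc_output, enc_slf_attn = enc_layer(
#     enc_output, non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask)
print(non_pad_mask.shape, slf_attn_mask.shape)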

@@ -113,15 +124,25 @@ def __init__(
def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):

dec_slf_attn_list, dec_enc_attn_list = [], []
slf_attn_mask = (get_padding_mask(tgt_seq, tgt_seq) + get_subsequent_mask(tgt_seq)).gt(0)
dec_enc_attn_mask = get_padding_mask(tgt_seq, src_seq)

# -- Prepare masks
non_pad_mask = get_non_pad_mask(tgt_seq)

slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)

# -- Forward
dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos)

for dec_layer in self.layer_stack:
dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
dec_output, enc_output,
slf_attn_mask=slf_attn_mask, dec_enc_attn_mask=dec_enc_attn_mask)
non_pad_mask=non_pad_mask,
slf_attn_mask=slf_attn_mask,
dec_enc_attn_mask=dec_enc_attn_mask)

if return_attns:
dec_slf_attn_list += [dec_slf_attn]
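
Note: the decoder's self-attention mask is the union of the look-ahead (subsequent) mask and the key-padding mask, which is what the (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) line above computes. A sketch, assuming PAD == 0 and a get_subsequent_mask equivalent to the usual upper-triangular look-ahead mask:

import torch

PAD = 0  # assumed padding index
tgt_seq = torch.tensor([[7, 8, 9, 0]])   # the last target position is padding
len_t = tgt_seq.size(1)

# Look-ahead mask: position i may not attend to any j > i.
slf_attn_mask_subseq = torch.triu(
    torch.ones((len_t, len_t), dtype=torch.uint8), diagonal=1).unsqueeze(0)   # (1, len_t, len_t)

# Key-padding mask over the target sequence itself.
slf_attn_mask_keypad = tgt_seq.eq(PAD).to(torch.uint8).unsqueeze(1).expand(-1, len_t, -1)

# Union of the two: a position is blocked if either mask flags it.
slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
print(slf_attn_mask[0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 1]], dtype=torch.int32)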
