내가 봤던 seq2seq with attention

In [10]:
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [11]:
# Dimensions of the toy problem used by the first (time-major) walkthrough.
SEQ_LEN, BATCH_SIZE = 10, 3
INPUT_DIM, OUTPUT_DIM = 30, 37
ENC_EMB_DIM = 32
DEC_EMB_DIM = 32
ENC_HID_DIM = 64
DEC_HID_DIM = 64
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

# Random token ids laid out as (SEQ_LEN, BATCH_SIZE). randint's upper bound is
# exclusive, so the random ids stop at INPUT_DIM - 3 / OUTPUT_DIM - 3.
# The first row is then forced to id 0 and the last row to the highest id of
# each vocabulary (presumably start/end-of-sequence markers).
x = torch.randint(low=1, high=INPUT_DIM - 2, size=(SEQ_LEN, BATCH_SIZE))
x[0] = 0
x[-1] = INPUT_DIM - 1

y = torch.randint(low=1, high=OUTPUT_DIM - 2, size=(SEQ_LEN, BATCH_SIZE))
y[0] = 0
y[-1] = OUTPUT_DIM - 1

print(x, x.shape, end='\n\n')
print(y, y.shape)

tensor([[ 0,  0,  0],
        [13, 11, 19],
        [25, 22,  2],
        [27,  9, 19],
        [27,  4,  6],
        [12, 13,  8],
        [ 5, 12, 13],
        [20, 14, 12],
        [ 3,  7, 15],
        [29, 29, 29]]) torch.Size([10, 3])

tensor([[ 0,  0,  0],
        [13,  4, 19],
        [23, 12, 21],
        [29, 14, 28],
        [15, 23, 33],
        [ 4, 14, 24],
        [17, 10, 27],
        [20, 24, 18],
        [ 6, 26, 16],
        [36, 36, 36]]) torch.Size([10, 3])


In [12]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder that also derives the decoder's first hidden state.

    Args:
        input_dim: Number of rows in the source embedding table.
        emb_dim: Width of each token embedding.
        enc_hid_dim: GRU hidden size per direction.
        dec_hid_dim: Width of the state vector returned for the decoder.
        dropout: Drop probability applied to the embedded tokens.
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=emb_dim)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=enc_hid_dim, bidirectional=True)
        # Both directions are concatenated before being projected to the decoder size.
        self.fc = nn.Linear(in_features=2 * enc_hid_dim, out_features=dec_hid_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        """Encode a batch of token ids.

        Args:
            src: Integer token ids; the GRU is not batch-first, so this is
                assumed to be (src_len, batch).

        Returns:
            A pair ``(outputs, hidden)``: the GRU outputs for every position,
            and ``tanh(fc(...))`` of the last two final-state slices joined
            along the feature axis.
        """
        token_vectors = self.embedding(src)
        outputs, final_states = self.rnn(self.dropout(token_vectors))
        # The last two slices of the final state are the two directions of the
        # (single) GRU layer.
        second_last, last = final_states[-2], final_states[-1]
        merged = torch.cat([second_last, last], dim=1)
        return outputs, torch.tanh(self.fc(merged))


class Decoder(nn.Module):
    """Single-step GRU decoder that consumes an attention-weighted context.

    Args:
        output_dim: Number of rows in the target embedding table and width of
            the prediction vector.
        emb_dim: Width of each target token embedding.
        enc_hid_dim: Encoder hidden size per direction (encoder outputs are
            expected to be twice this wide).
        dec_hid_dim: Hidden size of the decoder GRU.
        dropout: Drop probability applied to the embedded input token.
        attention: Callable invoked as ``attention(hidden, encoder_outputs)``;
            its result is used as per-source-position weights.
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        context_dim = 2 * enc_hid_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        # The GRU sees the context vector and the token embedding side by side.
        self.rnn = nn.GRU(context_dim + emb_dim, dec_hid_dim)
        # The output layer sees GRU output, context and embedding together.
        self.fc_out = nn.Linear(context_dim + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: Token ids for the current step, one per batch element.
            hidden: Current decoder state, one row per batch element.
            encoder_outputs: Encoder outputs with the source position on the
                first axis and the batch on the second.

        Returns:
            ``(prediction, hidden)``: the ``fc_out`` scores for this step and
            the updated decoder state with the leading step axis removed.
        """
        # Add a length-1 step axis so the embedding matches the GRU layout.
        embedded = self.dropout(self.embedding(input.unsqueeze(0)))

        # (batch, src_len) weights -> (batch, 1, src_len) for the batched matmul.
        weights = self.attention(hidden, encoder_outputs).unsqueeze(1)
        batch_major_outputs = encoder_outputs.permute(1, 0, 2)
        context = torch.bmm(weights, batch_major_outputs).permute(1, 0, 2)

        output, hidden = self.rnn(torch.cat((embedded, context), dim=2), hidden.unsqueeze(0))
        # Sanity check: the code expects the GRU output and its final state to
        # coincide for this single-step call.
        assert (output == hidden).all()

        features = torch.cat(
            (output.squeeze(0), context.squeeze(0), embedded.squeeze(0)),
            dim=1,
        )
        return self.fc_out(features), hidden.squeeze(0)


class Seq2Seq(nn.Module):
    """Ties an encoder and a step-wise decoder together with teacher forcing.

    Args:
        encoder: Module called as ``encoder(src)`` returning
            ``(encoder_outputs, hidden)``.
        decoder: Module called as ``decoder(input, hidden, encoder_outputs)``
            returning ``(scores, hidden)``; must expose ``output_dim``.
        device: Device on which the result buffer is placed.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        """Decode ``trg.shape[0] - 1`` steps and collect the per-step scores.

        Args:
            src: Source batch; its second axis is taken as the batch size.
            trg: Target batch; its first axis is taken as the number of steps,
                and ``trg[0]`` seeds the decoder.
            teacher_forcing_ratio: Probability, per step, of feeding the
                ground-truth token instead of the decoder's own argmax.

        Returns:
            Tensor of shape ``(trg_len, batch_size, decoder.output_dim)``.
            Row 0 is never written and stays all zeros.
        """
        batch_size, trg_len = src.shape[1], trg.shape[0]
        outputs = torch.zeros(trg_len, batch_size, self.decoder.output_dim).to(self.device)

        encoder_outputs, hidden = self.encoder(src)
        step_input = trg[0, :]
        for t in range(1, trg_len):
            scores, hidden = self.decoder(step_input, hidden, encoder_outputs)
            outputs[t] = scores
            # Exactly one random draw per step decides where the next input comes from.
            if random.random() < teacher_forcing_ratio:
                step_input = trg[t]
            else:
                step_input = scores.argmax(1)
        return outputs

In [13]:
class Attention(nn.Module):
    """Additive attention: scores every encoder position against the decoder state.

    Args:
        enc_hid_dim: Encoder hidden size per direction; encoder outputs are
            expected to be ``2 * enc_hid_dim`` wide.
        dec_hid_dim: Width of the decoder state and of the intermediate
            energy vector.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        # Collapses each energy vector to one unnormalised score.
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights over the source positions.

        Args:
            hidden: Decoder state, one row per batch element.
            encoder_outputs: Encoder outputs with the source position on the
                first axis and the batch on the second.

        Returns:
            Tensor of shape ``(batch, src_len)`` whose rows are softmax
            distributions over the source positions.
        """
        src_len = encoder_outputs.shape[0]
        # Broadcast the decoder state across source positions as a view;
        # torch.cat below materialises the data once, so no copy is needed here.
        hidden = hidden.unsqueeze(1).expand(-1, src_len, -1)
        # Move the batch to the front so both operands are (batch, src_len, feat).
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        scores = self.v(energy).squeeze(2)
        return F.softmax(scores, dim=1)

In [14]:
# Everything in this walkthrough runs on the CPU.
device = torch.device('cpu')

In [15]:
# Build the parts in dependency order: the decoder needs the attention module,
# and the full model needs both the encoder and the decoder.
attn = Attention(enc_hid_dim=ENC_HID_DIM, dec_hid_dim=DEC_HID_DIM)
enc = Encoder(
    input_dim=INPUT_DIM,
    emb_dim=ENC_EMB_DIM,
    enc_hid_dim=ENC_HID_DIM,
    dec_hid_dim=DEC_HID_DIM,
    dropout=ENC_DROPOUT,
)
dec = Decoder(
    output_dim=OUTPUT_DIM,
    emb_dim=DEC_EMB_DIM,
    enc_hid_dim=ENC_HID_DIM,
    dec_hid_dim=DEC_HID_DIM,
    dropout=DEC_DROPOUT,
    attention=attn,
)

model = Seq2Seq(encoder=enc, decoder=dec, device=device)
model = model.to(device)

In [17]:
# Smoke test: one forward pass with the default teacher-forcing ratio; only
# the shape of the result is inspected.
model(x, y).shape

torch.Size([10, 3, 37])

김기현님 attention

In [18]:
# Hyper-parameters for the second walkthrough. The upper-case names mirror the
# first walkthrough; dropout is lowered from 0.5 to 0.2 here.
SEQ_LEN = 10
BATCH_SIZE = 3

INPUT_DIM = 30
OUTPUT_DIM = 37
ENC_EMB_DIM = 32
DEC_EMB_DIM = 32
ENC_HID_DIM = 64
DEC_HID_DIM = 64
ENC_DROPOUT = 0.2
DEC_DROPOUT = 0.2

# Lower-case aliases used by the cells of this walkthrough.
input_size = INPUT_DIM
output_size = OUTPUT_DIM
word_vec_dim = ENC_EMB_DIM
hidden_size = ENC_HID_DIM
dropout_p = ENC_DROPOUT

In [84]:
# Padded toy batch for the batch-first walkthrough. Each row is laid out as
#   [0, word_1 .. word_n, <last vocabulary id>, PAD, PAD, ...]
# where n is that row's entry in *_seq_length (id 0 and the last vocabulary id
# are presumably start/end-of-sequence markers).
SRC_PAD_IDX = TRG_PAD_IDX = 1
MIN_WORDS   = 5

# randint's upper bound is exclusive, so each length lies in
# [MIN_WORDS, SEQ_LEN - 2]; SEQ_LEN - 2 words plus the two boundary ids fill a
# row exactly.
src_seq_length = torch.randint(MIN_WORDS, SEQ_LEN-1, (BATCH_SIZE,))
trg_seq_length = torch.randint(MIN_WORDS, SEQ_LEN-1, (BATCH_SIZE,))
# Make sure at least one row per side has no padding. The guard tests the same
# value it assigns; testing SEQ_LEN - 1 could never match a drawn length, so
# the last row was overwritten unconditionally.
if SEQ_LEN - 2 not in src_seq_length:
    src_seq_length[-1] = SEQ_LEN - 2
if SEQ_LEN - 2 not in trg_seq_length:
    trg_seq_length[-1] = SEQ_LEN - 2

# Word ids start at 2 so they never collide with id 0 or the padding id 1.
x = torch.randint(0+2, INPUT_DIM-2, size=(BATCH_SIZE, SEQ_LEN))
x[:, 0] = 0
for i, ind in enumerate(src_seq_length.tolist()):
    x[i, ind + 1] = INPUT_DIM - 1   # closing id right after the last word
    x[i, ind + 2:] = SRC_PAD_IDX    # everything after it is padding

y = torch.randint(0+2, OUTPUT_DIM-2, size=(BATCH_SIZE, SEQ_LEN))
y[:, 0] = 0
for i, ind in enumerate(trg_seq_length.tolist()):
    y[i, ind + 1] = OUTPUT_DIM - 1
    y[i, ind + 2:] = TRG_PAD_IDX

print(x, x.shape, end='\n\n')
print(y, y.shape)

tensor([[ 0, 21, 23, 23, 26,  4, 21, 29,  1,  1],
        [ 0,  6, 18,  6,  3,  3, 27, 29,  1,  1],
        [ 0, 11, 11, 23, 22, 21, 19,  4, 22, 29]]) torch.Size([3, 10])

tensor([[ 0, 30, 25, 15, 11, 33, 17, 36,  1,  1],
        [ 0, 17,  2, 21,  6, 10, 36,  1,  1,  1],
        [ 0, 27, 34, 28,  3, 34, 18,  5, 14, 36]]) torch.Size([3, 10])


In [85]:
# Read the batch size off y's leading dimension and display it.
batch_size = y.size(0)
batch_size

3

In [86]:
# Neither a mask nor source lengths are provided here; both start out as None.
mask, x_length = None, None

In [87]:
x

tensor([[ 0, 21, 23, 23, 26,  4, 21, 29,  1,  1],
        [ 0,  6, 18,  6,  3,  3, 27, 29,  1,  1],
        [ 0, 11, 11, 23, 22, 21, 19,  4, 22, 29]])

In [88]:
# Source-side embedding table: input_size rows, each word_vec_dim wide.
emb_src = nn.Embedding(input_size, word_vec_dim)

In [89]:
# Embed every token id of the source batch; a trailing embedding axis is added to x's shape.
emb_src_ = emb_src(x)

In [90]:
emb_src_.shape

torch.Size([3, 10, 32])

In [91]:
# Single-layer, unidirectional LSTM encoder without dropout. batch_first=True
# makes it take and return tensors with the batch on the first axis.
encoder_rnn = nn.LSTM(word_vec_dim,
                      hidden_size,
                      num_layers=1,
                      dropout=0,
                      bidirectional=False,
                      batch_first=True)

In [92]:
# Display the shape of every LSTM parameter tensor, in registration order.
[param.shape for param in encoder_rnn.parameters()]

[torch.Size([256, 32]),
 torch.Size([256, 64]),
 torch.Size([256]),
 torch.Size([256])]

In [93]:
# nn.LSTM returns (outputs, (h_n, c_n)). h_0_tgt first receives the whole
# state pair and is then unpacked into the final hidden and cell states.
h_src, h_0_tgt = encoder_rnn(emb_src_)
h_0_tgt, c_0_tgt = h_0_tgt

In [94]:
h_src.shape

torch.Size([3, 10, 64])

In [95]:
h_0_tgt.shape, c_0_tgt.shape

(torch.Size([1, 3, 64]), torch.Size([1, 3, 64]))

In [96]:
# Re-pack the hidden and cell states into a single (h, c) tuple and display
# its type to confirm.
h_0_tgt = (h_0_tgt, c_0_tgt)
type(h_0_tgt)

tuple

In [98]:
# Target-side embedding table: output_size rows, each word_vec_dim wide.
emb_dec = nn.Embedding(output_size, word_vec_dim)

In [100]:
# Embed the whole target batch at once and display the resulting shape.
emb_tgt = emb_dec(y)
emb_tgt.shape

torch.Size([3, 10, 32])

In [101]:
# Start with an empty list and a None placeholder.
# NOTE(review): the names suggest these collect per-step decoder outputs and
# hold the previous step's output — confirm against the code that follows.
h_tilde = []
h_t_tilde = None

In [102]:
decoder_hidden = h_0_tgt