In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

- Additive Attention

In [2]:
device = torch.device('cpu')

SEQ_LEN     = 15
BATCH_SIZE  = 5
INPUT_DIM   = 30
OUTPUT_DIM  = 37
HID_DIM     = 256
ENC_EMB_DIM = DEC_EMB_DIM = 32
ENC_HID_DIM = DEC_HID_DIM = 64
ENC_DROPOUT = DEC_DROPOUT = 0.1

In [3]:
SRC_PAD_IDX = TRG_PAD_IDX = 1
MIN_WORDS   = 5

src_seq_length = torch.randint(MIN_WORDS, SEQ_LEN-1, (BATCH_SIZE,))
trg_seq_length = torch.randint(MIN_WORDS, SEQ_LEN-1, (BATCH_SIZE,))
if SEQ_LEN - 1 not in src_seq_length:
    src_seq_length[-1] = SEQ_LEN - 2
if SEQ_LEN - 1 not in trg_seq_length:
    trg_seq_length[-1] = SEQ_LEN - 2

x = torch.randint(0+2, INPUT_DIM-2, size=(BATCH_SIZE, SEQ_LEN))
x[:, 0] = 0
for i, ind in enumerate(src_seq_length):
    x[i, ind+1 ] = INPUT_DIM - 1
    x[i, ind+2:] = SRC_PAD_IDX

y = torch.randint(0+2, OUTPUT_DIM-2, size=(BATCH_SIZE, SEQ_LEN))
y[:, 0] = 0
for i, ind in enumerate(trg_seq_length):
    y[i, ind+1 ] = OUTPUT_DIM - 1
    y[i, ind+2:] = TRG_PAD_IDX

print(x, x.shape, end='\n\n')
print(y, y.shape)

tensor([[ 0, 26, 25, 27, 10,  2,  9,  3, 26,  9, 20, 29,  1,  1,  1],
        [ 0,  5, 26, 25, 13, 18, 21,  6,  6, 12, 29,  1,  1,  1,  1],
        [ 0, 17, 27,  5, 23, 16, 11,  5, 26, 11, 19, 29,  1,  1,  1],
        [ 0,  9, 13, 25, 25,  5, 15, 21, 27, 29,  1,  1,  1,  1,  1],
        [ 0, 15, 11, 13, 23,  4, 14,  3,  8,  5,  8,  8, 22, 27, 29]]) torch.Size([5, 15])

tensor([[ 0, 18,  4,  3,  2,  4, 15, 36,  1,  1,  1,  1,  1,  1,  1],
        [ 0,  2, 28, 30, 27,  3, 32,  8, 36,  1,  1,  1,  1,  1,  1],
        [ 0, 22, 26,  9,  7, 21, 22, 36,  1,  1,  1,  1,  1,  1,  1],
        [ 0, 14, 29, 16, 34, 16, 10, 36,  1,  1,  1,  1,  1,  1,  1],
        [ 0, 33,  7, 19, 19, 30, 31, 10, 33, 24,  2,  3,  9, 26, 36]]) torch.Size([5, 15])


In [4]:
class Encoder(nn.Module):

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, 
                          enc_hid_dim, 
                          bidirectional=True,
                          batch_first=True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, hidden = self.rnn(embedded)
        H = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)
        hidden = torch.tanh(self.fc(H))
        return outputs, hidden

In [5]:
encoder = Encoder(INPUT_DIM, 
                  ENC_EMB_DIM, 
                  ENC_HID_DIM, 
                  DEC_HID_DIM, 
                  ENC_DROPOUT)

In [6]:
encoder_outputs, hidden = encoder(x)

In [7]:
encoder_outputs.shape, hidden.shape

(torch.Size([5, 15, 128]), torch.Size([5, 64]))

In [8]:
attn = nn.Linear((ENC_HID_DIM * 2) + DEC_HID_DIM, DEC_HID_DIM)
v    = nn.Linear(DEC_HID_DIM, 1, bias=False)

In [9]:
batch_size = encoder_outputs.shape[0]
src_len    = encoder_outputs.shape[1]
batch_size, src_len

(5, 15)

In [10]:
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
hidden.shape

torch.Size([5, 15, 64])

In [11]:
encoder_outputs.shape

torch.Size([5, 15, 128])

In [12]:
attn_input = torch.cat((hidden, encoder_outputs), dim=2)
attn_input.shape

torch.Size([5, 15, 192])

In [13]:
energy = torch.tanh(attn(attn_input))
energy.shape

torch.Size([5, 15, 64])

In [14]:
attn.weight.shape

torch.Size([64, 192])

In [15]:
attention = v(energy).squeeze(2)
attention

tensor([[-0.0681, -0.0243, -0.0442, -0.0483, -0.0822, -0.0971, -0.0957, -0.0324,
         -0.0575, -0.0685, -0.0634,  0.0040,  0.0625,  0.0376,  0.0002],
        [-0.0816, -0.0831, -0.0451, -0.0357, -0.0238, -0.0653, -0.0959,  0.0309,
          0.0625,  0.0946,  0.0811,  0.0908,  0.0891,  0.0769,  0.0400],
        [-0.0255, -0.0423,  0.0063,  0.0458,  0.0461,  0.0255,  0.0104, -0.0207,
         -0.0012,  0.0412, -0.0437, -0.0008,  0.0510,  0.0682,  0.0592],
        [-0.0753, -0.0629, -0.0061, -0.0498, -0.1224, -0.1380, -0.1147, -0.1310,
         -0.0648,  0.0158,  0.0538,  0.0454,  0.0311,  0.0213,  0.0067],
        [-0.0818, -0.0326, -0.0146, -0.0075,  0.0335,  0.0188,  0.0304,  0.0152,
         -0.0397, -0.0828, -0.0442, -0.0273, -0.0453, -0.0897, -0.0477]],
       grad_fn=<SqueezeBackward1>)

In [16]:
annotation = F.softmax(attention, dim=1)
annotation

tensor([[0.0647, 0.0676, 0.0662, 0.0659, 0.0637, 0.0628, 0.0629, 0.0670, 0.0653,
         0.0646, 0.0650, 0.0695, 0.0737, 0.0719, 0.0692],
        [0.0607, 0.0607, 0.0630, 0.0636, 0.0644, 0.0617, 0.0599, 0.0680, 0.0702,
         0.0724, 0.0715, 0.0722, 0.0720, 0.0712, 0.0686],
        [0.0640, 0.0629, 0.0661, 0.0687, 0.0688, 0.0674, 0.0663, 0.0643, 0.0656,
         0.0684, 0.0629, 0.0656, 0.0691, 0.0703, 0.0697],
        [0.0642, 0.0650, 0.0688, 0.0658, 0.0612, 0.0603, 0.0617, 0.0607, 0.0649,
         0.0703, 0.0730, 0.0724, 0.0714, 0.0707, 0.0697],
        [0.0631, 0.0663, 0.0675, 0.0680, 0.0708, 0.0698, 0.0706, 0.0695, 0.0658,
         0.0630, 0.0655, 0.0666, 0.0655, 0.0626, 0.0653]],
       grad_fn=<SoftmaxBackward>)

In [17]:
annotation = annotation.unsqueeze(1)
annotation.shape

torch.Size([5, 1, 15])

In [18]:
weighted = torch.bmm(annotation, encoder_outputs)
weighted.shape
# (b, n, m) X (b, m, p) ==>> (b, n, p)

torch.Size([5, 1, 128])

In [19]:
dec_emb = nn.Embedding(OUTPUT_DIM, DEC_EMB_DIM)
embedded = dec_emb(y[:, 0].unsqueeze(1))
embedded.shape

torch.Size([5, 1, 32])

In [20]:
rnn_input = torch.cat((embedded, weighted), dim=2)
rnn_input.shape

torch.Size([5, 1, 160])

- Additive Self-attention

In [21]:
x

tensor([[ 0, 26, 25, 27, 10,  2,  9,  3, 26,  9, 20, 29,  1,  1,  1],
        [ 0,  5, 26, 25, 13, 18, 21,  6,  6, 12, 29,  1,  1,  1,  1],
        [ 0, 17, 27,  5, 23, 16, 11,  5, 26, 11, 19, 29,  1,  1,  1],
        [ 0,  9, 13, 25, 25,  5, 15, 21, 27, 29,  1,  1,  1,  1,  1],
        [ 0, 15, 11, 13, 23,  4, 14,  3,  8,  5,  8,  8, 22, 27, 29]])

In [22]:
x_lengths = (x != 1).sum(dim=1)

In [23]:
x_lengths

tensor([12, 11, 12, 10, 15])

In [24]:
sorted_lengths, ind = torch.sort(x_lengths, descending=True)

In [25]:
sorted_x = x.index_select(0, ind)

In [26]:
_, rev_ind = torch.sort(ind)

In [27]:
restored_x = sorted_x.index_select(0, rev_ind)

In [28]:
device = torch.device('cpu')

SEQ_LEN     = 15
BATCH_SIZE  = 5
INPUT_DIM   = 30
OUTPUT_DIM  = 37
HID_DIM     = 256
ENC_EMB_DIM = DEC_EMB_DIM = 32
ENC_HID_DIM = DEC_HID_DIM = 64
ENC_DROPOUT = DEC_DROPOUT = 0.1

r = 8

In [29]:
embed = nn.Embedding(INPUT_DIM,  # vocab_size
                     ENC_EMB_DIM # embedding_size
                    )

In [30]:
rnn = nn.LSTM(ENC_EMB_DIM,
              ENC_HID_DIM, # hidden_size
              num_layers=1,
              batch_first=True,
              bidirectional=True
             )

In [31]:
attn = nn.Linear(2 * ENC_HID_DIM, # num_directions*hidden_size
                 DEC_HID_DIM,     # attention_dimension
                 bias=False
                )
attn2 = nn.Linear(DEC_HID_DIM,    # attention_dimension
                  r,              # keywords
                                  # (different parts to be expected
                                  #  from the sentence)
                  bias=False
                 )

In [32]:
tanh = nn.Tanh()
sigmoid = nn.Sigmoid()
attn_dist = nn.Softmax(dim=2)

In [33]:
fc = nn.Sequential(
    nn.Linear(r * ENC_HID_DIM * 2, 16),
    nn.ReLU(),
    nn.Linear(16, 2), # fc의 hidden_size, output_size
)

In [34]:
embedded = embed(sorted_x)

In [35]:
embedded.shape

torch.Size([5, 15, 32])

In [36]:
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

In [37]:
hidden = Variable(
    torch.zeros(1*2, batch_size, ENC_HID_DIM)
)
cell = Variable(
    torch.zeros(1*2, batch_size, ENC_HID_DIM)
)

In [38]:
embedded.shape

torch.Size([5, 15, 32])

In [39]:
packed = pack_padded_sequence(embedded, 
                              sorted_lengths.tolist(), 
                              batch_first=True)
packed

PackedSequence(data=tensor([[ 0.6832,  1.8163, -0.6768,  ...,  1.8970,  0.4989,  0.9598],
        [ 0.6832,  1.8163, -0.6768,  ...,  1.8970,  0.4989,  0.9598],
        [ 0.6832,  1.8163, -0.6768,  ...,  1.8970,  0.4989,  0.9598],
        ...,
        [-1.0283, -2.8824,  0.2092,  ..., -0.9529, -1.0811, -0.3010],
        [ 0.7270,  0.1353,  0.4488,  ...,  0.0234,  0.9150,  1.1722],
        [-0.0140,  0.8899,  0.0727,  ...,  0.1250,  1.0330, -0.0721]],
       grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 4, 3, 1, 1, 1]), sorted_indices=None, unsorted_indices=None)

In [40]:
output, (hidden, cell) = rnn(packed, (hidden, cell))

In [41]:
output

PackedSequence(data=tensor([[ 0.0149, -0.0333, -0.0658,  ...,  0.1088,  0.1697, -0.1746],
        [ 0.0149, -0.0333, -0.0658,  ...,  0.0406,  0.2472, -0.0383],
        [ 0.0149, -0.0333, -0.0658,  ...,  0.0265,  0.1049, -0.0137],
        ...,
        [ 0.1855,  0.1605,  0.1007,  ...,  0.0813, -0.1200, -0.0586],
        [ 0.0753,  0.1276,  0.0922,  ...,  0.0783, -0.0189, -0.0041],
        [ 0.0692, -0.0353,  0.0728,  ...,  0.0188,  0.0904, -0.0005]],
       grad_fn=<CatBackward>), batch_sizes=tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 4, 3, 1, 1, 1]), sorted_indices=None, unsorted_indices=None)

In [42]:
output.data.shape

torch.Size([60, 128])

In [43]:
hidden.shape, cell.shape

(torch.Size([2, 5, 64]), torch.Size([2, 5, 64]))

In [44]:
output, output_lengths = pad_packed_sequence(output, batch_first=True)

In [45]:
output.shape

torch.Size([5, 15, 128])

In [46]:
# self-attention
tanh_a1 = tanh(attn(output))
tanh_a1.shape

torch.Size([5, 15, 64])

In [47]:
score = attn2(tanh_a1)
score.shape

torch.Size([5, 15, 8])

In [48]:
A = attn_dist(score.transpose(1, 2)) # softmax
A.shape

torch.Size([5, 8, 15])

In [49]:
M = A.bmm(output)
M.shape

torch.Size([5, 8, 128])

In [50]:
# Penalization
eye = Variable(torch.eye(A.size(1)).expand(A.size(0), r, r))
eye

tensor([[[1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1.]],

        [[1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1.]],

        [[1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0.],
       

In [51]:
P = torch.bmm(A, A.transpose(1, 2)) - eye
P.shape

torch.Size([5, 8, 8])

In [52]:
loss_P = ((P**2).sum(1).sum(1) + 1e-10) ** 0.5
loss_P = torch.sum(loss_P) / A.size(0)
loss_P

tensor(2.6865, grad_fn=<DivBackward0>)

In [53]:
# output
M.view(M.size(0), -1).shape

torch.Size([5, 1024])

In [54]:
fc

Sequential(
  (0): Linear(in_features=1024, out_features=16, bias=True)
  (1): ReLU()
  (2): Linear(in_features=16, out_features=2, bias=True)
)

In [55]:
fc(M.view(M.size(0), -1))

tensor([[0.0646, 0.0821],
        [0.0565, 0.0867],
        [0.0421, 0.0829],
        [0.0607, 0.0734],
        [0.0531, 0.0873]], grad_fn=<AddmmBackward>)

- Multiplicate attention

In [56]:
SEQ_LEN      = 10
BATCH_SIZE   = 3
input_size   = INPUT_DIM = 30
output_size  = OUTPUT_DIM = 37
word_vec_dim = ENC_EMB_DIM = DEC_EMB_DIM = 32
hidden_size  = ENC_HID_DIM = DEC_HID_DIM = 64
dropout_p    = ENC_DROPOUT = DEC_DROPOUT = 0.2 # 0.5 to 0.2

In [57]:
SRC_PAD_IDX = TRG_PAD_IDX = 1
MIN_WORDS   = 5

src_seq_length = torch.randint(MIN_WORDS, SEQ_LEN-1, (BATCH_SIZE,))
trg_seq_length = torch.randint(MIN_WORDS, SEQ_LEN-1, (BATCH_SIZE,))
if SEQ_LEN - 1 not in src_seq_length:
    src_seq_length[-1] = SEQ_LEN - 2
if SEQ_LEN - 1 not in trg_seq_length:
    trg_seq_length[-1] = SEQ_LEN - 2

x = torch.randint(0+2, INPUT_DIM-2, size=(BATCH_SIZE, SEQ_LEN))
x[:, 0] = 0
for i, ind in enumerate(src_seq_length):
    x[i, ind+1 ] = INPUT_DIM - 1
    x[i, ind+2:] = SRC_PAD_IDX

y = torch.randint(0+2, OUTPUT_DIM-2, size=(BATCH_SIZE, SEQ_LEN))
y[:, 0] = 0
for i, ind in enumerate(trg_seq_length):
    y[i, ind+1 ] = OUTPUT_DIM - 1
    y[i, ind+2:] = TRG_PAD_IDX

print(x, x.shape, end='\n\n')
print(y, y.shape)

tensor([[ 0,  7,  6, 18, 27,  2, 29,  1,  1,  1],
        [ 0, 16, 27, 19,  9, 15, 21, 24,  6, 29],
        [ 0, 16, 27, 21,  6,  2, 11,  2, 10, 29]]) torch.Size([3, 10])

tensor([[ 0,  7, 21, 11, 28, 15, 27, 32, 14, 36],
        [ 0,  2, 13, 20, 24,  7, 36,  1,  1,  1],
        [ 0, 12, 19, 11, 14, 29, 27, 12, 23, 36]]) torch.Size([3, 10])


In [58]:
batch_size = y.size(0)
batch_size

3

In [59]:
mask, x_length = None, None

In [60]:
emb_src = nn.Embedding(input_size, word_vec_dim)
encoder_rnn = nn.LSTM(word_vec_dim,
                      hidden_size,
                      num_layers=1,
                      dropout=0,
                      bidirectional=False,
                      batch_first=True)
emb_dec = nn.Embedding(output_size, word_vec_dim)

In [61]:
emb_src_ = emb_src(x)

In [62]:
h_src, h_0_tgt = encoder_rnn(emb_src_)
h_0_tgt, c_0_tgt = h_0_tgt
h_src.shape, h_0_tgt.shape, c_0_tgt.shape

(torch.Size([3, 10, 64]), torch.Size([1, 3, 64]), torch.Size([1, 3, 64]))

In [63]:
h_0_tgt = (h_0_tgt, c_0_tgt)

In [64]:
emb_tgt_ = emb_dec(y)
emb_tgt_.shape

torch.Size([3, 10, 32])

In [65]:
h_tilde = []
h_t_tilde = None
decoder_hidden = h_0_tgt

In [66]:
emb_t = emb_tgt_[:, 0, :].unsqueeze(1)

In [67]:
emb_t.shape

torch.Size([3, 1, 32])

In [68]:
dec_rnn = nn.LSTM(word_vec_dim + hidden_size,
                  hidden_size,
                  num_layers=1,
                  dropout=0.0,
                  bidirectional=False,
                  batch_first=True
                  )

In [69]:
h_t_l_tilde = emb_t.new(batch_size, 1, hidden_size).zero_()

In [70]:
h_t_l_tilde.shape

torch.Size([3, 1, 64])

In [71]:
# input_feeding_trick
input_ft = torch.cat([emb_t, h_t_l_tilde], dim=-1)
input_ft.shape

torch.Size([3, 1, 96])

In [72]:
decoder_output, decoder_hidden = dec_rnn(input_ft, decoder_hidden)

In [74]:
# Attention
linear = nn.Linear(hidden_size, hidden_size, bias=False)
softmax = nn.Softmax(dim=-1)

In [76]:
decoder_output.shape

torch.Size([3, 1, 64])

In [81]:
query = linear(decoder_output.squeeze()).unsqueeze(-1)
query.shape

torch.Size([3, 64, 1])

In [82]:
weight = torch.bmm(h_src, query).squeeze(-1)
weight.shape

torch.Size([3, 10])

# Transformer

In [267]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

device = torch.device('cpu')

SRC_SEQ_LEN  = 15
TRG_SEQ_LEN  = 18
BATCH_SIZE   = 5
INPUT_DIM    = 30
OUTPUT_DIM   = 37
HID_DIM      = 256
N_SPLITS     = 8
N_ENC_BLOCKS = N_DEC_BLOCKS = 6
DROPOUT_P    = 0.1

In [268]:
SRC_PAD_IDX = TRG_PAD_IDX = 1
MIN_WORDS   = 5

src_seq_length = torch.randint(MIN_WORDS, SRC_SEQ_LEN-1, (BATCH_SIZE,))
trg_seq_length = torch.randint(MIN_WORDS, TRG_SEQ_LEN-1, (BATCH_SIZE,))
if SRC_SEQ_LEN - 1 not in src_seq_length:
    src_seq_length[-1] = SRC_SEQ_LEN - 2
if TRG_SEQ_LEN - 1 not in trg_seq_length:
    trg_seq_length[-1] = TRG_SEQ_LEN - 2

x = torch.randint(0+2, INPUT_DIM-2, size=(BATCH_SIZE, SRC_SEQ_LEN))
x[:, 0] = 0
for i, ind in enumerate(src_seq_length):
    x[i, ind+1 ] = INPUT_DIM - 1
    x[i, ind+2:] = SRC_PAD_IDX

y = torch.randint(0+2, OUTPUT_DIM-2, size=(BATCH_SIZE, TRG_SEQ_LEN))
y[:, 0] = 0
for i, ind in enumerate(trg_seq_length):
    y[i, ind+1 ] = OUTPUT_DIM - 1
    y[i, ind+2:] = TRG_PAD_IDX

print(x, x.shape, end='\n\n')
print(y, y.shape)

tensor([[ 0, 25, 18, 21,  8,  5,  3, 21, 17, 29,  1,  1,  1,  1,  1],
        [ 0, 12, 19, 17, 19, 16, 16, 15, 23, 23, 23, 17, 29,  1,  1],
        [ 0,  2, 24,  8, 10,  7, 29,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 0, 10,  5, 10, 21, 13, 19, 17,  3, 10,  7,  5,  7, 29,  1],
        [ 0, 17, 14,  7,  8, 13, 24,  3,  5, 13, 16,  3, 12, 21, 29]]) torch.Size([5, 15])

tensor([[ 0, 10, 29,  5, 19, 33, 36,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 0, 27, 33, 27, 23,  3, 23, 11, 21,  6, 25, 33, 12,  7, 33,  7, 36,  1],
        [ 0,  4, 28,  3,  2, 12,  2, 10, 11, 18, 26, 27,  2, 36,  1,  1,  1,  1],
        [ 0, 27, 14,  5, 13, 34, 34,  6, 19, 30, 22,  4, 15,  2, 18, 36,  1,  1],
        [ 0, 25,  6, 17, 15, 28,  3,  8, 29, 27, 25, 33, 10, 18, 13, 19, 15, 36]]) torch.Size([5, 18])


In [269]:
emb_enc     = nn.Embedding(INPUT_DIM, HID_DIM)
emb_dec     = nn.Embedding(OUTPUT_DIM, HID_DIM)
emb_dropout = nn.Dropout(DROPOUT_P)

In [270]:
x.shape, y.shape

(torch.Size([5, 15]), torch.Size([5, 18]))

In [271]:
x_lengths = (x != 1).sum(dim=1)

In [304]:
# _generate_mask
mask = []
max_length = max(x_lengths)
for l in x_lengths:
    if max_length - l > 0:
        mask += [torch.cat([x.new_ones(1, l).zero_(),
                            x.new_ones(1, (max_length - l))],
                           dim=-1)]
    else:
        mask += [x.new_ones(1, l).zero_()]
mask = torch.cat(mask, dim=0).bool()

In [305]:
mask

tensor([[False, False, False, False, False, False, False, False, False, False,
          True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True,
          True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False]])

In [306]:
x

tensor([[ 0, 25, 18, 21,  8,  5,  3, 21, 17, 29,  1,  1,  1,  1,  1],
        [ 0, 12, 19, 17, 19, 16, 16, 15, 23, 23, 23, 17, 29,  1,  1],
        [ 0,  2, 24,  8, 10,  7, 29,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 0, 10,  5, 10, 21, 13, 19, 17,  3, 10,  7,  5,  7, 29,  1],
        [ 0, 17, 14,  7,  8, 13, 24,  3,  5, 13, 16,  3, 12, 21, 29]])

In [307]:
(x == 1).byte()

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.uint8)

In [308]:
mask_enc = mask.unsqueeze(1).expand(
    mask.size(0), x.size(1), mask.size(-1))
mask_dec = mask.unsqueeze(1).expand(
    mask.size(0), y.size(1), mask.size(-1))

In [309]:
mask_enc.shape, mask_dec.shape

(torch.Size([5, 15, 15]), torch.Size([5, 18, 15]))

In [310]:
embedded_x = emb_enc(x)
embedded_x.shape

torch.Size([5, 15, 256])

In [311]:
# _positional_encoding
print(f"x.shape is {list(embedded_x.shape)}")
length, hidden_size = embedded_x.size(1), embedded_x.size(-1)
print(f"length is {length} and hidden_size is {hidden_size}")
enc = embedded_x.new_zeros(embedded_x.shape[1:])
print(f"enc is {enc}")
init_pos = 0
pos = init_pos + torch.arange(0, length).unsqueeze(-1)
dim = (10000. ** (torch.arange(0, hidden_size//2).div(hidden_size))).unsqueeze(0)
assert enc[:, 0::2].size() == (pos / dim).size()
assert enc[:, 1::2].size() == (pos / dim).size()
pos = pos.float()
dim = dim.float()
print(f"position is {pos.flatten()}")
print(f"pos div term is {dim}")
embedded_x = embedded_x + enc

x.shape is [5, 15, 256]
length is 15 and hidden_size is 256
enc is tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
position is tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14.])
pos div term is tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 

In [312]:
z = emb_dropout(embedded_x)

In [313]:
z.shape

torch.Size([5, 15, 256])

In [314]:
n_splits = N_SPLITS

In [315]:
# EncoderBlock
# input is z, mask_enc

# First Block
# MultiHead() -> Dropout() -> Residual Connection -> Seq(fc) -> Dropout() -> Residual Connection
# Sublayer's form is LayerNorm(x + sublayer(x))
# Seq(fc); Linear -> ReLU(or Leaky-ReLU) -> Linear

## MultiHead Attention!!
Q_linear = nn.Linear(hidden_size, hidden_size, bias=False)
K_linear = nn.Linear(hidden_size, hidden_size, bias=False)
V_linear = nn.Linear(hidden_size, hidden_size, bias=False)
linear   = nn.Linear(hidden_size, hidden_size, bias=False)

Q = K = V = z # Encoder는 모두 같음
              # (Batch_size, Seg_len, Hid_dim)
mask = mask

QWs = Q_linear(Q).split(hidden_size // n_splits, dim=-1) # (bsz, seqlen, h/8) * 8
KWs = K_linear(K).split(hidden_size // n_splits, dim=-1) # (bsz, seqlen, h/8) * 8
VWs = V_linear(V).split(hidden_size // n_splits, dim=-1) # (bsz, seqlen, h/8) * 8

QWs = torch.cat(QWs, dim=0) # (bsz*8, seqlen, h/8)
KWs = torch.cat(KWs, dim=0) # (bsz*8, seqlen, h/8)
VWs = torch.cat(VWs, dim=0) # (bsz*8, seqlen, h/8)

mask = torch.cat([mask_enc for _ in range(n_splits)], dim=0) # (bsz*8, m, n)
                                                             # in this case, m==n
### attention
dk = hidden_size // n_splits
w = torch.bmm(QWs, KWs.transpose(1, 2)) # reference
                                    # (bsz*n_splits, m, n)
assert w.size() == mask.size()
w.masked_fill_(mask, -float('inf'))

w = nn.Softmax(dim=-1)(w / (dk ** 0.5))
c = torch.bmm(w, VWs) # (bsz*n_splits, m, hidden_size/n_splits)

c = c.split(Q.size(0), dim=0)    # (bsz, m, h/8) * 8
c = linear(torch.cat(c, dim=-1)) # (bsz, m, h) -> (bsz, m, h)

# Prepare layers
attn_dropout = nn.Dropout(dropout_p)
attn_norm    = nn.LayerNorm(hidden_size)
fc = nn.Sequential(
    nn.Linear(hidden_size, hidden_size*4),
    nn.ReLU(),
    nn.Linear(hidden_size*4, hidden_size),
)
fc_dropout = nn.Dropout(dropout_p)
fc_norm    = nn.LayerNorm(hidden_size)

z = attn_norm(z + attn_dropout(c))
z = fc_norm(z + fc_dropout(fc(z))) # (bsz, n, hidden_size)

In [328]:
class Attention(nn.Module):

    def __init__(self):
        super().__init__()

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, mask=None, dk=64):
        # |Q| = (batch_size, m, hidden_size)
        # |K| = |V| = (batch_size, n, hidden_size)
        # |mask| = (batch_size, m, n)

        w = torch.bmm(Q, K.transpose(1, 2))
        # |w| = (batch_size, m, n)
        if mask is not None:
            assert w.size() == mask.size()
            w.masked_fill_(mask, -float('inf'))

        w = self.softmax(w / (dk**.5))
        c = torch.bmm(w, V)
        # |c| = (batch_size, m, hidden_size)

        return c


class MultiHead(nn.Module):

    def __init__(self, hidden_size, n_splits):
        super().__init__()

        self.hidden_size = hidden_size
        self.n_splits = n_splits

        # Note that we don't have to declare each linear layer, separately.
        self.Q_linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.K_linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.V_linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)

        self.attn = Attention()

    def forward(self, Q, K, V, mask=None):
        # |Q| = (batch_size, m, hidden_size)
        # |K| = |V| = (batch_size, n, hidden_size)
        # |mask| = (batch_size, m, n)

        QWs = self.Q_linear(Q).split(self.hidden_size // self.n_splits, dim=-1)
        KWs = self.K_linear(K).split(self.hidden_size // self.n_splits, dim=-1)
        VWs = self.V_linear(V).split(self.hidden_size // self.n_splits, dim=-1)
        # |QW_i| = (batch_size, m, hidden_size / n_splits)
        # |KW_i| = |VW_i| = (batch_size, n, hidden_size / n_splits)

        # By concatenating splited linear transformed results,
        # we can remove sequential operations,
        # like mini-batch parallel operations.
        QWs = torch.cat(QWs, dim=0)
        KWs = torch.cat(KWs, dim=0)
        VWs = torch.cat(VWs, dim=0)
        # |QWs| = (batch_size * n_splits, m, hidden_size / n_splits)
        # |KWs| = |VWs| = (batch_size * n_splits, n, hidden_size / n_splits)

        if mask is not None:
            mask = torch.cat([mask for _ in range(self.n_splits)], dim=0)
            # |mask| = (batch_size * n_splits, m, n)

        c = self.attn(
            QWs, KWs, VWs,
            mask=mask,
            dk=self.hidden_size // self.n_splits,
        )
        # |c| = (batch_size * n_splits, m, hidden_size / n_splits)

        # We need to restore temporal mini-batchfied multi-head attention results.
        c = c.split(Q.size(0), dim=0)
        # |c_i| = (batch_size, m, hidden_size / n_splits)
        c = self.linear(torch.cat(c, dim=-1))
        # |c| = (batch_size, m, hidden_size)

        return c

class EncoderBlock(nn.Module):

    def __init__(self, hidden_size, n_splits,
                 dropout_p=.1, use_leaky_relu=False
                 ):
        super().__init__()

        self.attn = MultiHead(hidden_size, n_splits)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.attn_dropout = nn.Dropout(dropout_p)

        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.LeakyReLU() if use_leaky_relu else nn.ReLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.fc_norm = nn.LayerNorm(hidden_size)
        self.fc_dropout = nn.Dropout(dropout_p)

    def forward(self, x, mask):
        # |x| = (batch_size, n, hidden_size)
        # |mask| = (batch_size, n, n)

        z = self.attn_norm(x + self.attn_dropout(self.attn(Q=x,
                                                           K=x,
                                                           V=x,
                                                           mask=mask)))
        z = self.fc_norm(z + self.fc_dropout(self.fc(z)))
        # |z| = (batch_size, n, hidden_size)

        return z, mask
    
class MySequential(nn.Sequential):

    def forward(self, *x):
        # nn.Sequential class does not provide multiple input arguments and returns.
        # Thus, we need to define new class to solve this issue.
        # Note that each block has same function interface.

        for module in self._modules.values():
            x = module(*x)

        return x

In [329]:
encoder = MySequential(
            *[EncoderBlock(
                hidden_size,
                n_splits,
                dropout_p,
                False,
              ) for _ in range(6)],
        )

In [347]:
z = emb_dropout(embedded_x)
z, _ = encoder(z, mask_enc)
z.shape # (bsz, n, hid_dim)

torch.Size([5, 15, 256])

In [348]:
embedded_y = emb_dec(y)
# _positional_encoding
print(f"y.shape is {list(embedded_y.shape)}")
length, hidden_size = embedded_y.size(1), embedded_y.size(-1)
print(f"length is {length} and hidden_size is {hidden_size}")
enc = embedded_y.new_zeros(embedded_y.shape[1:])
print(f"enc is {enc}")
init_pos = 0
pos = init_pos + torch.arange(0, length).unsqueeze(-1)
dim = (10000. ** (torch.arange(0, hidden_size//2).div(hidden_size))).unsqueeze(0)
assert enc[:, 0::2].size() == (pos / dim).size()
assert enc[:, 1::2].size() == (pos / dim).size()
pos = pos.float()
dim = dim.float()
print(f"position is {pos.flatten()}")
print(f"pos div term is {dim}")
embedded_y = embedded_y + enc

h = emb_dropout(embedded_y)

y.shape is [5, 18, 256]
length is 18 and hidden_size is 256
enc is tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
position is tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17.])
pos div term is tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1

In [349]:
# DecoderBlock
# input is h, z, mask_dec, None(prev)
prev = None
key_and_value = z

masked_attn = MultiHead(hidden_size, n_splits)
masked_attn_norm = nn.LayerNorm(hidden_size)
masked_attn_dropout = nn.Dropout(dropout_p)

attn = MultiHead(hidden_size, n_splits)
attn_norm = nn.LayerNorm(hidden_size)
attn_dropout = nn.Dropout(dropout_p)

# In the case of inference, we don't have to repeat same feed-forward operations,
# Thus, we save previous feed-forward results, including intermediate results.
if prev is not None:
    z = masked_attn_norm(
        h + masked_attn_dropout(
            masked_attn(h, prev, prev, mask=None)
        )
    ) # (bsz, 1, hidden_size)
else:
    # x shape is (bsz, m, hidden_size)
    batch_size = h.size(0)
    m = h.size(1)
    fwd_mask = torch.triu(h.new_ones((m, m)), diagonal=1).bool() # (m, m)
    fwd_mask = fwd_mask.unsqueeze(0).expand(batch_size, *fwd_mask.size()) # (bsz, m, m)
    z = masked_attn_norm(
        h + masked_attn_dropout(
            masked_attn(h, h, h, mask=fwd_mask)
        )
    ) # (bsz, m, hidden_size)
    
# key_and_value: (bsz, n, hidden_size)
# mask: (bsz, m, n)
z = attn_norm(z + attn_dropout(attn(Q=z,
                                    K=key_and_value,
                                    V=key_and_value,
                                    mask=mask_dec))) # (bsz, m, hidden_size)
z = fc_norm(z + fc_dropout(fc(z))) # (bsz, m, hidden_size)

In [350]:
z.shape

torch.Size([5, 18, 256])

In [356]:
class DecoderBlock(nn.Module):

    def __init__(self, hidden_size, n_splits,
                 dropout_p=.1, use_leaky_relu=False
                 ):
        super().__init__()

        self.masked_attn = MultiHead(hidden_size, n_splits)
        self.masked_attn_norm = nn.LayerNorm(hidden_size)
        self.masked_attn_dropout = nn.Dropout(dropout_p)

        self.attn = MultiHead(hidden_size, n_splits)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.attn_dropout = nn.Dropout(dropout_p)

        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.LeakyReLU() if use_leaky_relu else nn.ReLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.fc_norm = nn.LayerNorm(hidden_size)
        self.fc_dropout = nn.Dropout(dropout_p)

    def forward(self, x, key_and_value, mask, prev):
        # In case of inference, we don't have to repeat same feed-forward operations.
        # Thus, we save previous feed-forward results, including intermediate results.
        if prev is not None:
            # |x| = (batch_size, m=1, hidden_size)
            # |prev| = (batch_size, m', hidden_size)

            z = self.masked_attn_norm(x + self.masked_attn_dropout(
                self.masked_attn(x, prev, prev, mask=None)
            ))
            # |z| = (batch_size, 1, hidden_size)
        else:
            # |x| = (batch_size, m, hidden_size)
            batch_size = x.size(0)
            m = x.size(1)

            fwd_mask = torch.triu(x.new_ones((m, m)), diagonal=1).bool()
            # |fwd_mask| = (m, m)
            fwd_mask = fwd_mask.unsqueeze(0).expand(batch_size, *fwd_mask.size())
            # |fwd_mask| = (batch_size, m, m)

            z = self.masked_attn_norm(x + self.masked_attn_dropout(
                self.masked_attn(x, x, x, mask=fwd_mask)
            ))
            # |z| = (batch_size, m, hidden_size)

        # |key_and_value| = (batch_size, n, hidden_size)
        # |mask| = (batch_size, m, n)
        z = self.attn_norm(z + self.attn_dropout(self.attn(Q=z,
                                                           K=key_and_value,
                                                           V=key_and_value,
                                                           mask=mask)))
        # |z| = (batch_size, m, hidden_size)

        z = self.fc_norm(z + self.fc_dropout(self.fc(z)))
        # |z| = (batch_size, m, hidden_size)

        return z, key_and_value, mask, prev

In [357]:
z = emb_dropout(embedded_x)
z, _ = encoder(z, mask_enc)
z.shape # (bsz, n, hid_dim)
# --------------------------------------------
embedded_y = emb_dec(y)
# _positional_encoding
print(f"y.shape is {list(embedded_y.shape)}")
length, hidden_size = embedded_y.size(1), embedded_y.size(-1)
print(f"length is {length} and hidden_size is {hidden_size}")
enc = embedded_y.new_zeros(embedded_y.shape[1:])
print(f"enc is {enc}")
init_pos = 0
pos = init_pos + torch.arange(0, length).unsqueeze(-1)
dim = (10000. ** (torch.arange(0, hidden_size//2).div(hidden_size))).unsqueeze(0)
assert enc[:, 0::2].size() == (pos / dim).size()
assert enc[:, 1::2].size() == (pos / dim).size()
pos = pos.float()
dim = dim.float()
print(f"position is {pos.flatten()}")
print(f"pos div term is {dim}")
embedded_y = embedded_y + enc

h = emb_dropout(embedded_y)

y.shape is [5, 18, 256]
length is 18 and hidden_size is 256
enc is tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
position is tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17.])
pos div term is tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1

In [358]:
decoder = MySequential(
    *[DecoderBlock(
        hidden_size,
        n_splits,
        dropout_p,
        False,
      ) for _ in range(6)],
)

In [359]:
h, _, _, _ = decoder(h, z, mask_dec, None)
h.shape # (bsz, m, hid_dim)

torch.Size([5, 18, 256])

- feature를 어떻게 학습하는가 + inference는 나중시간에...