In [1]:
import torch
import torch.nn as nn

In [2]:
class Attention(nn.Module):

    def __init__(self):
        super().__init__()

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, mask=None, dk=64):
        # |Q| = (batch_size, m, hidden_size)
        # |K| = |V| = (batch_size, n, hidden_size)
        # |mask| = (batch_size, m, n)

        w = torch.bmm(Q, K.transpose(1, 2))
        # |w| = (batch_size, m, n)
        if mask is not None:
            assert w.size() == mask.size()
            w.masked_fill_(mask, -float('inf'))

        w = self.softmax(w / (dk**.5))
        c = torch.bmm(w, V)
        # |c| = (batch_size, m, hidden_size)

        return c


class MultiHead(nn.Module):

    def __init__(self, hidden_size, n_splits):
        super().__init__()

        self.hidden_size = hidden_size
        self.n_splits = n_splits

        # Note that we don't have to declare each linear layer, separately.
        self.Q_linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.K_linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.V_linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)

        self.attn = Attention()

    def forward(self, Q, K, V, mask=None):
        # |Q| = (batch_size, m, hidden_size)
        # |K| = |V| = (batch_size, n, hidden_size)
        # |mask| = (batch_size, m, n)

        QWs = self.Q_linear(Q).split(self.hidden_size // self.n_splits, dim=-1)
        KWs = self.K_linear(K).split(self.hidden_size // self.n_splits, dim=-1)
        VWs = self.V_linear(V).split(self.hidden_size // self.n_splits, dim=-1)
        # |QW_i| = (batch_size, m, hidden_size / n_splits)
        # |KW_i| = |VW_i| = (batch_size, n, hidden_size / n_splits)

        # By concatenating splited linear transformed results,
        # we can remove sequential operations,
        # like mini-batch parallel operations.
        QWs = torch.cat(QWs, dim=0)
        KWs = torch.cat(KWs, dim=0)
        VWs = torch.cat(VWs, dim=0)
        # |QWs| = (batch_size * n_splits, m, hidden_size / n_splits)
        # |KWs| = |VWs| = (batch_size * n_splits, n, hidden_size / n_splits)

        if mask is not None:
            mask = torch.cat([mask for _ in range(self.n_splits)], dim=0)
            # |mask| = (batch_size * n_splits, m, n)

        c = self.attn(
            QWs, KWs, VWs,
            mask=mask,
            dk=self.hidden_size // self.n_splits,
        )
        # |c| = (batch_size * n_splits, m, hidden_size / n_splits)

        # We need to restore temporal mini-batchfied multi-head attention results.
        c = c.split(Q.size(0), dim=0)
        # |c_i| = (batch_size, m, hidden_size / n_splits)
        c = self.linear(torch.cat(c, dim=-1))
        # |c| = (batch_size, m, hidden_size)

        return c

In [87]:
class EncoderBlock(nn.Module):

    def __init__(self, hidden_size, n_splits,
                 dropout_p=.1, use_leaky_relu=False
                 ):
        super().__init__()

        self.attn = MultiHead(hidden_size, n_splits)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.attn_dropout = nn.Dropout(dropout_p)

        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.LeakyReLU() if use_leaky_relu else nn.ReLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.fc_norm = nn.LayerNorm(hidden_size)
        self.fc_dropout = nn.Dropout(dropout_p)

    def forward(self, x, mask):
        # |x| = (batch_size, n, hidden_size)
        # |mask| = (batch_size, n, n)
        print('\n\n\n\n\n\n')
        print(x.mean())
        print(mask.shape)
        z = self.attn_norm(x + self.attn_dropout(self.attn(Q=x,
                                                           K=x,
                                                           V=x,
                                                           mask=mask)))
        z = self.fc_norm(z + self.fc_dropout(self.fc(z)))
        # |z| = (batch_size, n, hidden_size)
        print(z.mean())
        return z, mask

In [88]:
class MySequential(nn.Sequential):

    def forward(self, *x):
        # nn.Sequential class does not provide multiple input arguments and returns.
        # Thus, we need to define new class to solve this issue.
        # Note that each block has same function interface.

        for module in self._modules.values():
            x = module(*x)

        return x

In [89]:
encoder = MySequential(
            *[EncoderBlock(
                512,
                8,
                0.1,
                False,
              ) for _ in range(8)],
        )

In [90]:
x = [[    2, 10549,  3900, 45910,  3680, 45908,  2282,    16,     3,     1,
             1,     1],
        [    2,  2977,  5923,   956, 88842,    39,  2539,     3,     1,     1,
             1,     1],
        [    2,  1230,  1231,  1232,  1233,   323,  1234,  1235,  1236,   110,
             3,     1],
        [    2,  6726, 22523,  2457,   804, 61496,   105, 78511,  1664, 11442,
         71906,     3],
        [    2,  3085,   394,  2615,   133, 26667,     3,     1,     1,     1,
             1,     1]]

In [91]:
x = torch.LongTensor(x)

In [92]:
x

tensor([[    2, 10549,  3900, 45910,  3680, 45908,  2282,    16,     3,     1,
             1,     1],
        [    2,  2977,  5923,   956, 88842,    39,  2539,     3,     1,     1,
             1,     1],
        [    2,  1230,  1231,  1232,  1233,   323,  1234,  1235,  1236,   110,
             3,     1],
        [    2,  6726, 22523,  2457,   804, 61496,   105, 78511,  1664, 11442,
         71906,     3],
        [    2,  3085,   394,  2615,   133, 26667,     3,     1,     1,     1,
             1,     1]])

In [93]:
embedded_x = nn.Embedding(x.max()+5, 512)(x)

In [94]:
embedded_x

tensor([[[ 0.7817,  0.3814,  1.1691,  ...,  0.3171, -1.5591,  0.2325],
         [-0.1527,  0.2612,  0.7388,  ..., -0.1990,  0.8230, -0.0373],
         [ 0.6160, -0.4583,  0.1465,  ...,  0.6104,  0.3407,  1.5700],
         ...,
         [ 0.3604, -0.9384, -0.6313,  ..., -0.1238, -1.0589, -1.2559],
         [ 0.3604, -0.9384, -0.6313,  ..., -0.1238, -1.0589, -1.2559],
         [ 0.3604, -0.9384, -0.6313,  ..., -0.1238, -1.0589, -1.2559]],

        [[ 0.7817,  0.3814,  1.1691,  ...,  0.3171, -1.5591,  0.2325],
         [-0.3984,  0.0500,  0.4003,  ...,  0.8389,  0.6758,  1.1031],
         [ 0.4136,  0.6204, -0.7397,  ..., -0.1243,  0.4232,  0.7346],
         ...,
         [ 0.3604, -0.9384, -0.6313,  ..., -0.1238, -1.0589, -1.2559],
         [ 0.3604, -0.9384, -0.6313,  ..., -0.1238, -1.0589, -1.2559],
         [ 0.3604, -0.9384, -0.6313,  ..., -0.1238, -1.0589, -1.2559]],

        [[ 0.7817,  0.3814,  1.1691,  ...,  0.3171, -1.5591,  0.2325],
         [ 0.9963,  0.2772, -0.6779,  ..., -0

In [95]:
mask = x == 1

In [96]:
mask_enc = mask.unsqueeze(1).expand(mask.size(0), x.size(1), mask.size(-1))

In [97]:
mask_enc.shape

torch.Size([5, 12, 12])

In [98]:
enc = embedded_x.new_zeros(embedded_x.shape[1:])

In [99]:
enc.shape

torch.Size([12, 512])

In [100]:
pos = torch.arange(0, enc.size(0)).float().unsqueeze(-1)
pos.shape

torch.Size([12, 1])

In [101]:
dim = (10000. ** (torch.arange(0, 512//2).div(512))).unsqueeze(0)

In [102]:
dim.shape

torch.Size([1, 256])

In [103]:
enc[:, 0::2] = torch.sin(pos / dim)
enc[:, 1::2] = torch.cos(pos / dim)

In [104]:
embedded_x += enc

In [105]:
z = nn.Dropout(0.2)(embedded_x)

In [106]:
z

tensor([[[ 0.9771,  1.7267,  1.4614,  ...,  1.6464, -1.9489,  1.5406],
         [ 0.8610,  1.0018,  1.9753,  ...,  0.0000,  2.0805,  0.6288],
         [ 1.9066, -0.0000,  1.3198,  ...,  0.2428,  1.5625,  1.4423],
         ...,
         [ 0.9657, -0.0000, -0.2739,  ..., -1.2937, -0.8085, -2.7088],
         [-0.2295, -2.2218, -1.4691,  ..., -1.2036, -2.0037, -2.6188],
         [-0.7994, -1.1675, -2.0391,  ..., -0.1492, -2.5737, -1.5644]],

        [[ 0.9771,  1.7267,  0.0000,  ...,  1.6464, -1.9489,  1.5406],
         [ 0.5539,  0.0000,  0.0000,  ...,  1.7240,  1.8966,  2.0543],
         [ 1.6536,  0.0000,  0.2120,  ..., -0.6755,  1.6656,  0.3980],
         ...,
         [ 0.9657, -2.3119, -0.2739,  ..., -1.2937, -0.0000, -2.7088],
         [-0.0000, -2.2218, -0.0000,  ..., -0.0000, -2.0037, -2.6188],
         [-0.7994, -1.1675, -2.0391,  ..., -0.1492, -2.5737, -1.5644]],

        [[ 0.0000,  0.0000,  1.4614,  ...,  1.6464, -0.0000,  1.5406],
         [ 2.2973,  1.0219,  0.2045,  ...,  0

In [107]:
encoder(z, mask_enc)








tensor(-0.0028, grad_fn=<MeanBackward0>)
torch.Size([5, 12, 12])
tensor(-1.1859e-08, grad_fn=<MeanBackward0>)







tensor(-1.1859e-08, grad_fn=<MeanBackward0>)
torch.Size([5, 12, 12])
tensor(-2.9802e-09, grad_fn=<MeanBackward0>)







tensor(-2.9802e-09, grad_fn=<MeanBackward0>)
torch.Size([5, 12, 12])
tensor(-7.9473e-09, grad_fn=<MeanBackward0>)







tensor(-7.9473e-09, grad_fn=<MeanBackward0>)
torch.Size([5, 12, 12])
tensor(-9.1890e-09, grad_fn=<MeanBackward0>)







tensor(-9.1890e-09, grad_fn=<MeanBackward0>)
torch.Size([5, 12, 12])
tensor(4.2220e-09, grad_fn=<MeanBackward0>)







tensor(4.2220e-09, grad_fn=<MeanBackward0>)
torch.Size([5, 12, 12])
tensor(3.4769e-09, grad_fn=<MeanBackward0>)







tensor(3.4769e-09, grad_fn=<MeanBackward0>)
torch.Size([5, 12, 12])
tensor(9.9341e-10, grad_fn=<MeanBackward0>)







tensor(9.9341e-10, grad_fn=<MeanBackward0>)
torch.Size([5, 12, 12])
tensor(6.9539e-09, grad_fn=<MeanBackward0>)


(tensor([[[ 5.1711e-01,  2.6203e-02,  1.9705e+00,  ...,  1.3971e+00,
           -1.9742e+00,  8.4321e-01],
          [ 7.6448e-01, -4.3853e-01,  1.0880e+00,  ..., -6.9765e-03,
            8.7684e-01, -1.8241e-02],
          [ 1.5311e+00, -1.5038e-01,  9.8502e-01,  ...,  3.0358e-01,
            2.5034e-01,  1.8279e+00],
          ...,
          [ 7.8941e-01, -9.3172e-02,  6.8404e-01,  ...,  2.3345e-01,
           -7.9399e-01, -3.6145e-01],
          [-2.1299e-01, -1.7069e+00,  2.1603e-02,  ...,  3.3092e-01,
           -9.4661e-01, -2.4939e-02],
          [-2.0125e-01, -1.0280e+00, -4.2630e-01,  ...,  1.4328e+00,
           -9.7178e-01,  6.7541e-01]],
 
         [[ 6.5046e-01,  8.5465e-02,  2.2483e-01,  ...,  5.7862e-01,
           -1.5922e+00, -4.8244e-01],
          [-2.8916e-01, -1.5411e+00,  4.4682e-01,  ...,  1.0592e+00,
            5.5628e-01,  7.6843e-01],
          [ 2.2324e+00, -9.1635e-01,  8.9765e-01,  ..., -1.6690e-01,
            5.0763e-01, -7.0927e-02],
          ...,
    