From: [OpenNMT/OpenNMT-py: Open Source Neural Machine Translation in PyTorch](https://github.com/OpenNMT/OpenNMT-py)

In [2]:
import math
import torch
import torch.nn as nn

In [3]:
def generate_relative_positions_matrix(length, max_relative_positions,
                                       cache=False):
    """Return clipped relative-distance indices for relative-position attention.

    Entry [i, j] holds ``clamp(j - i, -max_relative_positions,
    max_relative_positions) + max_relative_positions``, so every value lies in
    ``[0, 2 * max_relative_positions]`` and can index an embedding table.
    With ``cache=True`` only a single row of shape (1, length) is produced:
    the distances ``-(length - 1) .. 0`` from the last position to every
    position.
    """
    limit = max_relative_positions
    if not cache:
        steps = torch.arange(length)
        # Broadcasting (1, L) - (L, 1) gives offsets[i, j] = j - i.
        offsets = steps.unsqueeze(0) - steps.unsqueeze(1)
    else:
        offsets = torch.arange(1 - length, 1).unsqueeze(0)
    # Shift the clipped offsets so they are non-negative.
    return offsets.clamp(min=-limit, max=limit) + limit


def relative_matmul(x, z, transpose):
    """Multiply each position's rows of ``x`` by that position's slice of ``z``.

    ``x`` is indexed as (batch, heads, length, d). For every position ``l``
    the (heads * batch, d) slice of ``x`` is matrix-multiplied by ``z[l]``,
    or by ``z[l]`` transposed when ``transpose`` is true. Standard matmul
    broadcasting applies, so ``z`` may have a leading dim of 1. The product
    is returned as (batch, heads, length, k).
    """
    batch, heads, length = x.shape[0], x.shape[1], x.shape[2]
    # Move the position axis first so torch.matmul batches over it.
    rows = x.permute(2, 0, 1, 3).reshape(length, heads * batch, -1)
    rhs = z.transpose(1, 2) if transpose else z
    product = torch.matmul(rows, rhs)
    return product.reshape(length, batch, heads, -1).permute(1, 2, 0, 3)

## Multi-Head Attention

In [9]:
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention, adapted from OpenNMT-py.

    Key, value and query are projected by separate linear layers, split into
    ``head_count`` heads of ``model_dim // head_count`` features, attended per
    head, merged back and passed through ``final_linear``. When
    ``max_relative_positions > 0`` and ``attn_type == "self"``, learned
    relative-position embeddings are added to both the scores and the context.
    """

    def __init__(self, head_count, model_dim, dropout=0.1,
                 max_relative_positions=0):
        """
        Args:
            head_count: number of attention heads; must divide ``model_dim``
                (checked with an ``assert``).
            model_dim: feature size of the inputs and of the output.
            dropout: dropout probability applied to the attention weights.
            max_relative_positions: clipping distance for relative positions;
                0 (the default) disables relative-position embeddings.
        """
        assert model_dim % head_count == 0
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim

        super(MultiHeadedAttention, self).__init__()
        self.head_count = head_count

        self.linear_keys = nn.Linear(model_dim,
                                     head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim,
                                       head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim,
                                      head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.final_linear = nn.Linear(model_dim, model_dim)

        self.max_relative_positions = max_relative_positions

        if max_relative_positions > 0:
            # One embedding per clipped distance in
            # [-max_relative_positions, max_relative_positions].
            vocab_size = max_relative_positions * 2 + 1
            self.relative_positions_embeddings = nn.Embedding(
                vocab_size, self.dim_per_head)

    def forward(self, key, value, query, mask=None, layer_cache=None, attn_type=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            key, value, query: the first dim is treated as batch and the last
                must be ``model_dim``; presumably (batch, seq_len, model_dim)
                -- confirm with callers.
            mask: optional bool tensor. It is unsqueezed at dim 1 (a head
                axis), and positions where it is True get a score of -1e18,
                i.e. they are excluded. (batch, 1, key_len) or
                (batch, query_len, key_len) is assumed.
            layer_cache: optional dict for incremental decoding, mutated in
                place.
                - With ``attn_type == "self"`` it uses the
                  "self_keys"/"self_values" entries and projects key and
                  value from ``query`` (the ``key``/``value`` arguments are
                  ignored).
                - With ``attn_type == "context"`` it uses
                  "memory_keys"/"memory_values" and reuses them once set.
            attn_type: "self", "context" or None. Relative-position terms are
                only applied when this is "self".

        Returns:
            ``(output, attns)``. ``output`` is (batch, query_len, model_dim).
            ``attns`` holds the pre-dropout attention weights, shaped
            (batch, head_count, query_len, key_len).

        NOTE(review): if ``layer_cache`` is given but ``attn_type`` is neither
        "self" nor "context", none of key/value/query is projected.
        """
        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        # NOTE(review): these two values are overwritten below (after any
        # cache concatenation) before they are ever read.
        key_len = key.size(1)
        query_len = query.size(1)

        def shape(x):
            """Split heads: (batch, len, heads*dim) -> (batch, heads, len, dim)."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Merge heads: (batch, heads, len, dim) -> (batch, len, heads*dim)."""
            return x.transpose(1, 2).contiguous() \
                    .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if attn_type == "self":
                # Incremental self-attention: keys/values are projected from
                # the incoming query (presumably one new decoding step) and
                # appended to the cached ones along the length axis (dim 2
                # after shape()).
                query, key, value = self.linear_query(query),\
                                    self.linear_keys(query),\
                                    self.linear_values(query)
                key = shape(key)
                value = shape(value)
                if layer_cache["self_keys"] is not None:
                    key = torch.cat(
                        (layer_cache["self_keys"], key),
                        dim=2)
                if layer_cache["self_values"] is not None:
                    value = torch.cat(
                        (layer_cache["self_values"], value),
                        dim=2)
                layer_cache["self_keys"] = key
                layer_cache["self_values"] = value
            elif attn_type == "context":
                query = self.linear_query(query)
                # Memory keys/values are projected on the first call only and
                # then reused from the cache.
                if layer_cache["memory_keys"] is None:
                    key, value = self.linear_keys(key),\
                                 self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
                else:
                    key, value = layer_cache["memory_keys"],\
                               layer_cache["memory_values"]
                layer_cache["memory_keys"] = key
                layer_cache["memory_values"] = value
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        if self.max_relative_positions > 0 and attn_type == "self":
            key_len = key.size(2)
            # (1, key_len) when a layer cache is used, else (key_len, key_len)
            relative_positions_matrix = generate_relative_positions_matrix(
                key_len, self.max_relative_positions,
                cache=True if layer_cache is not None else False)
            # (1 or key_len) x key_len x dim_per_head
            relations_keys = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))
            # NOTE(review): computed from the same embedding table and indices
            # as relations_keys, so the two tensors hold identical values.
            relations_values = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        # batch x num_heads x query_len x key_len
        query_key = torch.matmul(query, key.transpose(2, 3))

        if self.max_relative_positions > 0 and attn_type == "self":
            scores = query_key + relative_matmul(query, relations_keys, True)
        else:
            scores = query_key
        # Masking and softmax run in float32; the result is cast back below.
        scores = scores.float()

        if mask is not None:
            mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        drop_attn = self.dropout(attn)

        context_original = torch.matmul(drop_attn, value)

        if self.max_relative_positions > 0 and attn_type == "self":
            context = unshape(context_original
                              + relative_matmul(drop_attn,
                                                relations_values,
                                                False))
        else:
            context = unshape(context_original)

        output = self.final_linear(context)

        # Return multi-head attn (the weights before dropout).
        attns = attn.view(batch_size, head_count, query_len, key_len)
        return output, attns

    def update_dropout(self, dropout):
        """Set the attention-dropout probability in place."""
        self.dropout.p = dropout

In [75]:
class MultiHeadedAttention2(nn.Module):
    """Simplified multi-head attention: no layer cache, no relative positions.

    It uses the same pipeline as ``MultiHeadedAttention``:
    project -> split heads -> scaled dot-product -> merge heads -> final linear.
    It also prints intermediate tensor shapes on every forward call so the
    notebook can inspect them.
    """

    def __init__(self, head_count, model_dim, dropout=0.1):
        """
        Args:
            head_count: number of attention heads; must divide ``model_dim``.
            model_dim: feature size of the inputs and of the output.
            dropout: dropout probability applied to the attention weights.
        """
        # Initialise nn.Module before assigning any attributes on it.
        super(MultiHeadedAttention2, self).__init__()
        assert model_dim % head_count == 0
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim
        self.head_count = head_count

        self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.final_linear = nn.Linear(model_dim, model_dim)

    def forward(self, key, value, query, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            key, value, query: the first dim is treated as batch and the last
                must be ``model_dim``; presumably (batch, seq_len, model_dim).
            mask: optional bool tensor. It is unsqueezed at dim 1 (a head
                axis), and positions where it is True get a score of -1e18,
                i.e. they are excluded. (batch, 1, key_len) or
                (batch, query_len, key_len) is assumed.

        Returns:
            ``(output, attns)``. ``output`` is (batch, query_len, model_dim).
            ``attns`` holds the pre-dropout attention weights, shaped
            (batch, head_count, query_len, key_len).
        """
        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count

        def shape(x):
            """Split heads: (batch, len, heads*dim) -> (batch, heads, len, dim)."""
            return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)

        def unshape(x):
            """Merge heads: (batch, heads, len, dim) -> (batch, len, heads*dim)."""
            return x.transpose(1, 2).contiguous().view(
                batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query, then split them into heads.
        key = shape(self.linear_keys(key))
        value = shape(self.linear_values(value))
        query = shape(self.linear_query(query))

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores: batch x num_heads x query_len x key_len.
        query = query / math.sqrt(dim_per_head)
        scores = torch.matmul(query, key.transpose(2, 3)).float()

        if mask is not None:
            # Same convention as MultiHeadedAttention: True = excluded.
            mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        drop_attn = self.dropout(attn)
        # Shape tracing for the notebook cells that call this module.
        print(drop_attn.shape, value.shape)
        context_original = torch.matmul(drop_attn, value)
        print(context_original.shape)
        context = unshape(context_original)
        output = self.final_linear(context)
        print(output.shape)
        # Return multi-head attn (the weights before dropout).
        print(attn.shape)
        attns = attn.view(batch_size, head_count, query_len, key_len)
        print(attns.shape)
        return output, attns

In [79]:
mh2 = MultiHeadedAttention2(head_count=1, model_dim=128)

In [80]:
# Relies on `input` being the (1, 10, 128) tensor from the In [35] cell,
# which the notebook executed before this one.
output2, att2 = mh2(key=input, value=input, query=input)

torch.Size([1, 1, 10, 10]) torch.Size([1, 1, 10, 128])
torch.Size([1, 1, 10, 128])
torch.Size([1, 10, 128])
torch.Size([1, 1, 10, 10])
torch.Size([1, 1, 10, 10])


In [54]:
x = nn.Linear(in_features=128, out_features=128)(input)
x.size()

torch.Size([1, 10, 128])

In [64]:
x.view(1, -1, 16, 8).transpose(2, 1).size()

torch.Size([1, 16, 10, 8])

In [63]:
# Round trip: split the 128 features into 16 heads of 8, then merge them back.
# Uses an explicit tensor instead of IPython's `_`. In a top-to-bottom run,
# `_` would be the torch.Size shown by the previous cell, which has no
# .transpose.
heads = x.view(1, -1, 16, 8).transpose(1, 2)
heads.transpose(1, 2).contiguous().view(1, -1, 16 * 8).shape

torch.Size([1, 10, 128])

In [101]:
# Presumably expects `input` to be the LongTensor([1]) defined in the next
# cell (In [102]); the notebook ran these cells out of order.
nn.Embedding(num_embeddings=1000, embedding_dim=64)(input).size()

torch.Size([1, 64])

In [102]:
input = torch.tensor([1], dtype=torch.long)

In [103]:
input.size()

torch.Size([1])

In [35]:
# Random features shaped (batch, seq_len, model_dim); used as key, value and query.
input = torch.randn((1, 10, 128))

In [36]:
mh = MultiHeadedAttention(head_count=16, model_dim=128)

In [37]:
output, attn = mh(key=input, value=input, query=input)

In [38]:
output.size()

torch.Size([1, 10, 128])

In [39]:
attn.size()

torch.Size([1, 16, 10, 10])

## Positional Encoding

In [104]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first input.

    The table follows ``pe[pos, 2i] = sin(pos / 10000^(2i / d_model))`` and
    ``pe[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))``. It is stored as a
    (1, max_len, d_model) registered buffer, so it is not a trainable
    parameter.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        """
        Args:
            d_model: feature size; odd values are supported.
            dropout: dropout probability applied after adding the encoding.
            max_len: maximum sequence length covered by the table.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading batch dim so the table broadcasts over x's batch axis.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])``.

        ``x`` is indexed as (batch, seq_len, d_model) with seq_len <= max_len.
        """
        # The buffer does not require grad, so no Variable wrapper is needed
        # (Variable was also never imported in this notebook).
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [105]:
torch.arange(start=2, end=100, step=2)

tensor([ 2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36,
        38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72,
        74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98])

In [224]:
max_len, dim = 10, 16

In [245]:
pe = torch.zeros((max_len, dim))

In [246]:
pe.size()

torch.Size([10, 16])

In [247]:
position = torch.arange(max_len)[:, None]
position.size()

torch.Size([10, 1])

In [248]:
div_term = (torch.arange(0, dim, 2) * (-math.log(10000.0) / dim)).exp()
div_term.size()

torch.Size([8])

In [249]:
(position * div_term).sin().size()

torch.Size([10, 8])

In [250]:
pe[:, ::2] = (div_term * position.float()).sin()

In [251]:
pe[:, 1::2] = (div_term * position.float()).cos()

In [233]:
pe.size()

torch.Size([10, 16])

In [234]:
pe = pe[:, None]
pe.size()

torch.Size([10, 1, 16])

In [235]:
x = torch.randn((10, 1, 16))
x.size()

torch.Size([10, 1, 16])

In [242]:
x.select(0, 0)

tensor([[ 0.0599,  1.7412,  0.9173,  1.3884,  0.6826, -0.6830,  0.8005, -0.4316,
          0.1788,  0.0446, -1.1409, -0.8781,  0.8874,  0.7887,  0.3160, -1.0217]])

In [243]:
pe.select(0, 0)

tensor([[0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.]])

In [244]:
(pe[: x.shape[0]] + x)[0]

tensor([[ 0.0599,  2.7412,  0.9173,  2.3884,  0.6826,  0.3170,  0.8005,  0.5684,
          0.1788,  1.0446, -1.1409,  0.1219,  0.8874,  1.7887,  0.3160, -0.0217]])

In [252]:
pe = pe[None]
pe.size()

torch.Size([1, 10, 16])

In [253]:
x = torch.randn((1, 10, 16))
x.size()

torch.Size([1, 10, 16])

In [266]:
(x + pe[:, : x.shape[1]]).size()

torch.Size([1, 10, 16])

In [267]:
torch.add(x, pe).size()

torch.Size([1, 10, 16])

In [268]:
pe.select(1, 1).size()

torch.Size([1, 16])

## Encoder

In [5]:
def sequence_mask(lengths, max_len=None):
    """Create a boolean mask from sequence lengths.

    Args:
        lengths: 1-D tensor with one sequence length per batch element.
        max_len: number of mask columns. When None it defaults to
            ``lengths.max()``, or 0 if ``lengths`` is empty. An explicit 0 is
            honoured rather than being treated as "unset".

    Returns:
        Bool tensor of shape (batch, max_len). Entry [b, t] is True when
        ``t < lengths[b]`` (a real position) and False otherwise (padding).
    """
    if max_len is None:
        max_len = lengths.max() if lengths.numel() > 0 else 0
    positions = (torch.arange(0, max_len, device=lengths.device)
                 .type_as(lengths))
    # Broadcast (1, max_len) against (batch, 1) -> (batch, max_len).
    return positions.unsqueeze(0).lt(lengths.unsqueeze(1))

In [85]:
# Sequence lengths, shape (batch_size,):
# each element is the length of one sequence.
ts = torch.tensor([5.0, 4.0, 3.0, 2.0])
ts.size()

torch.Size([4])

In [89]:
# Invert so True marks padding, then add an axis: (batch_size, 1, seq_len).
mask = sequence_mask(ts).logical_not().unsqueeze(1)
mask.size()

torch.Size([4, 1, 5])

In [90]:
# Show the inverted mask: True marks positions at or beyond each sequence's
# length (padding), False marks real positions.
mask

tensor([[[False, False, False, False, False]],

        [[False, False, False, False,  True]],

        [[False, False, False,  True,  True]],

        [[False, False,  True,  True,  True]]])

In [91]:
# Add a head axis so the mask broadcasts against
# (batch_size, head_count, query_len, key_len) scores:
# (batch_size, 1, seq_len) -> (batch_size, 1, 1, seq_len).
mask = mask[:, None]
mask.size()

torch.Size([4, 1, 1, 5])

In [93]:
# Random integer scores shaped (batch_size, head_count, query_len, key_len).
score = torch.randint(low=1, high=10, size=(4, 8, 5, 5))
score.size()

torch.Size([4, 8, 5, 5])

In [98]:
# Copy, then fill the padded key positions in place.
scores = score.clone()
scores.masked_fill_(mask, -1e18)
scores.size()

torch.Size([4, 8, 5, 5])

## Decoder

In [99]:
def subsequent_mask(size):
    """Mask out subsequent positions for decoder self-attention.

    Returns a bool tensor of shape (1, size, size) whose entry [0, i, j] is
    True when position i may attend to position j, i.e. when j <= i.

    NOTE: True here means "allowed", which is the opposite of the
    ``masked_fill`` convention used by the attention modules in this notebook
    (True = excluded). Invert it before passing it as their ``mask``.
    """
    attn_shape = (1, size, size)
    # Ones strictly above the diagonal mark future positions; "== 0" flips
    # them to an "allowed" mask. Pure torch: numpy was never imported here.
    future = torch.triu(torch.ones(attn_shape, dtype=torch.uint8), diagonal=1)
    return future == 0

In [100]:
sequence_mask(lengths=ts)

tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False],
        [ True,  True,  True, False, False],
        [ True,  True, False, False, False]])

In [101]:
sequence_mask(ts, max_len=None)

tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False],
        [ True,  True,  True, False, False],
        [ True,  True, False, False, False]])