手写attention

In [6]:
import torch
import torch.nn as nn
import math

class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input
    ``X``. Scores are ``Q @ K^T / sqrt(hidden_dim)``, optionally masked,
    normalised with a softmax over the key axis, passed through dropout and
    used to take a weighted sum of the values.

    Args:
        hidden_dim: Size of the last dimension of the input; also the output
            size of each of the three projections.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_dim: int, dropout_rate: float = 0.1):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)

        self.atten_dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        X: torch.Tensor,
        attention_mask: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Apply self-attention to ``X``.

        Args:
            X: Input of shape ``(batch, seq, hidden)``.
            attention_mask: Optional tensor broadcastable to the score
                tensor ``(batch, seq, seq)``. Positions where the mask equals
                ``0`` are excluded from attention; every other position is
                kept.

        Returns:
            Tensor of shape ``(batch, seq, hidden)``.

        Note:
            A query row whose keys are *all* masked receives uniform
            attention weights (a finite output) rather than NaN.
        """
        Q = self.query(X)
        K = self.key(X)
        V = self.value(X)

        # (batch, seq, seq)
        attention_weight = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(
            self.hidden_dim
        )
        if attention_mask is not None:
            # Use the most negative finite value instead of -inf: masked
            # positions still underflow to a weight of exactly 0 after the
            # softmax, but a fully masked row no longer evaluates to
            # softmax([-inf, ...]) == NaN, which would also poison gradients.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0,
                torch.finfo(attention_weight.dtype).min,
            )
        attention_weight = torch.softmax(attention_weight, dim=-1)
        attention_weight = self.atten_dropout(attention_weight)

        # (batch, seq, hidden)
        attention_output = attention_weight @ V

        return attention_output

# Demo input: 3 sequences of 4 positions, each position a 2-dim vector.
X = torch.rand(3, 4, 2)

# One row per sequence (0 marks a key position to mask out), copied across
# the 4 query positions so the mask has shape (3, 4, 4) like the scores.
attention_mask = (
    torch.tensor(
        [
            [1, 1, 1, 0],
            [1, 1, 0, 0],
            [1, 0, 0, 0],
        ]
    )
    .unsqueeze(dim=1)
    .repeat(1, 4, 1)
)

net = Attention(2)
net(X, attention_mask)


tensor([[[ 0.2128, -0.2109],
         [ 0.3900, -0.3293],
         [ 0.3915, -0.3305],
         [ 0.4655, -0.4188]],

        [[ 0.4893, -0.3809],
         [ 0.4892, -0.3807],
         [ 0.4895, -0.3812],
         [ 0.4888, -0.3800]],

        [[ 0.3125, -0.2668],
         [ 0.3125, -0.2668],
         [ 0.3125, -0.2668],
         [ 0.3125, -0.2668]]], grad_fn=<UnsafeViewBackward0>)