In [1]:
import torch
import numpy as np
import torch.nn as nn
import math
import torch.nn.functional as F

In [30]:
class selfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no output projection).

    The input is projected to keys, queries and values by three independent
    linear layers, split into ``num_attention_heads`` heads, attended per head,
    and the per-head results are concatenated back along the last axis.

    Args:
        num_attention_heads: number of attention heads; must be positive and
            divide ``hidden_size`` evenly.
        input_size: size of the last (feature) axis of the input.
        hidden_size: size of the last axis of the output (all heads together).
        verbose: when True (the default, matching the original behaviour) the
            module prints a shape trace while it is built and on every forward
            pass. Pass ``False`` to silence it.

    Raises:
        ValueError: if ``num_attention_heads`` is not positive, or if
            ``hidden_size`` is not a multiple of ``num_attention_heads``.
    """

    def __init__(self, num_attention_heads, input_size, hidden_size, verbose=True):
        super().__init__()
        # Reject 0 (would raise ZeroDivisionError in the modulo below) and
        # negative head counts (would yield a negative per-head size).
        if num_attention_heads <= 0:
            raise ValueError(
                "the number of attention heads must be positive, got %d"
                % num_attention_heads
            )
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "the hidden size %d is not a multiple of the number of attention heads %d"
                % (hidden_size, num_attention_heads)
            )
        self.verbose = verbose
        self._log(f"num heads {num_attention_heads} input size {input_size} hidden size {hidden_size}")
        self.num_attention_heads = num_attention_heads
        # Exact integer division; divisibility was checked above.
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = hidden_size

        self.key_layer = nn.Linear(input_size, hidden_size)
        self.query_layer = nn.Linear(input_size, hidden_size)
        self.value_layer = nn.Linear(input_size, hidden_size)

    def _log(self, message):
        """Print ``message`` only when the shape trace is enabled."""
        if self.verbose:
            print(message)

    def trans_to_multiple_heads(self, x):
        """Split the last axis of ``x`` into heads and move the head axis forward.

        ``x`` of shape ``(..., seq_len, hidden_size)`` becomes
        ``(..., num_attention_heads, seq_len, attention_head_size)``.
        ``x`` must have at least two dimensions.
        """
        self._log(f"x.size {x.size()}")
        self._log(f"x.size()[ : -1] {x.size()[ : -1]}")
        new_size = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        self._log(f"new size {new_size}")
        x = x.view(new_size)
        self._log(f"x shape {x.shape}")
        # Swap the seq_len and head axes. Indexing from the end keeps this
        # identical to permute(0, 2, 1, 3) for 3-D input while also accepting
        # inputs with fewer or more leading dimensions.
        return x.transpose(-3, -2)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape ``(..., seq_len, input_size)``.

        Returns:
            Tensor of shape ``(..., seq_len, hidden_size)``.
        """
        key = self.key_layer(x)
        query = self.query_layer(x)
        value = self.value_layer(x)

        # Projections are (..., seq_len, hidden_size) at this point.
        self._log(f"orignal key shape {key.shape}")
        key_heads = self.trans_to_multiple_heads(key)
        self._log(f"砸裂 key multi heads shape {key_heads.shape}")

        query_heads = self.trans_to_multiple_heads(query)
        value_heads = self.trans_to_multiple_heads(value)

        # Q @ K^T over the last two axes, scaled by sqrt(d_k).
        attention_scores = torch.matmul(query_heads, key_heads.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalise over the key positions.
        attention_probs = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_probs, value_heads)
        self._log(f"context {context.shape}")

        # Move heads back next to the per-head features and merge them;
        # contiguous() is required before view() after the transpose.
        context = context.transpose(-3, -2).contiguous()
        new_size = context.size()[:-2] + (self.all_head_size,)
        return context.view(new_size)

        


In [31]:
# Smoke test: a random batch of 32 sequences, 20 steps each, 10 features per step,
# attended with 2 heads and a hidden size of 20.
features = torch.rand((32, 20, 10))
attention = selfAttention(2, 10, 20)
# Call the module instance rather than .forward() so nn.Module.__call__ runs
# any registered hooks.
result = attention(features)
# The sequence axis is preserved; the feature axis becomes hidden_size.
assert result.shape == (32, 20, 20)
print(result.shape)


num heads 2 input size 10 hidden size 20
orignal key shape torch.Size([32, 20, 20])
x.size torch.Size([32, 20, 20])
x.size()[ : -1] torch.Size([32, 20])
new size torch.Size([32, 20, 2, 10])
x shape torch.Size([32, 20, 2, 10])
砸裂 key multi heads shape torch.Size([32, 2, 20, 10])
x.size torch.Size([32, 20, 20])
x.size()[ : -1] torch.Size([32, 20])
new size torch.Size([32, 20, 2, 10])
x shape torch.Size([32, 20, 2, 10])
x.size torch.Size([32, 20, 20])
x.size()[ : -1] torch.Size([32, 20])
new size torch.Size([32, 20, 2, 10])
x shape torch.Size([32, 20, 2, 10])
context torch.Size([32, 2, 20, 10])
torch.Size([32, 20, 20])
