## Step1. multi-head가 없는 그냥 attention이라면?

In [2]:
def forward(self, query, key, value):
    """Single-head scaled dot-product attention (no batching, no head split).

    Shapes follow the tutorial's comments:
        query: [query_len, hidden_dim]
        key:   [key_len, hidden_dim]
        value: [value_len, hidden_dim]

    Returns:
        Tensor of shape [query_len, hidden_dim]: softmax(Q K^T / scale) V.
    """
    # Linear projections into query / key / value spaces (evaluated left to right).
    projected_q, projected_k, projected_v = (
        self.fc_q(query),
        self.fc_k(key),
        self.fc_v(value),
    )

    # Attention scores, [query_len, key_len]; dividing by self.scale keeps the
    # dot products from growing with the hidden size.
    scores = projected_q @ projected_k.permute(1, 0) / self.scale

    # Normalise each query's scores over the keys into a probability distribution.
    weights = scores.softmax(dim=-1)

    # Weighted sum of the values, [query_len, hidden_dim].
    return weights @ projected_v

matmul에 대한 파이토치 공식 문서: https://pytorch.org/docs/stable/generated/torch.matmul.html  
파이토치에서의 matmul과 dot의 차이: https://velog.io/@regista/torch.dot-torch.matmul-torch.mm-torch.bmm  


In [4]:
# Quote from the Transformer paper ("Attention Is All You Need") describing the
# multi-head setup: h = 8 heads, each with d_k = d_v = d_model / h = 64.
"""In this work we employ h = 8 parallel attention layers, or heads. For each of these we use
dk = dv = dmodel/h = 64. Due to the reduced dimension of each head, the total computational cost
is similar to that of single-head attention with full dimensionality."""
# English translation of the note below: "d_model in the sentence == hidden_dim in the code".
"""문장에서 dmodel == 코드에서 hidden_dim"""

'In this work we employ h = 8 parallel attention layers, or heads. For each of these we use\ndk = dv = dmodel/h = 64. Due to the reduced dimension of each head, the total computational cost\nis similar to that of single-head attention with full dimensionality.'

## Step2. 논문에 따라, head의 개수로 쪼갠 attention이라면?

In [None]:
def forward(self, query, key, value):
    """Multi-head scaled dot-product attention for a single (unbatched) sequence.

    hidden_dim is split into ``self.n_heads`` heads of ``self.head_dim`` features
    each, so that each head can learn a different attention pattern. Attention is
    computed per head, and the heads are concatenated and passed through
    ``self.fc_o``.

    Shapes follow the tutorial's comments:
        query: [query_len, hidden_dim]
        key:   [key_len, hidden_dim]
        value: [value_len, hidden_dim]

    Returns:
        (x, attention) where x is [query_len, hidden_dim] and attention is
        [n_heads, query_len, key_len].
    """
    def split_heads(tensor):
        # [seq_len, hidden_dim] -> [seq_len, n_heads, head_dim] -> [n_heads, seq_len, head_dim]
        return tensor.view(-1, self.n_heads, self.head_dim).permute(1, 0, 2)

    # Project first (fc_q, fc_k, fc_v in that order), then split into heads.
    q_proj = self.fc_q(query)
    k_proj = self.fc_k(key)
    v_proj = self.fc_v(value)
    heads_q = split_heads(q_proj)
    heads_k = split_heads(k_proj)
    heads_v = split_heads(v_proj)

    # Per-head scores, [n_heads, query_len, key_len]; the operands are always 3-D
    # here, so transpose(1, 2) swaps exactly the seq/feature axes.
    per_head_scores = heads_q @ heads_k.transpose(1, 2) / self.scale
    attention = per_head_scores.softmax(dim=-1)

    # Per-head weighted sum of values, [n_heads, query_len, head_dim].
    context = attention @ heads_v

    # Put heads next to each other again: [query_len, n_heads, head_dim] -> [query_len, hidden_dim].
    merged = context.permute(1, 0, 2).contiguous().view(-1, self.hidden_dim)

    return self.fc_o(merged), attention

# English translation of the note below: compared with the Step 1 code, only an
# extra leading n_heads dimension was added; the remaining dimensions are
# processed just as before.
"""step1의 코드와 비교해보면, 제일 첫번째로 들어가는 차원에서 n_heads만 추가되었다 뿐이지, 그 뒤에 차원은 그대로 진행됩니다."""

## Step3. 그런데 실무에서는, 이 데이터(문장)가 하나씩 들어오는 게 아니라, batch size만큼의 뭉치로 들어옵니다. 따라서 batch_size만큼 쌓여서 들어온다고 가정해줘야 합니다.

In [None]:
def forward(self, query, key, value):
    """Batched multi-head scaled dot-product attention.

    Args:
        query: [batch_size, query_len, hidden_dim]
        key:   [batch_size, key_len, hidden_dim]
        value: [batch_size, value_len, hidden_dim]

    Returns:
        (x, attention) where x is [batch_size, query_len, hidden_dim] and
        attention is [batch_size, n_heads, query_len, key_len]. The returned
        attention is taken *before* ``self.dropout``; dropout is only applied to
        the copy that weights the values.
    """
    batch_size = query.shape[0]

    Q = self.fc_q(query)
    K = self.fc_k(key)
    V = self.fc_v(value)
    # Q: [batch_size, query_len, hidden_dim]
    # K: [batch_size, key_len, hidden_dim]
    # V: [batch_size, value_len, hidden_dim]

    # Reshape hidden_dim -> n_heads x head_dim so each of the n_heads (h) heads
    # can learn a different attention pattern, then move the head axis before
    # the sequence axis.
    Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    # Q: [batch_size, n_heads, query_len, head_dim]
    # K: [batch_size, n_heads, key_len, head_dim]
    # V: [batch_size, n_heads, value_len, head_dim]

    # Attention scores.
    energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
    # energy: [batch_size, n_heads, query_len, key_len]

    # Convert scores into a probability distribution over the keys.
    attention = torch.softmax(energy, dim=-1)
    # attention: [batch_size, n_heads, query_len, key_len]

    # Scaled dot-product attention: weight the values.
    x = torch.matmul(self.dropout(attention), V)
    # x: [batch_size, n_heads, query_len, head_dim]

    # contiguous() is required because view() cannot operate on the permuted layout.
    x = x.permute(0, 2, 1, 3).contiguous()
    # x: [batch_size, query_len, n_heads, head_dim]

    x = x.view(batch_size, -1, self.hidden_dim)
    # x: [batch_size, query_len, hidden_dim]

    x = self.fc_o(x)
    # x: [batch_size, query_len, hidden_dim]

    return x, attention

## Step4. mask가 있을 경우

masked_fill 파이토치 공식 문서: https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html#torch.Tensor.masked_fill_

In [None]:
def forward(self, query, key, value, mask=None):
    """Batched multi-head scaled dot-product attention with an optional mask.

    Args:
        query: [batch_size, query_len, hidden_dim]
        key:   [batch_size, key_len, hidden_dim]
        value: [batch_size, value_len, hidden_dim]
        mask:  optional tensor that must broadcast against the score tensor
               [batch_size, n_heads, query_len, key_len]. Positions where
               ``mask == 0`` are excluded from attention. (Typical shapes are
               presumably [batch_size, 1, 1, key_len] for padding and
               [batch_size, 1, query_len, key_len] for look-ahead masks;
               confirm with the caller.)

    Returns:
        (x, attention) where x is [batch_size, query_len, hidden_dim] and
        attention is [batch_size, n_heads, query_len, key_len]. The returned
        attention is taken before ``self.dropout`` is applied.
    """
    batch_size = query.shape[0]

    Q = self.fc_q(query)
    K = self.fc_k(key)
    V = self.fc_v(value)
    # Q: [batch_size, query_len, hidden_dim]
    # K: [batch_size, key_len, hidden_dim]
    # V: [batch_size, value_len, hidden_dim]

    # Reshape hidden_dim -> n_heads x head_dim so each of the n_heads (h) heads
    # can learn a different attention pattern, then move the head axis before
    # the sequence axis.
    Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    # Q: [batch_size, n_heads, query_len, head_dim]
    # K: [batch_size, n_heads, key_len, head_dim]
    # V: [batch_size, n_heads, value_len, head_dim]

    # Attention scores.
    energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
    # energy: [batch_size, n_heads, query_len, key_len]

    if mask is not None:
        # Push masked scores to the most negative finite value of energy's dtype,
        # so softmax gives them ~0 weight. A hard-coded -1e10 does not fit in
        # float16, and masked_fill then raises an overflow error under
        # half-precision/AMP. In float32 the result is the same as with -1e10.
        energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

    # Convert scores into a probability distribution over the keys.
    attention = torch.softmax(energy, dim=-1)
    # attention: [batch_size, n_heads, query_len, key_len]

    # Scaled dot-product attention: weight the values.
    x = torch.matmul(self.dropout(attention), V)
    # x: [batch_size, n_heads, query_len, head_dim]

    # contiguous() is required because view() cannot operate on the permuted layout.
    x = x.permute(0, 2, 1, 3).contiguous()
    # x: [batch_size, query_len, n_heads, head_dim]

    x = x.view(batch_size, -1, self.hidden_dim)
    # x: [batch_size, query_len, hidden_dim]

    x = self.fc_o(x)
    # x: [batch_size, query_len, hidden_dim]

    return x, attention