In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

[Understanding Masking in PyTorch for Attention Mechanisms](https://medium.com/@swarms/understanding-masking-in-pytorch-for-attention-mechanisms-e725059fd49f) 를 참고함.

## Padding Mask
Padding masks are used to ignore the padding tokens in the input sequences.   
In NLP tasks, sequences are often padded to the same length to enable batch processing. Padding tokens should not influence the model’s learning, so we apply a padding mask to ensure they are ignored.  

In [2]:
def create_padding_mask(seq, pad_token=0):
    """Flag the padding positions of a batch of token-id sequences.

    Args:
        seq: Tensor of token ids; the example in this notebook passes a
            (batch_size, seq_len) tensor.
        pad_token: Token id that marks padding. Defaults to 0.

    Returns:
        Boolean tensor that is True where ``seq`` equals ``pad_token``, with
        two singleton axes inserted after the first axis, i.e.
        (batch_size, 1, 1, seq_len) for a 2-D input.
    """
    is_pad = seq.eq(pad_token)
    # Indexing with None adds the singleton axes at positions 1 and 2.
    return is_pad[:, None, None]  # (batch_size, 1, 1, seq_len)

In [3]:
# Demo: two sequences, right-padded with token id 0 to a common length of 4.
seq = torch.tensor([
    [7, 6, 0, 0],
    [1, 2, 3, 0],
])
padding_mask = create_padding_mask(seq, pad_token=0)
print(padding_mask)

tensor([[[[False, False,  True,  True]]],


        [[[False, False, False,  True]]]])


## Sequence Mask
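Sequence masks are more flexible and can be used to hide arbitrary parts of the sequence. The example below builds one such mask with `torch.triu`: an upper-triangular pattern in which each row marks every position to the right of the current one.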

torch.triu는 PyTorch에서 **행렬의 상삼각 행렬(upper triangular matrix)** 을
추출하거나 생성하는 데 사용되는 함수입니다. 
이 함수는 기본값(`diagonal=0`)에서 행렬의 대각선과 그 위쪽 요소를 유지하고,
나머지 요소를 0으로 설정합니다.
아래 코드처럼 `diagonal=1`을 지정하면 주대각선도 0이 되고, 대각선보다 위쪽 요소만 유지됩니다.

In [4]:
def create_sequence_mask(seq):
    """Build a square upper-triangular mask sized from the input sequence.

    Args:
        seq: Tensor whose second dimension (``seq.size(1)``) gives the
            sequence length. Only its shape is used; its values are ignored.

    Returns:
        Float tensor of shape (seq_len, seq_len) holding 1.0 strictly above
        the main diagonal and 0.0 on and below it.
    """
    # The length must come from the argument. The previous version assigned it
    # to a misspelled name (`seg_len`) and then read `seq_len`, which resolved
    # to a notebook global, or raised NameError on a fresh kernel.
    seq_len = seq.size(1)
    # diagonal=1 excludes the main diagonal, so position i never masks itself.
    mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1)
    return mask

In [5]:
# Demo: the input's values are never used, so a dummy zero tensor is enough.
seq_len = 4
sequence_mask = create_sequence_mask(torch.zeros((seq_len, seq_len)))
print(sequence_mask)

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])


## Look-ahead Mask
Look-ahead masks prevent the model from looking at future tokens.

In [6]:
def create_look_ahead_mask(size):
    """Return a square mask that marks the future positions of each row.

    Args:
        size: Side length of the mask (the sequence length).

    Returns:
        Float tensor of shape (size, size) with 1.0 strictly above the main
        diagonal and 0.0 elsewhere.
    """
    all_ones = torch.ones(size, size)
    # diagonal=1 zeroes the main diagonal and everything below it.
    return all_ones.triu(diagonal=1)  # (seq_len, seq_len)

In [7]:
# Demo: 4x4 mask; row i carries 1s in the columns to the right of column i.
look_ahead_mask = create_look_ahead_mask(size=4)
print(look_ahead_mask)

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])


In [8]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute softmax(q @ k^T / sqrt(d_k) + mask * -1e9) @ v.

    Args:
        q: Query tensor; its last dimension is used as d_k.
        k: Key tensor; its last two dimensions are swapped before the product.
        v: Value tensor, multiplied by the attention weights.
        mask: Optional tensor added to the logits after scaling by -1e9, so
            entries equal to 1 (or True) are pushed towards zero weight. It is
            added in place, so it must broadcast to the logits' shape.

    Returns:
        Tuple ``(output, attention_weights)``: the weighted sum of ``v`` and
        the softmax weights taken over the last axis.
    """
    depth = q.size(-1)
    scale = torch.sqrt(torch.tensor(depth, dtype=torch.float32))
    logits = torch.matmul(q, k.transpose(-2, -1)) / scale

    if mask is not None:
        # A large negative offset makes softmax assign ~0 weight to masked slots.
        logits += mask * -1e9

    weights = F.softmax(logits, dim=-1)
    return torch.matmul(weights, v), weights

In [10]:
# Demo inputs for the attention function. Values are random and no seed is set,
# so only the shapes are reproducible between runs.
d_model = 512
batch_size = 2
seq_len = 4

# Drawn in the order q, k, v, each of shape (batch_size, seq_len, d_model).
q, k, v = (torch.rand(batch_size, seq_len, d_model) for _ in range(3))
mask = create_look_ahead_mask(seq_len)

In [14]:
# Inspect shapes: the mask is 2-D with no batch axis, while q is
# (batch_size, seq_len, d_model).
mask.shape, q.shape

(torch.Size([4, 4]), torch.Size([2, 4, 512]))

In [19]:
# Swapping the last two axes of k turns (batch_size, seq_len, d_model)
# into (batch_size, d_model, seq_len).
k.transpose(-1, -2).shape

torch.Size([2, 512, 4])

In [22]:
# Run attention with the look-ahead mask and check the shape of the weights.
attention_output, attention_weights = scaled_dot_product_attention(q, k, v, mask=mask)
print(attention_weights.shape)

torch.Size([2, 4, 4])


In [20]:
# Raw (unscaled) query-key scores: q times k with its last two axes swapped.
matmul_qk = q @ k.transpose(-2, -1)

In [21]:
# One score per (query position, key position) pair for each batch element.
matmul_qk.shape

torch.Size([2, 4, 4])