In [5]:
import numpy as np
import torch
import torch.nn as nn

## Pad Mask

In [2]:
def get_attn_pad_mask(seq_q, seq_k, pad_idx=0):
    '''
    Build a boolean attention mask that hides the PAD positions of the key sequence.

    seq_q: [batch_size, len_q] token ids of the query sequence (only its shape is used)
    seq_k: [batch_size, len_k] token ids of the key sequence
    pad_idx: token id treated as padding (default 0)

    len_q and len_k may differ (e.g. src_len vs tgt_len); the batch sizes must match.

    Returns: [batch_size, len_q, len_k] bool tensor, True where the key position is PAD
    (i.e. must be masked). It is an expanded view of a [batch_size, 1, len_k] tensor,
    so call .contiguous() before writing into it.

    Raises: ValueError if seq_q and seq_k have different batch sizes.
    '''
    batch_size, len_q = seq_q.size()
    batch_size_k, len_k = seq_k.size()
    # Without this check a mismatch would be silently hidden: the mask would simply
    # take seq_k's batch size.
    if batch_size != batch_size_k:
        raise ValueError(
            'batch size mismatch: seq_q has %d, seq_k has %d' % (batch_size, batch_size_k))
    # Only the key side matters: a padded key must not be attended to by any query.
    pad_attn_mask = seq_k.eq(pad_idx).unsqueeze(1)  # [batch_size, 1, len_k]
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

In [3]:
# Step through the pad-mask construction by hand on a tiny batch.
# Token ids are drawn from [1, 10) with a seeded generator, so PAD (id 0) only
# appears where it is planted explicitly and the output is reproducible.
batch_size = 2
seq_len = 3
rng = torch.Generator().manual_seed(0)
seq_q = torch.randint(1, 10, (batch_size, seq_len), generator=rng)
seq_k = torch.randint(1, 10, (batch_size, seq_len), generator=rng)
seq_k[1, 1] = 0  # plant a single PAD token in the second key sequence

batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()

pad_attn_mask = seq_k.eq(0)                 # [batch_size, len_k], True at PAD positions
pad_attn_mask = pad_attn_mask.unsqueeze(1)  # [batch_size, 1, len_k]

# Broadcast the key mask across every query position.
pad_attn_mask_expanded = pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]
pad_attn_mask_expanded

## Subsequence Mask

In [18]:
data = np.ones((5, 5))
print('data: \n %s' % (data,))

# np.triu keeps the entries on and above the offset-th diagonal and zeroes the rest:
# offset=0 keeps the main diagonal, offset=1 starts strictly above it,
# offset=-1 additionally keeps the first sub-diagonal.
for offset in (0, 1, -1):
    data_triu = np.triu(data, k=offset)
    print('data k=%d: \n %s' % (offset, data_triu))

data: 
 [[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
data k=0: 
 [[1. 1. 1. 1. 1.]
 [0. 1. 1. 1. 1.]
 [0. 0. 1. 1. 1.]
 [0. 0. 0. 1. 1.]
 [0. 0. 0. 0. 1.]]
data k=1: 
 [[0. 1. 1. 1. 1.]
 [0. 0. 1. 1. 1.]
 [0. 0. 0. 1. 1.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0.]]
data k=-1: 
 [[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [0. 1. 1. 1. 1.]
 [0. 0. 1. 1. 1.]
 [0. 0. 0. 1. 1.]]


## ScaledDotProductAttention

In [20]:
data = torch.ones(5, 5)
# Despite its name, `mask` holds the flat indices 0..24 laid out as a 5x5 grid;
# the actual boolean mask (even indices) is derived from it in the next line.
mask = torch.arange(data.numel()).reshape(data.shape)

# masked_fill_ overwrites `data` in place with -1 wherever the condition is True
# and returns `data`, which is displayed as the cell output.
data.masked_fill_(torch.remainder(mask, 2).eq(0), -1)

tensor([[-1.,  1., -1.,  1., -1.],
        [ 1., -1.,  1., -1.,  1.],
        [-1.,  1., -1.,  1., -1.],
        [ 1., -1.,  1., -1.,  1.],
        [-1.,  1., -1.,  1., -1.]])

## MultiHeadAttention
