In [None]:
import torch
import torch.nn as nn

### Sequence mask

* works by setting X[i, valid_len[i]:, :, ...] = masked_value (e.g., 0) for every sequence i of tensor X, i.e. every position at or beyond the i-th valid length is overwritten

e.g., if X is a two-dimensional tensor of sentences (one row per sentence, one column per token position):

![sequence mask illustration](../../97_assets/images/03_transformers_attention_notes_fig1_sequence_mask_illustration.png)

In [None]:
""" sequence mask """

def sequence_mask(x, valid_len, value=0):
    """
    Apply a length mask to a batch of sequences.

    Entries x[i, valid_len[i]:, ...] are set to ``value``; positions before
    valid_len[i] along dim=1 are kept. The input tensor is not modified.

    Args:
        x: (torch.tensor) input with at least 2 dims; dim=0 indexes sequences,
           dim=1 indexes positions within a sequence, any trailing dims are
           masked together with their position
        valid_len: (torch.tensor or sequence of numbers; 1D) one non-negative
           valid length per sequence, i.e. len(valid_len) == x.shape[0];
           non-integer lengths are truncated toward zero, like int()
        value: (scalar) value written into the masked entries; default=0

    Return:
        (torch.tensor) a new tensor of the same shape as x with the masked
        entries replaced by ``value``

    Raises:
        ValueError: if x has fewer than 2 dims, or valid_len does not hold
            exactly one length per sequence of x
    """
    if x.dim() < 2:
        raise ValueError(
            "sequence_mask expects x with at least 2 dims, got {}-D".format(x.dim()))
    # Flatten, move to x's device and truncate toward zero (same as int()).
    lengths = torch.as_tensor(valid_len, device=x.device).reshape(-1).to(torch.long)
    if lengths.numel() != x.shape[0]:
        # Fail loudly: a silently unmasked row would let attention see padding.
        raise ValueError(
            "valid_len length mismatch! got {} lengths for {} sequences".format(
                lengths.numel(), x.shape[0]))
    positions = torch.arange(x.shape[1], device=x.device)
    # keep[i, j] is True while position j lies inside sequence i's valid length.
    keep = positions.unsqueeze(0) < lengths.unsqueeze(1)
    # Append singleton dims so the [batch, seq_len] mask broadcasts over trailing dims.
    keep = keep.reshape(keep.shape + (1,) * (x.dim() - 2))
    return x.masked_fill(~keep, value)

In [None]:
""" masked softmax """

def masked_softmax(x, valid_len):
    """
    Softmax over the last dimension that ignores positions beyond a valid length.

    - if valid_len is None, this is a plain softmax over dim=-1
    - otherwise x[i, j, k] for k >= (valid length of row (i, j)) is replaced by a
      very negative number before the softmax, so those positions receive
      (numerically) zero weight; a row whose valid length is 0 has all entries
      replaced by the same number and therefore gets uniform weights

    Args:
        x: (torch.tensor) input of shape [batch, #_query, #_kv]; only 3D is accepted
        valid_len: (torch.tensor or None) either 1D of shape [batch] (one length shared
            by all query rows of a batch element) or of shape [batch, #_query]
            (one length per query row)

    Return:
        (torch.tensor) a tensor of the same shape as x; every slice along dim=-1 is
        within [0, 1] and sums to 1

    Raises:
        TypeError: if valid_len is given and x is not 3D
    """
    if valid_len is None:
        return nn.functional.softmax(x, dim=-1)
    if x.dim() != 3:
        raise TypeError("input {}-D tensor shape not supported!".format(x.dim()))
    shape = x.shape
    if valid_len.dim() == 1:
        # One length per batch element: repeat it for each of the shape[1] query rows.
        valid_len = torch.repeat_interleave(valid_len, repeats=shape[1], dim=0)
    else:
        # One length per (batch, query) row: flatten to match the reshape below.
        valid_len = valid_len.reshape(-1)
    # -1e6 makes exp() underflow to 0 for typical scores, but float16 cannot hold it
    # (largest finite magnitude 65504); clamp to the dtype's most negative finite value
    # so half-precision inputs are masked rather than overflowing.
    mask_value = -1e6
    if x.is_floating_point():
        mask_value = max(mask_value, torch.finfo(x.dtype).min)
    # Flatten to [batch * #_query, #_kv] so sequence_mask masks along the key axis.
    x = sequence_mask(x.reshape(-1, shape[-1]), valid_len, value=mask_value)
    return nn.functional.softmax(x.reshape(shape), dim=-1)

In [None]:
""" dot-product attentIon """

class DotProductAttention(nn.Module):
    """
    (Scaled) dot-product attention

    - score(query, key) = dot-product(query, key) / sqrt(d), where d is the
      shared feature size of queries and keys; pass scaled=False to use the
      unscaled dot-product
    - requires that dim_q = dim_k
    """

    def __init__(self, dropout=0, *, scaled=True, **kwargs):
        """
        constructor

        Args:
            dropout: (float) dropout rate applied to the attention weights,
                default=0 (no dropout)
            scaled: (bool; keyword-only) divide scores by sqrt(d) before the
                softmax, default=True
            **kwargs: additional arguments forwarded to nn.Module.__init__
        """
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(p=dropout)
        self.scaled = scaled

    def forward(self, query, key, value, valid_len=None):
        """
        forward method

        Args:
            query: (torch.tensor) [batch, #_query, d]; query input to the attention module
            key: (torch.tensor) [batch, #_kv, d]; use query and key to compute attention weights with score function + softmax
            value: (torch.tensor) [batch, #_kv, dim_v]; output is weighted sum of values
            valid_len: (torch.tensor or None) key mask lengths passed to masked_softmax;
                shape [batch] or [batch, #_query]

        Return:
            (torch.tensor) [batch, #_query, dim_v]; attention-weighted sum of values
        """
        d = query.shape[-1]
        # 1) batched dot products; transpose key to [batch, d, #_kv] -> scores [batch, #_query, #_kv]
        scores = torch.bmm(query, key.transpose(1, 2))
        if self.scaled:
            # Keep score variance independent of d so the softmax does not
            # saturate for large feature sizes.
            scores = scores / d ** 0.5
        # 2) Dropout on the attention weights (as in the Transformer) randomly drops
        #    whole query-key connections during training only; it is a no-op in eval().
        #    Masked positions already carry ~0 weight, so dropping them changes nothing.
        attention_weights = self.dropout(masked_softmax(scores, valid_len))
        # 3) weighted average of values; attention_weights is already
        #    [batch, #_query, #_kv], so no extra transpose is needed
        return torch.bmm(attention_weights, value)


In [None]:
""" 
test dot-product attention module
"""

### ---- dot-product attention --- ###
atten = DotProductAttention(dropout=0.5)
# eval() disables dropout so the output below is deterministic
atten.eval()
# keys = [batch=2, #_kv=10, d=2]
keys = torch.ones(2, 10, 2)
# values = [batch=2, #_kv=10, dim_v=4]; row k of each batch is [4k, 4k+1, 4k+2, 4k+3]
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
# query = [batch=2, #_query=1, d=2]
query = torch.ones(2, 1, 2)
# valid_len: for batch=0, mask=2:; for batch=1, mask=6:
valid_len = torch.tensor([2, 6])
# output = [batch=2, #_query, dim_v]
out = atten(query, keys, values, valid_len)
print(out.shape)
print(out)
# All keys are identical, so every score is equal and attention is a plain average
# over the first valid_len rows of values:
#   batch 0: mean of rows 0-1 -> [2, 3, 4, 5]; batch 1: mean of rows 0-5 -> [10, 11, 12, 13]
expected = torch.tensor([[[2.0, 3.0, 4.0, 5.0]], [[10.0, 11.0, 12.0, 13.0]]])
assert out.shape == (2, 1, 4)
assert torch.allclose(out, expected), out