Commit 8113277

Making it explicit whether the attention mechanism supports an attention mask or not

check the assert
blefaudeux committed Apr 18, 2022
1 parent 9cad6bb commit 8113277
Showing 16 changed files with 86 additions and 14 deletions.
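Taken together, these changes mean that every attention mechanism now exposes supports_attention_mask and supports_key_padding_mask flags, and that the multi-head dispatcher asserts on them before forwarding a mask. Below is a minimal sketch of how calling code could query the new flag before passing a mask; it is not part of the commit, assumes xformers at this revision, and reuses the scaled_dot_product config that appears in tests/test_model_factory.py further down. Tensor shapes and the use of build_attention here are illustrative.

import torch

from xformers.components.attention import build_attention

# Same attention config as used in tests/test_model_factory.py below
attention = build_attention(
    {"name": "scaled_dot_product", "dropout": 0, "causal": True, "seq_len": 32}
)

q = k = v = torch.rand(1, 32, 64)
att_mask = torch.ones(32, 32, dtype=torch.bool)

# Only pass a mask if the mechanism declares support for it; the multi-head
# wrapper now raises an AssertionError when an unsupported mask is handed in
if attention.supports_attention_mask:
    out = attention(q, k, v, att_mask=att_mask)
else:
    out = attention(q, k, v)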
19 changes: 15 additions & 4 deletions tests/test_block_factory.py
@@ -50,6 +50,7 @@ def test_xformer_encoder_block(
     device: torch.device,
     reversible: bool,
 ):
+
     block_size = 16

     attention_config = {
@@ -112,7 +113,13 @@ def test_xformer_encoder_block(

     # Check that we support attention masking, at least interface wise (do not check correctness yet)
     att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
-    _ = block(inputs, att_mask=att_mask)
+    if block.mha.attention.supports_attention_mask:
+        _ = block(inputs, att_mask=att_mask)
+    else:
+        with pytest.raises(AssertionError):
+            # Check that passing an attention mask to a mechanism which does not support it raises
+            # an exception
+            _ = block(inputs, att_mask=att_mask)

     # Check that we support input masking, at least interface wise (do not check correctness yet)
     input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
@@ -223,7 +230,10 @@ def test_xformer_decoder_block(
     input_mask[input_mask < 0.0] = -float("inf")

     encoded = encoder_block(inputs)
-    _ = decoder_block(inputs, encoded, encoder_att_mask=att_mask, input_mask=input_mask)
+    if decoder_block.mha.attention.supports_attention_mask:
+        _ = decoder_block(
+            inputs, encoded, encoder_att_mask=att_mask, input_mask=input_mask
+        )

     # Test different sequence lengths when encoding and decoding
     if not decoder_block.mha.attention.requires_same_k_q_dimensions:
@@ -303,8 +313,9 @@ def test_embedding_projection():
     _ = block(inputs)

     # Check that we support attention masking, at least interface wise (do not check correctness yet)
-    att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
-    _ = block(inputs, att_mask=att_mask)
+    if block.mha.attention.supports_attention_mask:
+        att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+        _ = block(inputs, att_mask=att_mask)

     # Check that we support input masking, at least interface wise (do not check correctness yet)
     input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
6 changes: 3 additions & 3 deletions tests/test_model_factory.py
@@ -39,7 +39,7 @@
             "num_heads": 4,
             "residual_dropout": 0,
             "attention": {
-                "name": "linformer",
+                "name": "scaled_dot_product",
                 "dropout": 0,
                 "causal": True,
                 "seq_len": SEQ,
@@ -73,7 +73,7 @@
             "residual_dropout": 0,
             "dim_model": EMB,
             "attention": {
-                "name": "linformer",
+                "name": "scaled_dot_product",
                 "dropout": 0,
                 "causal": True,
                 "seq_len": SEQ,
@@ -84,7 +84,7 @@
             "residual_dropout": 0,
             "dim_model": EMB,
             "attention": {
-                "name": "linformer",
+                "name": "scaled_dot_product",
                 "dropout": 0,
                 "causal": True,
                 "seq_len": SEQ,
4 changes: 4 additions & 0 deletions xformers/components/attention/base.py
@@ -53,6 +53,10 @@ def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
         # so that the MHA wrapper should skip it
         self.requires_skip_multi_head = False

+        # Whether this attention mechanism supports attention masks
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
     @classmethod
     def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
         # Generate the class inputs from the config
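With these defaults in the base class, a mechanism only has to override the flags when its behavior differs. A minimal sketch of a hypothetical subclass declaring that it cannot honor an attention mask; the class name and its pooling computation are made up for illustration, and the usual attention registration plumbing is omitted.

import torch

from xformers.components.attention import Attention


class PoolingLikeAttention(Attention):
    # Hypothetical mechanism which ignores pairwise interactions,
    # so an attention mask cannot be honored
    def __init__(self, dropout: float = 0.0, *args, **kwargs):
        super().__init__()
        self.attn_drop = torch.nn.Dropout(dropout)

        # Override the base class defaults set above
        self.supports_attention_mask = False
        self.supports_key_padding_mask = False

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs):
        # Mean-pool the values as a stand-in computation
        return self.attn_drop(v.mean(dim=-2, keepdim=True).expand_as(v))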
5 changes: 4 additions & 1 deletion xformers/components/attention/blocksparse.py
@@ -114,9 +114,12 @@ def __init__(

         # key padding mask and attention mask must be passed in separately
         self.requires_separate_masks = True
-
         self.requires_same_k_q_dimensions = True

+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = True
+
     def update_mask_type(self, mask: torch.Tensor):
         global _mask_type_warning
         if _mask_type_warning:
4 changes: 4 additions & 0 deletions xformers/components/attention/compositional.py
@@ -189,6 +189,10 @@ def __init__(

         self.causal = causal

+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
         self._reset_parameters()

     def _reset_parameters(self):
4 changes: 4 additions & 0 deletions xformers/components/attention/favor.py
@@ -104,6 +104,10 @@ def __init__(

         self.feature_map: FeatureMap = feature_map_constructor(**feature_settings) # type: ignore

+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = False
+        self.supports_key_padding_mask = False
+
     @staticmethod
     def _maybe_promote(x: torch.Tensor) -> torch.Tensor:
         # Only promote fp16 buffers, bfloat16 would be fine for instance
3 changes: 3 additions & 0 deletions xformers/components/attention/fourier_mix.py
@@ -20,6 +20,9 @@ def __init__(self, dropout: float, *_, **__):
         """
         super().__init__()
         self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
+
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = False
         self.requires_input_projection = False

     def forward(self, q: torch.Tensor, *_, **__):
3 changes: 3 additions & 0 deletions xformers/components/attention/global_tokens.py
@@ -78,7 +78,10 @@ def __init__(
             else maybe_sparsify(self.attention_mask)
         )

+        # Properties specific to this attention mechanism
         self.requires_same_k_q_dimensions = True
+        self.supports_attention_mask = False
+        self.supports_key_padding_mask = False

     def forward(
         self,
4 changes: 4 additions & 0 deletions xformers/components/attention/lambda_layer.py
@@ -44,7 +44,11 @@ def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__):
         )
         self.rel_pos = calc_rel_pos(seq_len)
         self.attn_drop = torch.nn.Dropout(dropout, inplace=True)
+
+        # Properties specific to this attention mechanism
         self.requires_same_k_q_dimensions = True
+        self.supports_attention_mask = False
+        self.supports_key_padding_mask = False

     def forward(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
6 changes: 6 additions & 0 deletions xformers/components/attention/linformer.py
@@ -42,8 +42,14 @@ def __init__(
         self.F = nn.Linear(seq_len, k, bias=False)
         self.attn_drop = nn.Dropout(dropout, inplace=False)
         self.seq_len = seq_len
+
+        # MHA related flags:
+        # kq need to have the same dimension
         self.requires_same_k_q_dimensions = True
+
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = False

     def forward(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
     ):
4 changes: 4 additions & 0 deletions xformers/components/attention/local.py
@@ -77,6 +77,10 @@ def __init__(
         self.attention_mask: Optional[torch.Tensor] = None
         self.requires_same_k_q_dimensions = True

+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
     def _get_local_mask(self, shape: torch.Size) -> torch.Tensor:
         window_size = self.window_size * 2 + 1 if self.causal else self.window_size
         mask = local_1d_pattern(shape[1], window_size)
4 changes: 4 additions & 0 deletions xformers/components/attention/nystrom.py
@@ -154,6 +154,10 @@ def __init__(
         self.causal_mask_2: Optional[torch.Tensor] = None
         self.causal_mask_3: Optional[torch.Tensor] = None

+        # This attention does not support attention masks
+        self.supports_attention_mask = False
+        self.supports_key_padding_mask = True
+
     def forward(
         self,
         q: torch.Tensor,
4 changes: 4 additions & 0 deletions xformers/components/attention/ortho.py
@@ -72,6 +72,10 @@ def __init__(
         self.subsample_fraction = subsample_fraction
         self.landmark_selection = landmark_selection

+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
     def forward(
         self,
         q: torch.Tensor,
5 changes: 5 additions & 0 deletions xformers/components/attention/random.py
@@ -68,6 +68,11 @@ def __init__(
         self.rand_attention_mask: Optional[torch.Tensor] = None
         self.constant_masking = constant_masking
         self.force_sparsity = force_sparsity
+
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
         self.requires_same_k_q_dimensions = True

     def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor:
4 changes: 4 additions & 0 deletions xformers/components/attention/scaled_dot_product.py
@@ -57,6 +57,10 @@ def __init__(
         else:
             self.mask = None

+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
     def forward(
         self,
         q: torch.Tensor,
21 changes: 15 additions & 6 deletions xformers/components/multi_head_dispatch.py
@@ -155,10 +155,21 @@ def forward(
                 + "In that case causality is ill-determined. Please pad your sequences accordingly"
             )

+        kw_mask_args = {}
+        if att_mask is not None:
+            assert (
+                self.attention.supports_attention_mask
+            ), "This attention does not support attention masks"
+            kw_mask_args["att_mask"] = att_mask
+
+        if key_padding_mask is not None:
+            assert (
+                self.attention.supports_key_padding_mask
+            ), "This attention does not support key padding masks"
+            kw_mask_args["key_padding_mask"] = key_padding_mask
+
         if self.attention.requires_skip_multi_head:
-            return self.attention(
-                query, key, value, att_mask=att_mask, key_padding_mask=key_padding_mask
-            )
+            return self.attention(query, key, value, **kw_mask_args)

         # Calculate query, key, values for all heads in batch
         if self.attention.requires_input_projection:
@@ -199,9 +210,7 @@ def check(t, name):
         v = reshape_fn(v, B, S_K, self.num_heads, self.dim_k)

         # Self-attend
-        y = self.attention(
-            q=q, k=k, v=v, att_mask=att_mask, key_padding_mask=key_padding_mask
-        )
+        y = self.attention(q=q, k=k, v=v, **kw_mask_args)

         # Re-assemble all head outputs side by side
         y = (
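With the dispatcher changes above, masks are only forwarded as keyword arguments when the underlying mechanism declares support for them, so an unsupported mask now fails loudly instead of being silently dropped. A small sketch of the caller-facing effect, assuming xformers at this revision; the MultiHeadDispatch keyword arguments and tensor sizes used here are illustrative and not taken from the diff.

import torch

from xformers.components import MultiHeadDispatch
from xformers.components.attention import build_attention

SEQ, EMB, BATCH = 32, 64, 2

# Linformer is one of the mechanisms flagged above with supports_attention_mask = False
attention = build_attention(
    {"name": "linformer", "dropout": 0, "causal": True, "seq_len": SEQ}
)
multi_head = MultiHeadDispatch(
    dim_model=EMB, num_heads=4, residual_dropout=0.0, attention=attention
)

inputs = torch.rand(BATCH, SEQ, EMB)
att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool)

try:
    _ = multi_head(inputs, inputs, inputs, att_mask=att_mask)
except AssertionError as err:
    print(err)  # "This attention does not support attention masks"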
