[helpers] Better handling of different sequence dimensions (#143)
* Better handling of different sequence dimensions

* Handle different lengths in rotary embeddings

* Catch causal attention with different sequence lengths in MHA and raise an error with an explanation

* Assert when the K and Q sequence lengths differ, and check the assert in a unit test

* Add a rotary embedding unit test with different q/k lengths
blefaudeux committed Dec 7, 2021
1 parent 2d355ad commit adfb645
Showing 13 changed files with 65 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -152,6 +152,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*

- [Sine](xformers/components/positional_embedding/sine.py)
- [Vocabulary](xformers/components/positional_embedding/vocab.py)
- [Rotary](xformers/components/positional_embedding/rotary.py)

</p></details>

13 changes: 4 additions & 9 deletions tests/test_attentions.py
@@ -237,17 +237,12 @@ def test_different_kq_dimensions(
heads: int,
device: torch.device,
):
-    if attention_name in {
-        "global",
-        "local",
-        "random",
-        "lambda",
-        "linformer",
-        "blocksparse",
-    }:

+    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

+    if multi_head.attention.requires_same_k_q_dimensions:
        # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
        pytest.skip(f"{attention_name} does not support different k, q dimensions yet.")
-    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

seq_q = SEQ - 16
q = torch.rand((BATCH, seq_q, MODEL), device=device)
13 changes: 13 additions & 0 deletions tests/test_block_factory.py
@@ -211,3 +211,16 @@ def test_xformer_decoder_block(

encoded = encoder_block(inputs)
_ = decoder_block(inputs, encoded, encoder_att_mask=att_mask, input_mask=input_mask)

# Test different sequence lengths when encoding and decoding
if not decoder_block.mha.attention.requires_same_k_q_dimensions:
if not causal or not hasattr(decoder_block.mha.attention, "causal"):
_ = decoder_block(inputs[:, :-16], encoded)
else:
# Check that we assert properly
with pytest.raises(AssertionError):
_ = decoder_block(inputs[:, :-16], encoded)
else:
# Check that we assert properly
with pytest.raises(AssertionError):
_ = decoder_block(inputs[:, :-16], encoded)
5 changes: 4 additions & 1 deletion tests/test_rotary_embeddings.py
@@ -43,6 +43,9 @@ def test_rotary_embeddings(device):
0, 0, 0, 0
].clone() # all diagonal elements will have the same value
att_rot = (
-        att_rot <= 1e-5
+        att_rot <= 1e-4
) # all non diagonal elements had lower attention than diagonal (+ float tolerance)
assert torch.all(att_rot)

+    # Test that different sequence lengths is ok
+    _, _ = rotary(q[:, :, :-16, :], k)
8 changes: 8 additions & 0 deletions xformers/components/attention/base.py
@@ -34,11 +34,19 @@ class Attention(nn.Module, metaclass=ABCMeta):
@abstractmethod
def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
super().__init__()

# Requires the inputs to be projected
self.requires_input_projection = True

# Whether the head dimension needs to be present (if not it can be folded into the batch dimension)
self.requires_head_dimension = False

# key padding mask and attention mask must be passed in as separate arguments instead of a merged attention mask
self.requires_separate_masks = False

# Requires that K and Q have the same sequence length
self.requires_same_k_q_dimensions = False

@classmethod
def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
# Generate the class inputs from the config
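Note: the requires_same_k_q_dimensions flag defaults to False; the attention variants further down opt in. As a minimal sketch (not part of this commit; the class name and forward body are illustrative only), a custom attention could declare the constraint like this:

import torch
from xformers.components.attention.base import Attention

class FixedContextAttention(Attention):
    """Illustrative only: an attention that needs matching K/Q sequence lengths."""

    def __init__(self, dropout: float = 0.0, *args, **kwargs):
        super().__init__(dropout=dropout)
        # Opting in lets the multi-head dispatcher reject mismatched inputs early
        self.requires_same_k_q_dimensions = True

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs):
        att = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return att @ v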
8 changes: 5 additions & 3 deletions xformers/components/attention/blocksparse.py
@@ -108,7 +108,9 @@ def __init__(
# key padding mask and attention mask must be passed in separately
self.requires_separate_masks = True

-    def update_mask_type(self, mask: torch.Tensor, to_dtype: torch.dtype):
+        self.requires_same_k_q_dimensions = True

+    def update_mask_type(self, mask: torch.Tensor):
global _mask_type_warning
if _mask_type_warning:
logging.warning(
@@ -141,9 +143,9 @@ def forward(
# initial attention setup

if att_mask is not None and att_mask.dtype == torch.bool:
-            self.update_mask_type(att_mask, q.dtype)
+            self.update_mask_type(att_mask)
if key_padding_mask is not None and key_padding_mask.dtype == torch.bool:
-            self.update_mask_type(key_padding_mask, q.dtype)
+            self.update_mask_type(key_padding_mask)

assert (
att_mask is None or att_mask.dim() == 2
2 changes: 2 additions & 0 deletions xformers/components/attention/global_tokens.py
@@ -78,6 +78,8 @@ def __init__(
else maybe_sparsify(self.attention_mask)
)

self.requires_same_k_q_dimensions = True

def forward(
self,
q: torch.Tensor,
1 change: 1 addition & 0 deletions xformers/components/attention/lambda_layer.py
@@ -44,6 +44,7 @@ def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__):
)
self.rel_pos = calc_rel_pos(seq_len)
self.attn_drop = torch.nn.Dropout(dropout, inplace=True)
self.requires_same_k_q_dimensions = True

def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
1 change: 1 addition & 0 deletions xformers/components/attention/linformer.py
@@ -42,6 +42,7 @@ def __init__(
self.F = nn.Linear(seq_len, k, bias=False)
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.seq_len = seq_len
self.requires_same_k_q_dimensions = True

def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
1 change: 1 addition & 0 deletions xformers/components/attention/local.py
@@ -75,6 +75,7 @@ def __init__(

self.window_size = window_size
self.attention_mask: Optional[torch.Tensor] = None
self.requires_same_k_q_dimensions = True

def _get_local_mask(self, shape: torch.Size) -> torch.Tensor:
window_size = self.window_size * 2 + 1 if self.causal else self.window_size
1 change: 1 addition & 0 deletions xformers/components/attention/random.py
@@ -68,6 +68,7 @@ def __init__(
self.rand_attention_mask: Optional[torch.Tensor] = None
self.constant_masking = constant_masking
self.force_sparsity = force_sparsity
self.requires_same_k_q_dimensions = True

def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor:
sparsity = 1 - self.r
12 changes: 12 additions & 0 deletions xformers/components/multi_head_dispatch.py
@@ -146,6 +146,18 @@ def forward(
B, S_Q, _ = query.size() # Batch x Sequence x Embedding (latent)
_, S_K, _ = key.size() # K, Q's sequence length could differ

# Catch different query and key length but a causal attention
if S_Q != S_K:
assert (
not self.attention.requires_same_k_q_dimensions
), "This attention mechanism requires query and key to have the same sequence (context) lengths"

if hasattr(self.attention, "causal"):
assert not self.attention.causal, (
"Causal attention is not supported when key and query have different sequence lengths.\n"
+ "In that case causality is ill-determined. Please pad your sequences accordingly"
)

# Calculate query, key, values for all heads in batch
if self.attention.requires_input_projection:
q, k, v = self.in_proj_container(query=query, key=key, value=value)
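For illustration, a standalone sketch of the guard added above (not the library API), using the Batch x Sequence x Embedding convention from the comments; tensor sizes here are arbitrary examples:

import torch

def check_kq_lengths(query: torch.Tensor, key: torch.Tensor,
                     requires_same_k_q_dimensions: bool, causal: bool) -> None:
    # Mismatched sequence lengths are only allowed when the attention
    # supports them and is not causal
    _, s_q, _ = query.size()  # Batch x Sequence x Embedding
    _, s_k, _ = key.size()
    if s_q != s_k:
        assert not requires_same_k_q_dimensions, (
            "This attention mechanism requires query and key to have the same sequence lengths"
        )
        assert not causal, (
            "Causal attention is ill-determined when key and query lengths differ;"
            " pad the sequences instead"
        )

# Passes: non-causal attention which tolerates different lengths
check_kq_lengths(torch.rand(2, 48, 64), torch.rand(2, 64, 64),
                 requires_same_k_q_dimensions=False, causal=False)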
16 changes: 12 additions & 4 deletions xformers/components/positional_embedding/rotary.py
@@ -18,9 +18,14 @@ def rotate_half(x):


@torch.jit.script
-def apply_rotary_pos_emb(q, k, cos, sin):
+def apply_rotary_pos_emb(x, cos, sin):
    # NOTE: This could probably be moved to Triton
-    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

+    # Handle a possible sequence length mismatch in between q and k
+    cos = cos[:, :, : x.shape[-2], :]
+    sin = sin[:, :, : x.shape[-2], :]

+    return (x * cos) + (rotate_half(x) * sin)


class RotaryEmbedding(torch.nn.Module):
@@ -73,7 +78,10 @@ def forward(
self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
-            q, seq_dimension=-2
+            k, seq_dimension=-2
)

-        return apply_rotary_pos_emb(q, k, self._cos_cached, self._sin_cached)
+        return (
+            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
+            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
+        )
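A short usage sketch of the updated path; it assumes RotaryEmbedding takes the per-head dimension as its only required constructor argument and (batch, heads, seq, dim) inputs, as in the unit test above:

import torch
from xformers.components.positional_embedding.rotary import RotaryEmbedding

rotary = RotaryEmbedding(16)    # assumed signature: per-head dimension only
q = torch.rand(2, 4, 48, 16)    # (batch, heads, seq_q, dim_head), shorter query
k = torch.rand(2, 4, 64, 16)    # (batch, heads, seq_k, dim_head)

q_rot, k_rot = rotary(q, k)

# The cos/sin tables are built from k and sliced to each input's own sequence
# length, so shapes are preserved even when seq_q != seq_k
assert q_rot.shape == q.shape and k_rot.shape == k.shape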
