Commit

update blocksparse to use additive attention, more consistent with core
Diana Liskovich committed Nov 7, 2021
1 parent ec26827 commit 8f3c01e
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions xformers/components/attention/blocksparse.py
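For context on the change: core xformers attention expects additive masks, where 0 keeps a position and a large negative value such as "-inf" removes it before the softmax, while the previous blocksparse path expected a multiplicative 0/1 mask. A small illustrative contrast of the two conventions in plain PyTorch (just the masking arithmetic, not the Triton kernel):

import torch

scores = torch.tensor([[1.0, 2.0, 3.0, 4.0]])     # one row of attention scores
keep = torch.tensor([[True, True, False, True]])  # True = attend, False = mask out

# Old (multiplicative) convention: 1 keeps a value, 0 masks it.
mul_mask = keep.to(scores.dtype)

# New (additive) convention: 0 keeps a value, "-inf" masks it.
add_mask = torch.zeros_like(scores).masked_fill(~keep, float("-inf"))

probs = torch.softmax(scores + add_mask, dim=-1)  # masked position gets probability 0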
@@ -13,6 +13,7 @@
 from torch import nn
 
 from xformers.components.attention import Attention, AttentionConfig, register_attention
+from xformers.components.attention.utils import bool_mask_to_additive
 
 _mask_type_warning = True
 
@@ -119,10 +120,9 @@ def update_mask_type(self, mask: torch.Tensor, to_dtype: torch.dtype):
         global _mask_type_warning
         if _mask_type_warning:
             logging.warning(
-                "Mask has to be multiplicative. Fixing that but this slows things down"
+                "Mask has to be additive. Fixing that but this slows things down"
             )
             _mask_type_warning = False  # Only warn once
-        mask = mask.to(to_dtype)
+        mask = bool_mask_to_additive(mask)
 
     def forward(
         self,
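The hunk above replaces the plain dtype cast with a call to bool_mask_to_additive. A minimal sketch of what such a conversion can look like, assuming it simply maps True to 0.0 and False to "-inf" (the real helper lives in xformers/components/attention/utils.py and may differ in details such as the output dtype):

import torch

def bool_mask_to_additive_sketch(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # True means "keep this position", False means "mask it out".
    assert mask.dtype == torch.bool
    additive = torch.zeros_like(mask, dtype=dtype)
    additive.masked_fill_(~mask, float("-inf"))
    return additive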
@@ -136,11 +136,10 @@ def forward(
         **kwargs,
     ) -> torch.Tensor:
         r"""
-        att_mask            A 2D attention mask. The dtype must be the same as q. Multiplicative mask where a value
-                            of 1 will keep the value, while a value of 0 will mask the value.
+        att_mask            A 2D attention mask. The dtype must be the same as q. An additive mask is expected,
+                            meaning float values using "-inf" to mask values.
         key_padding_mask    A mask with size (batch size x sequence length). The dtype must be the same as q.
-                            Multiplicative mask where a value of 1 will keep the value, while a value of 0 will
-                            mask the value.
+                            An additive mask is expected, meaning float values using "-inf" to mask values
         """
 
         # NOTE:
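Per the updated docstring, att_mask and key_padding_mask are now float tensors in q's dtype with 0 where attention is allowed and "-inf" where it is blocked. A hedged example of building such masks by hand, with a causal att_mask and a key_padding_mask derived from per-sequence lengths (shapes follow the docstring, the rest is illustrative):

import torch

batch, seq = 2, 6
lengths = torch.tensor([6, 4])  # second sequence has 2 padding tokens

# 2D additive causal mask: 0 on and below the diagonal, "-inf" above it.
att_mask = torch.triu(torch.full((seq, seq), float("-inf")), diagonal=1)

# (batch, seq) additive padding mask: 0 for real tokens, "-inf" for padding.
valid = torch.arange(seq)[None, :] < lengths[:, None]
key_padding_mask = torch.zeros(batch, seq).masked_fill(~valid, float("-inf"))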
@@ -149,9 +148,9 @@ def forward(
         # If blocks are to be constantly masked, better perf would thus be reached by signalling them out in the
         # initial attention setup
 
-        if att_mask is not None and att_mask.dtype != q.dtype:
+        if att_mask is not None and att_mask.dtype == torch.bool:
             self.update_mask_type(att_mask, q.dtype)
-        if key_padding_mask is not None and key_padding_mask.dtype != q.dtype:
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.bool:
             self.update_mask_type(key_padding_mask, q.dtype)
 
         assert (
@@ -197,8 +196,8 @@ def forward(
             scale=scale,
             key_padding_mask=key_padding_mask,
             attn_mask=att_mask,
-            key_padding_mask_mode=MaskType.MUL,
-            attn_mask_mode=MaskType.MUL,
+            key_padding_mask_mode=MaskType.ADD,
+            attn_mask_mode=MaskType.ADD,
         )
 
         # - then (dense) attention is (sparse) attention matrix * dense (value)
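With both mask modes now MaskType.ADD, the masking semantics match a plain dense reference in which the additive masks are simply added to the scaled scores before the softmax. A dense, illustrative sketch of that reference (the actual computation is done by the fused block-sparse kernels, so this is only a statement of the intended math):

import torch

def dense_reference(q, k, v, att_mask=None, key_padding_mask=None, scale=None):
    # q, k, v: (batch, heads, seq, head_dim); masks are additive (0 keeps, "-inf" blocks).
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    if att_mask is not None:          # (seq, seq): broadcast over batch and heads
        scores = scores + att_mask
    if key_padding_mask is not None:  # (batch, seq): broadcast over heads and query positions
        scores = scores + key_padding_mask[:, None, None, :]
    return torch.softmax(scores, dim=-1) @ v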