Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

update blocksparse to use additive attention, more consistent with core #85

Merged
merged 1 commit into from
Nov 8, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 10 additions & 11 deletions xformers/components/attention/blocksparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch import nn

from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.utils import bool_mask_to_additive

_mask_type_warning = True

Expand Down Expand Up @@ -119,10 +120,9 @@ def update_mask_type(self, mask: torch.Tensor, to_dtype: torch.dtype):
global _mask_type_warning
if _mask_type_warning:
logging.warning(
"Mask has to be multiplicative. Fixing that but this slows things down"
"Mask has to be additive. Fixing that but this slows things down"
)
_mask_type_warning = False # Only warn once
mask = mask.to(to_dtype)
mask = bool_mask_to_additive(mask)

def forward(
self,
Expand All @@ -136,11 +136,10 @@ def forward(
**kwargs,
) -> torch.Tensor:
r"""
att_mask A 2D attention mask. The dtype must be the same as q. Multiplicative mask where a value
of 1 will keep the value, while a value of 0 will mask the value.
att_mask A 2D attention mask. The dtype must be the same as q. An additive mask is expected,
meaning float values using "-inf" to mask values.
key_padding_mask A mask with size (batch size x sequence length). The dtype must be the same as q.
Multiplicative mask where a value of 1 will keep the value, while a value of 0 will
mask the value.
An additive mask is expected, meaning float values using "-inf" to mask values
"""

# NOTE:
Expand All @@ -149,9 +148,9 @@ def forward(
# If blocks are to be constantly masked, better perf would thus be reached by signalling them out in the
# initial attention setup

if att_mask is not None and att_mask.dtype != q.dtype:
if att_mask is not None and att_mask.dtype == torch.bool:
self.update_mask_type(att_mask, q.dtype)
if key_padding_mask is not None and key_padding_mask.dtype != q.dtype:
if key_padding_mask is not None and key_padding_mask.dtype == torch.bool:
self.update_mask_type(key_padding_mask, q.dtype)

assert (
Expand Down Expand Up @@ -197,8 +196,8 @@ def forward(
scale=scale,
key_padding_mask=key_padding_mask,
attn_mask=att_mask,
key_padding_mask_mode=MaskType.MUL,
attn_mask_mode=MaskType.MUL,
key_padding_mask_mode=MaskType.ADD,
attn_mask_mode=MaskType.ADD,
)

# - then (dense) attention is (sparse) attention matrix * dense (value)
Expand Down