Skip to content

Commit

Permalink
fMHA: Explain reason why inputs are not supported
Browse files Browse the repository at this point in the history
Should make debugging easier

**EXAMPLE**

```
NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : torch.Size([1, 4096, 160, 128]) (torch.float32)
     key         : torch.Size([1, 4096, 160, 128]) (torch.float32)
     value       : torch.Size([1, 4096, 160, 127]) (torch.float32)
     attn_bias   : <class 'NoneType'>
     p           : 0.0
`cutlassF` is not supported because:
    (value.shape[-1] % 4) != 0
`flshattF` is not supported because:
    dtype=torch.float32 (supported: {torch.bfloat16, torch.float16})
    query.shape[-1] != value.shape[-1]
`tritonflashattF` is not supported because:
    dtype=torch.float32 (supported: {torch.bfloat16, torch.float16})
    query.shape[-1] != value.shape[-1]
`smallkF` is not supported because:
    query.shape[-1] != value.shape[-1]
    max(query.shape[-1] != value.shape[-1]) > 32
    unsupported embed per head: 128
```

ghstack-source-id: b94263f9ef62017451f94a2f8855a0d95a81700d
Pull Request resolved: https://github.com/fairinternal/xformers/pull/416

__original_commit__ = fairinternal/xformers@d1bc8c6fc220d29326e305d33314bb9908b36484
  • Loading branch information
danthe3rd authored and xFormers Bot committed Jan 12, 2023
1 parent b4d8ae3 commit ac5fd49
Show file tree
Hide file tree
Showing 8 changed files with 262 additions and 107 deletions.
71 changes: 63 additions & 8 deletions tests/test_mem_eff_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import random
from dataclasses import dataclass
from typing import Any, Sequence, Tuple, Type
from typing import Any, Sequence, Tuple, Type, TypeVar

import pytest
import torch
Expand All @@ -19,16 +19,14 @@

torch.backends.cuda.matmul.allow_tf32 = False
cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
compute_capability = (0, 0)
if torch.cuda.is_available():
_devices = ["cuda"]
_is_sm75 = torch.cuda.get_device_capability(_devices[0]) >= (7, 5)
else:
_devices = []
_is_sm75 = False
sm75_or_better_only = pytest.mark.skipif(not _is_sm75, reason="requires sm75+")
compute_capability = torch.cuda.get_device_capability("cuda")
sm75_or_better_only = pytest.mark.skipif(
compute_capability < (7, 5), reason="requires sm75+"
)
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


ALL_FW_OPS: Sequence[Type[fmha.common.AttentionFwOpBase]] = [
fmha.cutlass.FwOp,
fmha.flash.FwOp,
Expand All @@ -43,6 +41,23 @@
fmha.small_k.BwOp,
]

T = TypeVar(
"T", Type[fmha.common.AttentionFwOpBase], Type[fmha.common.AttentionBwOpBase]
)


def _filter_unsupported_ops(ops: Sequence[T]) -> Sequence[T]:
return [
op
for op in ops
if "cpu" in op.SUPPORTED_DEVICES
or op.CUDA_MINIMUM_COMPUTE_CAPABILITY <= compute_capability
]


ALL_FW_OPS = _filter_unsupported_ops(ALL_FW_OPS)
ALL_BW_OPS = _filter_unsupported_ops(ALL_BW_OPS)


def sample_random_supported_fw(
inp: fmha.Inputs, seed: int
Expand Down Expand Up @@ -1111,3 +1126,43 @@ def test_grad_checkpointing(
use_reentrant=use_reentrant,
)
x.mean().backward()


ALL_FW_OPS_NO_SMALLK = [op for op in ALL_FW_OPS if op is not fmha.small_k.FwOp]


@pytest.mark.parametrize(
"op", ALL_FW_OPS_NO_SMALLK, ids=[op.NAME for op in ALL_FW_OPS_NO_SMALLK]
)
def test_unsupported_cpu(op: Type[fmha.AttentionFwOpBase]):
q = torch.empty([1, 1, 1, 32])
with pytest.raises(ValueError):
fmha.memory_efficient_attention(q, q, q, op=(op, None))


@cuda_only
@pytest.mark.parametrize(
"op", ALL_FW_OPS_NO_SMALLK, ids=[op.NAME for op in ALL_FW_OPS_NO_SMALLK]
)
def test_unsupported_stride_lastdim(op: Type[fmha.AttentionFwOpBase]):
q = torch.empty([1, 1, 32, 4], device="cuda", dtype=torch.float16).permute(
0, 1, 3, 2
)
try:
fmha.memory_efficient_attention(q, q, q, op=(op, None))
except ValueError:
q = q.contiguous()
fmha.memory_efficient_attention(q, q, q, op=(op, None))


@cuda_only
@pytest.mark.parametrize(
"op", ALL_FW_OPS_NO_SMALLK, ids=[op.NAME for op in ALL_FW_OPS_NO_SMALLK]
)
def test_unsupported_stride_alignment(op: Type[fmha.AttentionFwOpBase]):
q = torch.empty([1, 2, 2, 33], device="cuda", dtype=torch.float16)[:, :, :, :32]
try:
fmha.memory_efficient_attention(q, q, q, op=(op, None))
except ValueError:
q = q.contiguous()
fmha.memory_efficient_attention(q, q, q, op=(op, None))
21 changes: 9 additions & 12 deletions xformers/ops/fmha/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
LowerTriangularMask,
bmk2bmhk,
)
from .dispatch import _dispatch_bw, _dispatch_fw
from .dispatch import _dispatch_bw, _dispatch_fw, _ensure_op_supports_or_raise
from .tensor_with_seqlen import TensorWithSeqLen # noqa

MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
Expand Down Expand Up @@ -307,10 +307,8 @@ def _memory_efficient_attention_forward(
output_shape = inp.normalize_bmhk()
if op is None:
op = _dispatch_fw(inp)
elif not op.supports(inp):
raise ValueError(
f"xformers.memory_efficient_attention: Operator {op.NAME} does not support this input"
)
else:
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)

out, *_ = op.apply(inp, needs_gradient=False)
return out.reshape(output_shape)
Expand All @@ -323,10 +321,8 @@ def _memory_efficient_attention_forward_requires_grad(
output_shape = inp.normalize_bmhk()
if op is None:
op = _dispatch_fw(inp)
elif not op.supports(inp):
raise ValueError(
f"xformers.memory_efficient_attention: Operator {op.NAME} does not support this input"
)
else:
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
out = op.apply(inp, needs_gradient=True)
assert out[1] is not None
return (out[0].reshape(output_shape), out[1])
Expand Down Expand Up @@ -378,10 +374,11 @@ def _memory_efficient_attention_backward(

if op is None:
op = _dispatch_bw(inp)
elif not op.supports(inp):
raise ValueError(
f"xformers.memory_efficient_attention: Operator {op.NAME} does not support this input"
else:
_ensure_op_supports_or_raise(
ValueError, "memory_efficient_attention_backward", op, inp
)

grads = op.apply(ctx, inp, grad)
grads.dq = grads.dq.reshape(shape_dq)
grads.dk = grads.dk.reshape(shape_dk)
Expand Down
48 changes: 38 additions & 10 deletions xformers/ops/fmha/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class AttentionOpBase(BaseOperator):

OPERATOR: Any
SUPPORTED_DEVICES: Set[str]
CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
SUPPORTED_DTYPES: Set[torch.dtype]
SUPPORTED_MAX_K: float
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)}
Expand All @@ -196,40 +197,51 @@ class AttentionOpBase(BaseOperator):

@classmethod
def supports(cls, d: Inputs) -> bool:
return not cls.not_supported_reasons(d)

@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
"""
Returns a list of reasons why this is not supported.
The kernel can run these inputs only if the returned list is empty
"""
reasons = []
device_type = d.query.device.type
dtype = d.query.dtype
if not cls.SUPPORTS_TENSOR_WITH_SEQLEN and (
isinstance(d.query, TensorWithSeqLen)
or isinstance(d.key, TensorWithSeqLen)
or isinstance(d.value, TensorWithSeqLen)
):
return False
reasons.append("tensors with custom seqlen are not supported")
if device_type not in cls.SUPPORTED_DEVICES:
return False
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
if dtype not in cls.SUPPORTED_DTYPES:
return False
reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
if (
not cls.SUPPORTS_DIFFERENT_VALUE_EMBED
and d.query.shape[-1] != d.value.shape[-1]
):
return False
reasons.append("query.shape[-1] != value.shape[-1]")
if max(d.query.shape[-1], d.value.shape[-1]) > cls.SUPPORTED_MAX_K:
return False
reasons.append(
f"max(query.shape[-1] != value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
)
if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
return False
reasons.append(f"attn_bias type is {type(d.attn_bias)}")
if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
return False
reasons.append("dropout > 0.0")
if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
return False
reasons.append("has custom scale")
# bfloat16 is only supported on A100+
# ... although the kernels can still run and give the
# correct result
if dtype is torch.bfloat16 and (
not device_type.startswith("cuda")
or torch.cuda.get_device_capability(d.query.device)[0] < 8
):
return False
return True
reasons.append("bf16 is only supported on A100+ GPUs")
return reasons


class AttentionFwOpBase(AttentionOpBase):
Expand Down Expand Up @@ -315,3 +327,19 @@ def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute(
(0, 2, 1, 3)
)


def check_lastdim_alignment_stride1(
reasons: List[str], name: str, x: torch.Tensor, alignment: int
) -> None:
if x.shape[-1] % alignment != 0:
reasons.append(f"{name}.shape[-1] % {alignment} != 0")
elif x.stride(-2) % alignment != 0:
reasons.append(
f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
)
# We can have stride=0 sometimes if dimension=1
if x.stride(-1) > 1:
reasons.append(
f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
)
60 changes: 30 additions & 30 deletions xformers/ops/fmha/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Gradients,
Inputs,
LowerTriangularMask,
check_lastdim_alignment_stride1,
)
from .tensor_with_seqlen import TensorWithSeqLen

Expand All @@ -29,7 +30,9 @@ def _uses_tensorcores(sm: int, is_half: bool) -> bool:


def _minimum_gemm_alignment(inp: Inputs) -> int:
cap = torch.cuda.get_device_capability(inp.query.device)
if inp.device.type != "cuda":
return 1
cap = torch.cuda.get_device_capability(inp.device)
sm = cap[0] * 10 + cap[1]
bits_per_scalar = {torch.float: 32, torch.half: 16, torch.bfloat16: 16}[
inp.query.dtype
Expand Down Expand Up @@ -117,15 +120,12 @@ def apply(
return out, ctx

@classmethod
def supports(cls, d: Inputs) -> bool:
if not super(FwOp, cls).supports(d):
return False
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(FwOp, cls).not_supported_reasons(d)
matmul_alignment_mn = _minimum_gemm_alignment(d)
if (d.query.shape[-1] % matmul_alignment_mn != 0) or (
d.value.shape[-1] % matmul_alignment_mn != 0
):
return False
return True
check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
return reasons


@register_operator
Expand All @@ -148,28 +148,28 @@ class BwOp(AttentionBwOpBase):
]

@classmethod
def supports(cls, d: Inputs) -> bool:
if not FwOp.supports(d):
return False
cap = torch.cuda.get_device_capability(d.query.device)
sm = cap[0] * 10 + cap[1]
# Sm86 does not have enough shared-memory
# See https://github.com/facebookresearch/xformers/issues/517
if (
sm >= 80
and sm != 80
and d.query.dtype is torch.float
and max(d.query.shape[-1], d.key.shape[-1]) > 64
):
return False
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(BwOp, cls).not_supported_reasons(d)
matmul_alignment_mn = _minimum_gemm_alignment(d)
if (
(d.query.shape[-1] % matmul_alignment_mn != 0)
or (d.value.shape[-1] % matmul_alignment_mn != 0)
or (d.key.shape[-1] % matmul_alignment_mn != 0)
):
return False
return True
check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn)
check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
if d.device.type == "cuda":
cap = torch.cuda.get_device_capability(d.device)
sm = cap[0] * 10 + cap[1]
# Sm86 does not have enough shared-memory
# See https://github.com/facebookresearch/xformers/issues/517
if (
sm >= 80
and sm != 80
and d.query.dtype is torch.float
and max(d.query.shape[-1], d.key.shape[-1]) > 64
):
reasons.append(
f"Sm{sm} does not have enough shared-memory to run this kernel"
" - see https://github.com/facebookresearch/xformers/issues/517"
)
return reasons

@classmethod
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
Expand Down

0 comments on commit ac5fd49

Please sign in to comment.