diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 924a53922efcd..20d2647dca4aa 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -320,9 +320,9 @@ TORCH_API bool can_use_mem_efficient_attention(sdp_params const& params, bool de
   return false;
 #endif
   // Constraints specific to mem efficient attention
-  constexpr auto default_mem_efficient_dtypes =
+  constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
       array_of(at::kHalf, at::kFloat, at::kBFloat16);
-  constexpr auto sm50_mem_efficient_dtypes =
+  constexpr auto less_than_sm80_mem_efficient_dtypes =
       array_of(at::kHalf, at::kFloat);
 
   // Define gate functions that determine if a mem efficient kernel can be ran
@@ -361,10 +361,10 @@ TORCH_API bool can_use_mem_efficient_attention(sdp_params const& params, bool de
   }
 
   auto dprop = at::cuda::getCurrentDeviceProperties();
-  if (dprop->major == 5) {
-    return check_tensor_dtype(params, sm50_mem_efficient_dtypes, debug);
+  if (dprop->major >= 8) {
+    return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug);
   }
-  return check_tensor_dtype(params, default_mem_efficient_dtypes, debug);
+  return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
 }
 
 SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 81e574b756558..b5f0ca082f48b 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -64,6 +64,7 @@ def use_deterministic_algorithims(mode: bool, warn_only: bool):
 isSM86or89Device = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 9)]
 isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
 isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
+isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8
 
 def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
     deviation = true_value - computed_value
@@ -1490,8 +1491,9 @@ def test_nested_fails_on_padding_head_dim(self, device):
 
     @onlyCUDA
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isSM5xDevice, "Does not support fused SDPA or not SM50 hardware")
-    def test_mem_efficient_fail_bfloat16_sm50(self, device):
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device,
+                     "Current platform does not support fused SDPA or is an SM80+ device.")
+    def test_mem_efficient_fail_bfloat16_less_than_sm80(self, device):
         dtype = torch.bfloat16
         size = SdpaShape(16, 16, 32, 32)
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
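
For context, a minimal sketch (not part of this PR) of the behavior the updated test exercises: after this change, the memory-efficient SDPA backend admits bfloat16 only on SM80+ devices, so forcing that backend with bfloat16 inputs on an older GPU should fail. The sketch assumes the torch.backends.cuda.sdp_kernel context manager from this era of PyTorch and a CUDA device with compute capability below 8.0.

# Sketch only: illustrates the dtype gate changed above; not part of the diff itself.
import torch
import torch.nn.functional as F

if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8:
    # bfloat16 inputs on a pre-SM80 device
    q, k, v = (torch.rand(16, 16, 32, 32, device="cuda", dtype=torch.bfloat16)
               for _ in range(3))
    # Disable flash and math so the mem-efficient dtype check is the deciding gate.
    with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_math=False, enable_mem_efficient=True):
        try:
            F.scaled_dot_product_attention(q, k, v)
        except RuntimeError as err:
            print("mem-efficient SDPA rejected bfloat16 on < SM80:", err)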