Fix allowed dtypes for mem_eff attention (pytorch#116026)
# Summary

Fix a bug in detecting mem-efficient attention capability for CUDA devices with compute capability below sm80:
pytorch-labs/gpt-fast#49
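
For context, a minimal repro sketch of the symptom (not part of the commit). It assumes a CUDA GPU with compute capability below 8.0 (e.g. sm70/sm75), illustrative tensor shapes, and the `torch.backends.cuda.sdp_kernel` context manager; these details are assumptions, not taken from the linked issue.

```python
# Hedged repro sketch, not from this PR: before the fix, bfloat16 was advertised
# as supported by the mem-efficient backend on any non-sm5x device, including
# sm70/sm75 GPUs that should be limited to fp16/fp32.
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

# Force the memory-efficient backend so the dtype capability check decides.
with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_math=False, enable_mem_efficient=True):
    # With this fix, bfloat16 is rejected on sm < 80, so this call is expected
    # to raise rather than dispatch to an unsupported kernel.
    out = F.scaled_dot_product_attention(q, k, v)
```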

Pull Request resolved: pytorch#116026
Approved by: https://github.com/janeyx99
drisspg authored and dmenig committed Dec 21, 2023
1 parent d7bf0db commit 432b6af
Showing 2 changed files with 9 additions and 7 deletions.
10 changes: 5 additions & 5 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -339,9 +339,9 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
   return false;
 #endif
   // Constraints specific to mem efficient attention
-  constexpr auto default_mem_efficient_dtypes =
+  constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
       array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
-  constexpr auto sm50_mem_efficient_dtypes =
+  constexpr auto less_than_sm80_mem_efficient_dtypes =
       array_of<at::ScalarType>(at::kHalf, at::kFloat);
 
   // Define gate functions that determine if a mem efficient kernel can be ran
@@ -381,10 +381,10 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
   }
 
   auto dprop = at::cuda::getCurrentDeviceProperties();
-  if (dprop->major == 5) {
-    return check_tensor_dtype(params, sm50_mem_efficient_dtypes, debug);
+  if (dprop->major >= 8) {
+    return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug);
   }
-  return check_tensor_dtype(params, default_mem_efficient_dtypes, debug);
+  return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
 }
 
 SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
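
For readers skimming the diff, a rough Python paraphrase of the corrected gate (illustrative only; the authoritative check is the C++ in `can_use_mem_efficient_attention`, and the helper name below is made up for this sketch):

```python
import torch

def allowed_mem_efficient_dtypes():
    # Paraphrase of the corrected gate: bfloat16 is only advertised on sm80+.
    major, _minor = torch.cuda.get_device_capability()
    if major >= 8:
        return {torch.float16, torch.float32, torch.bfloat16}
    return {torch.float16, torch.float32}
```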
6 changes: 4 additions & 2 deletions test/test_transformers.py
@@ -68,6 +68,7 @@ def use_deterministic_algorithims(mode: bool, warn_only: bool):
 isSM86or89Device = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 9)]
 isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
 isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
+isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8
 
 def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
     deviation = true_value - computed_value
@@ -1503,8 +1504,9 @@ def test_nested_fails_on_padding_head_dim(self, device):
 
 
     @onlyCUDA
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isSM5xDevice, "Does not support fused SDPA or not SM50 hardware")
-    def test_mem_efficient_fail_bfloat16_sm50(self, device):
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device,
+                     "Current platform does not support fused SDPA or is an SM80+ device.")
+    def test_mem_efficient_fail_bfloat16_less_than_sm80(self, device):
         dtype = torch.bfloat16
         size = SdpaShape(16, 16, 32, 32)
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
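
The body of the renamed test is collapsed in this view. A hedged sketch of the assertion pattern such a test typically follows is shown below; the exact body, context manager, and error type are assumptions rather than lines copied from `test_transformers.py`.

```python
# Hypothetical continuation of the test; the real body is not shown in this diff.
q, k, v = make_tensor(), make_tensor(), make_tensor()
with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_math=False, enable_mem_efficient=True):
    # On a pre-sm80 device, bfloat16 should now be rejected by the dtype check.
    self.assertRaises(
        RuntimeError,
        lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v))
```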
