Fix allowed dtypes for mem_eff attention (pytorch#116026)
# Summary

Fixes a bug in detecting memory-efficient attention capability for CUDA devices with compute capability below sm80:
pytorch-labs/gpt-fast#49

Pull Request resolved: pytorch#116026
Approved by: https://github.com/janeyx99
drisspg committed Dec 21, 2023
1 parent df3cab8 commit 67c9a77
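
For context, the user-visible symptom from the linked gpt-fast issue can be sketched roughly as below. This is a hypothetical repro, not code from the PR; it assumes a CUDA build of PyTorch from around this commit, where torch.backends.cuda.sdp_kernel is the context manager used to restrict SDPA backends.

# Hypothetical repro sketch (not part of the commit): request the
# memory-efficient SDPA backend with bfloat16 inputs. On a pre-sm80 GPU the
# dtype should be rejected; before this fix, sm6x/sm7x devices were gated
# with the permissive dtype list that includes bfloat16.
import torch
import torch.nn.functional as F

assert torch.cuda.is_available()
major, minor = torch.cuda.get_device_capability()
print(f"running on sm{major}{minor}")

def make():
    return torch.rand(16, 16, 32, 32, device="cuda", dtype=torch.bfloat16)

q, k, v = make(), make(), make()

# Restrict dispatch to the memory-efficient kernel only.
with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_math=False, enable_mem_efficient=True):
    try:
        F.scaled_dot_product_attention(q, k, v)
        print("memory-efficient SDPA accepted bfloat16 (expected on sm80+)")
    except RuntimeError as err:
        print(f"rejected, as expected below sm80: {err}")
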
Showing 2 changed files with 9 additions and 7 deletions.
10 changes: 5 additions & 5 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -320,9 +320,9 @@ TORCH_API bool can_use_mem_efficient_attention(sdp_params const& params, bool de
   return false;
 #endif
   // Constraints specific to mem efficient attention
-  constexpr auto default_mem_efficient_dtypes =
+  constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
       array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
-  constexpr auto sm50_mem_efficient_dtypes =
+  constexpr auto less_than_sm80_mem_efficient_dtypes =
       array_of<at::ScalarType>(at::kHalf, at::kFloat);

   // Define gate functions that determine if a mem efficient kernel can be ran
@@ -361,10 +361,10 @@ TORCH_API bool can_use_mem_efficient_attention(sdp_params const& params, bool de
   }

   auto dprop = at::cuda::getCurrentDeviceProperties();
-  if (dprop->major == 5) {
-    return check_tensor_dtype(params, sm50_mem_efficient_dtypes, debug);
+  if (dprop->major >= 8) {
+    return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug);
   }
-  return check_tensor_dtype(params, default_mem_efficient_dtypes, debug);
+  return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
 }

 SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
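
The effect of the renamed constants and the flipped comparison can be mirrored in Python for readers skimming the C++; the helper below is illustrative only and is not an API this commit adds.

# Illustrative mirror of the gate above (not a real PyTorch API): bfloat16
# is only an allowed dtype for the memory-efficient kernel on sm80 and
# newer; older devices are limited to half and float.
import torch

def mem_efficient_dtypes_sketch() -> tuple:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        return (torch.float16, torch.float32, torch.bfloat16)
    return (torch.float16, torch.float32)
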
6 changes: 4 additions & 2 deletions test/test_transformers.py
@@ -64,6 +64,7 @@ def use_deterministic_algorithims(mode: bool, warn_only: bool):
 isSM86or89Device = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 9)]
 isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
 isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
+isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8

 def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
     deviation = true_value - computed_value
@@ -1490,8 +1491,9 @@ def test_nested_fails_on_padding_head_dim(self, device):


     @onlyCUDA
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isSM5xDevice, "Does not support fused SDPA or not SM50 hardware")
-    def test_mem_efficient_fail_bfloat16_sm50(self, device):
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device,
+                     "Current platform does not support fused SDPA or is an SM80+ device.")
+    def test_mem_efficient_fail_bfloat16_less_than_sm80(self, device):
         dtype = torch.bfloat16
         size = SdpaShape(16, 16, 32, 32)
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
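
The new isLessThanSM80Device gate widens the test from SM 5.x only to every pre-sm80 device, matching the kernel-side check. A quick, illustrative way to see how the gate resolves on the local machine:

# Illustrative check of the new skip gate (assumes the usual meaning of
# torch.cuda.get_device_capability()): the test runs on any CUDA device
# below sm80 (e.g. sm50, sm61, sm75) and is skipped on sm80 and newer.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    runs = "runs" if major < 8 else "is skipped"
    print(f"sm{major}{minor}: test_mem_efficient_fail_bfloat16_less_than_sm80 {runs} here")

Running just this test locally should work with unittest's -k filter, e.g. python test/test_transformers.py -k test_mem_efficient_fail_bfloat16_less_than_sm80, though the exact invocation depends on the local setup.
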
