diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index e2c17754b1a53..15d83a890aabf 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1676,6 +1676,20 @@ def check_if_enable(test: unittest.TestCase): "rocm": TEST_WITH_ROCM, "asan": TEST_WITH_ASAN } + + invalid_platforms = list(filter(lambda p: p not in platform_to_conditional, platforms)) + if len(invalid_platforms) > 0: + invalid_plats_str = ", ".join(invalid_platforms) + valid_plats = ", ".join(platform_to_conditional.keys()) + + print(f"Test {disabled_test} is disabled for some unrecognized ", + f"platforms: [{invalid_plats_str}]. Please edit issue {issue_url} to fix the platforms ", + "assigned to this flaky test, changing \"Platforms: ...\" to a comma separated ", + f"subset of the following (or leave it blank to match all platforms): {valid_plats}") + + # Sanitize the platforms list so that we continue to disable the test for any valid platforms given + platforms = list(filter(lambda p: p in platform_to_conditional, platforms)) + if platforms == [] or any([platform_to_conditional[platform] for platform in platforms]): skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \ f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \