From 3982ee1a9e42839d3f2b2a32f684cc8786bfda77 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Tue, 20 Apr 2021 10:33:51 -0700 Subject: [PATCH] Switch assertWarnsOnceRegex logic to check any instead of all. (#56434) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56434 If we hit multiple TORCH_WARN from different sources when running the statement, it makes more sense to me that we want to check the regex is met in any one of the warning messages instead of all messages. Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D27871946 Pulled By: ailzhang fbshipit-source-id: 5940a8e43e4cc91aef213ef01e48d506fd9a1132 --- torch/testing/_internal/common_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index a42f61385b26c..3ced08f4f8292 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1469,10 +1469,10 @@ def assertWarnsOnceRegex(self, category, regex=''): torch.set_warn_always(prev) if len(ws) == 0: self.fail('no warning caught') - for w in ws: - self.assertTrue(type(w.message) is category) - self.assertTrue(re.match(pattern, str(w.message)), - f'{pattern}, {w.message}') + self.assertTrue(any([type(w.message) is category for w in ws])) + self.assertTrue( + any([re.match(pattern, str(w.message)) for w in ws]), + f'{pattern}, {[w.message for w in ws if type(w.message) is category]}') def assertExpected(self, s, subname=None): r"""