diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index a42f61385b26c..3ced08f4f8292 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1469,10 +1469,10 @@ def assertWarnsOnceRegex(self, category, regex=''): torch.set_warn_always(prev) if len(ws) == 0: self.fail('no warning caught') - for w in ws: - self.assertTrue(type(w.message) is category) - self.assertTrue(re.match(pattern, str(w.message)), - f'{pattern}, {w.message}') + self.assertTrue(any([type(w.message) is category for w in ws])) + self.assertTrue( + any([re.match(pattern, str(w.message)) for w in ws]), + f'{pattern}, {[w.message for w in ws if type(w.message) is category]}') def assertExpected(self, s, subname=None): r"""