diff --git a/tests/test_dtensor.py b/tests/test_dtensor.py index 22ce387..767c28d 100644 --- a/tests/test_dtensor.py +++ b/tests/test_dtensor.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import functools +from unittest.case import expectedFailure import numpy as np import torch @@ -351,6 +352,7 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin class ImplicitRegistrationTest(DTensorTestBase): + @expectedFailure @with_comms def test_implicit_registration(self): mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))