diff --git a/autoparallel/dtensor_util/utils.py b/autoparallel/dtensor_util/utils.py index 3341e2e9..09f3c11c 100644 --- a/autoparallel/dtensor_util/utils.py +++ b/autoparallel/dtensor_util/utils.py @@ -17,10 +17,10 @@ OpStrategy, StrategyType, ) +from torch.distributed.tensor._ops.registration import register_op_strategy from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, is_tensor_shardable, - register_op_strategy, ) from torch.distributed.tensor.placement_types import Placement, Replicate, Shard