From fe9aa36de21e03693cf66df95d88abde986a87cf Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Tue, 1 Jul 2025 13:04:50 +0000
Subject: [PATCH] Remove more invalid / uneven shardings

Before we were only removing shardings which were invalid for the
inputs of the ops. Now we are also removing those which are invalid
for the output.

With that, we can now remove the solver constraint to remove invalid
views, as those don't appear anymore
---
 autoparallel/optimize_sharding.py | 59 -------------------------------
 autoparallel/propagation_rules.py | 34 ++++++++++++++++--
 autoparallel/utils.py             | 25 +++----------
 3 files changed, 37 insertions(+), 81 deletions(-)

diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index aa054ad0..81e8a075 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -206,67 +206,8 @@ def add_default_constraints(self):
 
         self.add_output_input_consistent_constraint()
         self.add_inf_cost_constraint()
-        self.remove_invalid_configurations()
         self.penalize_inefficient_collectives()
 
-    def remove_invalid_configurations(self):
-        """
-        Remove shardings that could yield invalid configurations,
-        for example, when sharding a view on a dimension that would yield
-        an empty size. Maybe this should be fixed in the returned specs from PyTorch
-        though, but removing those invalid cases here for now
-        """
-        for s_i, node in enumerate(self.graph.nodes):
-            if node.op != "call_function":
-                continue
-            # only targetting view for now
-            if node.target != torch.ops.aten.view.default:
-                continue
-            orig_shape = node.args[0].meta["val"].shape
-            shape = list(node.args[1])
-            if len(orig_shape) > len(shape):
-                # TODO: FIXME as I think we should also handle this case
-                continue
-            # print("in heeeererereer", orig_shape, shape)
-            tgt_op_strat = self.strats[node]
-            for counter, parent in enumerate(node.all_input_nodes):
-                curr_op_strat = self.strats[parent]
-
-                for oi, tgt_strat in enumerate(tgt_op_strat.strategies):
-                    spec = tgt_strat.input_specs[counter]
-                    if not isinstance(spec, DTensorSpec):
-                        # TODO: check if this is correct
-                        continue
-
-                    for ii, curr_strat in enumerate(curr_op_strat.strategies):
-                        curr_spec = curr_strat.output_specs
-                        if not isinstance(curr_spec, DTensorSpec):
-                            continue
-                        shape = list(node.args[1])
-                        if -1 in shape:
-                            # handle cases where we need to infer the size
-                            numel = math.prod(orig_shape)
-                            index_loc = shape.index(-1)
-                            # this works because the shape we infer is -1
-                            # and there is a single one
-                            visible_numel = -math.prod(shape)
-                            shape[index_loc] = numel // visible_numel
-                        for mesh_shape, tgt_plc, curr_plc in zip(
-                            spec.mesh.shape, spec.placements, curr_spec.placements
-                        ):
-                            # only keep view shardings that don't yield empty shapes
-                            # which could happen with S(0)S(0) on a dimension whose shape
-                            # is smaller than world_size
-                            if tgt_plc.is_shard():
-                                dim = tgt_plc.dim
-                                if shape[dim] % mesh_shape == 0:
-                                    shape[dim] /= mesh_shape
-                                else:
-                                    self.prob += (
-                                        self.ds[(s_i, counter, oi, ii)]["va"] == 0,
-                                        _get_next_name("invalid_view"),
-                                    )
-
     def penalize_inefficient_collectives(self):
         """
         When performing shard_{n} -> replicate (for n != 0), there is additional
diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py
index 26f8fe7f..157294dd 100644
--- a/autoparallel/propagation_rules.py
+++ b/autoparallel/propagation_rules.py
@@ -76,6 +76,32 @@ def _build_meta_tensor(tensor_meta):
     )
 
 
+def remove_invalid_configs(out_strat, mesh):
+    kept = []
+    for strategy in out_strat.strategies:
+        is_valid = True
+        output_specs = strategy.output_specs
+        if isinstance(output_specs, DTensorSpec):
+            output_specs = [output_specs]
+        specs = list(strategy.input_specs) + list(output_specs)
+        for spec in specs:
+            if spec is None:
+                continue
+            shape = list(spec.tensor_meta.shape)
+            for mesh_shape, plc in zip(mesh.shape, spec.placements):
+                if plc.is_shard():
+                    dim = plc.dim
+                    if shape[dim] % mesh_shape == 0:
+                        shape[dim] //= mesh_shape
+                    else:
+                        is_valid = False
+                        break
+        if is_valid:
+            kept.append(strategy)
+
+    return OpStrategy(kept)
+
+
 def _create_all_options_no_nested_sharding(mesh, shape, tensor_meta=None):
     if tensor_meta is None:
         tensor_meta = _gen_tensor_meta(shape)
@@ -94,7 +120,9 @@ def _create_all_options_no_nested_sharding(mesh, shape, tensor_meta=None):
             continue
         spec = DTensorSpec.from_dim_map(mesh, op, [], tensor_meta)
         strats.append(OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]]))
-    return OpStrategy(strats)
+    out_strats = OpStrategy(strats)
+    out_strats = remove_invalid_configs(out_strats, mesh)
+    return out_strats
 
 
 def _create_all_options(mesh, shape, tensor_meta=None, tensor=None):
@@ -112,7 +140,9 @@ def _create_all_options(mesh, shape, tensor_meta=None, tensor=None):
     for placement in all_options:
         spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
         strats.append(OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]]))
-    return OpStrategy(strats)
+    out_strats = OpStrategy(strats)
+    out_strats = remove_invalid_configs(out_strats, mesh)
+    return out_strats
 
 
 @register_rule(operator.getitem)
diff --git a/autoparallel/utils.py b/autoparallel/utils.py
index 1d858684..d275f7a3 100644
--- a/autoparallel/utils.py
+++ b/autoparallel/utils.py
@@ -10,7 +10,7 @@
 from torch.distributed.tensor._ops.utils import generate_redistribute_costs
 from torch.utils._pytree import tree_flatten, tree_map_only
 
-from .propagation_rules import _op_partial_rules, _op_rules
+from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs
 
 
 def propagate_tensor_meta(op, user_args, out_strat):
@@ -90,7 +90,9 @@ def get_placement_options(mesh, op, specs, user_args):
 
     # print(op)
     if op in _op_rules:
-        return _op_rules[op](mesh, specs)
+        out_strat = _op_rules[op](mesh, specs)
+        out_strat = remove_invalid_configs(out_strat, mesh)
+        return out_strat
 
     strat = []
     for spec in specs:
@@ -119,24 +121,7 @@ def get_placement_options(mesh, op, specs, user_args):
 
     propagate_tensor_meta(op, user_args, out_strat)
     fill_missing_redistribute_cost(op, specs, out_strat)
-
-    kept = []
-    for strategy in out_strat.strategies:
-        is_valid = True
-        for input_spec in strategy.input_specs:
-            shape = list(input_spec.tensor_meta.shape)
-            for mesh_shape, plc in zip(mesh.shape, input_spec.placements):
-                if plc.is_shard():
-                    dim = plc.dim
-                    if shape[dim] % mesh_shape == 0:
-                        shape[dim] /= mesh_shape
-                    else:
-                        is_valid = False
-                        break
-        if is_valid:
-            kept.append(strategy)
-
-    out_strat = OpStrategy(kept)
+    out_strat = remove_invalid_configs(out_strat, mesh)
 
     return out_strat
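
The even-divisibility rule that the new remove_invalid_configs helper enforces can be illustrated without any DTensor machinery. The sketch below is a simplified, stand-alone model, not the patched code: placements are plain tuples instead of DTensorSpec/Placement objects, and the is_evenly_shardable helper, the example mesh shape, and the tensor shapes are made up for illustration. It mirrors only the core check: a Shard(dim) placement on a mesh axis is kept when the tensor dimension divides evenly, and the dimension is shrunk after each axis so nested shardings such as S(0)S(0) are caught as well.

# Simplified sketch of the per-strategy divisibility check (illustrative only).
# Each placement is one entry per mesh axis: ("shard", dim) or ("replicate",).
def is_evenly_shardable(shape, placements, mesh_shape):
    shape = list(shape)
    for axis_size, plc in zip(mesh_shape, placements):
        if plc[0] == "shard":
            dim = plc[1]
            if shape[dim] % axis_size != 0:
                return False  # uneven shard: the strategy would be dropped
            shape[dim] //= axis_size  # shrink so nested shardings are checked too
    return True


if __name__ == "__main__":
    mesh_shape = (4, 2)  # hypothetical 2D mesh with 8 ranks
    # S(0)S(0) on a (4, 8) tensor: 4 % 4 == 0, but then 1 % 2 != 0 -> invalid
    print(is_evenly_shardable((4, 8), [("shard", 0), ("shard", 0)], mesh_shape))  # False
    # S(0)S(1) on the same tensor: 4 % 4 == 0 and 8 % 2 == 0 -> valid
    print(is_evenly_shardable((4, 8), [("shard", 0), ("shard", 1)], mesh_shape))  # True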