From e6dea638a68c2c40657499b7317ba45d4ec29890 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Thu, 4 Sep 2025 15:28:52 +0000
Subject: [PATCH] Add sharding_transition_cost to getitem operators

Also update the OpSpec redistribution cost with the newly computed
communication cost, which makes debugging with print_costs_for_node
easier
---
 autoparallel/optimize_sharding.py | 42 +++++++++++++++----------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index 49f6d020..da4a1d8c 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -304,21 +304,6 @@ def build_ds(self):
                         self.strats[all_input_nodes[argi]] if all_input_nodes else None
                     )
                     for ii, comm_cost in enumerate(xxi):
-                        if argi_strat is not None:
-                            src_spec = argi_strat.strategies[ii].output_specs
-                            # TODO: operator.getitem being special is something
-                            # we might want to change in the future
-                            if node.target == operator.getitem:
-                                src_spec = src_spec[node.args[1]]
-                            tgt_spec = ssi.input_specs[argi]
-                            assert isinstance(src_spec, DTensorSpec)
-                            assert isinstance(tgt_spec, DTensorSpec)
-                            # we use our custom comm_cost function to estimate the cost
-                            # of the collective operation
-                            comm_cost = estimate_strategy_comms_cost(src_spec, tgt_spec)
-
-                            if node in grad_param_nodes:
-                                comm_cost = comm_cost / self.rescale_grad_comm_cost_for_mp
                         # Imagine we start node_i from S(0)S(0) and we want to reach node_{i+2} at
                         # RR, and that node_{i+1} is an op with zero cost (like alias).
                         # In this case, all of the following chains yield the same cost:
@@ -332,19 +317,34 @@ def build_ds(self):
                         # in a single go. To do this, we add a tie-break cost that is 1 if a redistribution
                         # happens prior to getting to this configuration, and 0 otherwise. This way,
                         # we will favor having fewer redistributions happening in the graph.
-                        if argi_strat is not None and node.target != operator.getitem:
-                            original_placement = argi_strat.strategies[
-                                ii
-                            ].output_specs.placements
-                            current_placement = ssi.input_specs[argi].placements
+                        if argi_strat is not None:
+                            src_spec = argi_strat.strategies[ii].output_specs
+                            # TODO: operator.getitem being special is something
+                            # we might want to change in the future
+                            if node.target == operator.getitem:
+                                src_spec = src_spec[node.args[1]]
+                            tgt_spec = ssi.input_specs[argi]
+                            assert isinstance(src_spec, DTensorSpec)
+                            assert isinstance(tgt_spec, DTensorSpec)
+                            # we use our custom comm_cost function to estimate the cost
+                            # of the collective operation
+                            comm_cost = estimate_strategy_comms_cost(src_spec, tgt_spec)
+
                             redistribution_happened = (
-                                current_placement != original_placement
+                                src_spec.placements != tgt_spec.placements
                             )
                             sharding_transition_cost = (
                                 int(redistribution_happened) * sharding_transition_scale
                             )
                         else:
                             sharding_transition_cost = 0
+
+                        if node in grad_param_nodes:
+                            comm_cost = comm_cost / self.rescale_grad_comm_cost_for_mp
+
+                        # update the OpSpec redistribution cost with the newly computed cost
+                        # so that print_costs_for_node reports the value actually used
+                        xxi[ii] = comm_cost
                         key = (s_i, argi, ss, ii)
                         # NOTE: this modifies ds in-place sometimes
                         # we might want to refactor this in the future
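
For readers who want to see the resulting control flow outside the diff, the sketch below reproduces the shape of the patched inner-loop body in isolation. It is an illustration under assumptions, not project code: FakeSpec, fake_comms_cost, cost_for_candidate and SHARDING_TRANSITION_SCALE are hypothetical stand-ins for DTensorSpec, estimate_strategy_comms_cost, the loop body in build_ds and sharding_transition_scale; only the branching mirrors the patch.

# Minimal, self-contained sketch of the cost update introduced by this patch.
# All names below are illustrative stand-ins, not the autoparallel API.
import operator
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeSpec:
    # placements is a tuple of strings such as ("S(0)", "R"); the real code
    # compares DTensorSpec.placements instead.
    placements: tuple


def fake_comms_cost(src, tgt):
    # Crude placeholder: charge 1.0 per mesh dimension whose placement changes.
    return float(sum(a != b for a, b in zip(src.placements, tgt.placements)))


SHARDING_TRANSITION_SCALE = 1.0  # tie-break weight (assumed value)


def cost_for_candidate(node_target, getitem_index, src_output_specs, tgt_spec, xxi, ii):
    # Mirrors one iteration of the patched loop for a single (strategy, argument) pair.
    src_spec = src_output_specs
    # getitem nodes now take the same path: pick the element of the source
    # strategy's output specs that the getitem selects.
    if node_target == operator.getitem:
        src_spec = src_output_specs[getitem_index]

    comm_cost = fake_comms_cost(src_spec, tgt_spec)

    # Tie-break: penalize configurations that require a redistribution so the
    # solver prefers chains with fewer redistributions overall.
    redistribution_happened = src_spec.placements != tgt_spec.placements
    sharding_transition_cost = int(redistribution_happened) * SHARDING_TRANSITION_SCALE

    # Write the recomputed cost back so that debugging helpers report the
    # value actually used (the role xxi[ii] plays in the real code).
    xxi[ii] = comm_cost
    return comm_cost + sharding_transition_cost


# Example: a getitem that selects a sharded output and wants it replicated.
outputs = (FakeSpec(("S(0)", "R")), FakeSpec(("R", "R")))
costs = [0.0]
total = cost_for_candidate(operator.getitem, 0, outputs, FakeSpec(("R", "R")), costs, 0)
print(total, costs)  # -> 2.0 [1.0]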