42 changes: 21 additions & 21 deletions autoparallel/optimize_sharding.py
@@ -304,21 +304,6 @@ def build_ds(self):
self.strats[all_input_nodes[argi]] if all_input_nodes else None
)
for ii, comm_cost in enumerate(xxi):
if argi_strat is not None:
src_spec = argi_strat.strategies[ii].output_specs
# TODO: operator.getitem being special is something
# we might want to change in the future
if node.target == operator.getitem:
src_spec = src_spec[node.args[1]]
tgt_spec = ssi.input_specs[argi]
assert isinstance(src_spec, DTensorSpec)
assert isinstance(tgt_spec, DTensorSpec)
# we use our custom comm_cost function to estimate the cost
# of the collective operation
comm_cost = estimate_strategy_comms_cost(src_spec, tgt_spec)

if node in grad_param_nodes:
comm_cost = comm_cost / self.rescale_grad_comm_cost_for_mp
# Imagine we start node_i from S(0)S(0) and we want to reach node_{i+2} at
# RR, and that node_{i+1} is an op with zero cost (like alias).
# In this case, all of the following chains yield the same cost:
@@ -332,19 +317,34 @@ def build_ds(self):
# in a single go. To do this, we add a tie-break cost that is 1 if a redistribution
# happens prior to getting to this configuration, and 0 otherwise. This way,
# we will favor having fewer redistributions happening in the graph.
if argi_strat is not None and node.target != operator.getitem:
original_placement = argi_strat.strategies[
ii
].output_specs.placements
current_placement = ssi.input_specs[argi].placements
if argi_strat is not None:
src_spec = argi_strat.strategies[ii].output_specs
# TODO: operator.getitem being special is something
# we might want to change in the future
if node.target == operator.getitem:
src_spec = src_spec[node.args[1]]
tgt_spec = ssi.input_specs[argi]
assert isinstance(src_spec, DTensorSpec)
assert isinstance(tgt_spec, DTensorSpec)
# we use our custom comm_cost function to estimate the cost
# of the collective operation
comm_cost = estimate_strategy_comms_cost(src_spec, tgt_spec)

redistribution_happened = (
current_placement != original_placement
src_spec.placements != tgt_spec.placements
)
sharding_transition_cost = (
int(redistribution_happened) * sharding_transition_scale
)
else:
sharding_transition_cost = 0

if node in grad_param_nodes:
comm_cost = comm_cost / self.rescale_grad_comm_cost_for_mp

# update OpSpec redistribution cost with our newly-computed cost
# this is useful for print_costs_for_node to print the updated cost
xxi[ii] = comm_cost
key = (s_i, argi, ss, ii)
# NOTE: this modifies ds in-place sometimes
# we might want to refactor this in the future
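For readers skimming the diff, here is a minimal, self-contained sketch of the tie-break idea described in the comment above. It is not the autoparallel code: the placement strings, the per-edge collective costs, and TRANSITION_SCALE (a stand-in for sharding_transition_scale) are made up for illustration; only the rule of adding 1 whenever an edge actually redistributes mirrors the change in this diff.

# Hypothetical illustration of the tie-break cost; placements and costs are made up.
TRANSITION_SCALE = 1  # stand-in for sharding_transition_scale


def edge_cost(src_placement: str, tgt_placement: str, comm_cost: float) -> float:
    """Collective cost of one edge, plus 1 if the edge actually redistributes."""
    redistribution_happened = src_placement != tgt_placement
    return comm_cost + int(redistribution_happened) * TRANSITION_SCALE


# node_{i+1} is a zero-cost op (e.g. alias), so all three chains from S(0)S(0)
# at node_i to RR at node_{i+2} have the same total collective cost (10.0):
chains = {
    # redistribute once, before the alias
    "early": [("S(0)S(0)", "RR", 10.0), ("RR", "RR", 0.0)],
    # redistribute once, after the alias
    "late": [("S(0)S(0)", "S(0)S(0)", 0.0), ("S(0)S(0)", "RR", 10.0)],
    # redistribute twice, in two partial steps
    "two_step": [("S(0)S(0)", "S(0)R", 5.0), ("S(0)R", "RR", 5.0)],
}

for name, edges in chains.items():
    total = sum(edge_cost(src, tgt, cost) for src, tgt, cost in edges)
    print(name, total)
# early 11.0, late 11.0, two_step 12.0

With the tie-break applied, the two single-redistribution chains remain tied with each other (they really are equivalent), but both beat the chain that splits the redistribution into two steps, which is the preference the comment describes: fewer redistributions in the graph.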