diff --git a/autoparallel/compute_estimation.py b/autoparallel/compute_estimation.py
index 3d8678cb..570a7462 100644
--- a/autoparallel/compute_estimation.py
+++ b/autoparallel/compute_estimation.py
@@ -3,10 +3,12 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 from dataclasses import dataclass
 from typing import Dict, Tuple
 
 import torch
+from torch.distributed.tensor._collective_utils import redistribute_cost
 from torch.utils._pytree import tree_flatten, tree_map_only
 from torch.utils.flop_counter import FlopCounterMode
 
@@ -198,6 +200,50 @@ def tensor_bytes(data):
     return read_bytes + write_bytes
 
 
+def estimate_strategy_comms_cost(src_spec, tgt_spec):
+    # TODO: need to use the optimal redistribution cost instead
+    comms_cost = redistribute_cost(src_spec, tgt_spec)
+    total_read_write_bytes = 0
+
+    src_sizes, _ = _get_sharded_shape_stride(src_spec)
+    tgt_sizes, _ = _get_sharded_shape_stride(tgt_spec)
+
+    gpu_memory_bandwidth = _get_device_gmem_bandwidth()
+
+    for src_plc, tgt_plc in zip(src_spec.placements, tgt_spec.placements):
+        if src_plc.is_partial() and tgt_plc.is_shard() and tgt_plc.dim != 0:
+            # penalize cases like P -> S(1), as there is an additional compute
+            # cost corresponding to reshuffling the whole input tensor.
+            # we multiply the cost by 2 because we need to count both the read
+            # of the input and the write of the output for the reshuffle
+            read_write_bytes = (
+                math.prod(src_sizes) * 2 * src_spec.tensor_meta.dtype.itemsize
+            )
+            total_read_write_bytes += read_write_bytes
+        elif src_plc.is_shard() and src_plc.dim != 0 and tgt_plc.is_replicate():
+            # penalize cases like S(1) -> R, as there is an additional compute
+            # cost corresponding to reshuffling the whole output tensor.
+            # we multiply the cost by 2 because we need to count both the read
+            # of the input and the write of the output for the reshuffle
+            read_write_bytes = (
+                math.prod(tgt_sizes) * 2 * tgt_spec.tensor_meta.dtype.itemsize
+            )
+            total_read_write_bytes += read_write_bytes
+        elif src_plc.is_replicate() and tgt_plc.is_partial():
+            # forbid the R -> P case, as it never makes sense for us
+            total_read_write_bytes += math.inf
+
+    compute_cost = total_read_write_bytes / gpu_memory_bandwidth * 1e6  # us
+    # assume 80% efficiency for memory-bound ops
+    factor = 1 / 0.8
+    compute_cost *= factor
+
+    # assume 70% efficiency for comms
+    comms_cost *= 1 / 0.7
+
+    return comms_cost + compute_cost
+
+
 def estimate_strategy_runtime_cost(node, strategy):
     if node.op != "call_function":
         return 0
diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index 2838a1ca..62deffd0 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -92,12 +92,13 @@
     get_plain_input_and_grad_nodes,
     get_plain_output_and_tangent_nodes,
 )
-from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed._tensor.placement_types import DTensorSpec
 from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
 from torch.utils._pytree import tree_flatten, tree_map_only
 
 from .compute_estimation import (
     _get_sharded_shape_stride,
+    estimate_strategy_comms_cost,
     estimate_strategy_runtime_cost,
 )
 from .graph_clustering import get_identical_regions
@@ -301,6 +302,17 @@ def build_ds(self):
                 if node.op != "placeholder":
                     argi_strat = self.strats[self._all_input_nodes(node)[argi]]
                 for ii, comm_cost in enumerate(xxi):
+                    if node.op != "placeholder":
+                        src_spec = argi_strat.strategies[ii].output_specs
+                        # TODO: operator.getitem being special-cased is something
+                        # we might want to change in the future
+                        if node.target == operator.getitem:
+                            src_spec = src_spec[node.args[1]]
+                        tgt_spec = ssi.input_specs[argi]
+                        assert isinstance(src_spec, DTensorSpec)
+                        assert isinstance(tgt_spec, DTensorSpec)
+                        comm_cost = estimate_strategy_comms_cost(src_spec, tgt_spec)
+
                     if node in grad_param_nodes:
                         comm_cost = comm_cost / self.rescale_grad_comm_cost_for_mp
                     # Imagine we start node_i from S(0)S(0) and we want to reach node_{i+2} at
@@ -506,61 +518,6 @@ def add_default_constraints(self):
         self.add_output_input_consistent_constraint()
         self.add_inf_cost_constraint()
 
-        self.penalize_inefficient_collectives()
-
-    def penalize_inefficient_collectives(self):
-        """
-        EFFICIENCY CONSTRAINTS (Category 5): Penalize inefficient collective operations like
-        non-batch dimension shard-to-replicate conversions and forbid invalid transitions.
-
-        - Shard(dim≠0) → Replicate: multiply cost by 4
-        - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
-        - Partial → Shard(dim≠0): multiply cost by 4
-
-        When performing shard_{n} -> replicate (for n != 0), there is additional
-        computation cost associated. Let's penalize it here while we don't add
-        the computation cost together in the comm cost
-        """
-        # return
-        for s_i, node in enumerate(self.graph.nodes):
-            if node.op != "call_function":
-                continue
-            tgt_op_strat = self.strats[node]
-            for counter, parent in enumerate(node.all_input_nodes):
-                curr_op_strat = self.strats[parent]
-
-                for oi, tgt_strat in enumerate(tgt_op_strat.strategies):
-                    spec = tgt_strat.input_specs[counter]
-                    if not isinstance(spec, DTensorSpec):
-                        # TODO: check if this is correct
-                        continue
-
-                    for ii, curr_strat in enumerate(curr_op_strat.strategies):
-                        curr_spec = curr_strat.output_specs
-                        if not isinstance(curr_spec, DTensorSpec):
-                            continue
-                        for tgt_plc, curr_plc in zip(
-                            spec.placements, curr_spec.placements
-                        ):
-                            if (
-                                tgt_plc.is_replicate()
-                                and curr_plc.is_shard()
-                                and curr_plc.dim != 0
-                            ):
-                                # penalize case S(1) -> R as there are additional compute cost
-                                # TODO: add proper compute cost in the optimization objective
-                                self.ds[(s_i, counter, oi, ii)]["cost"] *= 4
-                            elif tgt_plc.is_partial() and curr_plc.is_replicate():
-                                # forbit R -> P case as this doesn't make sense for us
-                                self.prob += self.ds[(s_i, counter, oi, ii)]["va"] == 0
-                            elif (
-                                tgt_plc.is_shard()
-                                and tgt_plc.dim != 0
-                                and curr_plc.is_partial()
-                            ):
-                                # penalize case P -> S(1) as there are additional compute cost
-                                self.ds[(s_i, counter, oi, ii)]["cost"] *= 4
-
     def get_violated_constraints_log(self):
         violated_constraints = [
             (k, c) for k, c in self.prob.constraints.items() if not c.valid()
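
Two notes on the new cost model, with small sketches. First, the reshuffle penalty arithmetic in estimate_strategy_comms_cost: a penalized redistribution pays one full read of the input plus one full write of the output, divided by the device memory bandwidth and scaled by the assumed 80% efficiency for memory-bound ops. A minimal standalone sketch of that arithmetic follows; the helper name and the 2 TB/s bandwidth figure are illustrative assumptions, not repo code:

import math

def reshuffle_penalty_us(sizes, itemsize, gmem_bandwidth_bytes_per_s):
    # one full read of the input plus one full write of the output,
    # hence the factor of 2 in the byte count
    read_write_bytes = math.prod(sizes) * 2 * itemsize
    cost_us = read_write_bytes / gmem_bandwidth_bytes_per_s * 1e6  # us
    return cost_us / 0.8  # assume 80% efficiency for memory-bound ops

# e.g. a bf16 [8192, 8192] tensor at an assumed 2 TB/s costs ~168 us
print(reshuffle_penalty_us((8192, 8192), 2, 2e12))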
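
Second, the placement-pair rules the function encodes per mesh dimension. This sketch (again illustrative, not part of the patch) classifies a single dimension's transition, using the public placement types that optimize_sharding.py already imports:

from torch.distributed.tensor.placement_types import Partial, Replicate, Shard

def classify_transition(src_plc, tgt_plc):
    # mirrors the branches in estimate_strategy_comms_cost
    if src_plc.is_partial() and tgt_plc.is_shard() and tgt_plc.dim != 0:
        return "penalized: P -> S(dim != 0) reshuffles the whole input"
    if src_plc.is_shard() and src_plc.dim != 0 and tgt_plc.is_replicate():
        return "penalized: S(dim != 0) -> R reshuffles the whole output"
    if src_plc.is_replicate() and tgt_plc.is_partial():
        return "forbidden: R -> P gets infinite cost"
    return "no extra compute cost"

print(classify_transition(Partial(), Shard(1)))     # penalized
print(classify_transition(Shard(1), Replicate()))   # penalized
print(classify_transition(Replicate(), Partial()))  # forbidden
print(classify_transition(Shard(0), Replicate()))   # no extra compute cost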