46 changes: 46 additions & 0 deletions autoparallel/compute_estimation.py
@@ -3,10 +3,12 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from typing import Dict, Tuple

import torch
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.utils._pytree import tree_flatten, tree_map_only
from torch.utils.flop_counter import FlopCounterMode

@@ -198,6 +200,50 @@ def tensor_bytes(data):
return read_bytes + write_bytes


def estimate_strategy_comms_cost(src_spec, tgt_spec):
# TODO: need to use optimal redistribution cost instead
comms_cost = redistribute_cost(src_spec, tgt_spec)
total_read_write_bytes = 0

src_sizes, _ = _get_sharded_shape_stride(src_spec)
tgt_sizes, _ = _get_sharded_shape_stride(tgt_spec)

gpu_memory_bandwidth = _get_device_gmem_bandwidth()

for src_plc, tgt_plc in zip(src_spec.placements, tgt_spec.placements):
if src_plc.is_partial() and tgt_plc.is_shard() and tgt_plc.dim != 0:
# penalize cases like P -> S(1) as there is an additional compute cost
# corresponding to reshuffling the whole input tensor;
# we multiply the cost by 2 because we need to count both reading the
# input and writing the output of the reshuffle
read_write_bytes = (
math.prod(src_sizes) * 2 * src_spec.tensor_meta.dtype.itemsize
)
total_read_write_bytes += read_write_bytes
elif src_plc.is_shard() and src_plc.dim != 0 and tgt_plc.is_replicate():
# penalize cases like S(1) -> R as there is an additional compute cost
# corresponding to reshuffling the whole output tensor;
# we multiply the cost by 2 because we need to count both reading the
# input and writing the output of the reshuffle
# the copy is bandwidth bound, so memory bandwidth is used to estimate its cost
read_write_bytes = (
math.prod(tgt_sizes) * 2 * tgt_spec.tensor_meta.dtype.itemsize
)
total_read_write_bytes += read_write_bytes
elif src_plc.is_replicate() and tgt_plc.is_partial():
# forbid the R -> P case as it doesn't make sense for us
total_read_write_bytes += math.inf

compute_cost = total_read_write_bytes / gpu_memory_bandwidth * 1e6 # us
# assume 80% efficiency for memory-bound ops
factor = 1 / 0.8
compute_cost *= factor

# assume 70% efficiency for comms
comms_cost *= 1 / 0.7

return comms_cost + compute_cost
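
For intuition, here is a rough worked example of the reshuffle penalty computed above. It is only a sketch: the tensor shape, dtype, and the 2 TB/s bandwidth are assumptions for illustration (the function itself queries _get_device_gmem_bandwidth()). The reshuffle copy is bandwidth bound, which is why memory bandwidth is used to estimate its cost.

import math

sizes = (4096, 4096)   # hypothetical sharded tensor shape
itemsize = 4           # float32
bandwidth = 2.0e12     # assumed device memory bandwidth, bytes/s

# read the whole tensor once and write it back once for the reshuffle
read_write_bytes = math.prod(sizes) * 2 * itemsize    # 134,217,728 bytes
compute_cost_us = read_write_bytes / bandwidth * 1e6  # ~67.1 us
compute_cost_us *= 1 / 0.8                            # ~83.9 us at 80% efficiency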


def estimate_strategy_runtime_cost(node, strategy):
if node.op != "call_function":
return 0
69 changes: 13 additions & 56 deletions autoparallel/optimize_sharding.py
@@ -92,12 +92,13 @@
get_plain_input_and_grad_nodes,
get_plain_output_and_tangent_nodes,
)
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
from torch.utils._pytree import tree_flatten, tree_map_only

from .compute_estimation import (
_get_sharded_shape_stride,
estimate_strategy_comms_cost,
estimate_strategy_runtime_cost,
)
from .graph_clustering import get_identical_regions
@@ -301,6 +302,17 @@ def build_ds(self):
if node.op != "placeholder":
argi_strat = self.strats[self._all_input_nodes(node)[argi]]
for ii, comm_cost in enumerate(xxi):
if node.op != "placeholder":
src_spec = argi_strat.strategies[ii].output_specs
# TODO: operator.getitem being special is something
# we might want to change in the future
if node.target == operator.getitem:
src_spec = src_spec[node.args[1]]
tgt_spec = ssi.input_specs[argi]
assert isinstance(src_spec, DTensorSpec)
assert isinstance(tgt_spec, DTensorSpec)
comm_cost = estimate_strategy_comms_cost(src_spec, tgt_spec)

if node in grad_param_nodes:
comm_cost = comm_cost / self.rescale_grad_comm_cost_for_mp
# Imagine we start node_i from S(0)S(0) and we want to reach node_{i+2} at
@@ -506,61 +518,6 @@ def add_default_constraints(self):
self.add_output_input_consistent_constraint()
self.add_inf_cost_constraint()

self.penalize_inefficient_collectives()

def penalize_inefficient_collectives(self):
"""
EFFICIENCY CONSTRAINTS (Category 5): Penalize inefficient collective operations like
non-batch dimension shard-to-replicate conversions and forbid invalid transitions.

- Shard(dim≠0) → Replicate: multiply cost by 4
- Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
- Partial → Shard(dim≠0): multiply cost by 4

When performing shard_{n} -> replicate (for n != 0), there is an additional
computation cost. We penalize it here as long as that computation cost is
not added to the comm cost itself.
"""
# return
for s_i, node in enumerate(self.graph.nodes):
if node.op != "call_function":
continue
tgt_op_strat = self.strats[node]
for counter, parent in enumerate(node.all_input_nodes):
curr_op_strat = self.strats[parent]

for oi, tgt_strat in enumerate(tgt_op_strat.strategies):
spec = tgt_strat.input_specs[counter]
if not isinstance(spec, DTensorSpec):
# TODO: check if this is correct
continue

for ii, curr_strat in enumerate(curr_op_strat.strategies):
curr_spec = curr_strat.output_specs
if not isinstance(curr_spec, DTensorSpec):
continue
for tgt_plc, curr_plc in zip(
spec.placements, curr_spec.placements
):
if (
tgt_plc.is_replicate()
and curr_plc.is_shard()
and curr_plc.dim != 0
):
# penalize the S(1) -> R case as there is additional compute cost
# TODO: add proper compute cost in the optimization objective
self.ds[(s_i, counter, oi, ii)]["cost"] *= 4
elif tgt_plc.is_partial() and curr_plc.is_replicate():
# forbid the R -> P case as it doesn't make sense for us
self.prob += self.ds[(s_i, counter, oi, ii)]["va"] == 0
elif (
tgt_plc.is_shard()
and tgt_plc.dim != 0
and curr_plc.is_partial()
):
# penalize the P -> S(1) case as there is additional compute cost
self.ds[(s_i, counter, oi, ii)]["cost"] *= 4
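
As a rough sanity check of why the flat 4x multiplier above is replaced, here is a hedged back-of-the-envelope comparison using the hypothetical numbers from the earlier sketch; the 50 us base collective cost is made up for illustration.

# hypothetical S(1) -> R edge with an assumed base collective cost
base_comm_cost_us = 50.0

# removed heuristic: flat 4x multiplier on the comm cost
old_cost_us = base_comm_cost_us * 4            # 200 us

# new model: 70%-efficiency comms plus the explicit reshuffle term
# (~83.9 us for a 4096 x 4096 fp32 tensor at an assumed 2 TB/s)
new_cost_us = base_comm_cost_us / 0.7 + 83.9   # ~155.3 us

Unlike the flat multiplier, the reshuffle term scales with the actual tensor size, so small tensors are penalized less and large ones more.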

def get_violated_constraints_log(self):
violated_constraints = [
(k, c) for k, c in self.prob.constraints.items() if not c.valid()