34 changes: 28 additions & 6 deletions autoparallel/graph_clustering.py
@@ -26,6 +26,7 @@
tree_flatten,
)
from torch._inductor.codecache import sha256_hash
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import OpStrategy

logger: logging.Logger = logging.getLogger(__name__)
@@ -52,7 +53,24 @@ def _normalize_args(
return (sorted_keys, tuple(_extract_args(arg) for arg in all_args))


def _prepare_op_strategy(op_strategy):
def _print_output_specs(op_strategy):
output = []
for s in op_strategy.strategies:
output_placements = []
output_specs = s.output_specs
if isinstance(output_specs, DTensorSpec):
output_specs = [output_specs]
for output_spec in output_specs:
if output_spec is None:
output_placements.append("(None)")
continue
plc_str = ",".join([str(p) for p in output_spec.placements])
output_placements.append(f"({plc_str})")
output.append(f"({','.join(output_placements)})")
return ", ".join(output)


def _prepare_op_strategy(op_strategy, output_only=False):
    # hashing op_strategy is expensive, so we hash the string representation
# instead, which is much cheaper and is a reasonable proxy for the
# clustering
@@ -62,14 +80,20 @@ def _prepare_op_strategy(op_strategy):
# view ops, which propagate the input shardings to the output.
# So we also add the strategy for a node as a hash key to avoid
# clustering nodes that look the same but have different strategies
if output_only:
return _print_output_specs(op_strategy)
return str(op_strategy)


def _hash_node(node, op_strategy, input_pickler):
def _hash_node(node, strategies, input_pickler):
key = (
node.meta.get("stack_trace"),
_normalize_args(node),
_prepare_op_strategy(op_strategy),
_prepare_op_strategy(strategies[node]),
tuple(
_prepare_op_strategy(strategies[s], output_only=True)
for s in node.all_input_nodes
),
)
return sha256_hash(input_pickler.dumps(key))

@@ -104,9 +128,7 @@ def get_identical_regions(
if node.op == "placeholder":
continue

duplicates = hash_to_duplicates[
_hash_node(node, strategies[node], input_pickler)
]
duplicates = hash_to_duplicates[_hash_node(node, strategies, input_pickler)]
duplicates.append(node)
node_to_duplicates[node] = duplicates
logger.info(f"Hashed nodes in {time.time() - t} s")
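To make the new hash key concrete, here is a minimal, self-contained sketch (not part of the patch) of the placement string that _print_output_specs builds. FakeSpec and FakeStrategy are hypothetical stand-ins for DTensorSpec and the entries of OpStrategy.strategies; real placements stringify roughly as "S(0)" / "R". One such string per input node is appended to the hash key so that ops whose own strategy strings coincide (e.g. view ops) are only clustered together when their inputs are sharded the same way.

    # Illustrative sketch only: stand-in classes showing the string format
    # produced by _print_output_specs. The real code uses DTensorSpec and
    # OpStrategy from torch.distributed.tensor; FakeSpec/FakeStrategy are
    # hypothetical stand-ins for this example.
    from dataclasses import dataclass
    from typing import Optional, Sequence, Union

    @dataclass
    class FakeSpec:                    # stand-in for DTensorSpec
        placements: Sequence[str]      # e.g. ("S(0)", "R") on a 2-D mesh

    @dataclass
    class FakeStrategy:                # stand-in for one entry in OpStrategy.strategies
        output_specs: Union[FakeSpec, Sequence[Optional[FakeSpec]]]

    def print_output_specs(strategies):
        output = []
        for s in strategies:
            specs = s.output_specs
            if isinstance(specs, FakeSpec):  # single-output ops return a bare spec
                specs = [specs]
            placements = []
            for spec in specs:
                if spec is None:
                    placements.append("(None)")
                    continue
                placements.append("(" + ",".join(str(p) for p in spec.placements) + ")")
            output.append(f"({','.join(placements)})")
        return ", ".join(output)

    print(print_output_specs([
        FakeStrategy(FakeSpec(("S(0)", "R"))),  # shard on dim 0, replicate on mesh dim 1
        FakeStrategy(FakeSpec(("R", "R"))),     # fully replicated
    ]))
    # -> ((S(0),R)), ((R,R))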
12 changes: 9 additions & 3 deletions autoparallel/optimize_sharding.py
@@ -218,7 +218,13 @@ def create_cluster_links(self, clusters):
for n0, ni in zip(cluster0, cluster_i):
s0 = self.node_map[n0]
s1 = self.node_map[ni]
for argi, oi, ii in self.walk_over_options(n0):
options_n0 = list(self.walk_over_options(n0))
options_ni = list(self.walk_over_options(ni))
assert options_n0 == options_ni, (
f"Problem with graph clustering: {n0} and {ni} don't have the same number "
"of input/output placements. Please report a bug"
)
for argi, oi, ii in options_n0:
self.cluster_links[(s1, argi, oi, ii)] = (s0, argi, oi, ii)

def _build_pulp_variable(self, key, ds):
@@ -475,7 +481,7 @@ def add_output_input_consistent_constraint(self):
va = self.ds[key]["va"]
vars_s_j.setdefault(s_j_ii, []).append(va)

if vars_s_i.keys() != vars_s_j.keys():
if len(vars_s_j) == 0:
vars_s_j = {}
for _, s_j_oi, s_j_ii in self.walk_over_options(user, argj):
key = (s_j, argj, s_j_oi, s_j_ii)
@@ -485,7 +491,7 @@
va = self.ds[key]["va"]
vars_s_j.setdefault(s_j_ii, []).append(va)

if vars_s_i.keys() != vars_s_j.keys():
if len(vars_s_i) == 0:
vars_s_i = {}
for _, s_i_oi, s_i_ii in self.walk_over_options(node, argi):
key = (s_i, argi, s_i_oi, s_i_ii)
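For intuition, a minimal, self-contained sketch (hypothetical names: link_clusters, options_of, node_id; not the real API) of the mirroring that create_cluster_links performs after the new consistency assert: every (arg, output-placement, input-placement) option of a node in a duplicate cluster is linked back to the matching option of the corresponding node in the reference cluster, so the solver can presumably mirror the reference cluster's sharding decisions onto its duplicates.

    # Hypothetical, self-contained sketch of the cluster-link mirroring above.
    def link_clusters(cluster0, cluster_i, options_of, node_id):
        cluster_links = {}
        for n0, ni in zip(cluster0, cluster_i):
            s0, s1 = node_id[n0], node_id[ni]
            options_n0 = list(options_of(n0))
            options_ni = list(options_of(ni))
            # Clustered nodes must expose identical option tuples, otherwise the
            # mirrored variables would not line up one-to-one.
            assert options_n0 == options_ni, f"{n0} and {ni} have different options"
            for argi, oi, ii in options_n0:
                cluster_links[(s1, argi, oi, ii)] = (s0, argi, oi, ii)
        return cluster_links

    # Two duplicate matmul nodes, each with two (arg, out, in) sharding options:
    opts = {"mm_0": [(0, 0, 0), (0, 1, 1)], "mm_1": [(0, 0, 0), (0, 1, 1)]}
    links = link_clusters(["mm_0"], ["mm_1"], lambda n: opts[n], {"mm_0": 0, "mm_1": 1})
    print(links)  # {(1, 0, 0, 0): (0, 0, 0, 0), (1, 0, 1, 1): (0, 0, 1, 1)}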