Merged
59 changes: 0 additions & 59 deletions autoparallel/optimize_sharding.py
@@ -206,67 +206,8 @@ def add_default_constraints(self):
self.add_output_input_consistent_constraint()
self.add_inf_cost_constraint()

self.remove_invalid_configurations()
self.penalize_inefficient_collectives()

def remove_invalid_configurations(self):
"""
Remove shardings that could yield invalid configurations,
for example, when sharding a view on a dimension that would yield
an empty size. Maybe this should be fixed in the returned specs from PyTorch
though, but removing those invalid cases here for now
"""
for s_i, node in enumerate(self.graph.nodes):
if node.op != "call_function":
continue
# only targetting view for now
if node.target != torch.ops.aten.view.default:
continue
orig_shape = node.args[0].meta["val"].shape
shape = list(node.args[1])
if len(orig_shape) > len(shape):
# TODO: FIXME as I think we should also handle this case
continue
# print("in heeeererereer", orig_shape, shape)
Contributor commented: lol
tgt_op_strat = self.strats[node]
for counter, parent in enumerate(node.all_input_nodes):
curr_op_strat = self.strats[parent]

for oi, tgt_strat in enumerate(tgt_op_strat.strategies):
spec = tgt_strat.input_specs[counter]
if not isinstance(spec, DTensorSpec):
# TODO: check if this is correct
continue

for ii, curr_strat in enumerate(curr_op_strat.strategies):
curr_spec = curr_strat.output_specs
if not isinstance(curr_spec, DTensorSpec):
continue
shape = list(node.args[1])
if -1 in shape:
# handle cases where we need to infer the size
numel = math.prod(orig_shape)
index_loc = shape.index(-1)
# this works because the shape we infer is -1
# and there is a single one
visible_numel = -math.prod(shape)
shape[index_loc] = numel // visible_numel
for mesh_shape, tgt_plc, curr_plc in zip(
spec.mesh.shape, spec.placements, curr_spec.placements
):
# only keep view shardings that don't yield empty shapes
# which could happen with S(0)S(0) on a dimension whose shape
# is smaller than world_size
if tgt_plc.is_shard():
dim = tgt_plc.dim
if shape[dim] % mesh_shape == 0:
shape[dim] /= mesh_shape
else:
self.prob += (
self.ds[(s_i, counter, oi, ii)]["va"] == 0,
_get_next_name("invalid_view"),
)

def penalize_inefficient_collectives(self):
"""
When performing shard_{n} -> replicate (for n != 0), there is additional
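For reference, the deleted branch above resolved a -1 entry in the requested view shape by dividing the total element count by the product of the explicitly given dimensions, before checking divisibility against the mesh. A minimal standalone sketch of that inference step (function and variable names here are illustrative, not taken from the codebase):

import math

def infer_view_shape(orig_shape, target_shape):
    # Resolve a single -1 entry in the requested view shape, mirroring the
    # inference done by the removed remove_invalid_configurations.
    shape = list(target_shape)
    if -1 in shape:
        numel = math.prod(orig_shape)      # total number of elements
        index_loc = shape.index(-1)
        visible_numel = -math.prod(shape)  # product of the known dims (the -1 flips the sign)
        shape[index_loc] = numel // visible_numel
    return shape

# e.g. a (4, 6) tensor viewed as (-1, 3) resolves to [8, 3]
assert infer_view_shape((4, 6), (-1, 3)) == [8, 3]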
34 changes: 32 additions & 2 deletions autoparallel/propagation_rules.py
@@ -76,6 +76,32 @@ def _build_meta_tensor(tensor_meta):
)


def remove_invalid_configs(out_strat, mesh):
kept = []
for strategy in out_strat.strategies:
is_valid = True
output_specs = strategy.output_specs
if isinstance(output_specs, DTensorSpec):
output_specs = [output_specs]
specs = list(strategy.input_specs) + list(output_specs)
Contributor commented: nit: list(output_specs) seems redundant to the line above?

Contributor (author) replied: strategy.output_specs can also be a tuple of DTensorSpec, so I'm just trying to make sure we are not concatenating lists and tuples together.

for spec in specs:
if spec is None:
continue
shape = list(spec.tensor_meta.shape)
for mesh_shape, plc in zip(mesh.shape, spec.placements):
if plc.is_shard():
dim = plc.dim
if shape[dim] % mesh_shape == 0:
shape[dim] //= mesh_shape
else:
is_valid = False
break
if is_valid:
kept.append(strategy)

return OpStrategy(kept)


def _create_all_options_no_nested_sharding(mesh, shape, tensor_meta=None):
if tensor_meta is None:
tensor_meta = _gen_tensor_meta(shape)
@@ -94,7 +120,9 @@ def _create_all_options_no_nested_sharding(mesh, shape, tensor_meta=None):
continue
spec = DTensorSpec.from_dim_map(mesh, op, [], tensor_meta)
strats.append(OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]]))
return OpStrategy(strats)
out_strats = OpStrategy(strats)
out_strats = remove_invalid_configs(out_strats, mesh)
return out_strats


def _create_all_options(mesh, shape, tensor_meta=None, tensor=None):
@@ -112,7 +140,9 @@ def _create_all_options(mesh, shape, tensor_meta=None, tensor=None):
for placement in all_options:
spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
strats.append(OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]]))
return OpStrategy(strats)
out_strats = OpStrategy(strats)
out_strats = remove_invalid_configs(out_strats, mesh)
return out_strats


@register_rule(operator.getitem)
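The new remove_invalid_configs helper keeps only those strategies whose shard placements divide the corresponding tensor dimensions evenly, on inputs as well as outputs. A minimal sketch of the criterion itself, using made-up shapes and a plain tuple in place of a real DeviceMesh (not code from this PR):

def shard_is_valid(tensor_shape, mesh_shape, shard_dims):
    # Keep a placement only if every sharded dimension is evenly divisible
    # by the size of the mesh dimension it is sharded over. shard_dims holds
    # one entry per mesh dimension: a tensor dim, or None for replicate.
    shape = list(tensor_shape)
    for mesh_dim_size, dim in zip(mesh_shape, shard_dims):
        if dim is None:
            continue
        if shape[dim] % mesh_dim_size != 0:
            return False
        shape[dim] //= mesh_dim_size  # nested sharding sees the reduced size
    return True

# e.g. a (5, 8) tensor on a 1-D mesh of 4 devices:
assert not shard_is_valid((5, 8), (4,), (0,))  # 5 % 4 != 0, filtered out
assert shard_is_valid((5, 8), (4,), (1,))      # 8 % 4 == 0, kept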
25 changes: 5 additions & 20 deletions autoparallel/utils.py
@@ -10,7 +10,7 @@
from torch.distributed.tensor._ops.utils import generate_redistribute_costs
from torch.utils._pytree import tree_flatten, tree_map_only

from .propagation_rules import _op_partial_rules, _op_rules
from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs


def propagate_tensor_meta(op, user_args, out_strat):
@@ -90,7 +90,9 @@ def get_placement_options(mesh, op, specs, user_args):
# print(op)

if op in _op_rules:
return _op_rules[op](mesh, specs)
out_strat = _op_rules[op](mesh, specs)
out_strat = remove_invalid_configs(out_strat, mesh)
return out_strat

strat = []
for spec in specs:
@@ -119,24 +121,7 @@

propagate_tensor_meta(op, user_args, out_strat)
fill_missing_redistribute_cost(op, specs, out_strat)

kept = []
for strategy in out_strat.strategies:
is_valid = True
for input_spec in strategy.input_specs:
shape = list(input_spec.tensor_meta.shape)
for mesh_shape, plc in zip(mesh.shape, input_spec.placements):
if plc.is_shard():
dim = plc.dim
if shape[dim] % mesh_shape == 0:
shape[dim] /= mesh_shape
else:
is_valid = False
break
if is_valid:
kept.append(strategy)

out_strat = OpStrategy(kept)
out_strat = remove_invalid_configs(out_strat, mesh)

return out_strat

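With the inline loop gone, both branches of get_placement_options (and the two _create_all_options helpers in propagation_rules.py) end by running the same filter. A simplified paraphrase of the resulting control flow, with placeholder callables standing in for the real helpers rather than the actual torch/autoparallel APIs:

def get_placement_options_sketch(mesh, op, specs, user_args, *,
                                 op_rules, query_sharding_prop,
                                 propagate_meta, fill_costs,
                                 remove_invalid_configs):
    # Every keyword argument is a stand-in for a real helper; the point is
    # that both paths now finish with the same validity filter.
    if op in op_rules:
        out_strat = op_rules[op](mesh, specs)
        return remove_invalid_configs(out_strat, mesh)
    out_strat = query_sharding_prop(op, specs)   # DTensor's sharding propagation
    propagate_meta(op, user_args, out_strat)
    fill_costs(op, specs, out_strat)
    return remove_invalid_configs(out_strat, mesh)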