Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions autoparallel/optimize_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,14 +739,16 @@ def get_solution(self, verbose=False):
# TODO: assert all nodes have a placement?
return opt

def _add_node_constraint(self, node, oi, constraint_name=None):
def _add_node_constraint(
self, node, output_constraint_indices, constraint_name=None
):
if constraint_name is None:
constraint_name = "user_constraint"
s_i = self.node_map[node]
vars_per_arg = {}
for argi, oi_, ii in self.walk_over_options(node):
if oi_ == oi:
va = self.ds[(s_i, argi, oi, ii)]["va"]
for argi, output_constraint_index, input_index in self.walk_over_options(node):
if output_constraint_index in output_constraint_indices:
va = self.ds[(s_i, argi, output_constraint_index, input_index)]["va"]
vars_per_arg.setdefault(argi, []).append(va)
for eqs in vars_per_arg.values():
self.prob += (pulp.lpSum(eqs) == 1, _get_next_name(constraint_name))
Expand Down Expand Up @@ -828,15 +830,20 @@ def add_node_constraint(self, node, placement=None, constraint_name=None):
if placement is None:
# default is Shard(0) to parallelize on the batch
placement = (Shard(0),) + (Replicate(),) * (self.mesh.ndim - 1)
for oi, s in enumerate(strat.strategies):
output_constraint_indices = []
for output_constraint_index, s in enumerate(strat.strategies):
spec = s.output_specs
if spec.placements == placement:
break
else:
output_constraint_indices.append(output_constraint_index)
if len(output_constraint_indices) == 0:
raise RuntimeError(
f"Couldn't find appropriate constraint {node} {constraint_name} {placement}"
)
self._add_node_constraint(node, oi=oi, constraint_name=constraint_name)
self._add_node_constraint(
node,
output_constraint_indices=output_constraint_indices,
constraint_name=constraint_name,
)

def add_sharded_input_constraint(
self, input_placements: Optional[list[Optional[tuple[Placement, ...]]]] = None
Expand Down