diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index 49f6d020..a42efb8d 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -739,14 +739,16 @@ def get_solution(self, verbose=False):
         # TODO: assert all nodes have a placement?
         return opt
 
-    def _add_node_constraint(self, node, oi, constraint_name=None):
+    def _add_node_constraint(
+        self, node, output_constraint_indices, constraint_name=None
+    ):
         if constraint_name is None:
             constraint_name = "user_constraint"
         s_i = self.node_map[node]
         vars_per_arg = {}
-        for argi, oi_, ii in self.walk_over_options(node):
-            if oi_ == oi:
-                va = self.ds[(s_i, argi, oi, ii)]["va"]
+        for argi, output_constraint_index, input_index in self.walk_over_options(node):
+            if output_constraint_index in output_constraint_indices:
+                va = self.ds[(s_i, argi, output_constraint_index, input_index)]["va"]
                 vars_per_arg.setdefault(argi, []).append(va)
         for eqs in vars_per_arg.values():
             self.prob += (pulp.lpSum(eqs) == 1, _get_next_name(constraint_name))
@@ -828,15 +830,20 @@ def add_node_constraint(self, node, placement=None, constraint_name=None):
         if placement is None:
             # default is Shard(0) to parallelize on the batch
             placement = (Shard(0),) + (Replicate(),) * (self.mesh.ndim - 1)
-        for oi, s in enumerate(strat.strategies):
+        output_constraint_indices = []
+        for output_constraint_index, s in enumerate(strat.strategies):
             spec = s.output_specs
             if spec.placements == placement:
-                break
-        else:
+                output_constraint_indices.append(output_constraint_index)
+        if len(output_constraint_indices) == 0:
             raise RuntimeError(
                 f"Couldn't find appropriate constraint {node} {constraint_name} {placement}"
             )
-        self._add_node_constraint(node, oi=oi, constraint_name=constraint_name)
+        self._add_node_constraint(
+            node,
+            output_constraint_indices=output_constraint_indices,
+            constraint_name=constraint_name,
+        )
 
     def add_sharded_input_constraint(
         self, input_placements: Optional[list[Optional[tuple[Placement, ...]]]] = None
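
For context, the change replaces "pin the node to the first strategy whose output placement matches" with "let the ILP choose among all matching strategies". Below is a minimal, self-contained sketch of that constraint pattern; the costs and strategy indices are invented for illustration, and only the `pulp` calls mirror what `_add_node_constraint` actually emits.

```python
import pulp

# Hypothetical example: a node has three candidate sharding strategies with
# made-up communication costs; strategies 1 and 2 both produce the requested
# output placement, so either is acceptable.
prob = pulp.LpProblem("sharding_sketch", pulp.LpMinimize)
costs = {0: 1.0, 1: 3.0, 2: 2.5}  # strategy index -> cost (invented)
va = {oi: pulp.LpVariable(f"va_{oi}", cat="Binary") for oi in costs}

# Objective: minimize the cost of the selected strategy.
prob += pulp.lpSum(cost * va[oi] for oi, cost in costs.items())

# Solver invariant: each node picks exactly one strategy overall.
prob += (pulp.lpSum(va.values()) == 1, "one_strategy_per_node")

# The constraint this patch generalizes: exactly one variable among the set of
# matching strategies must be selected (previously only a single index `oi`
# could be passed in).
output_constraint_indices = [1, 2]
prob += (
    pulp.lpSum(va[oi] for oi in output_constraint_indices) == 1,
    "user_constraint",
)

prob.solve(pulp.PULP_CBC_CMD(msg=False))
print({oi: int(va[oi].value()) for oi in costs})  # -> {0: 0, 1: 0, 2: 1}
```

Even though strategy 0 is cheapest overall, `user_constraint` forces the solver to pick the cheapest strategy within the allowed set (index 2 here). This is also why the for-else in `add_node_constraint` becomes an explicit empty-list check: the loop no longer breaks on the first match, so the old `else` clause would never fire.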