diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index 56051360..bed732b4 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -3,6 +3,80 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
+"""
+Sharding optimization using Integer Linear Programming (ILP).
+
+This module solves the optimal sharding strategy problem by formulating it as an ILP
+where each binary variable x_{i,a,o,j} ∈ {0,1} represents a choice of input placement j
+and output placement o for operation i and argument a. The objective minimizes total cost:
+
+    minimize: Σ_{i,a,o,j} c_{i,a,o,j} * x_{i,a,o,j}
+
+where:
+- x_{i,a,o,j}: binary decision variable (1 if strategy selected, 0 otherwise)
+- c_{i,a,o,j}: total cost (communication + computation) for this strategy choice
+
+subject to the following constraint categories:
+
+1. UNIQUENESS CONSTRAINTS: Each operation-argument pair must select exactly one
+   input-output placement combination.
+
+   ∀i,a: Σ_{o,j} x_{i,a,o,j} = 1
+
+   → Implemented in: add_unique_decision_constraint()
+
+2. CONSISTENCY CONSTRAINTS: For multi-argument operations, all arguments must agree
+   on the same output placement to ensure the operation can execute correctly.
+
+   ∀i,o: Σ_j x_{i,0,o,j} = Σ_j x_{i,1,o,j} = ... = Σ_j x_{i,A_i-1,o,j}
+   where A_i is the number of arguments for operation i.
+
+   → Implemented in: add_same_output_across_args_constraint()
+
+3. FLOW CONSTRAINTS: The output placement of producer operations must match the
+   input placement of consumer operations (dataflow consistency).
+
+   ∀(i→k): Σ_j x_{i,0,o,j} = Σ_j x_{k,a,j,o}
+   where operation i feeds into operation k at argument position a.
+
+   → Implemented in: add_output_input_consistent_constraint()
+
+4. COST CONSTRAINTS: Variables with infinite cost (invalid configurations) are
+   forced to zero.
+
+   ∀i,a,o,j: c_{i,a,o,j} = ∞ ⟹ x_{i,a,o,j} = 0
+
+   → Implemented in: add_inf_cost_constraint()
+
+5. EFFICIENCY CONSTRAINTS: Penalize inefficient collective operations like
+   non-batch dimension shard-to-replicate conversions and forbid invalid
+   transitions like replicate-to-partial.
+
+   - Shard(dim≠0) → Replicate: multiply cost by 4
+   - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
+   - Partial → Shard(dim≠0): multiply cost by 4
+
+   → Implemented in: penalize_inefficient_collectives()
+
+6. USER CONSTRAINTS (optional): Force specific placements for inputs, outputs,
+   parameters, or memory usage bounds.
+
+   6a. Input/Output constraints: x_{i,a,o*,j*} = 1 for specified (o*,j*)
+       → Implemented in: add_sharded_input_constraint(), add_sharded_output_constraint()
+
+   6b. Memory constraints: Σ_{params} (size_ratio * x_{param}) ≤ memory_limit
+       → Implemented in: add_parameter_memory_constraint()
+
+   6c. Parameter-gradient consistency: x_{param} = x_{grad_param}
+       → Implemented in: add_grad_param_constraints()
+
+   6d. General node constraints: Force specific placement for any node
+       → Implemented in: add_node_constraint()
+
+The solver finds the globally optimal sharding strategy that minimizes total
+runtime cost while satisfying all constraints.
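+
+Illustrative example (hypothetical indices, for notation only): on a 1-D mesh,
+suppose operation i=3 is a matmul, strategy index o=2 corresponds to producing a
+Shard(0) output, and strategy index j=1 corresponds to argument a=0 arriving as
+Shard(0). Selecting that combination sets x_{3,0,2,1} = 1, and the uniqueness
+constraint (category 1) then forces every other x_{3,0,o,j} to 0.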
+""" + import math import pulp @@ -43,6 +117,9 @@ def __init__(self, gm, mesh): self.mesh = mesh self.node_map = {node: i for i, node in enumerate(self.graph.nodes)} self.strats = self.build_sharding_metadata() + # ds: Decision variables dictionary mapping (s_i, argi, ss, ii) -> ILP variable data + # Each key represents a choice of input placement ii and output placement ss + # for operation s_i and argument argi (corresponds to x_{i,a,o,j} in math notation) self.ds, self.num_inp_out, self.num_args = self.build_ds() self.validate() self.prob = pulp.LpProblem("AutoParallel", pulp.LpMinimize) @@ -80,6 +157,26 @@ def build_sharding_metadata(self): return strats def build_ds(self): + """ + Build decision variables (ds) for the ILP optimization. + + Creates binary variables x_{i,a,o,j} for each valid combination of: + - s_i: operation index + - argi: argument index + - ss: output placement strategy index (o in math notation) + - ii: input placement strategy index (j in math notation) + + Returns: + ds: Dictionary mapping (s_i, argi, ss, ii) -> { + "va": PuLP binary variable, + "cost": communication + computation cost, + "full_strat": complete strategy object, + "out_strat": output placement specification, + "inp_strat": input placement specification + } + num_inp_out: Metadata about strategy counts per operation-argument + num_args: Number of arguments per operation + """ strats = self.strats ds = {} num_inp_out = {} @@ -127,6 +224,12 @@ def walk_over_options(self, node, constrain_arg=None): yield argi, oi, ii def add_unique_decision_constraint(self): + """ + UNIQUENESS CONSTRAINTS (Category 1): Each operation-argument pair must select exactly one + input-output placement combination. + + Mathematical form: ∀i,a: Σ_{o,j} x_{i,a,o,j} = 1 + """ # a single pair of input-output policy is chosen for s_i, node in enumerate(self.graph.nodes): if node.op not in {"placeholder", "call_function"}: @@ -139,6 +242,12 @@ def add_unique_decision_constraint(self): self.prob += (pulp.lpSum(eqs) == 1, _get_next_name("unique_decision")) def add_same_output_across_args_constraint(self): + """ + CONSISTENCY CONSTRAINTS (Category 2): For multi-argument operations, all arguments must agree + on the same output placement to ensure the operation can execute correctly. + + Mathematical form: ∀i,o: Σ_j x_{i,0,o,j} = Σ_j x_{i,1,o,j} = ... = Σ_j x_{i,A_i-1,o,j} + """ # enforce that the same output policy is chosen # across arguments for s_i, node in enumerate(self.graph.nodes): @@ -163,6 +272,12 @@ def add_same_output_across_args_constraint(self): ) def add_output_input_consistent_constraint(self): + """ + FLOW CONSTRAINTS (Category 3): The output placement of producer operations must match the + input placement of consumer operations (dataflow consistency). + + Mathematical form: ∀(i→k): Σ_j x_{i,0,o,j} = Σ_j x_{k,a,j,o} + """ # enforce that the input of strat_{i+1} == output of strat_{i} for s_i, node in enumerate(self.graph.nodes): if node.op == "output": @@ -196,6 +311,12 @@ def add_output_input_consistent_constraint(self): ) def add_inf_cost_constraint(self): + """ + COST CONSTRAINTS (Category 4): Variables with infinite cost (invalid configurations) are + forced to zero. 
+
+        Mathematical form: ∀i,a,o,j: c_{i,a,o,j} = ∞ ⟹ x_{i,a,o,j} = 0
+        """
         # force inf cost values to be 0, as the solver doesn't accept inf
         for x in self.ds.values():
             if not math.isfinite(x["cost"]):
@@ -213,6 +334,13 @@ def add_default_constraints(self):
 
     def penalize_inefficient_collectives(self):
         """
+        EFFICIENCY CONSTRAINTS (Category 5): Penalize inefficient collective operations like
+        non-batch dimension shard-to-replicate conversions and forbid invalid transitions.
+
+        - Shard(dim≠0) → Replicate: multiply cost by 4
+        - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
+        - Partial → Shard(dim≠0): multiply cost by 4
+
         When performing shard_{n} -> replicate (for n != 0), there is
         additional computation cost associated. Let's penalize it here while
         we don't add the computation cost together in the comm cost
@@ -440,6 +568,12 @@ def get_grad_param_nodes(self):
         return grad_param_nodes
 
     def add_grad_param_constraints(self):
+        """
+        USER CONSTRAINTS (Category 6c): Parameter-gradient consistency constraints.
+        Ensures parameters and their gradients have matching sharding strategies.
+
+        Mathematical form: x_{param} = x_{grad_param}
+        """
         # TODO: need to make sure that the params and grads are aligned, which are not always the case
         # and we might have fewer gradients than parameters
 
@@ -478,6 +612,12 @@ def add_grad_param_constraints(self):
             )
 
     def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high):
+        """
+        USER CONSTRAINTS (Category 6b): Memory constraints for parameters.
+        Ensures total parameter memory usage stays within the specified bounds.
+
+        Mathematical form: memory_factor_low ≤ Σ_{params} (size_ratio * x_{param}) ≤ memory_factor_high
+        """
         # get all parameters
         param_nodes = self.get_param_nodes()
         elms = []
@@ -499,6 +639,12 @@ def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high)
         self.prob += (pulp.lpSum(elms) >= memory_factor_low, "memory_constraint_low")
 
     def add_node_constraint(self, node, placement=None, constraint_name=None):
+        """
+        USER CONSTRAINTS (Category 6d): General node constraints.
+        Force a specific placement for any node.
+
+        Mathematical form: x_{i,a,o*,j*} = 1 for the specified (o*,j*)
+        """
         strat = self.strats[node]
         if placement is None:
             # default is Shard(0) to parallelize on the batch
@@ -514,6 +660,12 @@ def add_node_constraint(self, node, placement=None, constraint_name=None):
         self._add_node_constraint(node, oi=oi, constraint_name=constraint_name)
 
     def add_sharded_input_constraint(self, input_placements=None):
+        """
+        USER CONSTRAINTS (Category 6a): Input placement constraints.
+        Force specific placements for input nodes and their corresponding gradient inputs.
+
+        Mathematical form: x_{i,a,o*,j*} = 1 for specified input placements (o*,j*)
+        """
         input_nodes = self.get_input_nodes()
         if input_placements is None:
             input_placements = [None] * len(input_nodes)
@@ -538,6 +690,12 @@ def add_sharded_input_constraint(self, input_placements=None):
             )
 
     def add_sharded_output_constraint(self, output_placements=None):
+        """
+        USER CONSTRAINTS (Category 6a): Output placement constraints.
+        Force specific placements for output nodes and their corresponding gradient outputs.
+
+        Mathematical form: x_{i,a,o*,j*} = 1 for specified output placements (o*,j*)
+        """
         # add final constraint on the output strategy
         output_nodes = self.get_fn_output_nodes()
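
For reference, the ILP described in the new module docstring can be reproduced end to
end with a few lines of PuLP. The sketch below is illustrative only: the two-node
graph, the placement encoding, and the cost values are invented and do not correspond
to this module's internal data structures, but the variables, objective, uniqueness
constraint (category 1), and flow constraint (category 3) follow the definitions above.

# Minimal, self-contained PuLP sketch of the formulation above. Everything here
# (the two-node chain, the placement encoding, the costs) is invented for
# illustration and is not this module's actual data structure.
import pulp

PLACEMENTS = ["Shard(0)", "Replicate"]   # o and j index into this list
OPS = [0, 1]                             # op 1 consumes the output of op 0
NUM_ARGS = {0: 1, 1: 1}                  # single-argument ops, for brevity

# cost[(i, a, o, j)]: made-up cost of running op i when argument a arrives with
# placement j and the op produces placement o (the c_{i,a,o,j} of the docstring).
cost = {
    (0, 0, 0, 0): 0.0, (0, 0, 0, 1): 2.0, (0, 0, 1, 0): 3.0, (0, 0, 1, 1): 1.0,
    (1, 0, 0, 0): 0.0, (1, 0, 0, 1): 2.0, (1, 0, 1, 0): 3.0, (1, 0, 1, 1): 1.0,
}

prob = pulp.LpProblem("toy_sharding", pulp.LpMinimize)
x = {k: pulp.LpVariable(f"x_{k[0]}_{k[1]}_{k[2]}_{k[3]}", cat="Binary") for k in cost}

# Objective: minimize Σ c_{i,a,o,j} * x_{i,a,o,j}
prob += pulp.lpSum(cost[k] * x[k] for k in cost)

# Category 1 (uniqueness): every (op, arg) selects exactly one (o, j) pair.
for i in OPS:
    for a in range(NUM_ARGS[i]):
        prob += pulp.lpSum(x[i, a, o, j] for o in range(2) for j in range(2)) == 1

# Category 3 (flow): op 0's chosen output placement o must match the input
# placement used by op 1's argument 0 (Σ_j x_{0,0,o,j} = Σ_j x_{1,0,j,o}).
for o in range(2):
    prob += pulp.lpSum(x[0, 0, o, j] for j in range(2)) == pulp.lpSum(
        x[1, 0, jj, o] for jj in range(2)
    )

prob.solve(pulp.PULP_CBC_CMD(msg=False))
for (i, a, o, j), var in x.items():
    if var.value() > 0.5:
        print(f"op {i} arg {a}: input={PLACEMENTS[j]} output={PLACEMENTS[o]}")
print("total cost:", pulp.value(prob.objective))

With these made-up costs the solver keeps both ops at Shard(0) end to end, so no
redistribution cost is incurred and the reported total cost is 0.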