[claude code] Documentation for optimize_sharding.py #38
Merged
@@ -3,6 +3,80 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Sharding optimization using Integer Linear Programming (ILP).

This module solves the optimal sharding strategy problem by formulating it as an ILP
where each binary variable x_{i,a,o,j} ∈ {0,1} represents a choice of input placement j
and output placement o for operation i and argument a. The objective minimizes total cost:

    minimize: Σ_{i,a,o,j} c_{i,a,o,j} * x_{i,a,o,j}

where:
- x_{i,a,o,j}: binary decision variable (1 if strategy selected, 0 otherwise)
- c_{i,a,o,j}: total cost (communication + computation) for this strategy choice

subject to the following constraint categories:

1. UNIQUENESS CONSTRAINTS: Each operation-argument pair must select exactly one
   input-output placement combination.

       ∀i,a: Σ_{o,j} x_{i,a,o,j} = 1

   → Implemented in: add_unique_decision_constraint()

2. CONSISTENCY CONSTRAINTS: For multi-argument operations, all arguments must agree
   on the same output placement to ensure the operation can execute correctly.

       ∀i,o: Σ_j x_{i,0,o,j} = Σ_j x_{i,1,o,j} = ... = Σ_j x_{i,A_i-1,o,j}
       where A_i is the number of arguments for operation i.

   → Implemented in: add_same_output_across_args_constraint()

3. FLOW CONSTRAINTS: The output placement of producer operations must match the
   input placement of consumer operations (dataflow consistency).

       ∀(i→k): Σ_j x_{i,0,o,j} = Σ_j x_{k,a,j,o}
       where operation i feeds into operation k at argument position a.

   → Implemented in: add_output_input_consistent_constraint()

4. COST CONSTRAINTS: Variables with infinite cost (invalid configurations) are
   forced to zero.

       ∀i,a,o,j: c_{i,a,o,j} = ∞ ⟹ x_{i,a,o,j} = 0

   → Implemented in: add_inf_cost_constraint()

5. EFFICIENCY CONSTRAINTS: Penalize inefficient collective operations like
   non-batch dimension shard-to-replicate conversions and forbid invalid
   transitions like replicate-to-partial.

       - Shard(dim≠0) → Replicate: multiply cost by 4
       - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
       - Partial → Shard(dim≠0): multiply cost by 4

   → Implemented in: penalize_inefficient_collectives()

6. USER CONSTRAINTS (optional): Force specific placements for inputs, outputs,
   parameters, or memory usage bounds.

   6a. Input/Output constraints: x_{i,a,o*,j*} = 1 for specified (o*,j*)
       → Implemented in: add_sharded_input_constraint(), add_sharded_output_constraint()

   6b. Memory constraints: Σ_{params} (size_ratio * x_{param}) ≤ memory_limit
       → Implemented in: add_parameter_memory_constraint()

   6c. Parameter-gradient consistency: x_{param} = x_{grad_param}
       → Implemented in: add_grad_param_constraints()

   6d. General node constraints: Force specific placement for any node
       → Implemented in: add_node_constraint()

The solver finds the globally optimal sharding strategy that minimizes total
runtime cost while satisfying all constraints.
"""

import math

import pulp
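To make the formulation above concrete, here is a minimal, self-contained PuLP sketch of the same structure on a tiny hypothetical cost table (two operations, one argument each, two output and two input placements). The cost values, key layout, and variable names are illustrative only; the real module derives its costs and variables in build_ds().

```python
import math
import pulp

# Hypothetical cost table: (op, arg, out_placement, in_placement) -> cost.
# The real module derives these from the graph's sharding metadata.
costs = {
    (0, 0, 0, 0): 1.0,
    (0, 0, 0, 1): 4.0,
    (0, 0, 1, 0): math.inf,  # invalid configuration
    (0, 0, 1, 1): 2.0,
    (1, 0, 0, 0): 3.0,
    (1, 0, 1, 1): 0.5,
}

prob = pulp.LpProblem("AutoParallel_sketch", pulp.LpMinimize)

# Binary decision variables x_{i,a,o,j}.
x = {
    key: pulp.LpVariable(f"x_{key[0]}_{key[1]}_{key[2]}_{key[3]}", cat="Binary")
    for key in costs
}

# Objective: minimize total cost (infinite-cost entries are handled below).
prob += pulp.lpSum(c * x[k] for k, c in costs.items() if math.isfinite(c))

# Category 1 (uniqueness): each (op, arg) picks exactly one (o, j) pair.
for (i, a) in {(k[0], k[1]) for k in costs}:
    prob += (
        pulp.lpSum(x[k] for k in costs if k[0] == i and k[1] == a) == 1,
        f"unique_decision_{i}_{a}",
    )

# Category 4 (infinite cost): invalid configurations are forced to zero.
for k, c in costs.items():
    if not math.isfinite(c):
        prob += (x[k] == 0, f"inf_cost_{k[0]}_{k[1]}_{k[2]}_{k[3]}")

prob.solve(pulp.PULP_CBC_CMD(msg=False))
chosen = [k for k, v in x.items() if v.varValue == 1]
print(pulp.LpStatus[prob.status], chosen)
```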
@@ -43,6 +117,9 @@ def __init__(self, gm, mesh):
        self.mesh = mesh
        self.node_map = {node: i for i, node in enumerate(self.graph.nodes)}
        self.strats = self.build_sharding_metadata()
        # ds: Decision variables dictionary mapping (s_i, argi, ss, ii) -> ILP variable data
        # Each key represents a choice of input placement ii and output placement ss
        # for operation s_i and argument argi (corresponds to x_{i,a,o,j} in math notation)
        self.ds, self.num_inp_out, self.num_args = self.build_ds()
        self.validate()
        self.prob = pulp.LpProblem("AutoParallel", pulp.LpMinimize)
@@ -80,6 +157,26 @@ def build_sharding_metadata(self):
        return strats

    def build_ds(self):
        """
        Build decision variables (ds) for the ILP optimization.

        Creates binary variables x_{i,a,o,j} for each valid combination of:
        - s_i: operation index
        - argi: argument index
        - ss: output placement strategy index (o in math notation)
        - ii: input placement strategy index (j in math notation)

        Returns:
            ds: Dictionary mapping (s_i, argi, ss, ii) -> {
                "va": PuLP binary variable,
                "cost": communication + computation cost,
                "full_strat": complete strategy object,
                "out_strat": output placement specification,
                "inp_strat": input placement specification
            }
            num_inp_out: Metadata about strategy counts per operation-argument
            num_args: Number of arguments per operation
        """
        strats = self.strats
        ds = {}
        num_inp_out = {}
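As an illustration of the dictionary shape documented above, a single ds entry might look roughly like the sketch below. The key indices, cost, and placement strings are placeholder values, not the module's actual strategy objects.

```python
import pulp

# Hypothetical entry for operation s_i=3, argument argi=0,
# output strategy ss=2, input strategy ii=1.
key = (3, 0, 2, 1)
entry = {
    "va": pulp.LpVariable("x_3_0_2_1", cat="Binary"),  # the x_{i,a,o,j} variable
    "cost": 12.5,               # communication + computation cost for this choice
    "full_strat": None,         # placeholder: complete strategy object
    "out_strat": "Shard(0)",    # placeholder: output placement specification
    "inp_strat": "Replicate()"  # placeholder: input placement specification
}
ds = {key: entry}
```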
@@ -127,6 +224,12 @@ def walk_over_options(self, node, constrain_arg=None):
                yield argi, oi, ii

    def add_unique_decision_constraint(self):
        """
        UNIQUENESS CONSTRAINTS (Category 1): Each operation-argument pair must select exactly one
        input-output placement combination.

        Mathematical form: ∀i,a: Σ_{o,j} x_{i,a,o,j} = 1
        """
        # a single pair of input-output policy is chosen
        for s_i, node in enumerate(self.graph.nodes):
            if node.op not in {"placeholder", "call_function"}:
@@ -139,6 +242,12 @@ def add_unique_decision_constraint(self):
            self.prob += (pulp.lpSum(eqs) == 1, _get_next_name("unique_decision"))

    def add_same_output_across_args_constraint(self):
        """
        CONSISTENCY CONSTRAINTS (Category 2): For multi-argument operations, all arguments must agree
        on the same output placement to ensure the operation can execute correctly.

        Mathematical form: ∀i,o: Σ_j x_{i,0,o,j} = Σ_j x_{i,1,o,j} = ... = Σ_j x_{i,A_i-1,o,j}
        """
        # enforce that the same output policy is chosen
        # across arguments
        for s_i, node in enumerate(self.graph.nodes):
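A sketch of the category-2 constraint, not part of the diff: for one hypothetical two-argument operation (i=0) with two output and two input placements, every output placement o must be selected with the same "mass" by both arguments. Strategy counts and variable names are illustrative.

```python
import pulp

prob = pulp.LpProblem("consistency_sketch", pulp.LpMinimize)
num_out, num_in = 2, 2  # hypothetical strategy counts

x = {
    (0, a, o, j): pulp.LpVariable(f"x_0_{a}_{o}_{j}", cat="Binary")
    for a in range(2) for o in range(num_out) for j in range(num_in)
}

# Category 2: for each output placement o, argument 0 and argument 1 must
# agree. Combined with the uniqueness constraint each sum is 0 or 1, so
# this forces both arguments onto the same output placement.
for o in range(num_out):
    lhs = pulp.lpSum(x[(0, 0, o, j)] for j in range(num_in))
    rhs = pulp.lpSum(x[(0, 1, o, j)] for j in range(num_in))
    prob += (lhs == rhs, f"same_output_arg0_arg1_o{o}")
```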
@@ -163,6 +272,12 @@ def add_same_output_across_args_constraint(self):
            )

    def add_output_input_consistent_constraint(self):
        """
        FLOW CONSTRAINTS (Category 3): The output placement of producer operations must match the
        input placement of consumer operations (dataflow consistency).

        Mathematical form: ∀(i→k): Σ_j x_{i,0,o,j} = Σ_j x_{k,a,j,o}
        """
        # enforce that the input of strat_{i+1} == output of strat_{i}
        for s_i, node in enumerate(self.graph.nodes):
            if node.op == "output":
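A sketch of the category-3 dataflow constraint, not part of the diff: one hypothetical producer (i=0) feeds a consumer (k=1) at argument position a=0, and whenever the producer emits placement o, the consumer must consume placement o. Counts and names are illustrative.

```python
import pulp

prob = pulp.LpProblem("flow_sketch", pulp.LpMinimize)
num_out, num_in = 2, 2  # hypothetical strategy counts (equal so indices line up)

# Producer variables x_{0,0,o,j} and consumer variables x_{1,0,o,j}.
x = {
    (i, 0, o, j): pulp.LpVariable(f"x_{i}_0_{o}_{j}", cat="Binary")
    for i in range(2) for o in range(num_out) for j in range(num_in)
}

# Category 3: the producer's chosen output placement o must equal the
# consumer's chosen input placement at argument 0.
for o in range(num_out):
    producer_emits_o = pulp.lpSum(x[(0, 0, o, j)] for j in range(num_in))
    consumer_reads_o = pulp.lpSum(x[(1, 0, oo, o)] for oo in range(num_out))
    prob += (producer_emits_o == consumer_reads_o, f"flow_o{o}")
```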
@@ -196,6 +311,12 @@ def add_output_input_consistent_constraint(self):
            )

    def add_inf_cost_constraint(self):
        """
        COST CONSTRAINTS (Category 4): Variables with infinite cost (invalid configurations) are
        forced to zero.

        Mathematical form: ∀i,a,o,j: c_{i,a,o,j} = ∞ ⟹ x_{i,a,o,j} = 0
        """
        # force inf cost values to be 0, as the solver doesn't accept inf
        for x in self.ds.values():
            if not math.isfinite(x["cost"]):
@@ -213,6 +334,13 @@ def add_default_constraints(self):

    def penalize_inefficient_collectives(self):
        """
        EFFICIENCY CONSTRAINTS (Category 5): Penalize inefficient collective operations like
        non-batch dimension shard-to-replicate conversions and forbid invalid transitions.

        - Shard(dim≠0) → Replicate: multiply cost by 4
        - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
        - Partial → Shard(dim≠0): multiply cost by 4

        When performing shard_{n} -> replicate (for n != 0), there is additional
        computation cost associated. Let's penalize it here while we don't add
        the computation cost together in the comm cost

Contributor comment on this docstring: @wconstab we discussed this the other day, ideally we should just make the comms cost computation more accurate at some point
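A sketch of how the category-5 rules could be applied to ds-style entries, not part of the diff: it assumes placements are plain strings for brevity, whereas the real module inspects actual placement objects. The entries, penalty logic conditions, and names are illustrative.

```python
import pulp

# Hypothetical ds entries; the real entries hold placement objects, not strings.
ds = {
    (0, 0, 0, 0): {"va": pulp.LpVariable("x0", cat="Binary"),
                   "cost": 2.0, "inp_strat": "Shard(1)", "out_strat": "Replicate()"},
    (0, 0, 1, 0): {"va": pulp.LpVariable("x1", cat="Binary"),
                   "cost": 1.0, "inp_strat": "Replicate()", "out_strat": "Partial()"},
    (0, 0, 2, 0): {"va": pulp.LpVariable("x2", cat="Binary"),
                   "cost": 3.0, "inp_strat": "Partial()", "out_strat": "Shard(1)"},
}

prob = pulp.LpProblem("efficiency_sketch", pulp.LpMinimize)
PENALTY = 4  # multiplier from the docstring

for key, e in ds.items():
    inp, out = e["inp_strat"], e["out_strat"]
    if inp.startswith("Shard(") and inp != "Shard(0)" and out == "Replicate()":
        e["cost"] *= PENALTY                      # Shard(dim != 0) -> Replicate: penalize
    elif inp == "Replicate()" and out == "Partial()":
        prob += (e["va"] == 0, f"forbid_{key}")   # Replicate -> Partial: forbidden
    elif inp == "Partial()" and out.startswith("Shard(") and out != "Shard(0)":
        e["cost"] *= PENALTY                      # Partial -> Shard(dim != 0): penalize

# The (possibly penalized) costs then feed the objective.
prob += pulp.lpSum(e["cost"] * e["va"] for e in ds.values())
```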
@@ -440,6 +568,12 @@ def get_grad_param_nodes(self):
        return grad_param_nodes

    def add_grad_param_constraints(self):
        """
        USER CONSTRAINTS (Category 6c): Parameter-gradient consistency constraints.
        Ensures parameters and their gradients have matching sharding strategies.

        Mathematical form: x_{param} = x_{grad_param}
        """
        # TODO: need to make sure that the params and grads are aligned, which are not always the case
        # and we might have fewer gradients than parameters
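A sketch of the category-6c idea, not part of the diff: with one hypothetical set of candidate strategies per tensor, the parameter and its gradient are forced to pick the same one. The strategy count and names are illustrative.

```python
import pulp

prob = pulp.LpProblem("grad_param_sketch", pulp.LpMinimize)
num_strategies = 3  # hypothetical

x_param = [pulp.LpVariable(f"x_param_{s}", cat="Binary") for s in range(num_strategies)]
x_grad = [pulp.LpVariable(f"x_grad_{s}", cat="Binary") for s in range(num_strategies)]

# Category 6c: the parameter and its gradient must select the same strategy.
for s in range(num_strategies):
    prob += (x_param[s] == x_grad[s], f"grad_param_match_{s}")
```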
@@ -478,6 +612,12 @@ def add_grad_param_constraints(self):
        )

    def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high):
        """
        USER CONSTRAINTS (Category 6b): Memory constraints for parameters.
        Ensures total parameter memory usage stays within specified bounds.

        Mathematical form: Σ_{params} (size_ratio * x_{param}) ≤ memory_limit
        """
        # get all parameters
        param_nodes = self.get_param_nodes()
        elms = []
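A sketch of the category-6b bound, not part of the diff: it assumes each candidate placement for a parameter has a hypothetical size_ratio (fraction of the unsharded parameter kept per rank), weighted by that parameter's hypothetical share of total parameter bytes. The numbers and names are illustrative, not the module's actual accounting.

```python
import pulp

prob = pulp.LpProblem("memory_sketch", pulp.LpMinimize)

# Hypothetical (bytes_fraction, size_ratio, variable) triples for two parameters
# with two candidate placements each: replicated keeps the full parameter on
# every rank (ratio 1.0); sharding over 4 ranks keeps 1/4 (ratio 0.25).
candidates = [
    (0.5, 1.0,  pulp.LpVariable("p0_replicate", cat="Binary")),
    (0.5, 0.25, pulp.LpVariable("p0_shard", cat="Binary")),
    (0.5, 1.0,  pulp.LpVariable("p1_replicate", cat="Binary")),
    (0.5, 0.25, pulp.LpVariable("p1_shard", cat="Binary")),
]

memory_factor_low, memory_factor_high = 0.2, 0.6  # hypothetical bounds
elms = [frac * ratio * var for frac, ratio, var in candidates]

prob += (pulp.lpSum(elms) <= memory_factor_high, "memory_constraint_high")
prob += (pulp.lpSum(elms) >= memory_factor_low, "memory_constraint_low")
```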
@@ -499,6 +639,12 @@ def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high)
        self.prob += (pulp.lpSum(elms) >= memory_factor_low, "memory_constraint_low")

    def add_node_constraint(self, node, placement=None, constraint_name=None):
        """
        USER CONSTRAINTS (Category 6d): General node constraints.
        Force specific placement for any node.

        Mathematical form: x_{i,a,o*,j*} = 1 for specified (o*,j*)
        """
        strat = self.strats[node]
        if placement is None:
            # default is Shard(0) to parallelize on the batch
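A sketch of the category-6a/6d idea, not part of the diff: pinning one node to a specific output strategy by requiring that the variables matching that strategy sum to 1. The strategy index and counts are hypothetical.

```python
import pulp

prob = pulp.LpProblem("node_constraint_sketch", pulp.LpMinimize)
num_out, num_in = 3, 3  # hypothetical strategy counts for one node/argument

x = {
    (o, j): pulp.LpVariable(f"x_node_{o}_{j}", cat="Binary")
    for o in range(num_out) for j in range(num_in)
}

# Pin the node to output strategy oi = 0 (e.g. batch-dim sharding): combined
# with the uniqueness constraint, every variable with a different output
# strategy is then excluded.
oi = 0
prob += (
    pulp.lpSum(x[(oi, j)] for j in range(num_in)) == 1,
    "force_output_strategy_0",
)
```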
@@ -514,6 +660,12 @@ def add_node_constraint(self, node, placement=None, constraint_name=None):
        self._add_node_constraint(node, oi=oi, constraint_name=constraint_name)

    def add_sharded_input_constraint(self, input_placements=None):
        """
        USER CONSTRAINTS (Category 6a): Input placement constraints.
        Force specific placements for input nodes and their corresponding gradient inputs.

        Mathematical form: x_{i,a,o*,j*} = 1 for specified input placements (o*,j*)
        """
        input_nodes = self.get_input_nodes()
        if input_placements is None:
            input_placements = [None] * len(input_nodes)
@@ -538,6 +690,12 @@ def add_sharded_input_constraint(self, input_placements=None):
            )

    def add_sharded_output_constraint(self, output_placements=None):
        """
        USER CONSTRAINTS (Category 6a): Output placement constraints.
        Force specific placements for output nodes and their corresponding gradient outputs.

        Mathematical form: x_{i,a,o*,j*} = 1 for specified output placements (o*,j*)
        """
        # add final constraint on the output strategy
        output_nodes = self.get_fn_output_nodes()
Reviewer comment: can I just say that `ss` being an index and `ssi` being the value associated with that index has been driving me mad
Reply: Yeah, we should rename those variables to more meaningful names. Those names are there from the first version of the code I wrote, when I was naming things almost arbitrarily to get things going. I did some cleanup some time ago but those names remained; we should improve it soon.