Merged
158 changes: 158 additions & 0 deletions autoparallel/optimize_sharding.py
@@ -3,6 +3,80 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Sharding optimization using Integer Linear Programming (ILP).

This module solves the optimal sharding strategy problem by formulating it as an ILP
where each binary variable x_{i,a,o,j} ∈ {0,1} represents a choice of input placement j
and output placement o for operation i and argument a. The objective minimizes total cost:

minimize: Σ_{i,a,o,j} c_{i,a,o,j} * x_{i,a,o,j}

where:
- x_{i,a,o,j}: binary decision variable (1 if strategy selected, 0 otherwise)
- c_{i,a,o,j}: total cost (communication + computation) for this strategy choice

subject to the following constraint categories:

1. UNIQUENESS CONSTRAINTS: Each operation-argument pair must select exactly one
input-output placement combination.

∀i,a: Σ_{o,j} x_{i,a,o,j} = 1

→ Implemented in: add_unique_decision_constraint()

2. CONSISTENCY CONSTRAINTS: For multi-argument operations, all arguments must agree
on the same output placement to ensure the operation can execute correctly.

∀i,o: Σ_j x_{i,0,o,j} = Σ_j x_{i,1,o,j} = ... = Σ_j x_{i,A_i-1,o,j}
where A_i is the number of arguments for operation i.

→ Implemented in: add_same_output_across_args_constraint()

3. FLOW CONSTRAINTS: The output placement of producer operations must match the
input placement of consumer operations (dataflow consistency).

   ∀(i→k), ∀o: Σ_j x_{i,0,o,j} = Σ_{o'} x_{k,a,o',o}
where operation i feeds into operation k at argument position a.

→ Implemented in: add_output_input_consistent_constraint()

4. COST CONSTRAINTS: Variables with infinite cost (invalid configurations) are
forced to zero.

∀i,a,o,j: c_{i,a,o,j} = ∞ ⟹ x_{i,a,o,j} = 0

→ Implemented in: add_inf_cost_constraint()

5. EFFICIENCY CONSTRAINTS: Penalize inefficient collective operations like
non-batch dimension shard-to-replicate conversions and forbid invalid
transitions like replicate-to-partial.

- Shard(dim≠0) → Replicate: multiply cost by 4
- Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
- Partial → Shard(dim≠0): multiply cost by 4

→ Implemented in: penalize_inefficient_collectives()

6. USER CONSTRAINTS (optional): Force specific placements for inputs, outputs,
parameters, or memory usage bounds.

6a. Input/Output constraints: x_{i,a,o*,j*} = 1 for specified (o*,j*)
→ Implemented in: add_sharded_input_constraint(), add_sharded_output_constraint()

   6b. Memory constraints: memory_factor_low ≤ Σ_{params} (size_ratio * x_{param}) ≤ memory_factor_high
→ Implemented in: add_parameter_memory_constraint()

6c. Parameter-gradient consistency: x_{param} = x_{grad_param}
→ Implemented in: add_grad_param_constraints()

6d. General node constraints: Force specific placement for any node
→ Implemented in: add_node_constraint()

The solver finds the globally optimal sharding strategy that minimizes total
runtime cost while satisfying all constraints.
"""

import math

import pulp
@@ -43,6 +117,9 @@ def __init__(self, gm, mesh):
self.mesh = mesh
self.node_map = {node: i for i, node in enumerate(self.graph.nodes)}
self.strats = self.build_sharding_metadata()
# ds: Decision variables dictionary mapping (s_i, argi, ss, ii) -> ILP variable data
# Each key represents a choice of input placement ii and output placement ss
# for operation s_i and argument argi (corresponds to x_{i,a,o,j} in math notation)
self.ds, self.num_inp_out, self.num_args = self.build_ds()
self.validate()
self.prob = pulp.LpProblem("AutoParallel", pulp.LpMinimize)
@@ -80,6 +157,26 @@ def build_sharding_metadata(self):
return strats

def build_ds(self):
"""
Build decision variables (ds) for the ILP optimization.

Creates binary variables x_{i,a,o,j} for each valid combination of:
- s_i: operation index
- argi: argument index
- ss: output placement strategy index (o in math notation)
Review comment (Contributor): can I just say that ss being an index and ssi being the value associated with that index has been driving me mad

Review comment (Contributor): Yeah, we should rename those variables to more meaningful names. They date from the first version of the code I wrote, when I was naming things almost arbitrarily to get things going. I did some cleanup some time ago but those names remained; we should improve it soon.

- ii: input placement strategy index (j in math notation)

Returns:
ds: Dictionary mapping (s_i, argi, ss, ii) -> {
"va": PuLP binary variable,
"cost": communication + computation cost,
"full_strat": complete strategy object,
"out_strat": output placement specification,
"inp_strat": input placement specification
}
num_inp_out: Metadata about strategy counts per operation-argument
num_args: Number of arguments per operation
"""
strats = self.strats
ds = {}
num_inp_out = {}
@@ -127,6 +224,12 @@ def walk_over_options(self, node, constrain_arg=None):
yield argi, oi, ii

def add_unique_decision_constraint(self):
"""
UNIQUENESS CONSTRAINTS (Category 1): Each operation-argument pair must select exactly one
input-output placement combination.

Mathematical form: ∀i,a: Σ_{o,j} x_{i,a,o,j} = 1
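
    E.g. for one (op, arg) pair with two output and two input placements,
    the emitted constraint is (illustrative):

        x_{i,a,0,0} + x_{i,a,0,1} + x_{i,a,1,0} + x_{i,a,1,1} = 1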
"""
# a single input-output policy pair is chosen
for s_i, node in enumerate(self.graph.nodes):
if node.op not in {"placeholder", "call_function"}:
@@ -139,6 +242,12 @@ def add_same_output_across_args_constraint(self):
self.prob += (pulp.lpSum(eqs) == 1, _get_next_name("unique_decision"))

def add_same_output_across_args_constraint(self):
"""
CONSISTENCY CONSTRAINTS (Category 2): For multi-argument operations, all arguments must agree
on the same output placement to ensure the operation can execute correctly.

Mathematical form: ∀i,o: Σ_j x_{i,0,o,j} = Σ_j x_{i,1,o,j} = ... = Σ_j x_{i,A_i-1,o,j}
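
    E.g. for a two-argument op i with two output placements, this emits
    (illustrative):

        Σ_j x_{i,0,0,j} = Σ_j x_{i,1,0,j}   (args agree on output placement 0)
        Σ_j x_{i,0,1,j} = Σ_j x_{i,1,1,j}   (args agree on output placement 1)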
"""
# enforce that the same output policy is chosen
# across arguments
for s_i, node in enumerate(self.graph.nodes):
@@ -163,6 +272,12 @@ def add_output_input_consistent_constraint(self):
)

def add_output_input_consistent_constraint(self):
"""
FLOW CONSTRAINTS (Category 3): The output placement of producer operations must match the
input placement of consumer operations (dataflow consistency).

    Mathematical form: ∀(i→k), ∀o: Σ_j x_{i,0,o,j} = Σ_{o'} x_{k,a,o',o}
"""
# enforce that the input of strat_{i+1} == output of strat_{i}
for s_i, node in enumerate(self.graph.nodes):
if node.op == "output":
@@ -196,6 +311,12 @@ def add_output_input_consistent_constraint(self):
)

def add_inf_cost_constraint(self):
"""
COST CONSTRAINTS (Category 4): Variables with infinite cost (invalid configurations) are
forced to zero.

Mathematical form: ∀i,a,o,j: c_{i,a,o,j} = ∞ ⟹ x_{i,a,o,j} = 0
"""
# force inf cost values to be 0, as the solver doesn't accept inf
for x in self.ds.values():
if not math.isfinite(x["cost"]):
@@ -213,6 +334,13 @@ def add_default_constraints(self):

def penalize_inefficient_collectives(self):
"""
EFFICIENCY CONSTRAINTS (Category 5): Penalize inefficient collective operations, such as
non-batch-dimension shard-to-replicate conversions, and forbid invalid transitions.

Review comment (Contributor): @wconstab we discussed this the other day; ideally we should just make the comms cost computation more accurate at some point.

- Shard(dim≠0) → Replicate: multiply cost by 4
- Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
- Partial → Shard(dim≠0): multiply cost by 4

When performing Shard(n) → Replicate (for n ≠ 0), there is an additional
computation cost involved. We penalize it here until that computation cost
is accounted for in the comm cost.
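
Sketch of how the penalty applies to one entry d of self.ds (illustrative;
is_shard_nonzero, is_replicate, and is_partial are hypothetical helpers, not
functions in this module):

    if is_shard_nonzero(d["inp_strat"]) and is_replicate(d["out_strat"]):
        d["cost"] *= 4
    if is_replicate(d["inp_strat"]) and is_partial(d["out_strat"]):
        self.prob += d["va"] == 0  # forbidden transition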
@@ -440,6 +568,12 @@ def get_grad_param_nodes(self):
return grad_param_nodes

def add_grad_param_constraints(self):
"""
USER CONSTRAINTS (Category 6c): Parameter-gradient consistency constraints.
Ensures parameters and their gradients have matching sharding strategies.

Mathematical form: x_{param} = x_{grad_param}
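
    Sketch (illustrative; xs_p and xs_g are hypothetical handles to the
    decision variables of a parameter node and its gradient node): for each
    placement index oi, the added equality looks like

        self.prob += xs_p[oi] == xs_g[oi]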
"""
# TODO: need to make sure that the params and grads are aligned, which is not always the case,
# and we might have fewer gradients than parameters

@@ -478,6 +612,12 @@ def add_grad_param_constraints(self):
)

def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high):
"""
USER CONSTRAINTS (Category 6b): Memory constraints for parameters.
Ensures total parameter memory usage stays within specified bounds.

    Mathematical form: memory_factor_low ≤ Σ_{params} (size_ratio * x_{param}) ≤ memory_factor_high
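
    E.g. (illustrative): on a 2-device mesh, a parameter kept as Shard(0)
    would contribute a size_ratio of 0.5 while a replicated one contributes
    1.0, so setting memory_factor_high below 1.0 pushes parameters toward
    sharded placements.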
"""
# get all parameters
param_nodes = self.get_param_nodes()
elms = []
@@ -499,6 +639,12 @@ def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high)
self.prob += (pulp.lpSum(elms) >= memory_factor_low, "memory_constraint_low")

def add_node_constraint(self, node, placement=None, constraint_name=None):
"""
USER CONSTRAINTS (Category 6d): General node constraints.
Force specific placement for any node.

Mathematical form: x_{i,a,o*,j*} = 1 for specified (o*,j*)
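
    E.g. (illustrative, where opt is an instance of this optimizer):

        opt.add_node_constraint(node, placement=Shard(0))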
"""
strat = self.strats[node]
if placement is None:
# default is Shard(0) to parallelize on the batch
@@ -514,6 +660,12 @@ def add_node_constraint(self, node, placement=None, constraint_name=None):
self._add_node_constraint(node, oi=oi, constraint_name=constraint_name)

def add_sharded_input_constraint(self, input_placements=None):
"""
USER CONSTRAINTS (Category 6a): Input placement constraints.
Force specific placements for input nodes and their corresponding gradient inputs.

Mathematical form: x_{i,a,o*,j*} = 1 for specified input placements (o*,j*)
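
    E.g. (illustrative, for a graph with two inputs; opt is an instance of
    this optimizer):

        opt.add_sharded_input_constraint([Shard(0), Replicate()])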
"""
input_nodes = self.get_input_nodes()
if input_placements is None:
input_placements = [None] * len(input_nodes)
@@ -538,6 +690,12 @@ def add_sharded_input_constraint(self, input_placements=None):
)

def add_sharded_output_constraint(self, output_placements=None):
"""
USER CONSTRAINTS (Category 6a): Output placement constraints.
Force specific placements for output nodes and their corresponding gradient outputs.

Mathematical form: x_{i,a,o*,j*} = 1 for specified output placements (o*,j*)
"""
# add final constraint on the output strategy
output_nodes = self.get_fn_output_nodes()
