4 changes: 4 additions & 0 deletions autoparallel/api.py
Expand Up @@ -180,6 +180,7 @@ def __init__(
enable_ac: bool = True,
# None means 'auto'
ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto",
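# bias the sharding ILP toward placements that enable asyncTP (ag+mm / mm+rs) fusions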
enable_asynctp: bool = False,
**kwargs,
):
self.stack = ExitStack()
Expand Down Expand Up @@ -210,6 +211,8 @@ def __init__(
self.enable_ac = enable_ac
self.ac_stage_size_in_GiB = ac_stage_size_in_GiB

self.enable_asynctp = enable_asynctp

# NB: rest of the construction happens in __enter__
self.active = False

Expand All @@ -236,6 +239,7 @@ def __enter__(self):
self.mesh,
rescale_grad_comm_cost_for_mp,
repeated_subgraphs=self.kwargs.get("repeated_subgraphs", False),
enable_asynctp=self.enable_asynctp,
)

# makes sharding of params and gradients the same
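For context, a minimal usage sketch of the new flag (not part of this diff). It assumes AutoParallel is importable from autoparallel.api; build_model, input_fn, mesh and mp_policy are placeholders for the caller's own setup, mirroring examples/example_llama3.py below.

# Sketch only: turn on the asyncTP-aware ILP objective when constructing AutoParallel.
from autoparallel.api import AutoParallel

with AutoParallel(
    build_model(),          # placeholder for the user's model constructor
    input_fn,               # placeholder input factory
    mesh,                   # dp x tp DeviceMesh
    mp_policy,              # mixed-precision policy
    compile=True,
    repeated_subgraphs=True,
    enable_asynctp=True,    # discount comm cost for asyncTP-fusable placements
) as autop:
    autop.add_parameter_memory_constraint(low=None, high=None)
    # ... remaining constraints and placement optimization as in the example script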
137 changes: 134 additions & 3 deletions autoparallel/optimize_sharding.py
Expand Up @@ -105,6 +105,8 @@
from .propagation_rules import _create_all_options
from .utils import get_local_map_placement_option, get_placement_options

aten = torch.ops.aten


def _debug_node(node):
def my_print(x):
Expand All @@ -128,7 +130,12 @@ def _get_next_name(name):

class ShardingOptimizer:
def __init__(
self, gm, mesh, rescale_grad_comm_cost_for_mp=1.0, repeated_subgraphs=False
self,
gm,
mesh,
rescale_grad_comm_cost_for_mp=1.0,
repeated_subgraphs=False,
enable_asynctp=False,
):
self.gm = gm
self.graph = gm.graph
Expand All @@ -139,11 +146,13 @@ def __init__(
self.strats = self.build_sharding_metadata()

self.cluster_links = {}
if repeated_subgraphs:
t = time.time()
clusters = get_identical_regions(self.gm.graph, self.strats)
print(f"Found {len(clusters)} clusters in {time.time() - t:.2f}s")
self.create_cluster_links(clusters)
self.enable_asynctp = enable_asynctp

# ds: Decision variables dictionary mapping (s_i, argi, ss, ii) -> ILP variable data
# Each key represents a choice of input placement ii and output placement ss
Expand Down Expand Up @@ -556,11 +565,11 @@ def print_old(self):
print(self.get_violated_constraints_log())

def get_log(self, colored=False):

from torch.fx.graph import _color_fns, _identity

opt = {}
nodes = list(self.graph.nodes)
log_shapes = False
for x in self.res:
opt.setdefault(nodes[x[0]], []).append(self.ds[x])

Expand Down Expand Up @@ -600,10 +609,28 @@ def get_log(self, colored=False):
code.insert(l_id, line)
l_id += 1
continue
log_extra = ""
if log_shapes:
if (
isinstance(node, torch.fx.Node)
and "val" in node.meta
and isinstance(node.meta["val"], torch.Tensor)
):
log_extra += "("
for arg in node.args:
if (
isinstance(arg, torch.fx.Node)
and "val" in arg.meta
and isinstance(arg.meta["val"], torch.Tensor)
):
log_extra += str(list(arg.meta["val"].shape))
log_extra += ") -> "
log_extra += str(list(node.meta["val"].shape))
log_extra += "\n"
# advance to the generated code line that corresponds to this node
while not code[l_id].lstrip().startswith(repr(node)):
l_id += 1
code[l_id] += line
code[l_id] = log_extra + code[l_id] + line
l_id += 1
code = "\n".join(code)
total_cost = sum(self.ds[x]["cost"] for x in self.res)
Expand Down Expand Up @@ -641,6 +668,11 @@ def get_solution(self, verbose=False):
# add their costs
for x in self.ds.values():
opt_target[x["va"]] += x["cost"]

if self.enable_asynctp:
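# add_asynctp_scores returns negative deltas, so asyncTP-friendly placements become cheaper in the objective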
for va, cost in self.add_asynctp_scores().items():
opt_target[va] += cost

self.prob += pulp.lpSum([va * cost for va, cost in opt_target.items()])

# solver = pulp.HiGHS(msg=verbose)
Expand Down Expand Up @@ -885,6 +917,105 @@ def add_sharded_output_constraint(self, output_placements=None):
"them from the graph to avoid aliasing."
)

def add_asynctp_scores(self):
# Encourage placements that enable asyncTP fusions by discounting a fraction
# (currently 30%) of the redistribution comm cost for:
# 1. ag + mm: S(d) -> R, d < mm.ndim - 1
# 2. mm + rs: P -> S(d)
# TODO1: Filter out FSDP ag/rs that will not be asyncTPed.
# TODO2: With asyncTP the real win is hiding ((group_size - 1) / group_size)
# of the communication, minus the cost of the decomposition; for that we need
# the group_size from the redistribution.
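# Example: with a TP group size of 8, pipelining could hide up to 7/8 of the
# all_gather / reduce_scatter latency, so the ideal discount would scale with
# (group_size - 1) / group_size rather than the fixed 30% used below.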
def _get_transformations(src_spec, tgt_spec):
# TODO: Use real transform preparation
# For now just checking left to right
src_pls = src_spec.placements
tgt_pls = tgt_spec.placements
transformations = []
for src_pl, tgt_pl in zip(src_pls, tgt_pls):
if src_pl == tgt_pl:
continue
transformations.append((src_pl, tgt_pl))
return transformations

def _produces_asynctp_ag(src_spec, tgt_spec, mm_dim):
# Check that the last transition will be S(dim) -> Replicate

transformations = _get_transformations(src_spec, tgt_spec)
if len(transformations) == 0:
return False
last_t = transformations[-1]
return (
last_t[1].is_replicate()
and last_t[0].is_shard()
and last_t[0].dim < mm_dim - 1
)

def _produces_asynctp_rs(src_spec, tgt_spec, mm_dim):
# Check that the last transition will be P -> S(dim)
transformations = _get_transformations(src_spec, tgt_spec)
if len(transformations) == 0:
return False
last_t = transformations[-1]
return last_t[0].is_partial() and last_t[1].is_shard()

va_cost_delta = defaultdict(int)
strats = self.strats
for s_i, (node, s) in enumerate(strats.items()):
if not (node.op == "call_function" and node.target == aten.mm.default):
continue
mm_n = node
# Incentivize ag+mm
# arg0 of the mm should go S(dim) -> R so that an all_gather is emitted before the mm
a_n = node.args[0]
mm_sts = s.strategies
for mm_st_i, mm_st in enumerate(mm_sts):
a_sts = strats[a_n].strategies
mm_tgt_spec = mm_st.input_specs[0]
for a_st_i, a_st in enumerate(a_sts):
a_src_spec = a_st.output_spec
# TODO: Is adding the constraint to the arg enough, or do we need to follow the
# arg's ancestors and find the first sharding change?
if _produces_asynctp_ag(
a_src_spec, mm_tgt_spec, mm_n.meta["val"].ndim
):
# TODO: We want to calculate the cost of the specific AG, since it will be pipelined;
# for now just use the redistribution cost.
cost = mm_st.redistribute_cost[0][a_st_i]
if cost == float("inf"):
continue
va = self.ds[(s_i, 0, mm_st_i, a_st_i)]["va"]
va_cost_delta[va] += -0.3 * cost
# mm+rs
src_spec = mm_st.output_spec
if len(mm_n.users) == 0:
continue
mm_user = next(iter(mm_n.users))
mm_user_s_i = self.node_map[mm_user]
mm_u_arg_mm_i = -1
for i, arg in enumerate(mm_user.args):
if arg == mm_n:
mm_u_arg_mm_i = i
assert mm_u_arg_mm_i != -1
mm_user_sts = strats[mm_user].strategies
for mm_u_st_i, mm_u_st in enumerate(mm_user_sts):
if _produces_asynctp_rs(
src_spec,
mm_u_st.input_specs[mm_u_arg_mm_i],
mm_n.meta["val"].ndim,
):
# TODO: We want to calculate the cost of the specific RS, since it will be pipelined;
# for now just use the redistribution cost.
cost = mm_u_st.redistribute_cost[mm_u_arg_mm_i][mm_st_i]
if cost == float("inf"):
continue
key = (mm_user_s_i, mm_u_arg_mm_i, mm_u_st_i, mm_st_i)
va = self.ds[key]["va"]
va_cost_delta[va] += -0.3 * cost

return va_cost_delta

def validate(self):
for node in self.graph.nodes:
if node.op != "call_function":
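For intuition, a standalone sketch (not part of this diff) of the two placement transitions that add_asynctp_scores rewards. It assumes a recent PyTorch where Shard, Replicate and Partial are exposed under torch.distributed.tensor; the PR applies the same last-transition check to the DTensorSpecs attached to each mm strategy.

# Illustration only: the last placement transition decides whether a
# redistribution emits a collective that asyncTP can fuse with the mm.
from torch.distributed.tensor import Partial, Replicate, Shard


def last_transition(src, tgt):
    # Naive left-to-right diff over mesh dims, mirroring _get_transformations.
    changes = [(s, t) for s, t in zip(src, tgt) if s != t]
    return changes[-1] if changes else None


# Pattern 1 (ag + mm): S(d) -> R on an mm input emits an all_gather that can be
# pipelined with the matmul (the PR additionally requires d not to be the mm's last dim).
src, tgt = [Shard(0)], [Replicate()]
t = last_transition(src, tgt)
assert t is not None and t[0].is_shard() and t[1].is_replicate()

# Pattern 2 (mm + rs): P -> S(d) on the mm output emits a reduce_scatter that
# can be fused into the matmul epilogue.
src, tgt = [Partial()], [Shard(1)]
t = last_transition(src, tgt)
assert t is not None and t[0].is_partial() and t[1].is_shard()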
55 changes: 47 additions & 8 deletions examples/example_llama3.py
Expand Up @@ -211,7 +211,13 @@ def add_tp_constraints(autop):

# parallelize the model
with AutoParallel(
model, input_fn, mesh, mp_policy, compile=True, repeated_subgraphs=True
model,
input_fn,
mesh,
mp_policy,
compile=True,
repeated_subgraphs=True,
enable_asynctp=enable_asynctp,
) as autop:
autop.add_parameter_memory_constraint(low=None, high=None)

Expand All @@ -229,22 +235,55 @@ def add_tp_constraints(autop):
if enable_manual_constraint and not use_1d_mesh:
add_tp_constraints(autop)

if enable_asynctp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
enable_overlap_scheduling = True
enable_overlap_scheduling_bucketing = True
if enable_overlap_scheduling_bucketing:
assert (
enable_overlap_scheduling
), "bucketing can not be used without overlap scheduling"
enable_asynctp = True
import autoparallel.asynctp

# Set the flags on the module itself; rebinding `from ... import`-ed names would not affect autoparallel.asynctp.
autoparallel.asynctp._micro_pipeline_tp_ag_transpose_mm_enabled = True
autoparallel.asynctp._micro_pipeline_tp_ag_mm_last_dim_enabled = True
if (
enable_overlap_scheduling
or enable_overlap_scheduling_bucketing
or enable_asynctp
):
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = False
torch._inductor.config.allow_buffer_reuse = False
torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
enable_overlap_scheduling_bucketing
)

if enable_asynctp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
torch._inductor.config._micro_pipeline_tp = False
from autoparallel.asynctp import micro_pipeline_tp_pass
enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
torch._inductor.config._micro_pipeline_tp = False
# Disable the inductor asyncTP passes in favor of the autoparallel fork of those passes.
# TODO: Switch back to the inductor asyncTP passes once all additions have landed.
from autoparallel.asynctp import micro_pipeline_tp_pass

existing_post_grad_custom_post_pass = (
torch._inductor.config.post_grad_custom_post_pass
)
from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler

def _pass(graph):
if existing_post_grad_custom_post_pass is not None:
existing_post_grad_custom_post_pass(graph)
micro_pipeline_tp_pass(graph)

collective_info = None
if enable_overlap_scheduling:
overlap_scheduler = OverlapScheduler(graph.owning_module)
overlap_scheduler.run()
collective_info = overlap_scheduler.collective_info

if enable_asynctp:
micro_pipeline_tp_pass(graph, collective_info)

torch._inductor.config.post_grad_custom_post_pass = _pass

6 changes: 6 additions & 0 deletions mast/sweep.py
Expand Up @@ -123,6 +123,12 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
"--model.name=llama3",
"--compile.enable",
],
"llama3_FSDP_tp_async_tp_compile": llama3_2d_common_opts
+ [
"--model.name=llama3",
"--compile.enable",
"--parallelism.enable_async_tensor_parallel",
],
"llama3_autop_2d_compile": llama3_2d_common_opts
+ [
"--model.name=llama3_auto_parallel",