Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions autoparallel/auto_bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import torch
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing

Expand Down Expand Up @@ -112,3 +114,46 @@ def aten_autobucketing_reordering_pass(
max_in_flight_gb=configs.max_in_flight_gb,
max_coll_distance=configs.max_coll_distance,
)


def configure_inductor_for_autobucketing(mode: str = "aten"):
# allow configuring inductor comms optimizations from torchtitan commandline
if mode == "aten":
from autoparallel.auto_bucketing import (
aten_autobucketing_config,
aten_autobucketing_reordering_pass,
)

# this is from the stacked pr in https://github.com/pytorch/pytorch/pull/163960
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = False
aten_autobucketing_reordering_pass = partial(
aten_autobucketing_reordering_pass,
configs=aten_autobucketing_config, # type: ignore
)
torch._inductor.config.post_grad_custom_post_pass = (
aten_autobucketing_reordering_pass # type: ignore
)
elif mode == "inductor":
from autoparallel.auto_bucketing import (
simple_fsdp_autobucketing_reordering_pass,
simplefsdp_autobucketing_config,
)

torch._inductor.config.allow_buffer_reuse = False
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = True
simplefsdp_autobucketing_config.calibrate_number = 5
simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl"
simple_fsdp_autobucketing_reordering_pass = partial(
simple_fsdp_autobucketing_reordering_pass,
configs=simplefsdp_autobucketing_config, # type: ignore
)
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
simple_fsdp_autobucketing_reordering_pass
]
elif mode == "none":
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = False
else:
raise ValueError(f"Unknown comms bucket reorder strategy: {mode}")
38 changes: 2 additions & 36 deletions examples/example_llama3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

import time
from functools import partial

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy
Expand All @@ -13,12 +12,7 @@

from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
from autoparallel.api import AutoParallel
from autoparallel.auto_bucketing import (
aten_autobucketing_config,
aten_autobucketing_reordering_pass,
simple_fsdp_autobucketing_reordering_pass,
simplefsdp_autobucketing_config,
)
from autoparallel.auto_bucketing import configure_inductor_for_autobucketing

world_size = 64

Expand Down Expand Up @@ -89,35 +83,7 @@ def input_fn():
return x


autobucketing_level = "aten"

if autobucketing_level == "aten":
# this is from the stacked pr in https://github.com/pytorch/pytorch/pull/163960
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = False
aten_autobucketing_reordering_pass = partial(
aten_autobucketing_reordering_pass,
configs=aten_autobucketing_config,
)
torch._inductor.config.post_grad_custom_post_pass = (
aten_autobucketing_reordering_pass
)
elif autobucketing_level == "inductor":
torch._inductor.config.allow_buffer_reuse = False
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = True
simplefsdp_autobucketing_config.calibrate_number = 5
simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl"
simple_fsdp_autobucketing_reordering_pass = partial(
simple_fsdp_autobucketing_reordering_pass,
configs=simplefsdp_autobucketing_config,
)
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
simple_fsdp_autobucketing_reordering_pass
]
else:
raise ValueError(f"Unknown autobucketing_level {autobucketing_level}")

configure_inductor_for_autobucketing("aten")

# parallelize the model
with torch.device("meta"):
Expand Down
51 changes: 28 additions & 23 deletions mast/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,13 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
+ [
"--model.name=auto_parallel.llama3",
"--compile.enable",
"--experimental.comms_bucket_reorder_strategy=none",
],
"llama3_autop_1d_compile_bucket_reorder": llama3_1d_common_opts
"llama3_autop_1d_compile_aten_bucket_reorder": llama3_1d_common_opts
+ [
"--model.name=auto_parallel.llama3",
"--compile.enable",
"--experimental.bucket_all_gathers_fx=fsdp",
"--experimental.bucket_reduce_scatters_fx=fsdp",
"--experimental.reorder_for_compute_comm_overlap",
"--experimental.comms_bucket_reorder_strategy=aten",
],
}

Expand All @@ -127,41 +126,31 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
+ [
"--model.name=auto_parallel.llama3",
"--compile.enable",
"--experimental.comms_bucket_reorder_strategy=none",
],
"llama3_autop_2d_compile_bucket_reorder": llama3_2d_common_opts
"llama3_autop_2d_compile_aten_bucket_reorder": llama3_2d_common_opts
+ [
"--model.name=auto_parallel.llama3",
"--compile.enable",
"--experimental.bucket_all_gathers_fx=fsdp",
"--experimental.bucket_reduce_scatters_fx=fsdp",
"--experimental.reorder_for_compute_comm_overlap",
"--experimental.comms_bucket_reorder_strategy=aten",
],
}

test_run = {
"FSDP_tp_compile": llama3_2d_common_opts
+ [
"--model.name=llama3",
"--compile.enable",
],
}


all_runs = (
llama3_1d
| llama3_2d
| {
"llama3_autop_1d_compile_ruisi_bucket_reorder": llama3_1d_common_opts
"llama3_autop_1d_compile_inductor_bucket_reorder": llama3_1d_common_opts
+ [
"--model.name=auto_parallel.llama3",
"--compile.enable",
"--experimental.enable_simplefsdp_passes",
"--experimental.comms_bucket_reorder_strategy=inductor",
],
"llama3_autop_2d_compile_ruisi_bucket_reorder": llama3_2d_common_opts
"llama3_autop_2d_compile_inductor_bucket_reorder": llama3_2d_common_opts
+ [
"--model.name=auto_parallel.llama3",
"--compile.enable",
"--experimental.enable_simplefsdp_passes",
"--experimental.comms_bucket_reorder_strategy=inductor",
],
}
)
Expand All @@ -178,10 +167,26 @@ def build_sweep(names):
[
"llama3_FSDP_compile",
"llama3_autop_1d_compile",
"llama3_autop_1d_compile_ruisi_bucket_reorder",
"llama3_autop_1d_compile_inductor_bucket_reorder",
"llama3_FSDP_tp_compile",
"llama3_autop_2d_compile",
"llama3_autop_2d_compile_inductor_bucket_reorder",
]
),
"compare_1d_bucketing": build_sweep(
[
"llama3_FSDP_compile",
"llama3_autop_1d_compile",
"llama3_autop_1d_compile_aten_bucket_reorder",
"llama3_autop_1d_compile_inductor_bucket_reorder",
]
),
"compare_2d_bucketing": build_sweep(
[
"llama3_FSDP_tp_compile",
"llama3_autop_2d_compile",
"llama3_autop_2d_compile_ruisi_bucket_reorder",
"llama3_autop_2d_compile_aten_bucket_reorder",
"llama3_autop_2d_compile_inductor_bucket_reorder",
]
),
}
Expand Down