diff --git a/autoparallel/auto_bucketing.py b/autoparallel/auto_bucketing.py
index 641285b..447d4bb 100644
--- a/autoparallel/auto_bucketing.py
+++ b/autoparallel/auto_bucketing.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from functools import partial
+
 import torch
 from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
 
@@ -112,3 +114,48 @@ def aten_autobucketing_reordering_pass(
         max_in_flight_gb=configs.max_in_flight_gb,
         max_coll_distance=configs.max_coll_distance,
     )
+
+
+def configure_inductor_for_autobucketing(mode: str = "aten"):
+    """Configure torch._inductor for a comms bucketing/reordering strategy.
+
+    Allows selecting inductor comms optimizations by name, e.g. from the
+    torchtitan command line.
+
+    Args:
+        mode: "aten" installs the aten autobucketing pass as inductor's
+            ``post_grad_custom_post_pass``; "inductor" turns on inductor's
+            compute/comm overlap reordering with the simple-FSDP autobucketing
+            pass; "none" turns both reordering flags off.
+
+    Raises:
+        ValueError: if ``mode`` is not "aten", "inductor" or "none".
+    """
+    inductor_config = torch._inductor.config
+    if mode == "aten":
+        # this is from the stacked pr in https://github.com/pytorch/pytorch/pull/163960
+        inductor_config.reorder_for_peak_memory = False
+        inductor_config.reorder_for_compute_comm_overlap = False
+        # Bind under a fresh local name: re-assigning the module-level pass name
+        # here would make it function-local and require a self-import.
+        aten_pass = partial(
+            aten_autobucketing_reordering_pass,
+            configs=aten_autobucketing_config,  # type: ignore
+        )
+        inductor_config.post_grad_custom_post_pass = aten_pass  # type: ignore
+    elif mode == "inductor":
+        inductor_config.allow_buffer_reuse = False
+        inductor_config.reorder_for_peak_memory = False
+        inductor_config.reorder_for_compute_comm_overlap = True
+        simplefsdp_autobucketing_config.calibrate_number = 5
+        simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl"
+        simple_fsdp_pass = partial(
+            simple_fsdp_autobucketing_reordering_pass,
+            configs=simplefsdp_autobucketing_config,  # type: ignore
+        )
+        inductor_config.reorder_for_compute_comm_overlap_passes = [simple_fsdp_pass]
+    elif mode == "none":
+        inductor_config.reorder_for_peak_memory = False
+        inductor_config.reorder_for_compute_comm_overlap = False
+    else:
+        raise ValueError(f"Unknown comms bucket reorder strategy: {mode}")
diff --git a/examples/example_llama3.py b/examples/example_llama3.py
index bc41e96..9aa2ea8 100644
--- a/examples/example_llama3.py
+++ b/examples/example_llama3.py
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import time
-from functools import partial
 
 import torch
 from torch.distributed.fsdp import MixedPrecisionPolicy
@@ -13,12 +12,7 @@
 
 from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
 from autoparallel.api import AutoParallel
-from autoparallel.auto_bucketing import (
-    aten_autobucketing_config,
-    aten_autobucketing_reordering_pass,
-    simple_fsdp_autobucketing_reordering_pass,
-    simplefsdp_autobucketing_config,
-)
+from autoparallel.auto_bucketing import configure_inductor_for_autobucketing
 
 world_size = 64
 
@@ -89,35 +83,7 @@ def input_fn():
     return x
 
 
-autobucketing_level = "aten"
-
-if autobucketing_level == "aten":
-    # this is from the stacked pr in https://github.com/pytorch/pytorch/pull/163960
-    torch._inductor.config.reorder_for_peak_memory = False
-    torch._inductor.config.reorder_for_compute_comm_overlap = False
-    aten_autobucketing_reordering_pass = partial(
-        aten_autobucketing_reordering_pass,
-        configs=aten_autobucketing_config,
-    )
-    torch._inductor.config.post_grad_custom_post_pass = (
-        aten_autobucketing_reordering_pass
-    )
-elif autobucketing_level == "inductor":
-    torch._inductor.config.allow_buffer_reuse = False
-    torch._inductor.config.reorder_for_peak_memory = False
-    torch._inductor.config.reorder_for_compute_comm_overlap = True
-    simplefsdp_autobucketing_config.calibrate_number = 5
-    simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl"
-    simple_fsdp_autobucketing_reordering_pass = partial(
-        simple_fsdp_autobucketing_reordering_pass,
-        configs=simplefsdp_autobucketing_config,
-    )
-    torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
-        simple_fsdp_autobucketing_reordering_pass
-    ]
-else:
-    raise ValueError(f"Unknown autobucketing_level {autobucketing_level}")
-
+configure_inductor_for_autobucketing("aten")
 
 # parallelize the model
 with torch.device("meta"):
diff --git a/mast/sweep.py b/mast/sweep.py
index 3ab22ed..dfa2b45 100644
--- a/mast/sweep.py
+++ b/mast/sweep.py
@@ -106,14 +106,13 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
     + [
         "--model.name=auto_parallel.llama3",
         "--compile.enable",
+        "--experimental.comms_bucket_reorder_strategy=none",
     ],
-    "llama3_autop_1d_compile_bucket_reorder": llama3_1d_common_opts
+    "llama3_autop_1d_compile_aten_bucket_reorder": llama3_1d_common_opts
     + [
         "--model.name=auto_parallel.llama3",
         "--compile.enable",
-        "--experimental.bucket_all_gathers_fx=fsdp",
-        "--experimental.bucket_reduce_scatters_fx=fsdp",
-        "--experimental.reorder_for_compute_comm_overlap",
+        "--experimental.comms_bucket_reorder_strategy=aten",
     ],
 }
 
@@ -127,41 +126,31 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
     + [
         "--model.name=auto_parallel.llama3",
         "--compile.enable",
+        "--experimental.comms_bucket_reorder_strategy=none",
     ],
-    "llama3_autop_2d_compile_bucket_reorder": llama3_2d_common_opts
+    "llama3_autop_2d_compile_aten_bucket_reorder": llama3_2d_common_opts
     + [
         "--model.name=auto_parallel.llama3",
         "--compile.enable",
-        "--experimental.bucket_all_gathers_fx=fsdp",
-        "--experimental.bucket_reduce_scatters_fx=fsdp",
-        "--experimental.reorder_for_compute_comm_overlap",
+        "--experimental.comms_bucket_reorder_strategy=aten",
     ],
 }
 
-test_run = {
-    "FSDP_tp_compile": llama3_2d_common_opts
-    + [
-        "--model.name=llama3",
-        "--compile.enable",
-    ],
-}
-
-
 all_runs = (
     llama3_1d
     | llama3_2d
     | {
-        "llama3_autop_1d_compile_ruisi_bucket_reorder": llama3_1d_common_opts
+        "llama3_autop_1d_compile_inductor_bucket_reorder": llama3_1d_common_opts
         + [
             "--model.name=auto_parallel.llama3",
             "--compile.enable",
-            "--experimental.enable_simplefsdp_passes",
+            "--experimental.comms_bucket_reorder_strategy=inductor",
         ],
-        "llama3_autop_2d_compile_ruisi_bucket_reorder": llama3_2d_common_opts
+        "llama3_autop_2d_compile_inductor_bucket_reorder": llama3_2d_common_opts
         + [
             "--model.name=auto_parallel.llama3",
             "--compile.enable",
-            "--experimental.enable_simplefsdp_passes",
+            "--experimental.comms_bucket_reorder_strategy=inductor",
         ],
     }
 )
@@ -178,10 +167,26 @@ def build_sweep(names):
         [
             "llama3_FSDP_compile",
             "llama3_autop_1d_compile",
-            "llama3_autop_1d_compile_ruisi_bucket_reorder",
+            "llama3_autop_1d_compile_inductor_bucket_reorder",
+            "llama3_FSDP_tp_compile",
+            "llama3_autop_2d_compile",
+            "llama3_autop_2d_compile_inductor_bucket_reorder",
+        ]
+    ),
+    "compare_1d_bucketing": build_sweep(
+        [
+            "llama3_FSDP_compile",
+            "llama3_autop_1d_compile",
+            "llama3_autop_1d_compile_aten_bucket_reorder",
+            "llama3_autop_1d_compile_inductor_bucket_reorder",
+        ]
+    ),
+    "compare_2d_bucketing": build_sweep(
+        [
             "llama3_FSDP_tp_compile",
             "llama3_autop_2d_compile",
-            "llama3_autop_2d_compile_ruisi_bucket_reorder",
+            "llama3_autop_2d_compile_aten_bucket_reorder",
+            "llama3_autop_2d_compile_inductor_bucket_reorder",
         ]
     ),
 }