diff --git a/autoparallel/auto_bucketing.py b/autoparallel/auto_bucketing.py
index 447d4bb..1ae913d 100644
--- a/autoparallel/auto_bucketing.py
+++ b/autoparallel/auto_bucketing.py
@@ -119,21 +119,14 @@ def aten_autobucketing_reordering_pass(
 def configure_inductor_for_autobucketing(mode: str = "aten"):
     # allow configuring inductor comms optimizations from torchtitan commandline
     if mode == "aten":
-        from autoparallel.auto_bucketing import (
-            aten_autobucketing_config,
-            aten_autobucketing_reordering_pass,
-        )
-
-        # this is from the stacked pr in https://github.com/pytorch/pytorch/pull/163960
-        torch._inductor.config.reorder_for_peak_memory = False
-        torch._inductor.config.reorder_for_compute_comm_overlap = False
-        aten_autobucketing_reordering_pass = partial(
-            aten_autobucketing_reordering_pass,
-            configs=aten_autobucketing_config,  # type: ignore
+        torch._inductor.config.aten_distributed_optimizations.enable_overlap_scheduling = (
+            True
         )
-        torch._inductor.config.post_grad_custom_post_pass = (
-            aten_autobucketing_reordering_pass  # type: ignore
+        torch._inductor.config.aten_distributed_optimizations.collective_bucketing = (
+            True
         )
+        torch._inductor.config.aten_distributed_optimizations.insert_overlap_deps = True
+        torch._inductor.config.aten_distributed_optimizations.max_compute_pre_fetch = 10
     elif mode == "inductor":
         from autoparallel.auto_bucketing import (
             simple_fsdp_autobucketing_reordering_pass,
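
For context, a minimal usage sketch of the updated helper, assuming the autoparallel package is importable and that `model` is a module you intend to compile with inductor (everything outside the identifiers shown in the diff is illustrative, not part of this change):

import torch
from autoparallel.auto_bucketing import configure_inductor_for_autobucketing

# Turn on the aten-level distributed optimizations (overlap scheduling,
# collective bucketing, overlap deps, compute prefetch) before compiling.
configure_inductor_for_autobucketing(mode="aten")

compiled_model = torch.compile(model)  # `model` is assumed to be defined elsewhere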