diff --git a/autoparallel/api.py b/autoparallel/api.py index ce087738..3ae9dc63 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -288,8 +288,6 @@ def build_model_graph(self): ep.module(), inputs, decompositions=decomp_table, - fw_compiler=self.compiler_fn, - bw_compiler=self.compiler_fn, ) gm = self.joint_with_descriptors.graph_module assert_has_no_collectives(gm) @@ -454,7 +452,9 @@ def apply_placement(self, sharding_placement=None): ) self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors( - self.joint_with_descriptors + self.joint_with_descriptors, + fw_compiler=self.compiler_fn, + bw_compiler=self.compiler_fn, ) # TODO: this probably belongs in the AOTAutograd API