diff --git a/autoparallel/api.py b/autoparallel/api.py index dcb80689..d036d73d 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -280,8 +280,10 @@ def build_model_graph(self): # we basically want to remove noops in here prev = torch._inductor.config.pattern_matcher torch._inductor.config.pattern_matcher = False - gm = joint_graph_passes(gm) - torch._inductor.config.pattern_matcher = prev + try: + gm = joint_graph_passes(gm) + finally: + torch._inductor.config.pattern_matcher = prev remove_assert_ops(gm.graph) gm.graph.eliminate_dead_code() gm.recompile()