diff --git a/autoparallel/activation_checkpointing.py b/autoparallel/activation_checkpointing.py
index b930ab2c..5cafd920 100644
--- a/autoparallel/activation_checkpointing.py
+++ b/autoparallel/activation_checkpointing.py
@@ -49,19 +49,17 @@ def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
 
 
 def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
+    """
+    Returns True if the node is a wait_tensor node that is the result of an all_gather
+    that can be arbitrarily prefetched, i.e., if all of its recursive inputs are
+    single-input operators that lead to a graph input.
+    """
     if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
-        # TODO: this needs to be improved, its firing in autoparallel "2D" case where input to AG is wait,
-        # maybe just 2D FSDP
-        # ag_node = node.args[0]
-        # assert is_graph_input(ag_node.args[0]) or (
-        #     ag_node.args[0].op == "call_function"
-        #     and ag_node.args[0].target == torch.ops.prims.convert_element_type.default
-        #     and is_graph_input(ag_node.args[0].args[0])
-        # ), (
-        #     "Assume all_gather_into_tensor input is either graph input "
-        #     + f"or dtype conversion of graph input, but got {ag_node.args[0]}"
-        # )
-        return True
+        n: torch.fx.Node = node.all_input_nodes[0]
+        while len(n.all_input_nodes) == 1:
+            if is_graph_input(n.all_input_nodes[0]):
+                return True
+            n = n.all_input_nodes[0]
     return False
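
A minimal sketch of what the new walk does, on a hand-built FX graph that mimics the FSDP pattern `placeholder -> convert_element_type -> all_gather_into_tensor -> wait_tensor`. The helper names here (`is_graph_input`, `leads_to_graph_input`) are simplified stand-ins for illustration, not the repo's actual helpers, and the sketch assumes a recent PyTorch where the native functional collective ops (`torch.ops._c10d_functional.*`) are registered:

```python
import torch
import torch.fx


def is_graph_input(node: torch.fx.Node) -> bool:
    # Simplified stand-in for the repo's helper: a graph input is a placeholder.
    return node.op == "placeholder"


def leads_to_graph_input(ag_node: torch.fx.Node) -> bool:
    # Same walk as the new body of is_wait_tensor_from_fsdp: follow the chain
    # of single-input operators feeding the all_gather until we either reach
    # a graph input (prefetchable) or hit a multi-input node (give up).
    n = ag_node
    while len(n.all_input_nodes) == 1:
        if is_graph_input(n.all_input_nodes[0]):
            return True
        n = n.all_input_nodes[0]
    return False


# Build (without executing) the FSDP-style pattern.
g = torch.fx.Graph()
param = g.placeholder("param")
cast = g.call_function(
    torch.ops.prims.convert_element_type.default, (param, torch.bfloat16)
)
ag = g.call_function(
    torch.ops._c10d_functional.all_gather_into_tensor.default, (cast, 8, "0")
)
wait = g.call_function(torch.ops._c10d_functional.wait_tensor.default, (ag,))

# The chain all_gather -> cast -> param reaches a graph input, so the
# all_gather behind this wait_tensor can be arbitrarily prefetched.
assert leads_to_graph_input(wait.all_input_nodes[0])
```

Unlike the deleted assert, the loop makes no assumption about the shape of the chain (graph input, or exactly one dtype cast of a graph input); it accepts any depth of single-input ops, which also covers the "2D" case flagged in the old TODO where the all_gather input is itself a wait.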