autoparallel/activation_checkpointing.py (22 changes: 10 additions & 12 deletions)

@@ -49,19 +49,17 @@ def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
 
 
 def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
+    """
+    Returns True if the node is a wait_tensor node that is the result of an all_gather
+    that can be arbitrarily prefetched, i.e., if all its recursive inputs are
+    single-input operators that lead to a graph input.
+    """
     if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
-        # TODO: this needs to be improved, its firing in autoparallel "2D" case where input to AG is wait,
-        # maybe just 2D FSDP
-        # ag_node = node.args[0]
-        # assert is_graph_input(ag_node.args[0]) or (
-        #     ag_node.args[0].op == "call_function"
-        #     and ag_node.args[0].target == torch.ops.prims.convert_element_type.default
-        #     and is_graph_input(ag_node.args[0].args[0])
-        # ), (
-        #     "Assume all_gather_into_tensor input is either graph input "
-        #     + f"or dtype conversion of graph input, but got {ag_node.args[0]}"
-        # )
-        return True
+        n: torch.fx.Node = node.all_input_nodes[0]
+        while len(n.all_input_nodes) == 1:
+            if is_graph_input(n.all_input_nodes[0]):
+                return True
+            n = n.all_input_nodes[0]
     return False
 
 
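A note on what the rewritten predicate does: instead of asserting on the all_gather's producer (the commented-out check it deletes), it walks up from the all_gather through the chain of single-input ops and answers True only if that chain terminates at a graph input, which is what makes the all_gather safe to prefetch arbitrarily early. Below is a minimal, self-contained sketch of that traversal on a toy torch.fx graph; chain_reaches_graph_input and the module M are illustrative names, not part of the PR, and a plain "placeholder" check stands in for the repo's is_graph_input helper.

# Illustrative sketch (not from the PR): walk up through single-input ops
# and report whether the chain ends at a graph input ("placeholder" in fx).
import torch
import torch.fx


def chain_reaches_graph_input(node: torch.fx.Node) -> bool:
    # Mirrors the loop added in the PR: follow each node's unique producer;
    # succeed on reaching a placeholder, and stop as soon as a node has zero
    # or more than one input.
    n = node
    while len(n.all_input_nodes) == 1:
        if n.all_input_nodes[0].op == "placeholder":
            return True
        n = n.all_input_nodes[0]
    return False


class M(torch.nn.Module):
    def forward(self, x):
        # x -> to(bfloat16) -> relu: a pure chain of single-input ops.
        return torch.relu(x.to(torch.bfloat16))


gm = torch.fx.symbolic_trace(M())
relu_node = next(n for n in gm.graph.nodes if n.target is torch.relu)
print(chain_reaches_graph_input(relu_node))  # True: the chain reaches x

In the PR itself the loop starts at the all_gather node (node.all_input_nodes[0] of the wait_tensor), so a dtype cast between the graph input and the all_gather, as in the old commented-out assert, is accepted naturally rather than special-cased.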