Commit
[fx] fix offload codegen test (#1648)
* [fx] fix offload codegen test

* [fx] modify typing
Cypher30 committed Sep 27, 2022
1 parent 45b39a6 commit 5d0fdb9
Showing 2 changed files with 12 additions and 12 deletions.
8 changes: 4 additions & 4 deletions colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -1,6 +1,6 @@
 import colossalai
 import torch
-from typing import List, Callable, Any, Tuple, Dict
+from typing import List, Callable, Any, Tuple, Dict, Iterable
 
 try:
     from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
@@ -157,7 +157,7 @@ def _find_offload_regions(nodes: List[Node]):
     current_region = None
 
     for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), list):
+        if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable):
             act_offload_label = node.activation_offload
 
             if current_region == None:
@@ -796,7 +796,7 @@ def emit_node(node: Node, body):
 
     # if any node has a list of labels for activation_checkpoint, we
     # will use nested type of activation checkpoint codegen
-    if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
+    if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes):
         emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
     else:
         emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
@@ -999,7 +999,7 @@ def emit_node(node: Node, body):
 
     # if any node has a list of labels for activation_checkpoint, we
     # will use nested type of activation checkpoint codegen
-    if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
+    if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes):
         emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
     else:
         emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
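The widened check matters because offload labels may arrive as either lists or tuples: `isinstance(x, list)` rejects tuples, while `Iterable` accepts both and still lets plain booleans (the single-label case, as on linear5 in the test below) fall through. A minimal standalone sketch of the behavior, not part of the commit:

from typing import Iterable  # the bare typing.Iterable alias supports isinstance checks

for label in [(0, True, False), [0, True, False], True, None]:
    # old check: isinstance(label, list) accepts only the list form;
    # new check: isinstance(label, Iterable) accepts lists and tuples,
    # while bool/None (single-label or unlabeled nodes) still fail it.
    # Note that str is also Iterable, so the check is broader than list-or-tuple.
    print(repr(label), isinstance(label, list), isinstance(label, Iterable))

# (0, True, False) False True
# [0, True, False] True True
# True False False
# None False False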
16 changes: 8 additions & 8 deletions tests/test_fx/test_codegen/test_offload_codegen.py
@@ -83,13 +83,13 @@ def _run_offload_codegen(rank):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            setattr(node, "activation_offload", (0, True, False))
+            setattr(node, "activation_offload", [0, True, False])
         if node.name == "linear1":
-            setattr(node, "activation_offload", (0, True, False))
+            setattr(node, "activation_offload", [0, True, False])
         if node.name == "linear2":
-            setattr(node, "activation_offload", (1, True, True))
+            setattr(node, "activation_offload", [1, True, True])
         if node.name == "linear4":
-            setattr(node, "activation_offload", (2, False, True))
+            setattr(node, "activation_offload", [2, False, True])
         if node.name == "linear5":
             setattr(node, "activation_checkpoint", [0])
             setattr(node, "activation_offload", True)
@@ -138,13 +138,13 @@ def _run_offload_codegen_torch11(rank):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            setattr(node, "activation_offload", (0, True, False))
+            setattr(node, "activation_offload", [0, True, False])
         if node.name == "linear1":
-            setattr(node, "activation_offload", (0, True, False))
+            setattr(node, "activation_offload", [0, True, False])
         if node.name == "linear2":
-            setattr(node, "activation_offload", (1, True, True))
+            setattr(node, "activation_offload", [1, True, True])
         if node.name == "linear4":
-            setattr(node, "activation_offload", (2, False, True))
+            setattr(node, "activation_offload", [2, False, True])
         if node.name == "linear5":
             setattr(node, "activation_checkpoint", [0])
             setattr(node, "activation_offload", True)
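For reference, these labels are attached to torch.fx graph nodes before codegen runs. Below is a minimal standalone sketch of the same labeling pattern on a hypothetical two-layer module; it covers the labeling step only, while the real test traces a larger model and then exercises ColossalAI's codegen on the resulting graph:

import torch
from torch.fx import symbolic_trace

# hypothetical stand-in for the test's model; layer names mirror the test above
class TwoLinear(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear0 = torch.nn.Linear(4, 4)
        self.linear1 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear1(self.linear0(x))

graph = symbolic_trace(TwoLinear()).graph
for node in graph.nodes:
    # list labels rather than tuples, matching the updated test; either
    # form now satisfies the Iterable check in _find_offload_regions
    if node.name == "linear0":
        setattr(node, "activation_offload", [0, True, False])
    if node.name == "linear1":
        setattr(node, "activation_offload", [0, True, False])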
