Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fx] Add offload codegen #1598

Merged
merged 36 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
04e5272
Merge pull request #1 from hpcaitech/main
Cypher30 Jul 14, 2022
75618b3
Merge pull request #2 from hpcaitech/main
Cypher30 Jul 15, 2022
3e4620c
Merge pull request #3 from hpcaitech/main
Cypher30 Jul 20, 2022
cf24049
Merge remote-tracking branch 'upstream/main' into main
Jul 20, 2022
3d223b6
Merge remote-tracking branch 'upstream/main' into main
Jul 21, 2022
644115c
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 22, 2022
d995ade
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 25, 2022
bba2dbe
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 26, 2022
05ca628
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 26, 2022
0a967da
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 6, 2022
0637c0d
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 8, 2022
74a6227
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 10, 2022
e550490
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 10, 2022
2d7f5d9
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 11, 2022
b62e870
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 12, 2022
b4b0974
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 15, 2022
65c20de
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 16, 2022
1660bfc
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 17, 2022
6eb0ad0
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 20, 2022
56df059
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 26, 2022
480e932
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 30, 2022
0fa66ee
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 30, 2022
1d013b0
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 31, 2022
5774db2
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 5, 2022
e8ff699
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 6, 2022
855c728
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 7, 2022
2c113ea
Merge branch 'main' of github.com:Cypher30/ColossalAI into main
Sep 8, 2022
838ba70
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 13, 2022
cacec2b
Merge branch 'main' of github.com:Cypher30/ColossalAI into main
Sep 13, 2022
5ed6ef0
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 14, 2022
6dbac6a
[fx] add input activation offload to codegen
Sep 14, 2022
9a7dcc6
[fx] modify unit test
Sep 14, 2022
b0afb21
Merge branch 'hpcaitech:main' into feature/add_offload_codegen
Cypher30 Sep 14, 2022
18b0e4f
[fx] remove two skips in torch11
Sep 14, 2022
dd8cf0e
Merge branch 'feature/add_offload_codegen' of github.com:Cypher30/Col…
Sep 14, 2022
3d9d63a
[fx] use all_input_nodes instead of _input_nodes
Sep 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
110 changes: 95 additions & 15 deletions colossalai/fx/codegen/activation_checkpoint_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,38 @@
__all__ = ['python_code_with_activation_checkpoint']


def _gen_saved_tensors_hooks():
"""
Generate saved tensors hooks
"""

pack_hook = """def pack_hook(self, x):
if getattr(x, "offload", None):
return (x.device, x.cpu())
else:
return x
"""

unpack_hook = """def unpack_hook(self, packed):
if isinstance(packed, tuple):
device, tensor = packed
return tensor.to(device)
else:
return packed
"""

return pack_hook, unpack_hook


def _gen_save_tensors_hooks_context():
"""
Generate save tensors hooks context
"""

context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):\n"
return context


def _find_input_and_output_nodes(nodes: List[Node]):
"""
Find the input and output node names which are not found in the given list of nodes.
Expand Down Expand Up @@ -211,7 +243,7 @@ def emit_ckpt_func(body,
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)

ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n')
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n"
Expand Down Expand Up @@ -251,7 +283,7 @@ def emit_ckpt_func(body,
delete_unused_value_func(node, ckpt_func)
node_idx += 1

ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n')
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
ckpt_func += ckpt_func_buffer
activation_offload = getattr(node_list[0], "activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
Expand All @@ -266,7 +298,7 @@ def emit_ckpt_func(body,
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)

ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n')
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
if in_ckpt:
Expand All @@ -292,6 +324,9 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod

node_list = list(nodes)

# this flag is to prevent repeated insert of save tensors
# hooks definition in ckpt_func
is_hook_inserted = False
node_idx = 0
while 1:
# break if we finish the processing all the nodes
Expand All @@ -307,8 +342,27 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# process node in forward function
else:
node = node_list[node_idx]
emit_node_func(node, body)
delete_unused_value_func(node, body)

# if a node is outside of a checkpoint region and wants to offload
# its input activation, we will use torch.autograd.graph.saved_tensors_hooks
# to complete the offload process.
if getattr(node, "activation_offload", False):
if not is_hook_inserted:
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")

for par in node._input_nodes:
Cypher30 marked this conversation as resolved.
Show resolved Hide resolved
# annotate the input tensor for pack hook
body.append(f"setattr({repr(par)}, 'offload', True)\n")

body.append(_gen_save_tensors_hooks_context())
emit_node_func(node, body)
body[-1] = ' ' + body[-1]
delete_unused_value_func(node, body)

else:
emit_node_func(node, body)
delete_unused_value_func(node, body)
node_idx += 1


Expand All @@ -323,6 +377,10 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,

node_list = list(nodes)

# use this variable to avoid inserting hook functions
# to ckpt_func repeatedly
is_hook_inserted = False

# find the input and output var names for each region
for idx, (start, end) in enumerate(ckpt_regions):
ckpt_node_list = node_list[start:end + 1]
Expand All @@ -348,8 +406,26 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
else:
emit_node_func(node, body)
delete_unused_value_func(node, body)
# if a node is outside of a checkpoint region and wants to offload
# its input activation, we will use torch.autograd.graph.saved_tensors_hooks
# to complete the offload process.
if getattr(node, "activation_offload", False):
if not is_hook_inserted:
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")

for par in node._input_nodes:
# annotate the input tensor for pack hook
body.append(f"setattr({repr(par)}, 'offload', True)\n")

body.append(_gen_save_tensors_hooks_context())
emit_node_func(node, body)
body[-1] = ' ' + body[-1]
delete_unused_value_func(node, body)

else:
emit_node_func(node, body)
delete_unused_value_func(node, body)

if idx in end_idx:
# if this is the last node of the ckpt region
Expand Down Expand Up @@ -587,10 +663,13 @@ def emit_node(node: Node, body):

# Modified for activation checkpointing
ckpt_func = []
if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:

# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)

if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
Expand All @@ -612,7 +691,6 @@ def emit_node(node: Node, body):

# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
# TODO: Remove inline import
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = ''.join(ckpt_func) + prologue
prologue = prologue
Expand Down Expand Up @@ -788,10 +866,13 @@ def emit_node(node: Node, body):

# Modified for activation checkpointing
ckpt_func = []
if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:

# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)

if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
Expand Down Expand Up @@ -827,7 +908,6 @@ def emit_node(node: Node, body):

# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
# TODO: Remove inline import
fn_code = f"""
{wrap_stmts}

Expand Down
14 changes: 10 additions & 4 deletions colossalai/fx/graph_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,20 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, c
super().__init__(root, graph, class_name)

def bind(self, ckpt_def, globals):
"""Bind checkpoint functions to ColoGraphModule
We need to bind our checkpoint functions to the GraphModule so
that we could correctly use self.checkpoint for GraphModule forward
"""Bind the functions needed to correctly execute gm forward

We need to bind the checkpoint functions and the saved_tensors_hooks functions
to gm so that we can correctly execute gm forward

Args:
ckpt_def (_type_): definition before the forward function
globals (_type_): global variables
"""

ckpt_code = "\n".join(ckpt_def)
globals_copy = globals.copy()
_exec_with_source(ckpt_code, globals_copy)
func_list = [func for func in globals_copy.keys() if "checkpoint" in func]
func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func]
for func in func_list:
tmp_func = globals_copy[func]
setattr(self, func, tmp_func.__get__(self, self.__class__))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from operator import mod
import torch
import torch.nn.functional as F
import pytest
Expand Down
159 changes: 159 additions & 0 deletions tests/test_fx/test_codegen/test_offload_codegen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import copy
import torch
import torch.nn.functional as F
import pytest
import torch.multiprocessing as mp
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule

# ``with_codegen`` records which code generation entry point could be imported;
# the tests below use it to decide whether the codegen-based test can run.
try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    with_codegen = True
except ImportError:
    # fall back to older pytorch version
    # (only a failed import triggers the fallback; a bare ``except`` would also
    # hide unrelated errors such as KeyboardInterrupt)
    from colossalai.fx.codegen import python_code_with_activation_checkpoint
    with_codegen = False


class MyNet(torch.nn.Module):
    """Toy network made of five ``Linear(4, 4)`` layers applied back to back.

    The layers are registered as ``linear1`` ... ``linear5``.
    """

    def __init__(self) -> None:
        super().__init__()
        # create the layers in order so that attribute names and the
        # parameter initialisation order are linear1, linear2, ..., linear5
        for idx in range(1, 6):
            setattr(self, f"linear{idx}", torch.nn.Linear(4, 4))

    def forward(self, x):
        """Feed ``x`` through ``linear1`` to ``linear5`` and return the result."""
        out = x
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4, self.linear5):
            out = layer(out)
        return out


def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
if not torch.allclose(m_p.grad, gm_p.grad):
return False
return True


def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
    """Check that ``gm`` reproduces the forward output and gradients of ``model``.

    Args:
        model: the reference module.
        gm: the graph module under test.
        data: input fed to both modules.

    Raises:
        AssertionError: if the outputs are not exactly equal or the parameter
            gradients are not close.
    """

    # forward: both modules must produce exactly the same output
    reference_out = model(data)
    traced_out = gm(data)
    assert torch.equal(reference_out, traced_out), "fx_out doesn't comply with original output"

    # backward: reduce each output to a scalar and back-propagate,
    # reference module first
    reference_out.sum().backward()
    traced_out.sum().backward()
    assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"


def _run_offload_codegen(rank):
    """Worker that checks the offload code generated by ``ActivationCheckpointCodeGen``.

    Args:
        rank: rank of this process, forwarded to ``colossalai.launch``.
    """
    # colossalai has to be launched so that colossalai.utils.checkpoint
    # can be executed correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # model and input both live on the GPU
    model = MyNet().cuda()
    data = torch.rand(4, 4).cuda()

    # trace the module, then swap in the activation checkpoint codegen
    graph = ColoTracer(trace_act_ckpt=True).trace(model)
    graph.set_codegen(ActivationCheckpointCodeGen())

    # linear2 offloads outside of any checkpoint region, linear3 offloads
    # inside checkpoint region 0 (shared with linear4), so both kinds of
    # input offload are covered
    for node in graph.nodes:
        if node.name in ("linear2", "linear3"):
            node.activation_offload = True
        if node.name in ("linear3", "linear4"):
            node.activation_checkpoint = [0]

    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()
    print(gm)

    # every piece of the offload machinery must show up in the generated code
    code = graph.python_code("self").src
    expected_snippets = (
        "def pack_hook(self, x):",
        "def unpack_hook(self, packed):",
        "setattr(linear1, 'offload', True)",
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):",
        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)",
    )
    assert all(snippet in code for snippet in expected_snippets)

    _test_fwd_and_bwd(model, gm, data)
    gpc.destroy()


@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
    """Run ``_run_offload_codegen`` in a single spawned process."""
    mp.spawn(fn=_run_offload_codegen, nprocs=1)


def _run_offload_codegen_torch11(rank):
    """Worker that checks the offload code from ``python_code_with_activation_checkpoint``.

    Args:
        rank: rank of this process, forwarded to ``colossalai.launch``.
    """
    # colossalai has to be launched so that colossalai.utils.checkpoint
    # can be executed correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # model and input both live on the GPU
    model = MyNet().cuda()
    data = torch.rand(4, 4).cuda()

    # trace the module
    graph = ColoTracer(trace_act_ckpt=True).trace(model)

    # bind the fallback code generator to this graph instance so that it
    # replaces the graph's own ``_python_code`` method
    graph._python_code = python_code_with_activation_checkpoint.__get__(graph)

    # linear2 offloads outside of any checkpoint region, linear3 offloads
    # inside checkpoint region 0 (shared with linear4), so both kinds of
    # input offload are covered
    for node in graph.nodes:
        if node.name in ("linear2", "linear3"):
            node.activation_offload = True
        if node.name in ("linear3", "linear4"):
            node.activation_checkpoint = [0]

    # the deep copy is required: backward gradients of ``model`` and ``gm``
    # are compared below, so the two must not share parameters
    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()
    print(gm)

    # every piece of the offload machinery must show up in the generated code
    code = graph.python_code("self").src
    expected_snippets = (
        "def pack_hook(self, x):",
        "def unpack_hook(self, packed):",
        "setattr(linear1, 'offload', True)",
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):",
        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)",
    )
    assert all(snippet in code for snippet in expected_snippets)

    _test_fwd_and_bwd(model, gm, data)
    gpc.destroy()


@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
def test_act_ckpt_python_code_torch11():
    """Run ``_run_offload_codegen_torch11`` in a single spawned process."""
    mp.spawn(fn=_run_offload_codegen_torch11, nprocs=1)


if __name__ == "__main__":
_run_offload_codegen(0)