Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fx] Add nested checkpoint in activation checkpoint codegen #1585

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
04e5272
Merge pull request #1 from hpcaitech/main
Cypher30 Jul 14, 2022
75618b3
Merge pull request #2 from hpcaitech/main
Cypher30 Jul 15, 2022
3e4620c
Merge pull request #3 from hpcaitech/main
Cypher30 Jul 20, 2022
cf24049
Merge remote-tracking branch 'upstream/main' into main
Jul 20, 2022
3d223b6
Merge remote-tracking branch 'upstream/main' into main
Jul 21, 2022
644115c
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 22, 2022
d995ade
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 25, 2022
bba2dbe
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 26, 2022
05ca628
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 26, 2022
0a967da
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 6, 2022
0637c0d
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 8, 2022
74a6227
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 10, 2022
e550490
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 10, 2022
2d7f5d9
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 11, 2022
b62e870
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 12, 2022
b4b0974
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 15, 2022
65c20de
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 16, 2022
1660bfc
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 17, 2022
6eb0ad0
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 20, 2022
56df059
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 26, 2022
480e932
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 30, 2022
0fa66ee
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 30, 2022
1d013b0
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 31, 2022
5774db2
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 5, 2022
e8ff699
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 6, 2022
855c728
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 7, 2022
2c113ea
Merge branch 'main' of github.com:Cypher30/ColossalAI into main
Sep 8, 2022
a987de0
[fx] add nested activation_checkpoint codegen
Sep 12, 2022
06da433
Merge branch 'hpcaitech:main' into feature/add_nested_checkpoint_codegen
Cypher30 Sep 12, 2022
bfcb7cd
undo algorithms commits
Sep 12, 2022
eef18c8
solver
Sep 12, 2022
59de066
undo some commits
Sep 12, 2022
456d844
[fx] torch11 add nested activation checkpoint codegen
Sep 12, 2022
f1f356b
remove some imports
Sep 12, 2022
0b10758
[fx] add some comments in activation codegen
Sep 12, 2022
931e4e8
[fx] codegen instance error fix
Sep 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
213 changes: 211 additions & 2 deletions colossalai/fx/codegen/activation_checkpoint_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,209 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'


def _end_of_ckpt(node: Node, check_idx: int) -> bool:
"""Check if the node could end the ckpt region

Args:
node (Node): torch.fx.Node
check_idx (int): the index of checkpoint level for
nested checkpoint

Returns:
bool
"""
if hasattr(node, "activation_checkpoint"):
if isinstance(node.activation_checkpoint, list):
return node.activation_checkpoint[check_idx] == None
else:
return False
else:
return True


def _find_nested_ckpt_regions(nodes, check_idx=0):
"""
Find the nested checkpoint regions given a list of consecutive nodes. The outputs
will be list of tuples, each tuple is in the form of (start_index, end_index).
"""
ckpt_regions = []
start = -1
end = -1
current_region = None

for idx, node in enumerate(nodes):
if hasattr(node, 'activation_checkpoint'):
if isinstance(getattr(node, 'activation_checkpoint'), int):
act_ckpt_label = node.activation_checkpoint
else:
act_ckpt_label = node.activation_checkpoint[check_idx]

# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
if current_region is None:
current_region = act_ckpt_label
start = idx

# if activation checkpoint has changed
# we restart the tracking
# e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
if act_ckpt_label != current_region:
assert start != -1
ckpt_regions.append((start, idx - 1))
current_region = act_ckpt_label
start = idx
end = -1
elif current_region is not None and _end_of_ckpt(node, check_idx):
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
assert start != -1 and end != -1
ckpt_regions.append((start, end))
start = end = -1
current_region = None
else:
pass

if current_region is not None:
end = len(nodes) - 1
ckpt_regions.append((start, end))
return ckpt_regions


def emit_ckpt_func(body,
ckpt_func,
node_list: List[Node],
emit_node_func,
delete_unused_value_func,
level=0,
in_ckpt=False):
"""Emit ckpt fuction in nested way

Args:
body: forward code, in recursive calls, this part will be checkpoint
functions code
ckpt_func: checkpoint functions code, in recursive calls, this part
will be a buffer
node_list (List[Node]): list of torch.fx.Node
emit_node_func: function to emit a node
delete_unused_value_func: function to delete unused value
level (int, optional): checkpoint level. Defaults to 0.
in_ckpt (bool, optional): indicates wether the func is in recursive
call. Defaults to False.
"""
inputs, outputs = _find_input_and_output_nodes(node_list)

# if the current checkpoint function use int as label, using old generation method
if isinstance(node_list[0].activation_checkpoint, int):
label = node_list[0].activation_checkpoint
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
for node in node_list:
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)

ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n"
body.append(usage)

# use nested ckpt function codegen
else:
# label given by each layer, e.g. if you are currently at level [0, 1, 1]
# the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')

# if there is more level to fetch
if level + 1 < len(node_list[0].activation_checkpoint):
ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]

# use ckpt_func_buffer to store nested checkpoint functions
ckpt_func_buffer = []
node_idx = 0
while 1:
if node_idx >= len(node_list):
break

if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func,
delete_unused_value_func, level + 1, True)
node_idx += len(ckpt_node_list)

else:
node = node_list[node_idx]
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
node_idx += 1

ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n')
ckpt_func += ckpt_func_buffer
activation_offload = getattr(node_list[0], "activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
if in_ckpt:
usage = ' ' + usage
body.append(usage)

# last level
else:
for node in node_list:
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)

ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
if in_ckpt:
usage = ' ' + usage
body.append(usage)


def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
this function to emit the activation checkpoint codes.

Args:
body: forward code
ckpt_func: checkpoint functions code
nodes: graph.nodes
emit_node_func: function to emit node
delete_unused_value_func: function to remove the unused value
"""
ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]

node_list = list(nodes)

node_idx = 0
while 1:
# break if we finish the processing all the nodes
if node_idx >= len(node_list):
break

# process ckpt_regions
if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)

# process node in forward function
else:
node = node_list[node_idx]
emit_node_func(node, body)
delete_unused_value_func(node, body)
node_idx += 1


def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
# find the activation checkpoint regions
ckpt_regions = _find_ckpt_regions(nodes)
Expand Down Expand Up @@ -384,7 +587,10 @@ def emit_node(node: Node, body):

# Modified for activation checkpointing
ckpt_func = []
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)

if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
Expand Down Expand Up @@ -582,7 +788,10 @@ def emit_node(node: Node, body):

# Modified for activation checkpointing
ckpt_func = []
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)

if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch
import torch.nn.functional as F
import pytest
import torch.multiprocessing as mp
from torch.utils.checkpoint import checkpoint
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule

try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
with_codegen = True
except:
# fall back to older pytorch version
from colossalai.fx.codegen import python_code_with_activation_checkpoint
with_codegen = False


class MyModule(torch.nn.Module):

def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(4, 4)
self.linear2 = torch.nn.Linear(4, 4)
self.linear3 = torch.nn.Linear(4, 4)
self.linear4 = torch.nn.Linear(4, 4)
self.linear5 = torch.nn.Linear(4, 4)
self.linear6 = torch.nn.Linear(4, 4)

def forward(self, x):
return self.linear6(self.linear5(self.linear4(self.linear3(self.linear2(self.linear1(x))))))


def _run_act_ckpt_codegen(rank):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

# build model and run forward
model = MyModule()
data1 = torch.rand(4, 4)

# copy model to cuda
model = model.to(device="cuda")
data1 = data1.to(device="cuda")

non_fx_out = model(data1)

# trace the module and replace codegen
tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(model)
codegen = ActivationCheckpointCodeGen()
graph.set_codegen(codegen)

# annotate nested checkpoint
for node in graph.nodes:
if node.name == "linear1":
setattr(node, "activation_checkpoint", [0, 0, 0])
continue
if node.name == "linear2":
setattr(node, "activation_checkpoint", [0, 0, None])
if node.name == "linear3":
setattr(node, "activation_checkpoint", [0, 0, 1])
if node.name == "linear4":
setattr(node, "activation_checkpoint", [0, 1, None])
if node.name == "linear5":
setattr(node, "activation_checkpoint", 1)
gm = ColoGraphModule(model, graph)
gm.recompile()

# assert checkpoint function will be generated and
code = graph.python_code('self').src
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code

# recompile and verify the outputs are consistent
fx_out = gm(data1)
assert torch.equal(non_fx_out, fx_out)

gpc.destroy()


@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
mp.spawn(_run_act_ckpt_codegen, nprocs=1)


def _run_act_ckpt_python_code_torch11(rank):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

# build model and run forward
model = MyModule()
data1 = torch.rand(4, 4)

# copy model to cuda
model = model.to(device="cuda")
data1 = data1.to(device="cuda")

non_fx_out = model(data1)

# trace the module and replace codegen
tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(model)
codegen = ActivationCheckpointCodeGen()
graph.set_codegen(codegen)

# annotate nested checkpoint
for node in graph.nodes:
if node.name == "linear1":
setattr(node, "activation_checkpoint", [0, 0, 0])
continue
if node.name == "linear2":
setattr(node, "activation_checkpoint", [0, 0, None])
if node.name == "linear3":
setattr(node, "activation_checkpoint", [0, 0, 1])
if node.name == "linear4":
setattr(node, "activation_checkpoint", [0, 1, None])
if node.name == "linear5":
setattr(node, "activation_checkpoint", 1)
gm = ColoGraphModule(model, graph)
gm.recompile()

# assert checkpoint function will be generated and
code = graph.python_code('self').src
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code

# recompile and verify the outputs are consistent
fx_out = gm(data1)
assert torch.equal(non_fx_out, fx_out)

gpc.destroy()


@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
def test_act_ckpt_python_code_torch11():
mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)


if __name__ == '__main__':
_run_act_ckpt_codegen(rank=0)