[fx] Add nested checkpoint in activation checkpoint codegen #1585

Merged

Cypher30

What's New?

To use the activation checkpoint solver with the rotor and pofo settings, we may need nested checkpoints, i.e. generated code of the following form:

def checkpoint_0(self, x):
    linear3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)
    linear4 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)
    return linear4
def checkpoint_0_0(self, x):
    linear1 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)
    linear2 = self.linear2(linear1);  linear1 = None
    linear3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)
    return linear3
def checkpoint_0_0_0(self, x):
    linear1 = self.linear1(x);  x = None
    return linear1
def checkpoint_0_0_1(self, linear2):
    linear3 = self.linear3(linear2);  linear2 = None
    return linear3
def checkpoint_0_1(self, linear3):
    linear4 = self.linear4(linear3);  linear3 = None
    return linear4
def checkpoint_1(self, linear4):
    linear5 = self.linear5(linear4);  linear4 = None
    return linear5
def forward(self, x):
    linear4 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)
    linear5 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)
    linear6 = self.linear6(linear5);  linear5 = None
    return linear6

In the upcoming solver update, the annotation process will be able to detect these structures. Each node.activation_checkpoint (if annotated) will be a list giving the checkpoint label at each nesting level. For example, a node annotated with [0, 1, 1] belongs to checkpoint_0_1_1; this function is called by checkpoint_0_1, which is called by checkpoint_0, which in turn is called by forward.
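To make the labeling convention concrete, here is a minimal sketch of how a per-level label list maps to a checkpoint function name and its chain of callers. The helper names `ckpt_func_name` and `call_chain` are hypothetical and exist only for this illustration; they are not part of the codegen in this PR.

```python
from typing import List


def ckpt_func_name(label: List[int]) -> str:
    """Build the checkpoint function name for a per-level label list.

    Hypothetical helper: [0, 1, 1] -> "checkpoint_0_1_1".
    """
    return "checkpoint_" + "_".join(str(i) for i in label)


def call_chain(label: List[int]) -> List[str]:
    """List the functions from the innermost checkpoint out to forward.

    Each prefix of the label names one enclosing checkpoint function.
    """
    chain = [ckpt_func_name(label[:depth]) for depth in range(len(label), 0, -1)]
    chain.append("forward")
    return chain


if __name__ == "__main__":
    print(" <- ".join(call_chain([0, 1, 1])))
```

Each prefix of the list names one enclosing function, so the nesting depth of a node is simply the length of its annotation.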

The old version of the activation checkpoint codegen is preserved; the following check chooses which codegen to use:

if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
    emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
    emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
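The selection rule can be tried in isolation with stand-in node objects. This is a sketch, not the PR's code: `uses_nested_checkpoint` is a hypothetical helper, `SimpleNamespace` replaces real fx nodes, and it assumes the older annotation style stores a non-list value such as an integer.

```python
from types import SimpleNamespace


def uses_nested_checkpoint(nodes) -> bool:
    """Return True if any node carries a list-valued annotation.

    Mirrors the check above: the nested codegen is chosen as soon as
    one node's activation_checkpoint attribute is a list.
    """
    return any(
        isinstance(getattr(node, "activation_checkpoint", None), list)
        for node in nodes
    )


# Stand-in nodes (assumption: old-style annotations are plain integers).
unannotated = SimpleNamespace()
flat = SimpleNamespace(activation_checkpoint=0)
nested = SimpleNamespace(activation_checkpoint=[0, 1, 1])

if __name__ == "__main__":
    print(uses_nested_checkpoint([unannotated, flat]))   # old codegen path
    print(uses_nested_checkpoint([unannotated, nested]))  # nested codegen path
```

Using `getattr` with a default of `None` means unannotated nodes never trigger the nested path, so graphs without any list annotation keep the old behavior.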

Since we haven't yet implemented ColoGraphModule for torch 1.11, I simply skip the test for it. The following is the test result on torch 1.12:
[Screenshot: test result, 2022-09-12]

Cypher30 and others added 30 commits July 14, 2022 16:07
@Cypher30 Cypher30 merged commit f3687e4 into hpcaitech:main Sep 12, 2022