
[fx] added activation checkpoint codegen #1355

Merged
merged 1 commit on Jul 25, 2022

Conversation

@FrankLeeeee (Contributor) commented on Jul 22, 2022

In the previous PR #1349, we annotated the nodes that are activation-checkpointed. In this PR, we use these annotations to generate the model's forward code with activation checkpointing. The new CodeGen inherits from the PyTorch CodeGen; the code changes can be found by searching for the following multi-line comment.

#########################################
# Modified for activation checkpointing #
#########################################
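
For readers unfamiliar with the fx CodeGen hook, the sketch below (a minimal illustration assuming torch >= 1.12, not the exact code in this PR) shows where a subclass plugs in; the comments summarize what the modified codegen does.

# A minimal sketch, assuming torch >= 1.12, of how a CodeGen subclass plugs into
# torch.fx code generation; an illustration of the mechanism, not this PR's code.
from torch.fx.graph import CodeGen


class ActivationCheckpointCodeGen(CodeGen):
    # _gen_python_code is the (private) hook that turns an fx node list into the
    # Python source of forward(); the codegen in this PR overrides it so that:
    #   1. consecutive nodes sharing an activation-checkpoint annotation (from #1349)
    #      are grouped into a region,
    #   2. each region is emitted as an inner function checkpoint_<i>,
    #   3. the region is invoked via torch.utils.checkpoint.checkpoint(checkpoint_<i>, ...),
    #   4. unannotated nodes fall through to the stock CodeGen behaviour.
    def _gen_python_code(self, nodes, root_module, namespace):
        # Placeholder body: simply delegate to the parent implementation here.
        return super()._gen_python_code(nodes, root_module, namespace)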

A unit test is added in this PR as well. Below is the code generated with and without activation checkpointing.

# without activation checkpoint
def forward(self, x):
    mlp1_linear1 = self.mlp1.linear1(x)
    mlp1_linear1_1 = self.mlp1.linear1(x)
    mlp2_linear1 = self.mlp2.linear1(x)
    mlp2_linear1_1 = self.mlp2.linear1(x);  x = None
    add = mlp1_linear1 + mlp1_linear1_1;  mlp1_linear1 = mlp1_linear1_1 = None
    add_1 = add + mlp2_linear1;  add = mlp2_linear1 = None
    add_2 = add_1 + mlp2_linear1_1;  add_1 = mlp2_linear1_1 = None
    return add_2

# with activation checkpoint
def forward(self, x):
    def checkpoint_0(x):
        mlp1_linear1 = self.mlp1.linear1(x)
        mlp1_linear1_1 = self.mlp1.linear1(x)
        return mlp1_linear1, mlp1_linear1_1
    mlp1_linear1, mlp1_linear1_1 = torch.utils.checkpoint.checkpoint(checkpoint_0, x)
    def checkpoint_1(x):
        mlp2_linear1 = self.mlp2.linear1(x)
        mlp2_linear1_1 = self.mlp2.linear1(x);  x = None
        return mlp2_linear1, mlp2_linear1_1
    mlp2_linear1, mlp2_linear1_1 = torch.utils.checkpoint.checkpoint(checkpoint_1, x)
    add = mlp1_linear1 + mlp1_linear1_1;  mlp1_linear1 = mlp1_linear1_1 = None
    add_1 = add + mlp2_linear1;  add = mlp2_linear1 = None
    add_2 = add_1 + mlp2_linear1_1;  add_1 = mlp2_linear1_1 = None
    return add_2
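
For context, a hypothetical end-to-end sketch is shown below; the toy model mirrors the generated forward above, but the annotation attribute, class name, and import path are assumptions rather than this PR's exact API.

# Hypothetical usage sketch (annotation attribute and import path are assumptions).
import torch.nn as nn
from torch.fx import symbolic_trace

from colossalai.fx.codegen import ActivationCheckpointCodeGen  # import path assumed


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp1 = MLP()
        self.mlp2 = MLP()

    def forward(self, x):
        return (self.mlp1.linear1(x) + self.mlp1.linear1(x)
                + self.mlp2.linear1(x) + self.mlp2.linear1(x))


gm = symbolic_trace(Net())

# Tag each node with the checkpoint region it belongs to (this is what the
# annotation pass from #1349 produces; the attribute name is an assumption).
for node in gm.graph.nodes:
    if node.op == "call_module" and node.target.startswith("mlp1"):
        node.activation_checkpoint = 0
    elif node.op == "call_module" and node.target.startswith("mlp2"):
        node.activation_checkpoint = 1

# Swap in the activation-checkpoint codegen (torch >= 1.12) and regenerate forward.
gm.graph.set_codegen(ActivationCheckpointCodeGen())
gm.recompile()
print(gm.code)  # should resemble the "with activation checkpoint" forward above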

As the fx CodeGen is only available in torch 1.12 and the CI runs torch 1.11, the test is skipped in pytest; the full local test log is shown below.

[Screenshot: local test log, 2022-07-22 4:46 PM]
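
For reference, a version-gated skip of this kind could be written as sketched below (an illustrative example, not necessarily the exact test added here).

# Sketch of a torch-version-gated skip (not necessarily the exact test in this PR).
import pytest
import torch
from packaging import version

CODEGEN_AVAILABLE = version.parse(torch.__version__) >= version.parse("1.12.0")


@pytest.mark.skipif(not CODEGEN_AVAILABLE,
                    reason="torch.fx CodeGen is only available in torch >= 1.12")
def test_activation_checkpoint_codegen():
    ...  # trace, annotate, attach the new codegen, and compare gm.code / outputs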
