[fx] add rules to linearize computation graphs for searching. #1461

Merged
merged 22 commits into from
Aug 17, 2022

Conversation

Contributor

@super-dainiu super-dainiu commented Aug 16, 2022

What's new?

Since most existing activation checkpointing frameworks assume the forward graph is linearized, I developed a small utility function that detects all potential checkpoint nodes. This way, chen_greedy() operates on a linearized graph and will not place checkpoints inside skip-connection blocks. I also removed chen_sqrtn() because it is a bit outdated.
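
To make the rule concrete, here is a minimal, illustrative sketch of what "linearizing" means for an fx graph: sweep the topologically ordered node list and only cut a segment boundary where a single activation is still needed by later nodes, so a skip connection can never be split across two segments. The linearize() helper below and its return format are assumptions for illustration only, not the code added in this PR.

import torch
from torch.fx import symbolic_trace
from torchvision.models import resnet18

def linearize(gm: torch.fx.GraphModule):
    # Illustrative sketch: split the node list into segments whose
    # boundaries are crossed by exactly one live activation.
    nodes = [n for n in gm.graph.nodes if n.op not in ('placeholder', 'output')]
    segments, current, visited = [], [], set()
    for node in nodes:
        current.append(node)
        visited.add(node)
        # activations produced so far that some unvisited node still consumes
        live = [n for n in visited if any(u not in visited for u in n.users)]
        # safe cut point: only one value flows across this boundary,
        # i.e. we are not inside a residual / skip-connection block
        if len(live) <= 1:
            segments.append(current)
            current = []
    if current:
        segments.append(current)
    return segments

gm = symbolic_trace(resnet18())
print([len(seg) for seg in linearize(gm)])  # segment sizes along the linearized graph
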
Hopefully, together with the new checkpoint function in Colossal-AI (#1460), we can now run linearized searches on arbitrary computation graphs, such as resnet18(). The generated forward looks like this:

def forward(self, x : torch.Tensor) -> torch.Tensor:
    import colossalai
    conv1 = self.conv1(x);  x = None
    def checkpoint_0(conv1):
        bn1 = self.bn1(conv1);  conv1 = None
        relu = self.relu(bn1);  bn1 = None
        maxpool = self.maxpool(relu);  relu = None
        layer1_0_conv1 = getattr(self.layer1, "0").conv1(maxpool)
        layer1_0_bn1 = getattr(self.layer1, "0").bn1(layer1_0_conv1);  layer1_0_conv1 = None
        layer1_0_relu = getattr(self.layer1, "0").relu(layer1_0_bn1);  layer1_0_bn1 = None
        layer1_0_conv2 = getattr(self.layer1, "0").conv2(layer1_0_relu);  layer1_0_relu = None
        layer1_0_bn2 = getattr(self.layer1, "0").bn2(layer1_0_conv2);  layer1_0_conv2 = None
        add = layer1_0_bn2 + maxpool;  layer1_0_bn2 = maxpool = None
        return add
    add = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, conv1, use_reentrant=False)
    def checkpoint_1(add):
        layer1_0_relu_1 = getattr(self.layer1, "0").relu(add);  add = None
        layer1_1_conv1 = getattr(self.layer1, "1").conv1(layer1_0_relu_1)
        layer1_1_bn1 = getattr(self.layer1, "1").bn1(layer1_1_conv1);  layer1_1_conv1 = None
        layer1_1_relu = getattr(self.layer1, "1").relu(layer1_1_bn1);  layer1_1_bn1 = None
        layer1_1_conv2 = getattr(self.layer1, "1").conv2(layer1_1_relu);  layer1_1_relu = None
        layer1_1_bn2 = getattr(self.layer1, "1").bn2(layer1_1_conv2);  layer1_1_conv2 = None
        add_1 = layer1_1_bn2 + layer1_0_relu_1;  layer1_1_bn2 = layer1_0_relu_1 = None
        layer1_1_relu_1 = getattr(self.layer1, "1").relu(add_1);  add_1 = None
        layer2_0_conv1 = getattr(self.layer2, "0").conv1(layer1_1_relu_1)
        layer2_0_bn1 = getattr(self.layer2, "0").bn1(layer2_0_conv1);  layer2_0_conv1 = None
        layer2_0_relu = getattr(self.layer2, "0").relu(layer2_0_bn1);  layer2_0_bn1 = None
        layer2_0_conv2 = getattr(self.layer2, "0").conv2(layer2_0_relu);  layer2_0_relu = None
        layer2_0_bn2 = getattr(self.layer2, "0").bn2(layer2_0_conv2);  layer2_0_conv2 = None
        layer2_0_downsample_0 = getattr(getattr(self.layer2, "0").downsample, "0")(layer1_1_relu_1);  layer1_1_relu_1 = None
        layer2_0_downsample_1 = getattr(getattr(self.layer2, "0").downsample, "1")(layer2_0_downsample_0);  layer2_0_downsample_0 = None
        add_2 = layer2_0_bn2 + layer2_0_downsample_1;  layer2_0_bn2 = layer2_0_downsample_1 = None
        layer2_0_relu_1 = getattr(self.layer2, "0").relu(add_2);  add_2 = None
        layer2_1_conv1 = getattr(self.layer2, "1").conv1(layer2_0_relu_1)
        layer2_1_bn1 = getattr(self.layer2, "1").bn1(layer2_1_conv1);  layer2_1_conv1 = None
        layer2_1_relu = getattr(self.layer2, "1").relu(layer2_1_bn1);  layer2_1_bn1 = None
        layer2_1_conv2 = getattr(self.layer2, "1").conv2(layer2_1_relu);  layer2_1_relu = None
        layer2_1_bn2 = getattr(self.layer2, "1").bn2(layer2_1_conv2);  layer2_1_conv2 = None
        add_3 = layer2_1_bn2 + layer2_0_relu_1;  layer2_1_bn2 = layer2_0_relu_1 = None
        return add_3
    add_3 = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, add, use_reentrant=False)
    layer2_1_relu_1 = getattr(self.layer2, "1").relu(add_3);  add_3 = None
    layer3_0_conv1 = getattr(self.layer3, "0").conv1(layer2_1_relu_1)
    layer3_0_bn1 = getattr(self.layer3, "0").bn1(layer3_0_conv1);  layer3_0_conv1 = None
    layer3_0_relu = getattr(self.layer3, "0").relu(layer3_0_bn1);  layer3_0_bn1 = None
    layer3_0_conv2 = getattr(self.layer3, "0").conv2(layer3_0_relu);  layer3_0_relu = None
    layer3_0_bn2 = getattr(self.layer3, "0").bn2(layer3_0_conv2);  layer3_0_conv2 = None
    layer3_0_downsample_0 = getattr(getattr(self.layer3, "0").downsample, "0")(layer2_1_relu_1);  layer2_1_relu_1 = None
    layer3_0_downsample_1 = getattr(getattr(self.layer3, "0").downsample, "1")(layer3_0_downsample_0);  layer3_0_downsample_0 = None
    add_4 = layer3_0_bn2 + layer3_0_downsample_1;  layer3_0_bn2 = layer3_0_downsample_1 = None
    layer3_0_relu_1 = getattr(self.layer3, "0").relu(add_4);  add_4 = None
    layer3_1_conv1 = getattr(self.layer3, "1").conv1(layer3_0_relu_1)
    layer3_1_bn1 = getattr(self.layer3, "1").bn1(layer3_1_conv1);  layer3_1_conv1 = None
    layer3_1_relu = getattr(self.layer3, "1").relu(layer3_1_bn1);  layer3_1_bn1 = None
    layer3_1_conv2 = getattr(self.layer3, "1").conv2(layer3_1_relu);  layer3_1_relu = None
    layer3_1_bn2 = getattr(self.layer3, "1").bn2(layer3_1_conv2);  layer3_1_conv2 = None
    add_5 = layer3_1_bn2 + layer3_0_relu_1;  layer3_1_bn2 = layer3_0_relu_1 = None
    layer3_1_relu_1 = getattr(self.layer3, "1").relu(add_5);  add_5 = None
    layer4_0_conv1 = getattr(self.layer4, "0").conv1(layer3_1_relu_1)
    layer4_0_bn1 = getattr(self.layer4, "0").bn1(layer4_0_conv1);  layer4_0_conv1 = None
    layer4_0_relu = getattr(self.layer4, "0").relu(layer4_0_bn1);  layer4_0_bn1 = None
    layer4_0_conv2 = getattr(self.layer4, "0").conv2(layer4_0_relu);  layer4_0_relu = None
    layer4_0_bn2 = getattr(self.layer4, "0").bn2(layer4_0_conv2);  layer4_0_conv2 = None
    layer4_0_downsample_0 = getattr(getattr(self.layer4, "0").downsample, "0")(layer3_1_relu_1);  layer3_1_relu_1 = None
    layer4_0_downsample_1 = getattr(getattr(self.layer4, "0").downsample, "1")(layer4_0_downsample_0);  layer4_0_downsample_0 = None
    add_6 = layer4_0_bn2 + layer4_0_downsample_1;  layer4_0_bn2 = layer4_0_downsample_1 = None
    layer4_0_relu_1 = getattr(self.layer4, "0").relu(add_6);  add_6 = None
    layer4_1_conv1 = getattr(self.layer4, "1").conv1(layer4_0_relu_1)
    layer4_1_bn1 = getattr(self.layer4, "1").bn1(layer4_1_conv1);  layer4_1_conv1 = None
    layer4_1_relu = getattr(self.layer4, "1").relu(layer4_1_bn1);  layer4_1_bn1 = None
    layer4_1_conv2 = getattr(self.layer4, "1").conv2(layer4_1_relu);  layer4_1_relu = None
    layer4_1_bn2 = getattr(self.layer4, "1").bn2(layer4_1_conv2);  layer4_1_conv2 = None
    add_7 = layer4_1_bn2 + layer4_0_relu_1;  layer4_1_bn2 = layer4_0_relu_1 = None
    layer4_1_relu_1 = getattr(self.layer4, "1").relu(add_7);  add_7 = None
    avgpool = self.avgpool(layer4_1_relu_1);  layer4_1_relu_1 = None
    flatten = torch.flatten(avgpool, 1);  avgpool = None
    fc = self.fc(flatten);  flatten = None
    return fc
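
For context, here is a rough end-to-end sketch of how a checkpointed forward like the one above could be produced. The import paths and call signatures (ColoTracer, MetaInfoProp, chen_greedy and its budget argument) are assumptions about the Colossal-AI fx API, not verbatim from this PR, and may not match the repository exactly:

import torch
from torchvision.models import resnet18

# NOTE: the imports and signatures below are assumptions; check colossalai.fx
# for the actual module layout and solver interface.
from colossalai.fx import ColoTracer
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy

model = resnet18()
data = torch.randn(2, 3, 224, 224, device='meta')

# trace the model into an fx graph using meta tensors (no real memory is allocated)
graph = ColoTracer().trace(model, meta_args={'x': data})
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)

# annotate every node with its activation size so the solver can reason about memory
MetaInfoProp(gm).run(data)

# run the greedy search on the linearized graph and regenerate the forward
gm = chen_greedy(gm, B=2 * 1024 ** 3)  # hypothetical 2 GiB activation budget
gm.recompile()
print(gm.code)  # prints a forward with checkpoint_0, checkpoint_1, ... segments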

super-dainiu and others added 16 commits August 9, 2022 23:23
* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.
* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] merge development into main (#1)

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] fix test and algorithm bugs in activation checkpointing.

* [fx] polish ckpt_test.

* [fx] add rules to linearize computation graphs for searching.

Contributor

@Cypher30 Cypher30 left a comment

Okay, I have looked through all the changes and currently I don't spot any mistakes. I will modify the codegen to check whether we need to use use_reentrant=False when calling checkpoint functions.
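
For reference, the behavior in question mirrors PyTorch's non-reentrant checkpointing: with use_reentrant=False the segment is recomputed under the ordinary autograd engine during backward instead of through a re-entrant custom autograd function. A minimal, self-contained illustration using torch.utils.checkpoint (the segment function below is a made-up stand-in, not code from this PR):

import torch
from torch.utils.checkpoint import checkpoint

linear = torch.nn.Linear(16, 16)

def segment(x):
    # stand-in for one of the generated checkpoint_* segments
    return torch.relu(linear(x))

x = torch.randn(4, 16, requires_grad=True)

# non-reentrant checkpointing: activations inside `segment` are not stored
# and are recomputed during backward under the autograd engine
y = checkpoint(segment, x, use_reentrant=False)
y.sum().backward()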

Contributor

@Cypher30 Cypher30 left a comment

Just hold the PR~

@super-dainiu
Contributor Author

Hey, can anyone review this?

Contributor

@Cypher30 Cypher30 left a comment

Okay, let's check if the CI can pass.

@super-dainiu super-dainiu merged commit e7383f5 into hpcaitech:main Aug 17, 2022