
[fx] Add offload codegen #1598

Merged
merged 36 commits into hpcaitech:main on Sep 14, 2022

Conversation

Cypher30
Contributor

What's New

Previously, we had an offload option in the activation checkpoint region to support offloading the inputs of that region. However, in the upcoming activation solver, we might need to offload the input of a node that is not inside any checkpoint region. Therefore, I use saved_tensors_hooks to handle this kind of offload.
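As an illustration only (not code from this PR), here is a minimal, runnable sketch of how torch.autograd.graph.saved_tensors_hooks offloads saved activations to the CPU and restores them during backward; the toy model and shapes are assumptions:

import torch
import torch.nn as nn

def pack_hook(x):
    # Called when autograd saves a tensor for backward:
    # remember its device and keep only a CPU copy alive.
    return (x.device, x.cpu())

def unpack_hook(packed):
    # Called when backward needs the tensor again:
    # move the CPU copy back to its original device.
    device, tensor = packed
    return tensor.to(device)

model = nn.Linear(4, 4)
inp = torch.rand(2, 4, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    out = model(inp)
out.sum().backward()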

As we haven't implemented the torch 1.11 version of ColoGraphModule yet, I skip the unit test for that part and attach the torch 1.12 results below.
[Screenshot: unit test results on torch 1.12]

Cypher30 and others added 30 commits July 14, 2022 16:07
@FrankLeeeee
Contributor

Paste an example of the generated code here.

if node.name == "linear4":
    setattr(node, "activation_checkpoint", [0])

gm = ColoGraphModule(copy.deepcopy(model), graph)
Contributor

I think the copy is not needed.

Contributor Author

The copy is needed because we want to test backward gradients.
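A minimal sketch of the gradient-comparison pattern this implies, using vanilla torch.fx.GraphModule as a stand-in for ColoGraphModule (the toy model, shapes, and test body are assumptions, not the PR's actual test):

import copy
import torch
import torch.fx
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
graph = torch.fx.Tracer().trace(model)

# Deep-copy the root so the GraphModule owns independent parameters.
gm = torch.fx.GraphModule(copy.deepcopy(model), graph)

data = torch.rand(2, 4)
model(data).sum().backward()
gm(data).sum().backward()

# With the deepcopy, model and gm hold separate parameter tensors, so the
# backward results can be compared. Without it, both modules would share
# the very same Parameter objects and the check would prove nothing.
for p_ref, p_gm in zip(model.parameters(), gm.parameters()):
    assert torch.allclose(p_ref.grad, p_gm.grad)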

Contributor

ok i see.

@Cypher30
Contributor Author

Paste an example of the generated code here.

Here is an example of the generated code:

def pack_hook(self, x):
    if getattr(x, "offload", None):
        return (x.device, x.cpu())
    else:
        return x

def unpack_hook(self, packed):
    if isinstance(packed, tuple):
        device, tensor = packed
        return tensor.to(device)
    else:
        return packed

def checkpoint_0(self, linear2):
    linear3 = self.linear3(linear2);  linear2 = None
    linear4 = self.linear4(linear3);  linear3 = None
    return linear4

def forward(self, x):
    linear1 = self.linear1(x);  x = None
    setattr(linear1, 'offload', True)
    with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):
        linear2 = self.linear2(linear1);  linear1 = None
    linear4 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)
    linear5 = self.linear5(linear4);  linear4 = None
    return linear5

Comment on lines 153 to 154
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
Contributor

Why do we need two skips?

Contributor Author

Sure, I could remove the first one~
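For reference, a sketch of the markers with the redundant conditional skip dropped (the test function name here is hypothetical):

import pytest

# The unconditional skip already covers this case while the torch 1.11
# ColoGraphModule is unimplemented, so the version-based skipif is redundant.
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
def test_act_ckpt_codegen_torch11():
    ...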

Cypher30 merged commit a7cda6f into hpcaitech:main on Sep 14, 2022