[fx] Add offload codegen #1598
Conversation
Paste an example of the generated code here.
if node.name == "linear4":
    setattr(node, "activation_checkpoint", [0])

gm = ColoGraphModule(copy.deepcopy(model), graph)
I think the copy is not needed.
The copy is needed because we want to test the backward gradients.
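For context, here is a minimal sketch of the kind of gradient check the deepcopy enables (the stand-in transform and the shapes are illustrative, not the actual test from this PR):

import copy
import torch
import torch.nn as nn

model = nn.Linear(16, 16)
# Stand-in for building the transformed module from a deepcopy; in the
# test above this would be ColoGraphModule(copy.deepcopy(model), graph).
gm = copy.deepcopy(model)

data = torch.randn(2, 16)
model(data).sum().backward()
gm(data).sum().backward()

# Without the deepcopy, model and gm would share the same parameter
# tensors, so the two backward passes would accumulate into the same
# .grad buffers and the comparison below would be meaningless.
for p_ref, p_gm in zip(model.parameters(), gm.parameters()):
    assert torch.allclose(p_ref.grad, p_gm.grad)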
OK, I see.
Here is an example of the generated code:

def pack_hook(self, x):
    if getattr(x, "offload", None):
        return (x.device, x.cpu())
    else:
        return x

def unpack_hook(self, packed):
    if isinstance(packed, tuple):
        device, tensor = packed
        return tensor.to(device)
    else:
        return packed

def checkpoint_0(self, linear2):
    linear3 = self.linear3(linear2); linear2 = None
    linear4 = self.linear4(linear3); linear3 = None
    return linear4

def forward(self, x):
    linear1 = self.linear1(x); x = None
    setattr(linear1, 'offload', True)
    with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):
        linear2 = self.linear2(linear1); linear1 = None
    linear4 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)
    linear5 = self.linear5(linear4); linear4 = None
    return linear5
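To summarize what the generated code does: setattr(linear1, 'offload', True) tags the tensor, and the saved_tensors_hooks context makes autograd route every tensor it saves for backward through pack_hook, which moves tagged tensors to CPU (remembering their original device) and passes untagged ones through unchanged. During backward, unpack_hook moves the offloaded tensor back to its original device. The checkpointed region is still emitted as a colossalai.utils.activation_checkpoint.checkpoint call, as before.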
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
Why do we need two skips?
Sure, I could remove the first one~
What's New
Previously, we had an offload option in the activation checkpoint region to support offloading the region's inputs. However, the upcoming activation solver may offload the input of a node that is not inside any checkpoint region, so I use saved_tensors_hooks to implement this kind of offloading.
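For readers unfamiliar with the mechanism, here is a minimal standalone sketch of the saved_tensors_hooks offload pattern (simplified from, and not identical to, the generated code shown above):

import torch

def pack_hook(x):
    # Called when autograd saves a tensor for backward: move it to CPU
    # and remember its original device. (For CPU tensors this is a no-op;
    # for GPU tensors it frees device memory until backward.)
    return (x.device, x.cpu())

def unpack_hook(packed):
    # Called when backward needs the saved tensor: move it back.
    device, tensor = packed
    return tensor.to(device)

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x @ w).relu()
y.sum().backward()  # saved activations are restored via unpack_hook here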
As we haven't implemented the torch 1.11 version of ColoGraphModule, I skip the unit test for this part and attach the results on torch 1.12 below.