
Add function to_sequential to PipelineModule#1014

Open
sdtblck wants to merge 8 commits into deepspeedai:master from sdtblck:patch-2

Conversation

@sdtblck
Contributor

@sdtblck sdtblck commented Apr 28, 2021

In https://github.com/EleutherAI/gpt-neox we were previously maintaining two separate models - one if the user wanted to use pipeline parallel, and one if they didn't.

The more straightforward solution was to add a `to_sequential` function to export the PipelineModule as an nn.Sequential model, so we could train with deepspeed features that aren't compatible with pipeline parallelism (e.g. ZeRO stage 2+).

Figure this might be a useful addition to the base module, too. I'm not 100% sure if the support for tied layers here is as flexible as it could / should be, since their capabilities are not very well documented, but it works at least for our purposes (with tied Embeddings as the output layer).

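For context, the conversion the PR describes can be sketched roughly like this. This is a dependency-free stand-in: `LayerSpec`, `build`, and `run` are illustrative names modeled on, but not identical to, DeepSpeed's actual API.

```python
# Hypothetical sketch of the idea behind to_sequential: walk the layer
# specs a pipeline module was built from and materialize them in order.
# LayerSpec here is a simplified stand-in, not DeepSpeed's real class.

class LayerSpec:
    """Deferred layer construction: store the class and its ctor arguments."""
    def __init__(self, typename, *args, **kwargs):
        self.typename = typename
        self.args = args
        self.kwargs = kwargs

    def build(self):
        return self.typename(*self.args, **self.kwargs)


def to_sequential(specs):
    """Flatten a list of specs into an ordered list of callables
    (a stand-in for nn.Sequential in this torch-free sketch)."""
    layers = []
    for spec in specs:
        if isinstance(spec, LayerSpec):
            layers.append(spec.build())
        elif callable(spec):  # plain functions / lambdas pass through
            layers.append(spec)
        else:
            raise TypeError(f"unsupported layer spec: {spec!r}")
    return layers


def run(layers, x):
    """Apply the layers in order, like Sequential.forward."""
    for layer in layers:
        x = layer(x)
    return x
```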
Contributor

@ShadenSmith ShadenSmith left a comment


This is a great idea, thanks @sdtblck !

One caveat is that we lose the activation checkpointing that the PipelineModule's forward can be configured to use. But users can instead use torch's checkpoint_sequential() if they want checkpointing. Or we could wrap the layers in a similar way as Lambda if we really want to mirror functionality. What are your thoughts?
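The suggested workaround can be sketched with torch's `checkpoint_sequential`, which recomputes each segment's activations during backward instead of storing them. The toy model, shapes, and segment count below are assumptions for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Toy stand-in for a model exported to nn.Sequential.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())
x = torch.randn(4, 8, requires_grad=True)

# Split the Sequential into 2 segments; intermediate activations inside
# each segment are recomputed on backward, trading compute for memory.
out = checkpoint_sequential(model, 2, x)
out.sum().backward()
```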

```python
else:
    # check that it's a lambda function
    LAMBDA = lambda: 0
    if isinstance(spec, type(LAMBDA)) and spec.__name__ == LAMBDA.__name__:
```
Contributor

PipelineModule should work with any callable object, and I think the `Lambda` module above will too. Maybe the filtering condition could be `hasattr(spec, '__call__')` to support things like named methods?
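The difference between the two checks can be seen with a quick sketch (the helper names here are illustrative, not from the PR):

```python
# The lambda-only check rejects other callables; the suggested
# hasattr(spec, '__call__') check accepts lambdas, named functions,
# and callable objects alike.
LAMBDA = lambda: 0

def is_lambda(spec):
    return isinstance(spec, type(LAMBDA)) and spec.__name__ == LAMBDA.__name__

def is_callable_spec(spec):
    return hasattr(spec, '__call__')

def named_fn(x):
    return x

class CallableObj:
    def __call__(self, x):
        return x
```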

Contributor Author

good point yes! I'll make that change too

@ShadenSmith
Contributor

In addition to to_sequential, there may be another way we could accomplish this while keeping the normal PipelineModule, if that would be useful.

If we short-circuit this condition and use the regular training engine, I think that PipelineModule should behave as a normal torch.nn.Module and you can use ZeRO-2, etc. I intended for that to be the case, but not tested these days.

https://github.com/microsoft/DeepSpeed/blob/dad26428e3f28898b8d0f5ace1b3df3e6db8f8e8/deepspeed/__init__.py#L119-L120

@sdtblck
Contributor Author

sdtblck commented Apr 30, 2021

> In addition to to_sequential, there may be another way we could accomplish this while keeping the normal PipelineModule, if that would be useful.
>
> If we short-circuit this condition and use the regular training engine, I think that PipelineModule should behave as a normal torch.nn.Module and you can use ZeRO-2, etc. I intended for that to be the case, but not tested these days.
>
> https://github.com/microsoft/DeepSpeed/blob/dad26428e3f28898b8d0f5ace1b3df3e6db8f8e8/deepspeed/__init__.py#L119-L120

Hi @ShadenSmith, I actually tried this as well, and it seems this way of doing things drops any tied modules (since the pipe engine handles them specially). So, for example, if we used this with a model with tied embeddings, the `to_logits` function that uses the word embedding weights would just get silently dropped.
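For readers unfamiliar with the pattern, the tied-embedding setup being discussed looks roughly like this minimal sketch (not gpt-neox's actual code; sizes are arbitrary):

```python
import torch
import torch.nn as nn

vocab_size, dim = 100, 16
embed = nn.Embedding(vocab_size, dim)
to_logits = nn.Linear(dim, vocab_size, bias=False)

# Tie the output projection to the embedding table: both modules now
# share a single Parameter. If the tie is silently dropped, the model
# trains an untied (or missing) output projection instead.
to_logits.weight = embed.weight

logits = to_logits(embed(torch.tensor([1, 2, 3])))
```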

@sdtblck
Contributor Author

sdtblck commented Apr 30, 2021

> This is a great idea, thanks @sdtblck !
>
> One caveat is that we lose the activation checkpointing that the PipelineModule's forward can be configured to use. But users can instead use torch's checkpoint_sequential() if they want checkpointing. Or we could wrap the layers in a similar way as Lambda if we really want to mirror functionality. What are your thoughts?

Hm. Yeah this is a good point that I had overlooked. I'll spend some time looking into the best way to get this working today.

sdtblck added 2 commits April 30, 2021 16:11
Used to convert a deepspeed PipelineModule to an nn.Sequential like model whilst retaining activation checkpointing.
@sdtblck
Contributor Author

sdtblck commented Apr 30, 2021

Hi @ShadenSmith

I think the two latest commits should fix both the above requirements. There is maybe some repeated code between SequentialModel and PipelineModule that could be slimmed down - but I have tested with gpt-neox and it works well.

@rocm-mici

Can one of the admins verify this patch?

@jeffra jeffra requested a review from duli2012 as a code owner June 23, 2023 21:31
@ShadenSmith ShadenSmith self-assigned this Aug 18, 2023
@loadams
Collaborator

loadams commented Nov 14, 2023

@sdtblck - just fixed some formatting issues that were preventing this - if the tests pass, would this be good to merge now?
