Add function `to_sequential` to PipelineModule #1014

sdtblck wants to merge 8 commits into deepspeedai:master

Conversation
In https://github.com/EleutherAI/gpt-neox we were previously maintaining two separate models: one if the user wanted to use pipeline parallelism, and one if they didn't. The more straightforward solution was to add a `to_sequential` function that exports the PipelineModule as an nn.Sequential model, so we can train with DeepSpeed features that aren't compatible with pipeline parallelism (e.g. ZeRO stage 2+). I figure this might be a useful addition to the base module, too. I'm not 100% sure the support for tied layers here is as flexible as it could/should be, since their capabilities aren't very well documented, but it works at least for our purposes (with tied embeddings as the output layer).
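The core idea can be sketched roughly like this. Note this is a toy stand-in, not the actual PipelineModule internals: the `forward_funcs` attribute and the filtering logic here are assumptions for illustration only.

```python
import torch
import torch.nn as nn

class TinyPipeline:
    """Stand-in for a PipelineModule holding its layers in order.
    `forward_funcs` is an assumed attribute name for this sketch."""
    def __init__(self):
        self.forward_funcs = [nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)]

    def to_sequential(self):
        # Keep the nn.Module layers in pipeline order; a real
        # implementation also has to handle plain callables and tied layers.
        return nn.Sequential(*[f for f in self.forward_funcs
                               if isinstance(f, nn.Module)])

seq = TinyPipeline().to_sequential()
out = seq(torch.randn(3, 4))  # the exported model runs like any nn.Sequential
```

Once exported, the model can be passed to the regular (non-pipeline) training engine like any other nn.Module.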
ShadenSmith left a comment:
This is a great idea, thanks @sdtblck !
One caveat is that we lose the activation checkpointing that the PipelineModule's forward can be configured to use. But users can instead use torch's checkpoint_sequential() if they want checkpointing. Or we could wrap the layers in a similar way as Lambda if we really want to mirror functionality. What are your thoughts?
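For reference, using torch's `checkpoint_sequential()` on the exported model might look like the sketch below (the model here is just a placeholder for whatever `to_sequential` returns):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Placeholder for the model returned by to_sequential()
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())

x = torch.randn(2, 8, requires_grad=True)
# Split the sequence into 2 segments; activations inside each segment
# are recomputed during backward instead of being stored.
out = checkpoint_sequential(model, 2, x)
out.sum().backward()
```

This recovers the memory savings of activation checkpointing without the PipelineModule's forward being involved.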
deepspeed/runtime/pipe/module.py (outdated)

    else:
        # check that it's a lambda function
        LAMBDA = lambda: 0
        if isinstance(spec, type(LAMBDA)) and spec.__name__ == LAMBDA.__name__:
PipelineModule should work with any callable object, and I think the Lambda module above will too. Maybe the filtering condition could be `hasattr(spec, '__call__')` to support things like named methods?
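To illustrate the difference between the two checks (toy specs, not taken from the PR):

```python
# The lambda-only check accepts only functions literally named '<lambda>':
LAMBDA = lambda: 0

def named_builder():          # a named factory function
    return "layer"

class CallableSpec:           # a callable object
    def __call__(self):
        return "layer"

specs = [LAMBDA, named_builder, CallableSpec()]

lambda_only = [isinstance(s, type(LAMBDA)) and s.__name__ == LAMBDA.__name__
               for s in specs]                            # [True, False, False]
callable_check = [hasattr(s, '__call__') for s in specs]  # [True, True, True]
```

The `hasattr(spec, '__call__')` test accepts lambdas, named functions, and callable objects alike, which is why it is the broader filtering condition.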
Good point, yes! I'll make that change too.
In addition to […], if we short-circuit this condition and use the regular training engine, I think that […]
Hi @ShadenSmith, I actually tried this as well, and it seems this way of doing things drops any tied modules (since the pipe engine handles them specially). For example, if we used this with a model with tied embeddings, the `to_logits` function that uses the word embedding weights would just get silently dropped.
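The weight-tying pattern at issue looks roughly like this toy model (the class and sizes are hypothetical; `to_logits` follows the naming used in the discussion):

```python
import torch
import torch.nn as nn

class TiedLM(nn.Module):
    """Toy model with input/output embeddings tied, as in the discussion."""
    def __init__(self, vocab=10, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.body = nn.Linear(dim, dim)

    def to_logits(self, hidden):
        # Reuses the embedding matrix as the output projection; if a
        # naive export drops this tied step, the model silently breaks.
        return hidden @ self.embed.weight.t()

    def forward(self, tokens):
        return self.to_logits(self.body(self.embed(tokens)))

model = TiedLM()
logits = model(torch.tensor([[1, 2, 3]]))  # shape: (1, 3, vocab)
```

Because the output projection is not a standalone layer but a reuse of `embed.weight`, an export that only collects the pipeline's layer list would lose it without any error.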
Hm. Yeah, this is a good point that I had overlooked. I'll spend some time looking into the best way to get this working today.
Used to convert a DeepSpeed PipelineModule to an nn.Sequential-like model while retaining activation checkpointing.
Hi @ShadenSmith, I think the two latest commits should fix both of the above requirements. There is maybe some repeated code between […]
Can one of the admins verify this patch?
@sdtblck - just fixed some formatting issues that were preventing this - if the tests pass, would this be good to merge now?