[BUG] `zero.Init` and composable models break and memory leak #2617
Comments
Hmm, preparing the "external" model ahead of time seems to solve it: you can see I moved the
@tjruwase, so the above works, but it leaks memory on each training iteration like there is no tomorrow. It appears to partition these "external" weights if I pass the model, but something is very wrong: the leak is very sizeable (and depends on the model size). I did various large/small outside-model compositions, and it appears that the leak size depends on the size of the main model rather than the included model. So I suspect the whole
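As an aside, a generic way to confirm "grows on every iteration" behavior like this (not the actual debugging done in this thread, just a sketch) is to sample allocated memory once per step and check that the deltas stay positive. The sketch below uses `tracemalloc` for CPU-side allocations; for the GPU-side leak discussed here, `torch.cuda.memory_allocated()` would be the analogous probe. The `leaky_step` function is a made-up stand-in:

```python
# Generic per-iteration leak check for CPU-side allocations (illustrative;
# for GPU memory, torch.cuda.memory_allocated() is the analogous probe).
import tracemalloc

def leaky_step(retained):
    # Stand-in for a training step that accidentally keeps references alive.
    retained.append([0] * 10_000)

def per_iteration_growth(step, iters=5):
    """Run `step` repeatedly and return allocated-bytes growth per iteration."""
    tracemalloc.start()
    retained, sizes = [], []
    for _ in range(iters):
        step(retained)
        current, _peak = tracemalloc.get_traced_memory()
        sizes.append(current)
    tracemalloc.stop()
    # Differences between consecutive samples: one entry per iteration gap.
    return [b - a for a, b in zip(sizes, sizes[1:])]

growth = per_iteration_growth(leaky_step)
print(growth)
assert all(g > 0 for g in growth)  # steady positive growth => a real leak
```

A healthy step would show growth that flattens to roughly zero after warm-up; a leak shows steady positive deltas whose magnitude tracks the leaked objects' size, which is how one can tell whether the main model or the included model is the source.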
Also, I figured out how to solve the initial reproducible test I shared in the OP, where the external model is created directly in

The problem happens because of the timing of when the submodule init overrides are done, in DeepSpeed/deepspeed/runtime/zero/partition_parameters.py, lines 376 to 379 in 7e2103f.

So one can't use

But alas, the leak is the same.
Super, thank you, @tohtana! If the test I supplied works, then all is wonderful! Tunji and I are working through something similar at the moment, so we will be thoroughly testing this shortly.
Describe the bug

At HF's m4 we have models that are composed of 2 pretrained models and some new layers added to be trained. Everything works under stage 2, but `zero.Init` breaks under stage 3 with those composed models.

To Reproduce

Add this to `tests/unit/runtime/zero/test_zero.py`:

Now running it:
It can be one of many classes that fail here; the name is sort of random, but it's a name from the nested model, `hf-internal-testing/tiny-random-vit` in this case.

It's easy to see that this code didn't run for these submodules (from enabling the debug print there):

DeepSpeed/deepspeed/runtime/zero/partition_parameters.py
Lines 376 to 379 in 7e2103f
Even if I hack `from_pretrained` to not run the nested `zero.Init` context, the issue still happens, because the `_enable_class` in the snippet above didn't run regardless of the context.

Oh, actually, in the test I shared, `from_pretrained` isn't set up to run `zero.Init` anyway, so really the issue is simply that `zero.Init` doesn't seem to be able to handle such composed models.

So `no attribute '_old_init'` is a symptom and shouldn't be fixed by itself; we need to figure out why the "external" model didn't get the `zero.Init` treatment, even though it's being created inside this context.

@tjruwase, @jeffra
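As a side note on why the nested `zero.Init` context mentioned above is delicate (a toy sketch with hypothetical `enable`/`disable` helpers, not DeepSpeed's actual code): if a patching context of this shape is entered twice before being exited, the saved original gets overwritten with the first wrapper, so unwinding both contexts still leaves the class patched:

```python
# Toy illustration of why re-entrant __init__ patching is fragile
# (hypothetical helpers; not DeepSpeed's implementation).

def enable(cls):
    cls._old_init = cls.__init__  # a second call saves the *wrapper* here
    saved = cls.__init__

    def wrapper(self, *args, **kwargs):
        saved(self, *args, **kwargs)
        self.wrap_count = getattr(self, "wrap_count", 0) + 1
    cls.__init__ = wrapper

def disable(cls):
    cls.__init__ = cls._old_init


class M:
    def __init__(self):
        pass

enable(M)   # outer context enters
enable(M)   # nested context enters: _old_init now points at the first wrapper
disable(M)  # nested context exits
disable(M)  # outer context exits, but __init__ is still the first wrapper
print(M().wrap_count)  # 1 -- the patch leaked past both exits
```

A robust implementation has to make enable/disable reference-counted or idempotent per class, rather than relying on a single `_old_init` slot.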