[BUG] zero.Init and composable models break and memory leak #2617

Open
stas00 opened this issue Dec 16, 2022 · 5 comments

stas00 commented Dec 16, 2022

Describe the bug

At HF's m4 we have models composed of 2 pretrained models plus some new layers added to be trained. Everything works under stage 2, but under stage 3 zero.Init breaks with those composed models.

To Reproduce

add this to tests/unit/runtime/zero/test_zero.py:

class TestZeroNestedHFModel(DistributedTest):
    world_size = 1

    def test(self):
        config_dict = {
            "train_batch_size": 4,
            "zero_optimization": {
                "stage": 3
            }
        }
        hidden_dim = 10

        from transformers import AutoConfig, AutoModel
        mname = "hf-internal-testing/tiny-random-vit"

        class MyModel2(torch.nn.Module):
            def __init__(self, hidden_dim):
                super(MyModel2, self).__init__()
                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
                self.cel = torch.nn.CrossEntropyLoss()

                config = AutoConfig.from_pretrained(mname)
                self.model1 = AutoModel.from_config(config)

            def forward(self, x, y):
                x = self.model1(x)
                x = self.l1(x)
                loss = self.cel(x, y)
                val = (x, loss)
                return val

        with deepspeed.zero.Init(config_dict_or_path=config_dict):
            model2 = MyModel2(hidden_dim)

now running it:

$ pytest tests/unit/runtime/zero/test_zero.py -k HF
[...]
  File "/mnt/nvme0/code/github/00optimize/deepspeed/tests/unit/runtime/zero/test_zero.py", line 1427, in test
    model2 = MyModel2(hidden_dim)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/zero/partition_parameters.py", line 419, in __exit__
    shutdown_init_context()
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/zero/partition_parameters.py", line 460, in shutdown_init_context
    _disable_class(subclass)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/zero/partition_parameters.py", line 456, in _disable_class
    cls.__init__ = cls._old_init
AttributeError: type object 'ClippedGELUActivation' has no attribute '_old_init'

It can be any one of many classes that fail here - which name shows up is somewhat random, but it's always a class from the nested model, hf-internal-testing/tiny-random-vit in this case.

It's easy to see (by enabling the debug print there) that this code didn't run for those submodules:

# Replace .__init__() for all existing subclasses of torch.nn.Module recursively
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
    # print(f"subclass={subclass.__module__}.{subclass.__qualname__}")
    _enable_class(subclass)

Even if I hack from_pretrained to not run the nested zero.Init context, the issue still happens, because the _enable_class call in the snippet above didn't run for those classes regardless of the context.

Oh, actually, in the test I shared from_pretrained isn't set up to run zero.Init anyway, so really the issue is simply that zero.Init doesn't seem to be able to handle such composed models.

So the missing '_old_init' attribute is only a symptom and shouldn't be fixed by itself - we need to figure out why the "external" model didn't get the zero.Init treatment, even though it's being created inside this context.
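
For anyone following along, here is a rough, simplified sketch of what that enable/disable machinery in partition_parameters.py does - the wrapper name _partitioning_wrapper is my paraphrase, not the actual DeepSpeed code:

def _partitioning_wrapper(orig_init):
    # hypothetical stand-in for DeepSpeed's real wrapper, which partitions any
    # parameters the wrapped constructor creates
    def wrapped(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        # ... partition the newly created parameters of self here ...
    return wrapped

def _enable_class(cls):
    # remember the original constructor and install the partitioning wrapper
    cls._old_init = cls.__init__
    cls.__init__ = _partitioning_wrapper(cls.__init__)

def _disable_class(cls):
    # restore the original constructor on __exit__ - this is the line that blows
    # up with AttributeError when a class never went through _enable_class
    cls.__init__ = cls._old_init

So the traceback above is simply _disable_class hitting a class that _enable_class never touched.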

@tjruwase, @jeffra

stas00 added the bug Something isn't working and training labels Dec 16, 2022
stas00 changed the title from [BUG] nested zero.Init issues to [BUG] zero.Init and composable models break Dec 16, 2022

stas00 commented Dec 16, 2022

Hmm, preparing the "external" model ahead of time seems to solve it:

class TestZeroSmart(DistributedTest):
    world_size = 1

    def test(self):
        config_dict = {
            "train_batch_size": 4,
            "zero_optimization": {
                "stage": 3
            }
        }
        hidden_dim = 10

        from transformers import AutoConfig, AutoModel
        mname = "hf-internal-testing/tiny-random-vit"

        config = AutoConfig.from_pretrained(mname)
        model1 = AutoModel.from_config(config)

        class MyModel2(torch.nn.Module):
            def __init__(self, model1, hidden_dim):
                super(MyModel2, self).__init__()
                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
                self.cel = torch.nn.CrossEntropyLoss()
                self.model1 = model1

            def forward(self, x, y):
                x = self.model1(x)
                x = self.l1(x)
                loss = self.cel(x, y)
                val = (x, loss)
                return val

        with deepspeed.zero.Init(config_dict_or_path=config_dict):
            model2 = MyModel2(model1, hidden_dim)

You can see I moved the model1 creation outside of MyModel2's __init__.

tjruwase self-assigned this Dec 16, 2022

stas00 commented Dec 19, 2022

@tjruwase, so the above works, but leaks memory on each training iteration like there is no tomorrow.

It appears to partition these "external" weights when I pass the model in, but something is very wrong: the leak is very sizeable (and depends on the model size).

If I turn zero.Init off, stage 3 works fine - no leaking.

I tried various large/small outside model compositions and it appears that the leak size depends on the size of the main model rather than the included model. So I suspect the whole zero.Init wiring gets confused somewhere, and instead of being freed, the pre-fetched weights somehow remain allocated.
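
For reference, this is roughly how I track the growth - a generic sketch I call right after model_engine.step() in the training loop (model_engine being the object returned by deepspeed.initialize), not the exact harness I used:

import torch

def report_cuda_memory(step):
    # allocated memory should stay roughly flat from step to step; with
    # zero.Init plus the composed model it grows on every iteration instead
    torch.cuda.synchronize()
    allocated_mb = torch.cuda.memory_allocated() / 2**20
    print(f"step {step}: {allocated_mb:.1f} MiB allocated")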


stas00 commented Dec 19, 2022

Also, I figured out how to fix the initial repro test I shared in the OP, where the external model is created directly in __init__ rather than passed as an argument. It works if you add at the top:

from transformers import ViTModel

which is the class AutoModel resolves to in this example.

The problem happens because, at the point where zero.Init installs the submodule __init__ overrides, ViTModel and its subclasses haven't been imported yet, so they don't get registered here:

# Replace .__init__() for all existing subclasses of torch.nn.Module recursively
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
    # print(f"subclass={subclass.__module__}.{subclass.__qualname__}")
    _enable_class(subclass)

So one can't use AutoModel w/o explicitly importing the class it gets resolved to - see the sketch below.
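
Concretely, the arrangement that makes the original repro pass is something like this (a sketch reusing MyModel2 and config_dict from the test in the OP):

# import the concrete class AutoModel will resolve to *before* entering
# zero.Init, so it already exists when the __init__ overrides are installed
from transformers import ViTModel  # noqa: F401 - imported only for its side effect

with deepspeed.zero.Init(config_dict_or_path=config_dict):
    model2 = MyModel2(hidden_dim)  # AutoModel.from_config(...) inside now works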

But alas, the leak is the same.

stas00 changed the title from [BUG] zero.Init and composable models break to [BUG] zero.Init and composable models break and memory leak Dec 19, 2022
tjruwase assigned tohtana and unassigned tjruwase May 15, 2023

tohtana commented Jun 30, 2023

@stas00 I confirmed that the test case you posted on this issue (TestZeroNestedHFModel) runs without error now that #3592 has been merged.
Can you try it with the latest revision on the master branch?


stas00 commented Jun 30, 2023

Super, thank you, @tohtana! If the test I supplied works, then all is wonderful!

Tunji and I are working through something similar at the moment, so we will be thoroughly testing this shortly.
