Fix incorrect patch in zero.init#5921

Closed
VeryLazyBoy wants to merge 9 commits into deepspeedai:master from VeryLazyBoy:fix-incorrect-patch-zero-init
Conversation

@VeryLazyBoy
Contributor

The code below has a bug: `cls.__init__` on line 525 can already have been patched (through a superclass) before it is assigned to `_old_init`. As a result, an incorrect `__init__` can be backed up and later restored:

https://github.com/microsoft/DeepSpeed/blob/ffe0af23575c4f03a07408eacfc50b1a58781429/deepspeed/runtime/zero/partition_parameters.py#L524-L534
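To make the failure mode concrete, here is a minimal, DeepSpeed-free sketch. The class names `Base`/`Child` and the `wrap` helper are hypothetical stand-ins for the real wrapper in `partition_parameters.py`: when a base class is patched before a subclass that only inherits `__init__`, the subclass backs up the already-wrapped method, and restoring from that backup leaves the wrapper in place.

```python
class Base:
    def __init__(self):
        self.tag = "original"


class Child(Base):
    pass  # no own __init__; resolves to Base.__init__ via the MRO


def wrap(init):
    # Stand-in for the partitioning wrapper installed by zero.Init.
    def wrapper(self, *args, **kwargs):
        init(self, *args, **kwargs)
        self.tag = "wrapped"
    return wrapper


original = Base.__init__

# Patch in the problematic order: Base first, then Child.
for cls in (Base, Child):
    cls._old_init = cls.__init__      # for Child, this is already the wrapper!
    cls.__init__ = wrap(cls.__init__)

# "Restore" from the backups, as done on exiting the zero.Init context.
for cls in (Base, Child):
    cls.__init__ = cls._old_init

print(Base.__init__ is original)   # True: Base was backed up correctly
print(Child.__init__ is original)  # False: Child permanently holds the wrapper
```

Since `get_all_subclasses` does not guarantee a parents-before-children order, the real bug only triggers for some iteration orders, which is why the test case above may need several runs to fail.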

Test Case

import deepspeed
from torch import nn


class ModelA(nn.Module):
    def __init__(self):
        super().__init__()


class ModelB(ModelA):
    pass


original_init = ModelA.__init__


ds_config = {
    'fp16': {'enabled': False},
    'bf16': {'enabled': True},
    'zero_optimization': {
        'stage': 3,
        'offload_optimizer': {
            'device': 'cpu',
            'pin_memory': True
        },
        'offload_param': {
            'device': 'cpu',
            'pin_memory': True
        },
    },
    'gradient_accumulation_steps': 1,
    'gradient_clipping': 1,
    'train_batch_size': 1,
    'train_micro_batch_size_per_gpu': 1
}


with deepspeed.zero.Init(config_dict_or_path=ds_config, enabled=True, mem_efficient_linear=False, mpu=None):
    model_a = ModelA()
    assert ModelA.__init__ != original_init

assert ModelA.__init__ == original_init
assert ModelB.__init__ == original_init   # Fails here. If it does not fail, run the script a few times: the result depends on the order in which classes are patched

@VeryLazyBoy VeryLazyBoy requested a review from tjruwase as a code owner August 12, 2024 18:15
@VeryLazyBoy
Contributor Author

@microsoft-github-policy-service agree

@VeryLazyBoy
Contributor Author

A better solution is proposed that handles `__init_subclass__` as well.

@tjruwase tjruwase requested a review from tohtana August 15, 2024 17:53
@tohtana
Collaborator

tohtana commented Aug 21, 2024

Thank you @VeryLazyBoy for the great catch!

I think the issue is that we patch the superclass's `__init__` when `cls` does not define its own. I tried another approach in this branch; do you think this works?
It is less intrusive because we do not set `__init__` on classes that do not define it.
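As a rough illustration of that idea (again with hypothetical names, not the actual branch code): skipping classes that do not define `__init__` in their own `__dict__` means an inherited, already-patched method is never backed up, so restoring cannot go wrong.

```python
class Base:
    def __init__(self):
        self.tag = "original"


class Child(Base):
    pass  # only inherits __init__


def wrap(init):
    def wrapper(self, *args, **kwargs):
        init(self, *args, **kwargs)
        self.tag = "wrapped"
    return wrapper


original = Base.__init__

patched = []
for cls in (Base, Child):
    if '__init__' in cls.__dict__:    # skip classes that only inherit __init__
        cls._old_init = cls.__init__
        cls.__init__ = wrap(cls.__init__)
        patched.append(cls)

for cls in patched:                   # restore only what was actually patched
    cls.__init__ = cls._old_init

print(Base.__init__ is original)   # True
print(Child.__init__ is original)  # True: Child was never touched
```

With this check, `Child` keeps resolving `__init__` through the MRO, so the wrapping still takes effect during construction and restoration is exact regardless of iteration order.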

@VeryLazyBoy
Contributor Author

@tohtana Yes! Your approach is less intrusive and much better. Let's go ahead with this new method. Should I close this pull request?

@tohtana
Collaborator

tohtana commented Aug 21, 2024

@VeryLazyBoy Thank you for your response!
Let me create a PR from my branch to make sure it works. Let's close this PR once all tests pass on that one.

github-merge-queue bot pushed a commit that referenced this pull request Sep 4, 2024
This PR fixes the issue reported in #5921.
With this change, we apply the patch for parameter partitioning only to
classes that define `__init__`, so that the patch is never applied
multiple times.
A class that does not define `__init__` now uses its superclass's, so this
PR also applies the patch to the root class,
`torch.nn.modules.module.Module`.

Thanks @VeryLazyBoy for the report and initial solution.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Mar 20, 2025
(same commit message as above)