fix step in adam #1823

Merged: 12 commits into microsoft:master on May 9, 2022
Conversation

szhengac (Contributor)

DeepSpeed ZeRO performs the optimization step on each sub-group independently, but the Adam implementation uses a single step counter for the whole param group, so different sub-groups end up using different step values when computing the exponential weights for the first- and second-order momentum. This PR fixes that issue.
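
To make the issue concrete, here is a minimal sketch (the GroupStepAdam class below is hypothetical and only mimics the single group-level counter; it is not DeepSpeed code). ZeRO stage 3 invokes the optimizer once per sub-group, so a shared group-level counter advances several times within one training iteration, and each sub-group sees a different bias correction:

class GroupStepAdam:
    """Toy optimizer that keeps one step counter per param group."""
    def __init__(self, beta1=0.9, beta2=0.999):
        self.group = {"step": 0, "betas": (beta1, beta2)}

    def step_subgroup(self, subgroup):
        # Called once per sub-group by ZeRO stage 3, so the shared
        # counter is bumped multiple times per training iteration.
        self.group["step"] += 1
        beta1, beta2 = self.group["betas"]
        bias_correction1 = 1 - beta1 ** self.group["step"]
        bias_correction2 = 1 - beta2 ** self.group["step"]
        return bias_correction1, bias_correction2

opt = GroupStepAdam()
# One training iteration, two sub-groups of the same param group:
print(opt.step_subgroup("sub_group_0"))  # computed with step = 1
print(opt.step_subgroup("sub_group_1"))  # computed with step = 2 -> different weights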

tjruwase previously approved these changes on Mar 17, 2022

tjruwase (Contributor)

@szhengac, thanks!

tjruwase self-requested a review on March 18, 2022 at 11:13
tjruwase dismissed their stale review on March 18, 2022 at 11:14

Just became aware of issues I had overlooked.

tjruwase (Contributor)

@szhengac, apologies, I am reverting my approval because my colleagues just made me aware of some issues that require a bit more time to understand. Primarily, this fix changes the behavior of the non-ZeRO-stage-3 code paths and will break existing checkpoints, since it changes the checkpoint state. Ideally, since only ZeRO stage 3 exhibits the problem, the fix should probably live in stage 3 rather than in the shared fused Adam code. Apologies for the confusion.
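
To make the checkpoint concern concrete, here is a rough sketch of where the step counter lives in the optimizer state dict before and after a change like this (the layout follows standard torch.optim state_dict conventions; the exact keys inside DeepSpeed checkpoints may differ, and tensors are elided as strings):

# Before the fix: one counter per param group, which is what
# checkpoints written by the current code contain.
old_osd = {
    "param_groups": [{"lr": 0.01, "step": 1000, "params": [0, 1]}],
    "state": {
        0: {"exp_avg": "...", "exp_avg_sq": "..."},
        1: {"exp_avg": "...", "exp_avg_sq": "..."},
    },
}

# After the fix: one counter per parameter, stored in per-param state.
new_osd = {
    "param_groups": [{"lr": 0.01, "params": [0, 1]}],
    "state": {
        0: {"step": 1000, "exp_avg": "...", "exp_avg_sq": "..."},
        1: {"step": 1000, "exp_avg": "...", "exp_avg_sq": "..."},
    },
}

# Loading old_osd into the new code would leave state[i]["step"] unset
# unless the new code falls back to the group-level value.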

tjruwase (Contributor)

@szhengac, could you please provide a unit test that demonstrates the original zero stage 3 problem? Thanks!

szhengac (Contributor, Author)

@tjruwase I can provide a unit test. The main problem is that we get wrong optimizer behavior, because different parameters end up using different step counts within the same iteration. I checked the implementations of the other two optimizers in DeepSpeed, which also use the per-parameter state to store the step count. We also need to make sure the optimizer implementations stay consistent:

lamb: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/lamb/fused_lamb.py#L163
adagrad: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/adagrad/cpu_adagrad.py#L89

tjruwase (Contributor)

@szhengac, how about making these two changes for backward compatibility with existing stage < 3 checkpoints:

  1. Restore the group-level step.
  2. Initialize the param-level step from the group-level step if one exists, rather than initializing it to zero.

I will annotate the PR with these suggestions.
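
A minimal sketch of what these two suggestions could look like (illustration only, not DeepSpeed's actual FusedAdam code; the function and variable names here are made up for the example):

def adam_step_sketch(param_groups, state):
    """Illustrative only: the per-parameter counter drives the Adam math,
    while the group-level counter is kept so older checkpoints stay loadable."""
    for group in param_groups:
        for p in group["params"]:
            p_state = state.setdefault(p, {})
            if "step" not in p_state:
                # 2. Seed from the group-level step if one exists, e.g. when
                #    resuming from a checkpoint written before this change.
                p_state["step"] = group.get("step", 0)
            p_state["step"] += 1
            # ... the exp_avg / exp_avg_sq updates would use p_state["step"] ...
        # 1. Keep the group-level step too, so the state dict still carries
        #    the field that pre-fix checkpoints (and their loaders) expect.
        group["step"] = group.get("step", 0) + 1

# Resuming from an "old" checkpoint that only stored a group-level step:
groups = [{"params": ["w", "b"], "step": 1000}]
state = {}
adam_step_sketch(groups, state)
print(state)   # {'w': {'step': 1001}, 'b': {'step': 1001}}
print(groups)  # [{'params': ['w', 'b'], 'step': 1001}]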

@@ -131,6 +124,8 @@ def step(self,
    state = self.state[p]
    # State initialization
    if len(state) == 0:
        # DeepSpeed processes one subgroup at a time, so we need to track the step for each tensor separately
        state['step'] = 0
tjruwase (Contributor) suggested:

state['step'] = group.get('step', 0)

szhengac (Contributor, Author) replied:

This looks like a good idea. I will also add a more detailed comment to explain it; otherwise other readers may get confused.

szhengac (Contributor, Author)

@tjruwase any update on this PR?

tjruwase (Contributor)

> @tjruwase any update on this PR?

This looks good to me. Is it possible for you to do some e2e runs comparing loss curves before and after this change for zero < 3?

szhengac (Contributor, Author)

szhengac commented Apr 19, 2022

@tjruwase Sure, will do it when I have capacity for the experiments. I have done some e2e finetuning with zero 3 after the fix; it improves accuracy by 6-7% when there are two param groups.

tjruwase (Contributor)

> @tjruwase Sure, will do it when I have capacity for the experiments. I have done some e2e finetuning with zero 3 after the fix; it improves accuracy by 6-7% when there are two param groups.

That is exciting news about the zero3 improvement :).

tjruwase (Contributor)

@szhengac, just checking if you had time to push this? Thanks!

szhengac (Contributor, Author)

@tjruwase the conflict has been resolved.

tjruwase (Contributor)

@szhengac, thanks for resolving the conflict. Did you get a chance to measure convergence impact for zero < 3? Thanks!

szhengac (Contributor, Author)

szhengac commented May 3, 2022

> @szhengac, thanks for resolving the conflict. Did you get a chance to measure convergence impact for zero < 3? Thanks!

I haven't had extra machine capacity to test it yet.

szhengac (Contributor, Author)

szhengac commented May 9, 2022

@tjruwase I have done some loss testing and I can obtain exactly the same loss curve.

szhengac (Contributor, Author)

szhengac commented May 9, 2022

But at the same time, I also realized that zero 1 and 2 give different loss trajectories regardless of this PR. A similar issue was reported before: #966

tjruwase (Contributor)

tjruwase commented May 9, 2022

> But at the same time, I also realized that zero 1 and 2 give different loss trajectories regardless of this PR. A similar issue was reported before: #966

I thought for that issue we had resolved the loss divergence across the zero stages, right?

szhengac (Contributor, Author)

szhengac commented May 9, 2022

I thought so. But the following script gives me different results:

import os
import json
import argparse
import torch
import deepspeed
from torch import nn
from apex.normalization import FusedLayerNorm as BertLayerNorm
from torch.utils.data.distributed import DistributedSampler


class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_dim, empty_grad=False, zero=0):
        super(SimpleModel, self).__init__()
        linear = torch.nn.Linear(hidden_dim, hidden_dim)
        FinalLayerNorm = BertLayerNorm(hidden_dim, eps=1e-12)
        mlp = [linear, FinalLayerNorm]
        mlp.append(torch.nn.Linear(hidden_dim, hidden_dim//2))
        for _ in range(6):
            l = torch.nn.Linear(hidden_dim//2, hidden_dim//2)
            mlp.append(l)
        mlp.append(torch.nn.Linear(hidden_dim//2, hidden_dim))
        l = torch.nn.Linear(hidden_dim, hidden_dim)
        l.weight = linear.weight
        l.bias = linear.bias
        mlp.append(l)
        if zero == 3:
            deepspeed.zero.register_external_parameter(self, linear.weight)
            deepspeed.zero.register_external_parameter(self, linear.bias)
        self.mlp = nn.Sequential(*mlp)
        if empty_grad:
            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden_dim = x
        hidden_dim = self.mlp(hidden_dim)
        return self.cross_entropy_loss(hidden_dim, y)


def create_config_from_dict(tmpdir, config_dict):
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path

def get_data_loader(model, total_samples, hidden_dim, device, dtype, local_rank):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples,
                              dtype=torch.long,
                              device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    sampler = DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               sampler=sampler)
    return train_loader


def get_args(tmpdir, config_dict):
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument('--zero', type=int, default=0)
    args = parser.parse_args()  #args=''

    config_dict["zero_optimization"]["stage"] = args.zero
    print('config_dict["zero_optimization"]', config_dict["zero_optimization"])
    config_path = create_config_from_dict(tmpdir, config_dict)

    args.deepspeed_config = config_path
    return args


def print0(msg):
    if torch.distributed.get_rank() == 0:
        print(msg, flush=True)


rank = int(os.environ['RANK'])
print('seed:', 2222 + rank)
torch.random.manual_seed(2222 + rank)

config_dict = {
    "train_batch_size": 16*2*8,
    "train_micro_batch_size_per_gpu": 16,
    "steps_per_print": 1,
    "zero_allow_untested_optimizer": True,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.01,
            "weight_decay": 0.01,
            "bias_correction": True,
            "eps": 1e-6
        }
    },
    "gradient_clipping": 0.1,
    "fp16": {
        "enabled": False,
        "initial_scale_power": 10
    },
    "bfloat16": {
        "enabled": True
    },
    "zero_optimization": {
        "stage": 1,
        "overlap_comm": False,
        "contiguous_gradients": False,
        "reduce_bucket_size": 20,
        "stage3_param_persistence_threshold" : 1e6,
        "stage3_param_persistence_threshold": 1e6,
        "stage3_max_reuse_distance" : 1e6,
        "stage3_max_live_parameters": 1e6,
        "stage3_prefetch_bucket_size": 5e8
    }
}

args = get_args('/tmp/', config_dict)
hidden_dim = 1027


#with deepspeed.zero.Init(config=config_dict, enabled=True):
model = SimpleModel(hidden_dim, empty_grad=False, zero=args.zero)

model, _, _,_ = deepspeed.initialize(args=args,
                                     model=model,
                                     model_parameters=model.parameters(),
                                     dist_init_required=True)

data_loader = get_data_loader(model=model,
                              total_samples=50000,
                              hidden_dim=hidden_dim,
                              device=model.device,
                              dtype=torch.bfloat16,
                              local_rank=args.local_rank)


def print_params(tag, model):
    if torch.distributed.get_rank() == 0:
        for n, p in model.named_parameters():
            print0("{} {}:{}".format(tag, n, p))

#print_params('pre-train', model)
for n, batch in enumerate(data_loader):
    loss = model(batch[0], batch[1])
    #if torch.distributed.get_rank() == 0 and model.is_gradient_accumulation_boundary():
    model.backward(loss)
    model.step()
    if torch.distributed.get_rank() == 0 and model.is_gradient_accumulation_boundary():
        print("{}, LOSS: {}".format(n, loss.item()))
    #print_params('step={}'.format(n), model)
    if n == 32: break

tjruwase (Contributor)

tjruwase commented May 9, 2022

> I thought so. But the following script gives me different results:

Got it. This means we have a regression in DeepSpeed that is independent of this PR. Could you please create a new issue with your test case? Thanks so much for sharing this test case, in addition to this PR. The PR is good to merge.

tjruwase merged commit de88718 into microsoft:master on May 9, 2022
szhengac deleted the optimizer branch on May 10, 2022 at 00:30
szhengac (Contributor, Author)

@tjruwase I have opened an issue: #1945

igor0 pushed a commit to igor0/DeeperSpeed that referenced this pull request Nov 14, 2022
* fix step in adam

* fix backward compatibility and add unittest

* add unittest

* fix unbounded error when there are more than 1 param groups

* fix typo

* remove trailing whitespace

* fix end of file

Co-authored-by: Shuai Zheng <shzheng@amazon.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
igor0 added a commit to igor0/DeeperSpeed that referenced this pull request Nov 14, 2022