From 807480a04bdc371525776e42c31f7c72d9132f0a Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Fri, 14 Feb 2020 17:04:15 -0800
Subject: [PATCH] Fix issue with empty grads for non-fused optimizers (#83)

bug fixes for adamw/lamb and corresponding tests
---
 deepspeed/pt/fp16_unfused_optimizer.py |  16 +-
 tests/unit/test_fp16.py                | 208 +++++++++++++++++++++++++
 2 files changed, 221 insertions(+), 3 deletions(-)
 mode change 100644 => 100755 deepspeed/pt/fp16_unfused_optimizer.py
 create mode 100644 tests/unit/test_fp16.py

diff --git a/deepspeed/pt/fp16_unfused_optimizer.py b/deepspeed/pt/fp16_unfused_optimizer.py
old mode 100644
new mode 100755
index 2a15ca20167c..3e7b998658e3
--- a/deepspeed/pt/fp16_unfused_optimizer.py
+++ b/deepspeed/pt/fp16_unfused_optimizer.py
@@ -116,8 +116,13 @@ def step_fused_lamb(self, closure=None):
         grads_groups = []
         norm_groups = []
         for i, group in enumerate(self.fp16_groups):
-            grads_groups.append([p.grad for p in group])
-            grads_groups_flat.append(_flatten_dense_tensors(grads_groups[i]))
+            grads = [
+                torch.zeros(p.size(),
+                            dtype=p.dtype,
+                            device=p.device) if p.grad is None else p.grad for p in group
+            ]
+            grads_groups.append(grads)
+            grads_groups_flat.append(_flatten_dense_tensors(grads))
             norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))
 
         self.overflow = self.overflow_checker.check_using_norm(norm_groups)
@@ -162,7 +167,12 @@ def step(self, closure=None):
 
             # copying gradients to fp32 to work with fp32 parameters
             for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]):
-                fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)
+                if fp16_param.grad is None:
+                    fp32_param.grad = torch.zeros(fp16_param.size(),
+                                                  dtype=fp32_param.dtype,
+                                                  device=fp32_param.device)
+                else:
+                    fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)
 
         self.unscale_and_clip_grads(norm_groups)
 
diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py
new file mode 100644
index 000000000000..0c69f097de60
--- /dev/null
+++ b/tests/unit/test_fp16.py
@@ -0,0 +1,208 @@
+import torch
+import deepspeed
+import argparse
+import pytest
+import json
+import os
+from common import distributed_test
+
+
+def create_config_from_dict(tmpdir, config_dict):
+    config_path = os.path.join(tmpdir, 'temp_config.json')
+    with open(config_path, 'w') as fd:
+        json.dump(config_dict, fd)
+    return config_path
+
+
+class SimpleModel(torch.nn.Module):
+    def __init__(self, hidden_dim, empty_grad=False):
+        super(SimpleModel, self).__init__()
+        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
+        if empty_grad:
+            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
+        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
+
+    def forward(self, x, y):
+        hidden_dim = x
+        hidden_dim = self.linear(hidden_dim)
+        return self.cross_entropy_loss(hidden_dim, y)
+
+
+def test_temp_config_json(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+    }
+    config_path = create_config_from_dict(tmpdir, config_dict)
+    config_json = json.load(open(config_path, 'r'))
+    assert 'train_batch_size' in config_json
+
+
+def prepare_optimizer_parameters(model):
+    param_optimizer = list(model.named_parameters())
+    optimizer_grouped_parameters = [{
+        'params': [p for n,
+                   p in param_optimizer],
+        'weight_decay': 0.0
+    }]
+    return optimizer_grouped_parameters
+
+
+def get_data_loader(model, total_samples, hidden_dim, device):
+    batch_size = model.train_micro_batch_size_per_gpu()
+    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
+    train_label = torch.empty(total_samples,
+                              dtype=torch.long,
+                              device=device).random_(hidden_dim)
+    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
+    return train_loader
+
+
+def get_args(tmpdir, config_dict):
+    config_path = create_config_from_dict(tmpdir, config_dict)
+    parser = argparse.ArgumentParser()
+    args = parser.parse_args(args='')
+    args.deepspeed = True
+    args.deepspeed_config = config_path
+    args.local_rank = 0
+    return args
+
+
+def test_lamb_fp16_basic(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Lamb",
+            "params": {
+                "lr": 0.00015,
+                "max_grad_norm": 1.0
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+    args = get_args(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1, 2])
+    def _test_lamb_fp16_basic(args, model, hidden_dim):
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             model_parameters=model.parameters(),
+                                             dist_init_required=False)
+        data_loader = get_data_loader(model=model,
+                                      total_samples=50,
+                                      hidden_dim=hidden_dim,
+                                      device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_lamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_lamb_fp16_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Lamb",
+            "params": {
+                "lr": 0.00015,
+                "max_grad_norm": 1.0
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+    args = get_args(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=True)
+
+    @distributed_test(world_size=[1])
+    def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             model_parameters=model.parameters(),
+                                             dist_init_required=False)
+        data_loader = get_data_loader(model=model,
+                                      total_samples=50,
+                                      hidden_dim=hidden_dim,
+                                      device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_lamb_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_adamw_fp16_basic(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "steps_per_print": 1,
+        "fp16": {
+            "enabled": True
+        }
+    }
+    args = get_args(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1])
+    def _test_adamw_fp16_basic(args, model, hidden_dim):
+        optimizer = torch.optim.AdamW(params=model.parameters())
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             optimizer=optimizer,
+                                             dist_init_required=False)
+        data_loader = get_data_loader(model=model,
+                                      total_samples=50,
+                                      hidden_dim=hidden_dim,
+                                      device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adamw_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_adamw_fp16_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "steps_per_print": 1,
+        "fp16": {
+            "enabled": True
+        }
+    }
+    args = get_args(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=True)
+
+    @distributed_test(world_size=[1])
+    def _test_adamw_fp16_empty_grad(args, model, hidden_dim):
+        optimizer = torch.optim.AdamW(params=model.parameters())
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             optimizer=optimizer,
+                                             dist_init_required=False)
+        data_loader = get_data_loader(model=model,
+                                      total_samples=50,
+                                      hidden_dim=hidden_dim,
+                                      device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adamw_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
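
Note (illustrative, not part of the commit): the failure being fixed is that a parameter which never receives a gradient keeps p.grad == None, and the unfused fp16 optimizer previously called fp16_param.grad.to(...) or flattened the grad list without checking for that. The standalone sketch below reproduces the situation with plain PyTorch and applies the same zero-fill guard the patch adds; the names UnusedLayerModel and sync_master_grads are hypothetical and exist only for this sketch, and both parameter copies are kept in fp32 for CPU portability, whereas the real optimizer holds fp16 model params and an fp32 master copy.

    # Illustrative only -- not part of the patch.
    import torch


    class UnusedLayerModel(torch.nn.Module):
        # Mirrors SimpleModel(empty_grad=True): `unused` never runs in forward(),
        # so its parameters finish backward() with .grad still set to None.
        def __init__(self, hidden_dim):
            super().__init__()
            self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
            self.unused = torch.nn.Linear(hidden_dim, hidden_dim)

        def forward(self, x):
            return self.linear(x).sum()


    def sync_master_grads(model_params, master_params):
        # Same guard the patch adds in step(): a parameter that never received a
        # gradient is treated as having an all-zero gradient instead of crashing
        # on `None.to(...)`.
        for master_p, model_p in zip(master_params, model_params):
            if model_p.grad is None:
                master_p.grad = torch.zeros(model_p.size(),
                                            dtype=master_p.dtype,
                                            device=master_p.device)
            else:
                master_p.grad = model_p.grad.to(master_p.dtype)


    if __name__ == "__main__":
        model = UnusedLayerModel(hidden_dim=10)
        model(torch.randn(4, 10)).backward()

        model_params = list(model.parameters())
        assert model.unused.weight.grad is None  # empty grad after backward()

        # Stand-in for the fp32 master copy kept by the unfused optimizer.
        master_params = [p.detach().clone() for p in model_params]
        sync_master_grads(model_params, master_params)
        assert all(p.grad is not None for p in master_params)

The new tests exercise this path end to end: SimpleModel(empty_grad=True) leaves its extra layer without gradients, and the *_empty_grad tests step Lamb and AdamW through deepspeed.initialize without error, using the distributed_test decorator (imported from common) to launch the requested world size.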