[DeepSpeed] restore memory for evaluation #10114

Merged 4 commits on Feb 10, 2021. The diff below shows changes from 3 of the 4 commits.
examples/tests/deepspeed/ds_config.json (28 additions, 31 deletions)

```diff
@@ -7,40 +7,37 @@
"min_loss_scale": 1
},

"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true,
"cpu_offload": true
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true,
"cpu_offload": true
},

"zero_allow_untested_optimizer": true,
"zero_allow_untested_optimizer": true,

"optimizer": {
"type": "AdamW",
"params": {
"lr": 3e-5,
"betas": [
0.8,
0.999
],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 3e-5,
"betas": [0.8, 0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},

"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 500
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 500
}
},

"steps_per_print": 2000,
"wall_clock_breakdown": false
Expand Down
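For orientation (not part of the PR): the config above runs ZeRO stage 2 with `cpu_offload`, which is what parks optimizer state in CPU memory during training. A minimal sketch that loads this config and checks the fields the test relies on; the repo-relative path is an assumption:

```python
import json

# Assumed to run from the transformers repo root.
with open("examples/tests/deepspeed/ds_config.json") as f:
    ds_config = json.load(f)

zero = ds_config["zero_optimization"]
assert zero["stage"] == 2
assert zero["cpu_offload"] is True
# JSON has no integer scientific notation: "2e8" parses as the float 200000000.0.
assert zero["reduce_bucket_size"] == 2e8
print(ds_config["optimizer"]["type"], ds_config["scheduler"]["type"])  # AdamW WarmupLR
```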
examples/tests/deepspeed/test_deepspeed.py (46 additions, 37 deletions)

```diff
@@ -17,8 +17,14 @@
 import unittest

 from transformers.integrations import is_deepspeed_available
-from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
-from transformers.trainer_callback import TrainerState
+from transformers.testing_utils import (
+    TestCasePlus,
+    execute_subprocess_async,
+    get_gpu_count,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    slow,
+)
 from transformers.trainer_utils import set_seed

@@ -42,45 +48,53 @@ def require_deepspeed(test_case):
     return test_case


+@slow
 @require_deepspeed
+@require_torch_gpu
 class TestDeepSpeed(TestCasePlus):

-    # XXX: need to do better validation beyond just that the run was successful
-    def run_quick(self, distributed=None, extra_args_str=None, remove_args_str=None):
-        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, remove_args_str)
-        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
-        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
-        first_step_stats = eval_metrics[0]
-        assert "eval_bleu" in first_step_stats
-
-    def run_quick_no_train(self, distributed=None, extra_args_str=None):
-        remove_args_str = "--do_train"
-        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, remove_args_str)
-        val_metrics = load_json(os.path.join(output_dir, "val_results.json"))
-        assert "val_bleu" in val_metrics
-        test_metrics = load_json(os.path.join(output_dir, "test_results.json"))
-        assert "test_bleu" in test_metrics
-
     @require_torch_multi_gpu
-    def test_basic(self):
-        self.run_quick()
+    def test_basic_distributed(self):
+        self.run_quick(distributed=True)

     @require_torch_multi_gpu
     def test_grad_acum(self):
-        self.run_quick(extra_args_str="--gradient_accumulation_steps 2")
+        self.run_quick(distributed=True, extra_args_str="--gradient_accumulation_steps 2")

-    @require_torch_multi_gpu
-    def test_no_train(self):
+    def test_do_eval_no_train(self):
         # we should not fail if train is skipped
-        self.run_quick_no_train()
+        output_dir = self.run_trainer(
+            eval_steps=1,
+            max_len=12,
+            model_name=MBART_TINY,
+            num_train_epochs=1,
+            distributed=False,
+            extra_args_str="--do_eval",
+            remove_args_str="--do_train",
+        )
+        val_metrics = load_json(os.path.join(output_dir, "val_results.json"))
+        assert "val_bleu" in val_metrics
+
+    # XXX: need to do better validation beyond just that the run was successful
+    def run_quick(self, distributed=True, extra_args_str=None, remove_args_str=None):
+        output_dir = self.run_trainer(
+            eval_steps=1,
+            max_len=12,
+            model_name=MBART_TINY,
+            num_train_epochs=1,
+            distributed=distributed,
+            extra_args_str=extra_args_str,
+            remove_args_str=remove_args_str,
+        )
+        train_metrics = load_json(os.path.join(output_dir, "train_results.json"))
+        assert "train_runtime" in train_metrics

     def run_trainer(
         self,
         eval_steps: int,
         max_len: str,
         model_name: str,
         num_train_epochs: int,
-        distributed: bool = False,
+        distributed: bool = True,
         extra_args_str: str = None,
         remove_args_str: str = None,
     ):
@@ -97,26 +111,20 @@ def run_trainer(
             --max_target_length {max_len}
             --val_max_target_length {max_len}
             --do_train
-            --do_eval
-            --do_predict
             --num_train_epochs {str(num_train_epochs)}
             --per_device_train_batch_size 4
             --per_device_eval_batch_size 4
             --learning_rate 3e-3
             --warmup_steps 8
-            --evaluation_strategy steps
-            --predict_with_generate
             --logging_steps 0
             --save_steps {str(eval_steps)}
             --eval_steps {str(eval_steps)}
             --group_by_length
             --label_smoothing_factor 0.1
             --adafactor
             --task translation
             --tgt_lang ro_RO
             --src_lang en_XX
         """.split()
+        # --eval_beams  2

         if extra_args_str is not None:
             args.extend(extra_args_str.split())
@@ -126,12 +134,13 @@ def run_trainer(
             args = [x for x in args if x not in remove_args]

         ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
-        distributed_args = f"""
-            {self.test_file_dir}/../../seq2seq/finetune_trainer.py
-        """.split()
-        cmd = ["deepspeed"] + distributed_args + args + ds_args
+        script = [f"{self.examples_dir_str}/seq2seq/finetune_trainer.py"]
+        num_gpus = get_gpu_count() if distributed else 1
+        launcher = f"deepspeed --num_gpus {num_gpus}".split()
+
+        cmd = launcher + script + args + ds_args
         # keep for quick debug
-        # print(" ".join(cmd)); die
+        # print(" ".join([f"PYTHONPATH={self.src_dir_str}"] + cmd)); die
         execute_subprocess_async(cmd, env=self.get_env())

         return output_dir
```
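The launcher change in `run_trainer` is the core of this file's fix: bare `deepspeed` grabs every visible GPU, while `deepspeed --num_gpus N` pins the world size, so a non-distributed test really runs on one GPU. A sketch of the command the updated test assembles for a single-GPU run (paths are repo-relative assumptions, `sshleifer/tiny-mbart` is the usual value of `MBART_TINY`, and the argument list is abbreviated):

```python
# Not test code: reconstructs the subprocess command for distributed=False.
launcher = "deepspeed --num_gpus 1".split()
script = ["examples/seq2seq/finetune_trainer.py"]
args = "--model_name_or_path sshleifer/tiny-mbart --output_dir /tmp/out --do_eval".split()
ds_args = "--deepspeed examples/tests/deepspeed/ds_config.json".split()

cmd = launcher + script + args + ds_args
print(" ".join(cmd))
# deepspeed --num_gpus 1 examples/seq2seq/finetune_trainer.py
#   --model_name_or_path sshleifer/tiny-mbart --output_dir /tmp/out --do_eval
#   --deepspeed examples/tests/deepspeed/ds_config.json
```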
src/transformers/trainer.py (15 additions, 9 deletions)

```diff
@@ -17,6 +17,7 @@
"""

import collections
import gc
import inspect
import math
import os
Expand Down Expand Up @@ -266,8 +267,9 @@ def __init__(

         # postpone switching model to cuda when:
         # 1. MP - since we are trying to fit a much bigger than 1 gpu model
-        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway
-        if not (self.is_model_parallel or args.deepspeed):
+        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
+        #    and we only use deepspeed for training at the moment
+        if not self.is_model_parallel and not (args.deepspeed and args.do_train):
             model = model.to(args.device)

         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
```
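Spelled out (illustrative snippet, not Trainer code): the model is now moved to the device in `__init__` unless DeepSpeed is both configured and about to train, which is exactly what lets `prediction_loop` below drop its late `.to()` call for eval-only runs:

```python
# When does Trainer.__init__ call model.to(args.device) after this PR?
for is_model_parallel in (False, True):
    for deepspeed in (False, True):
        for do_train in (False, True):
            move_now = not is_model_parallel and not (deepspeed and do_train)
            print(f"mp={is_model_parallel} deepspeed={deepspeed} "
                  f"do_train={do_train} -> to(device) in __init__: {move_now}")
```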
```diff
@@ -1036,6 +1038,14 @@ def train(
         # add remaining tr_loss
         self._total_loss_scalar += tr_loss.item()

+        if self.deepspeed:
+            # free up any memory that might be useful for eval
+            self.deepspeed = None
+            self.optimizer = None
+            self.lr_scheduler = None
+            self.model_wrapped = None
```
Review thread on the `self.model_wrapped = None` line above:

Collaborator:
I'm not super fond of having this done automatically. It should be in a `free_memory` method of the Trainer (or a name you like better) that is explicitly called by the user between training and evaluation, IMO. This is also useful beyond DeepSpeed.

Contributor Author (@stas00, Feb 10, 2021):
These are two different things.

1. In the case of DeepSpeed this is a clean cut. We explicitly init all of these at the beginning of `train`:

```python
if self.args.deepspeed:
    model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
    self.model = model.module
    self.model_wrapped = model  # will get further wrapped in DDP
    self.deepspeed = model  # DeepSpeedEngine object
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
```

So this PR explicitly cleans these up at the end of `train`; this is completely opaque. A user had no way to init those explicitly and thus has no need to do anything special.

2. As for the generic case, the story is different, because the user may supply their own optimizer/lr_scheduler, and in that case, yes, they need to have control over whether to clean up or not.

As you pointed out this would be useful to the user, but it's a different situation, so let's solve it separately?

Contributor Author (@stas00):
Though I do think I need to fix this:

```python
self.model_wrapped = None
```

to restore it to `self.model`.

Collaborator:
Ah yes, it's true that in this case they are instantiated inside the `train` method, so this makes sense.

The hunk continues:
```diff
+            gc.collect()  # force memory release
+
         return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)

     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
```

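For comparison, the explicit `free_memory` method the reviewer suggests might look roughly like this; it is a hypothetical sketch, not code from this PR, with attribute names taken from the diff above and the `model_wrapped` restoration from the follow-up comment:

```python
import gc

import torch


def free_memory(trainer):
    """Hypothetical Trainer helper: drop training-only objects so that
    evaluation gets the GPU memory back."""
    trainer.deepspeed = None
    trainer.optimizer = None
    trainer.lr_scheduler = None
    trainer.model_wrapped = trainer.model  # restore rather than None
    gc.collect()  # break reference cycles holding CUDA tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached blocks back to the driver
```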
```diff
@@ -1593,13 +1603,9 @@ def prediction_loop(
         )

         if self.args.deepspeed and not self.args.do_train:
-            # In the future we probably can run deepspeed for inference too, but this will require
-            # some thinking about how to best run it - since while it works DeepSpeed wasn't
-            # designed for inference
-
-            # since we have to postpone model.to() till training for DeepSpeed, if there was no
-            # training, we must put the model on the right device
-            self.model = self.model.to(self.args.device)
+            # no harm, but flagging to the user that deepspeed config is ignored for eval
+            # flagging only for when --do_train wasn't passed as only then it's redundant
+            logger.info("Detected the deepspeed argument but it will not be used for evaluation")

         model = self.model
```
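The user-visible effect of the whole PR, sketched as an eval-only run (hedged: `model` and `eval_dataset` are placeholders, and `deepspeed` was already an accepted `TrainingArguments` option at the time): the config is ignored for evaluation, an info message is logged, and the model is already on the right device because `__init__` no longer postpones the move.

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="/tmp/out",
    do_train=False,  # eval only: DeepSpeed init in train() is never reached
    do_eval=True,
    deepspeed="examples/tests/deepspeed/ds_config.json",  # logged as unused for eval
)
# model and eval_dataset are assumed to be constructed elsewhere
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)  # noqa: F821
metrics = trainer.evaluate()  # plain PyTorch inference on args.device
```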