-
Notifications
You must be signed in to change notification settings - Fork 882
/
deepspeed.py
252 lines (200 loc) · 8.82 KB
/
deepspeed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import json
from copy import deepcopy
from ..optimizer import AcceleratedOptimizer
from ..scheduler import AcceleratedScheduler
class HfDeepSpeedConfig:
    """
    This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.

    A `weakref` of this object is stored in the module's globals to be able to access the config from areas where
    things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore
    it's important that this object remains alive while the program is still running.

    [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration
    with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic
    the DeepSpeed configuration is not modified in any way.

    Args:
        config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict.

    Raises:
        ValueError: if `config_file_or_dict` is neither a `str` nor a `dict`.
    """

    def __init__(self, config_file_or_dict):
        if isinstance(config_file_or_dict, dict):
            # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
            # modified it, it will not be accepted here again, since `auto` values would have been overridden
            config = deepcopy(config_file_or_dict)
        elif isinstance(config_file_or_dict, str):
            with io.open(config_file_or_dict, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            raise ValueError("expecting either a path to a DeepSpeed config file or a pre-populated dict")
        self.config = config

        # zero stage - this is done as early as possible, before model is created, to allow
        # ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
        # during ``zero.Init()`` which needs to know the dtype, and some other hparams.
        self._stage = self.get_value("zero_optimization.stage", -1)

        # offload: enabled when either the optimizer or the params are offloaded to a supported device
        self._offload = False
        if self.is_zero2() or self.is_zero3():
            offload_devices_valid = {"cpu", "nvme"}
            offload_devices = {
                self.get_value("zero_optimization.offload_optimizer.device"),
                self.get_value("zero_optimization.offload_param.device"),
            }
            if offload_devices & offload_devices_valid:
                self._offload = True

    def find_config_node(self, ds_key_long):
        """
        Walks the dotted path `ds_key_long` and returns `(parent_dict, last_key)`.

        `parent_dict` is `None` when any part of the path leading to the last key is missing or is not a dict
        (e.g. a scalar or a list), so callers can safely call `.get(last_key)` on a non-`None` result.
        """
        config = self.config

        # find the config node of interest if it exists
        nodes = ds_key_long.split(".")
        ds_key = nodes.pop()
        for node in nodes:
            if not isinstance(config, dict):
                return None, ds_key
            config = config.get(node)
            if config is None:
                return None, ds_key

        # the resolved parent itself may be a scalar/list, which can't hold `ds_key`
        if not isinstance(config, dict):
            return None, ds_key

        return config, ds_key

    def get_value(self, ds_key_long, default=None):
        """
        Returns the set value or `default` if no value is set
        """
        config, ds_key = self.find_config_node(ds_key_long)
        if config is None:
            return default
        return config.get(ds_key, default)

    def del_config_sub_tree(self, ds_key_long, must_exist=False):
        """
        Deletes a sub-section of the config file if it's found.

        Unless `must_exist` is `True` the section doesn't have to exist.

        Raises:
            ValueError: if `must_exist` is `True` and the entry can't be found.
        """
        config = self.config
        parent_config = None
        node = None

        # find the config node of interest if it exists; membership is tested with `in` so that an entry
        # explicitly set to `None` still counts as existing
        for node in ds_key_long.split("."):
            if not isinstance(config, dict) or node not in config:
                if must_exist:
                    raise ValueError(f"Can't find {ds_key_long} entry in the config: {self.config}")
                return
            parent_config = config
            config = config[node]

        # if found remove it (`str.split` always yields at least one node, so `parent_config` is set here)
        parent_config.pop(node)

    def is_true(self, ds_key_long):
        """
        Returns `True`/`False` only if the value is set, always `False` otherwise. So use this method to ask the very
        specific question of whether the value is set to `True` (and it's not set to `False` or isn't set).
        """
        value = self.get_value(ds_key_long)
        return False if value is None else bool(value)

    def is_false(self, ds_key_long):
        """
        Returns `True`/`False` only if the value is set, always `False` otherwise. So use this method to ask the very
        specific question of whether the value is set to `False` (and it's not set to `True` or isn't set).
        """
        value = self.get_value(ds_key_long)
        return False if value is None else not bool(value)

    def is_zero2(self):
        """Returns `True` if `zero_optimization.stage` is 2."""
        return self._stage == 2

    def is_zero3(self):
        """Returns `True` if `zero_optimization.stage` is 3."""
        return self._stage == 3

    def is_offload(self):
        """Returns `True` if ZeRO stage 2 or 3 is used with optimizer or param offload to `cpu` or `nvme`."""
        return self._offload
class DeepSpeedEngineWrapper:
    """
    Internal wrapper for deepspeed.runtime.engine.DeepSpeedEngine. This is used to follow conventional training loop.

    Args:
        engine (deepspeed.runtime.engine.DeepSpeedEngine): deepspeed engine to wrap
    """

    def __init__(self, engine):
        self.engine = engine

    def backward(self, loss, **kwargs):
        """
        Runs the engine's backward pass on `loss` and then immediately calls the engine's `step`.

        Args:
            loss: The loss to backpropagate; passed to `engine.backward` unchanged.
            **kwargs: Optional keyword arguments forwarded as-is to `engine.backward`.
        """
        # runs backpropagation and handles mixed precision
        self.engine.backward(loss, **kwargs)

        # deepspeed `engine.step` performs following operations:
        # gradient accumulation check
        # gradient clipping
        # optimizer step
        # zero grad
        # checking overflow
        # lr_scheduler step
        self.engine.step()
class DeepSpeedOptimizerWrapper(AcceleratedOptimizer):
    """
    Internal wrapper around a deepspeed optimizer.

    Args:
        optimizer (`torch.optim.optimizer.Optimizer`):
            The optimizer to wrap.
    """

    def __init__(self, optimizer):
        super().__init__(optimizer, device_placement=False, scaler=None)

    def zero_grad(self, set_to_none=None):
        pass  # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed

    def step(self):
        pass  # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed

    @property
    def step_was_skipped(self):
        """Whether or not the optimizer step was done, or skipped because of gradient overflow."""
        # Not every wrapped optimizer exposes an `overflow` flag; when it is absent report "not skipped"
        # instead of raising `AttributeError`.
        return getattr(self.optimizer, "overflow", False)
class DeepSpeedSchedulerWrapper(AcceleratedScheduler):
    """
    Internal wrapper around a deepspeed scheduler whose `step` is a no-op.

    Args:
        scheduler (`torch.optim.lr_scheduler.LambdaLR`):
            The scheduler to wrap.
        optimizers (one or a list of `torch.optim.Optimizer`):
            Forwarded unchanged to the `AcceleratedScheduler` constructor.
    """

    def __init__(self, scheduler, optimizers):
        super().__init__(scheduler, optimizers)

    def step(self):
        # Intentionally does nothing: `accelerator.backward(loss)` already takes care of stepping the
        # scheduler, so stepping here as well is not needed.
        return None
class DummyOptim:
    """
    Placeholder optimizer that only records the arguments it was given. It is primarily used to follow conventional
    training loop when optimizer config is specified in the deepspeed config file.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float):
            Learning rate.
        weight_decay (float):
            Weight decay.
        **kwargs:
            Other arguments, stored as a dict on `self.kwargs`.
    """

    def __init__(self, params, lr=0.001, weight_decay=0, **kwargs):
        # Nothing is validated or copied; the values are kept exactly as passed in.
        self.params, self.lr = params, lr
        self.weight_decay, self.kwargs = weight_decay, kwargs
class DummyScheduler:
    """
    Placeholder scheduler that only records the arguments it was given. It is primarily used to follow conventional
    training loop when scheduler config is specified in the deepspeed config file.

    Args:
        optimizer (`torch.optim.optimizer.Optimizer`):
            The optimizer to wrap.
        total_num_steps (int):
            Total number of steps.
        warmup_num_steps (int):
            Number of steps for warmup.
        **kwargs:
            Other arguments, stored as a dict on `self.kwargs`.
    """

    def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, **kwargs):
        # Nothing is validated or copied; the values are kept exactly as passed in.
        self.optimizer, self.kwargs = optimizer, kwargs
        self.total_num_steps, self.warmup_num_steps = total_num_steps, warmup_num_steps