-
Notifications
You must be signed in to change notification settings - Fork 392
/
training_args.py
338 lines (299 loc) · 16.2 KB
/
training_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from transformers import TrainingArguments
from transformers.debug_utils import DebugOption
from transformers.trainer_utils import (
EvaluationStrategy,
FSDPOption,
HubStrategy,
IntervalStrategy,
SchedulerType,
ShardedDDPOption,
)
from transformers.training_args import OptimizerNames, default_logdir, logger
from transformers.utils import (
ExplicitEnum,
get_full_repo_name,
is_torch_available,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_tf32_available,
logging,
)
if is_torch_available():
import torch
class ORTOptimizerNames(ExplicitEnum):
    """
    String identifiers accepted for the optimizers implemented in `onnxruntime.training.optim`.
    """

    # Fused AdamW optimizer identifier.
    ADAMW_ORT_FUSED = "adamw_ort_fused"
@dataclass
class ORTTrainingArguments(TrainingArguments):
    """
    Training arguments whose `optim` field accepts, in addition to the Transformers optimizer names, the ONNX Runtime
    optimizer names listed in `ORTOptimizerNames`.

    Parameters:
        optim (`str` or [`training_args.ORTOptimizerNames`] or [`transformers.training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
            The optimizer to use, including optimizers in Transformers: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor. And optimizers implemented by ONNX Runtime: adamw_ort_fused.
    """

    optim: Optional[str] = field(
        default="adamw_hf",
        metadata={"help": "The optimizer to use."},
    )

    # This method will not need to be overridden after the deprecation of `--adafactor` in version 5 of 🤗 Transformers.
    def __post_init__(self):
        """
        Normalize and validate the arguments once the dataclass fields have been set.

        The steps run in a fixed order: environment/paths, strategy enums, step and best-model consistency checks,
        precision (fp16/bf16/tf32) checks, optimizer resolution, reporting, sharding options, debug options,
        DeepSpeed, and finally the deprecated `push_to_hub_*` arguments. This method does not call
        `super().__post_init__()`.

        Raises:
            ValueError: If one of the argument combinations checked below is inconsistent or unsupported by the
                current setup. The enum conversions may also raise for unknown identifiers (presumably `ValueError`,
                as relied upon for `optim` below).
        """
        # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
        # This needs to happen before any call to self.device or self.n_gpu.
        env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
        if env_local_rank != -1 and env_local_rank != self.local_rank:
            self.local_rank = env_local_rank

        # expand paths, if not os.makedirs("~/bar") will make directory
        # in the current directory instead of the actual home
        # see https://github.com/huggingface/transformers/issues/10628
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)
        if self.logging_dir is None and self.output_dir is not None:
            self.logging_dir = os.path.join(self.output_dir, default_logdir())
        if self.logging_dir is not None:
            self.logging_dir = os.path.expanduser(self.logging_dir)

        # Progress bars are disabled by default only when the logger is quieter than WARN.
        if self.disable_tqdm is None:
            self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN

        if isinstance(self.evaluation_strategy, EvaluationStrategy):
            warnings.warn(
                "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5"
                " of 🤗 Transformers. Use `IntervalStrategy` instead",
                FutureWarning,
            )
            # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it.
            self.evaluation_strategy = self.evaluation_strategy.value

        # Convert the string-valued options into their enum types.
        self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
        self.logging_strategy = IntervalStrategy(self.logging_strategy)
        self.save_strategy = IntervalStrategy(self.save_strategy)
        self.hub_strategy = HubStrategy(self.hub_strategy)

        self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
        # Any evaluation strategy other than "no" turns evaluation on.
        if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
            self.do_eval = True

        # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero
        if self.evaluation_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0):
            if self.logging_steps > 0:
                logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}")
                self.eval_steps = self.logging_steps
            else:
                raise ValueError(
                    f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or"
                    " --logging_steps"
                )

        # logging_steps must be non-zero for logging_strategy that is other than 'no'
        if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0:
            raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps")

        # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible.
        if self.load_best_model_at_end:
            if self.evaluation_strategy != self.save_strategy:
                raise ValueError(
                    "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation "
                    f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}"
                )
            if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0:
                raise ValueError(
                    "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation "
                    f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
                )

        # Defaults for best-model tracking: "loss", for which lower is better.
        if self.load_best_model_at_end and self.metric_for_best_model is None:
            self.metric_for_best_model = "loss"
        if self.greater_is_better is None and self.metric_for_best_model is not None:
            self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
        if self.run_name is None:
            self.run_name = self.output_dir

        if self.framework == "pt" and is_torch_available():
            if self.fp16_backend and self.fp16_backend != "auto":
                warnings.warn(
                    "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
                    " `half_precision_backend` instead",
                    FutureWarning,
                )
                self.half_precision_backend = self.fp16_backend

            if self.bf16 or self.bf16_full_eval:
                if self.no_cuda and not is_torch_bf16_cpu_available():
                    # cpu
                    raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10")
                elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
                    # gpu
                    raise ValueError(
                        "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
                    )

        if self.fp16 and self.bf16:
            raise ValueError("At most one of fp16 and bf16 can be True, but not both")

        if self.fp16_full_eval and self.bf16_full_eval:
            raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")

        if self.bf16:
            if self.half_precision_backend == "apex":
                raise ValueError(
                    " `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
                    " `--half_precision_backend cuda_amp` instead"
                )
            # `sharded_ddp` has not been normalized to a list yet; any non-empty value is rejected.
            if self.sharded_ddp:
                raise ValueError("sharded_ddp is not supported with bf16")

        # Try the ONNX Runtime optimizer names first, then fall back to the Transformers ones.
        try:
            self.optim = ORTOptimizerNames(self.optim)
        except ValueError:
            self.optim = OptimizerNames(self.optim)
        if self.adafactor:
            warnings.warn(
                "`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim"
                " adafactor` instead",
                FutureWarning,
            )
            # The deprecated flag takes precedence over whatever `optim` resolved to above.
            self.optim = OptimizerNames.ADAFACTOR

        if (
            is_torch_available()
            and (self.device.type != "cuda")
            and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
            and (self.fp16 or self.fp16_full_eval)
        ):
            raise ValueError(
                "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
                " (`--fp16_full_eval`) can only be used on CUDA devices."
            )

        if (
            is_torch_available()
            and (self.device.type != "cuda")
            and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
            and (self.device.type != "cpu")
            and (self.bf16 or self.bf16_full_eval)
        ):
            raise ValueError(
                "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
                " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
            )

        if self.framework == "pt" and is_torch_available() and self.torchdynamo is not None:
            if is_torch_tf32_available():
                # NOTE(review): parentheses spell out the grouping Python already applied (`and` binds tighter
                # than `or`), so TF32 is also enabled whenever bf16 is set. Confirm this is the intended condition;
                # an explicit `tf32` value is still honoured by the block that follows.
                if (self.tf32 is None and not self.fp16) or self.bf16:
                    logger.info("Setting TF32 in CUDA backends to speedup torchdynamo.")
                    torch.backends.cuda.matmul.allow_tf32 = True
            else:
                logger.warning(
                    "The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here."
                )

        # An explicit `tf32` value always wins over the torchdynamo default set above.
        if is_torch_available() and self.tf32 is not None:
            if self.tf32:
                if is_torch_tf32_available():
                    torch.backends.cuda.matmul.allow_tf32 = True
                else:
                    raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
            else:
                if is_torch_tf32_available():
                    torch.backends.cuda.matmul.allow_tf32 = False
                # no need to assert on else

        if self.report_to is None:
            logger.info(
                "The default value for the training argument `--report_to` will change in v5 (from all installed "
                "integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as "
                "now. You should start updating your code and make this info disappear :-)."
            )
            self.report_to = "all"
        # Normalize `report_to` to a list of integration names.
        if self.report_to == "all" or self.report_to == ["all"]:
            # Import at runtime to avoid a circular import.
            from transformers.integrations import get_available_reporting_integrations

            self.report_to = get_available_reporting_integrations()
        elif self.report_to == "none" or self.report_to == ["none"]:
            self.report_to = []
        elif not isinstance(self.report_to, list):
            self.report_to = [self.report_to]

        if self.warmup_ratio < 0 or self.warmup_ratio > 1:
            raise ValueError("warmup_ratio must lie in range [0,1]")
        elif self.warmup_ratio > 0 and self.warmup_steps > 0:
            logger.info(
                "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio"
                " during training"
            )

        # Normalize `sharded_ddp` (bool or space-separated string) to a list of options, then validate it.
        if isinstance(self.sharded_ddp, bool):
            self.sharded_ddp = "simple" if self.sharded_ddp else ""
        if isinstance(self.sharded_ddp, str):
            self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
        if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
            raise ValueError(
                "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
                '`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
            )
        elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
            raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
        elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
            raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")

        # Same normalization and validation for `fsdp`.
        if isinstance(self.fsdp, bool):
            self.fsdp = "full_shard" if self.fsdp else ""
        if isinstance(self.fsdp, str):
            self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
        if self.fsdp == [FSDPOption.OFFLOAD]:
            raise ValueError(
                "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
                '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
            )
        elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
            raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")

        if len(self.fsdp) == 0 and self.fsdp_min_num_params > 0:
            warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")

        if self.tpu_metrics_debug:
            warnings.warn(
                "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
                " `--debug tpu_metrics_debug` instead",
                FutureWarning,
            )
            # Fold the deprecated flag into `debug` before it is split into options below.
            self.debug += " tpu_metrics_debug"
            self.tpu_metrics_debug = False
        if isinstance(self.debug, str):
            self.debug = [DebugOption(s) for s in self.debug.split()]

        if self.deepspeed:
            # - must be run very last in arg parsing, since it will use a lot of these settings.
            # - must be run before the model is created.
            from transformers.deepspeed import HfTrainerDeepSpeedConfig

            # will be used later by the Trainer
            # note: leave self.deepspeed unmodified in case a user relies on it not to be modified)
            self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
            self.hf_deepspeed_config.trainer_config_process(self)

        if self.push_to_hub_token is not None:
            warnings.warn(
                "`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
                "`--hub_token` instead.",
                FutureWarning,
            )
            self.hub_token = self.push_to_hub_token

        # Map the deprecated `push_to_hub_model_id` / `push_to_hub_organization` pair onto `hub_model_id`.
        if self.push_to_hub_model_id is not None:
            self.hub_model_id = get_full_repo_name(
                self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token
            )
            if self.push_to_hub_organization is not None:
                warnings.warn(
                    "`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in "
                    "version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this "
                    f"argument (in this case {self.hub_model_id}).",
                    FutureWarning,
                )
            else:
                warnings.warn(
                    "`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
                    "`--hub_model_id` instead and pass the full repo name to this argument (in this case "
                    f"{self.hub_model_id}).",
                    FutureWarning,
                )
        elif self.push_to_hub_organization is not None:
            self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}"
            warnings.warn(
                "`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
                "`--hub_model_id` instead and pass the full repo name to this argument (in this case "
                f"{self.hub_model_id}).",
                FutureWarning,
            )