-
Notifications
You must be signed in to change notification settings - Fork 25.1k
/
integrations.py
410 lines (332 loc) · 15.9 KB
/
integrations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
# Integrations with other Python libraries
import math
import os
# Import 3rd-party integrations first:
try:
# Comet needs to be imported before any ML frameworks
import comet_ml # noqa: F401
# XXX: there should be comet_ml.ensure_configured(), like `wandb`, for now emulate it
comet_ml.Experiment(project_name="ensure_configured")
_has_comet = True
except (ImportError, ValueError):
_has_comet = False
try:
import wandb
wandb.ensure_configured()
if wandb.api.api_key is None:
_has_wandb = False
wandb.termwarn("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
else:
_has_wandb = False if os.getenv("WANDB_DISABLED") else True
except (ImportError, AttributeError):
_has_wandb = False
try:
import optuna # noqa: F401
_has_optuna = True
except (ImportError):
_has_optuna = False
try:
import ray # noqa: F401
_has_ray = True
except (ImportError):
_has_ray = False
try:
from torch.utils.tensorboard import SummaryWriter # noqa: F401
_has_tensorboard = True
except ImportError:
try:
from tensorboardX import SummaryWriter # noqa: F401
_has_tensorboard = True
except ImportError:
_has_tensorboard = False
# No transformer imports above this point
from .file_utils import is_torch_tpu_available
from .trainer_callback import TrainerCallback
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun
from .utils import logging
logger = logging.get_logger(__name__)
# Integration functions:
def is_wandb_available():
    """Return the `_has_wandb` flag that was computed when this module was imported."""
    return _has_wandb
def is_comet_available():
    """Return the `_has_comet` flag that was computed when this module was imported."""
    return _has_comet
def is_tensorboard_available():
    """Return the `_has_tensorboard` flag that was computed when this module was imported."""
    return _has_tensorboard
def is_optuna_available():
    """Return the `_has_optuna` flag that was computed when this module was imported."""
    return _has_optuna
def is_ray_available():
    """Return the `_has_ray` flag that was computed when this module was imported."""
    return _has_ray
def hp_params(trial):
    """
    Extract the hyperparameters carried by a hyperparameter-search trial object.

    Args:
        trial: An :obj:`optuna.Trial` (its ``params`` are returned) or a :obj:`dict` (returned as is when Ray is
            available).

    Returns:
        The hyperparameters of the trial.

    Raises:
        RuntimeError: If :obj:`trial` matches none of the available backends.
    """
    # The availability check comes first so `optuna` is never referenced when it failed to import.
    if is_optuna_available() and isinstance(trial, optuna.Trial):
        return trial.params
    if is_ray_available() and isinstance(trial, dict):
        return trial
    raise RuntimeError(f"Unknown type for trial {trial.__class__}")
def default_hp_search_backend():
    """
    Pick the hyperparameter-search backend to use by default.

    Returns:
        :obj:`"optuna"` if optuna is available, otherwise :obj:`"ray"` if Ray is available, otherwise :obj:`None`.
    """
    if is_optuna_available():
        return "optuna"
    if is_ray_available():
        return "ray"
    return None
def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search with optuna and return the best run.

    Args:
        trainer: The trainer whose ``train``/``evaluate``/``compute_objective`` methods drive each trial.
        n_trials (:obj:`int`): Number of trials, forwarded to ``study.optimize``.
        direction (:obj:`str`): Optimization direction, forwarded to ``optuna.create_study``.
        **kwargs: ``timeout`` and ``n_jobs`` are extracted and passed to ``study.optimize``; everything else is
            passed to ``optuna.create_study``.

    Returns:
        :obj:`BestRun`: Built from the number (as a string), value and params of the study's best trial.
    """

    def _objective(trial, checkpoint_dir=None):
        # Resume from the last listed sub-directory carrying the checkpoint prefix, if any.
        resume_from = None
        if checkpoint_dir:
            candidates = [
                entry for entry in os.listdir(checkpoint_dir) if entry.startswith(PREFIX_CHECKPOINT_DIR)
            ]
            if candidates:
                resume_from = os.path.join(checkpoint_dir, candidates[-1])

        trainer.objective = None
        trainer.train(model_path=resume_from, trial=trial)

        # Training did not set an objective (no evaluation happened in the loop): evaluate once now.
        if getattr(trainer, "objective", None) is None:
            eval_metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(eval_metrics)
        return trainer.objective

    # These two options belong to `study.optimize`, not to `optuna.create_study`.
    optimize_options = {"timeout": kwargs.pop("timeout", None), "n_jobs": kwargs.pop("n_jobs", 1)}

    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_objective, n_trials=n_trials, **optimize_options)

    winner = study.best_trial
    return BestRun(str(winner.number), winner.value, winner.params)
def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search with Ray Tune and return the best run.

    Args:
        trainer: The trainer driving each trial. Its ``model`` attribute is set to :obj:`None` and its
            :class:`TensorBoardCallback` (if any) is removed for the duration of the search, because they do not
            pickle.
        n_trials (:obj:`int`): Passed to ``ray.tune.run`` as ``num_samples``.
        direction (:obj:`str`): Its first three characters are used as the ``mode`` when selecting the best trial
            (presumably ``"minimize"``/``"maximize"`` giving ``"min"``/``"max"``; confirm against the caller).
        **kwargs: Forwarded to ``ray.tune.run``. ``n_jobs`` is consumed only when a default
            ``resources_per_trial`` is computed here.

    Returns:
        :obj:`BestRun`: Built from the trial id, last reported ``objective`` and config of the best trial.

    Raises:
        RuntimeError: If a scheduler that needs intermediate results is used without ``do_eval`` and
            ``evaluate_during_training`` enabled in the trainer arguments.
    """

    def _objective(trial, checkpoint_dir=None):
        model_path = None
        if checkpoint_dir:
            # The last listed sub-directory carrying the checkpoint prefix wins.
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    model_path = os.path.join(checkpoint_dir, subdir)
        trainer.objective = None
        trainer.train(model_path=model_path, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
            trainer._tune_save_checkpoint()
            ray.tune.report(objective=trainer.objective, **metrics, done=True)

    # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None

    try:
        # Setup default `resources_per_trial` and `reporter`.
        if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
            # `args.n_gpu` is considered the total number of GPUs that will be split
            # among the `n_jobs`
            n_jobs = int(kwargs.pop("n_jobs", 1))
            num_gpus_per_trial = trainer.args.n_gpu
            if num_gpus_per_trial / n_jobs >= 1:
                num_gpus_per_trial = int(math.ceil(num_gpus_per_trial / n_jobs))
            kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial}

        if "progress_reporter" not in kwargs:
            from ray.tune import CLIReporter

            kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])

        if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
            # `keep_checkpoints_num=0` would disable checkpointing
            trainer.use_tune_checkpoints = True
            if kwargs["keep_checkpoints_num"] > 1:
                # The placeholder was previously logged verbatim; fill it in with the actual count.
                logger.warning(
                    "Currently keeping {} checkpoints for each trial. Checkpoints are usually huge, "
                    "consider setting `keep_checkpoints_num=1`.".format(kwargs["keep_checkpoints_num"])
                )

        if "scheduler" in kwargs:
            from ray.tune.schedulers import (
                ASHAScheduler,
                HyperBandForBOHB,
                MedianStoppingRule,
                PopulationBasedTraining,
            )

            # Check if checkpointing is enabled for PopulationBasedTraining
            if isinstance(kwargs["scheduler"], PopulationBasedTraining):
                if not trainer.use_tune_checkpoints:
                    logger.warning(
                        "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
                        "This means your trials will train from scratch everytime they are exploiting "
                        "new configurations. Consider enabling checkpointing by passing "
                        "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
                    )

            # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
            if isinstance(
                kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
            ) and (not trainer.args.do_eval or not trainer.args.evaluate_during_training):
                raise RuntimeError(
                    "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                    "This means your trials will not report intermediate results to Ray Tune, and "
                    "can thus not be stopped early or used to exploit other trials parameters. "
                    "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                    "make sure you pass `do_eval=True` and `evaluate_during_training=True` in the "
                    "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
                )

        analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
        best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
        best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
    finally:
        # Put the TensorBoard callback back even when validation or the search itself raised,
        # so the trainer is not left silently without it.
        if _tb_writer is not None:
            trainer.add_callback(_tb_writer)
    return best_run
def rewrite_logs(d):
    """
    Return a copy of the log dictionary with its keys namespaced for the logging backends.

    Keys starting with ``"eval_"`` have that prefix replaced by ``"eval/"``; every other key is prefixed with
    ``"train/"``. Values are left untouched and the insertion order of :obj:`d` is kept.
    """
    marker = "eval_"

    def _namespaced(key):
        if key.startswith(marker):
            return "eval/" + key[len(marker) :]
        return "train/" + key

    return {_namespaced(key): value for key, value in d.items()}
class TensorBoardCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
    <https://www.tensorflow.org/tensorboard>`__.

    Args:
        tb_writer (:obj:`SummaryWriter`, `optional`):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        assert (
            _has_tensorboard
        ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        """Create the summary writer, writing to :obj:`log_dir` or, when it is not given, ``args.logging_dir``."""
        log_dir = log_dir or args.logging_dir
        self.tb_writer = SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        """Make sure a writer exists (world process zero only) and record the training arguments and model config."""
        if not state.is_world_process_zero:
            return

        log_dir = None
        if state.is_hyper_param_search:
            # Give each trial its own sub-directory so runs do not overwrite each other.
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        # Only create a writer when none is set, so a writer passed to `__init__` is actually used
        # instead of being replaced (and never closed).
        if self.tb_writer is None:
            self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)
            self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write every numeric entry of :obj:`logs` as a scalar at the current global step (world process zero only)."""
        if not state.is_world_process_zero:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        """Close the writer and forget it, so a later training run or log call creates a fresh one."""
        if self.tb_writer:
            self.tb_writer.close()
            # Drop the reference: a closed writer must not be reused by a subsequent run.
            self.tb_writer = None
class WandbCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `Weights & Biases
    <https://www.wandb.com/>`__.
    """

    def __init__(self):
        assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
        self._initialized = False

    def setup(self, args, state, model, reinit, **kwargs):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

        Environment:
            WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
                Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
                logging or :obj:`"all"` to log gradients and parameters.
            WANDB_PROJECT (:obj:`str`, `optional`, defaults to :obj:`"huggingface"`):
                Set this to a custom string to store results in a different project.
            WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to disable wandb entirely.
        """
        self._initialized = True
        if state.is_world_process_zero:
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**args.to_sanitized_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                # Training arguments take precedence over model config entries with the same key.
                combined_dict = {**model_config, **combined_dict}

            trial_name = state.trial_name
            init_args = {}
            if trial_name is not None:
                # During a hyperparameter search, name the run after the trial and group trials by `args.run_name`.
                run_name = trial_name
                init_args["group"] = args.run_name
            else:
                run_name = args.run_name

            wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"),
                config=combined_dict,
                name=run_name,
                reinit=reinit,
                **init_args,
            )
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Run :meth:`setup` on first use, and again (with ``reinit=True``) for every hyperparameter-search trial."""
        hp_search = state.is_hyper_param_search
        if not self._initialized or hp_search:
            self.setup(args, state, model, reinit=hp_search, **kwargs)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Send :obj:`logs` to wandb at the current global step (world process zero only)."""
        if not self._initialized:
            self.setup(args, state, model, reinit=False)
        if state.is_world_process_zero:
            logs = rewrite_logs(logs)
            wandb.log(logs, step=state.global_step)
class CometCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `Comet ML
    <https://www.comet.ml/site/>`__.
    """

    def __init__(self):
        assert _has_comet, "CometCallback requires comet-ml to be installed. Run `pip install comet-ml`."
        self._initialized = False

    def setup(self, args, state, model):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE (:obj:`str`, `optional`):
                "OFFLINE", "ONLINE", or "DISABLED"
            COMET_PROJECT_NAME (:obj:`str`, `optional`):
                Comet.ml project name for experiments
            COMET_OFFLINE_DIRECTORY (:obj:`str`, `optional`):
                Folder to use for saving offline experiments when :obj:`COMET_MODE` is "OFFLINE"

        For a number of configurable items in the environment,
        see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__.
        """
        self._initialized = True
        if state.is_world_process_zero:
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            # Kept under its own name so it does not shadow the `args` parameter (the training
            # arguments), which is what must be logged under the "args/" prefix below.
            experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            experiment = None
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**experiment_kwargs)
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**experiment_kwargs)
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            # Any other mode (e.g. "DISABLED") leaves `experiment` unset and nothing is logged.
            if experiment is not None:
                experiment._set_model_graph(model, framework="transformers")
                experiment._log_parameters(args, prefix="args/", framework="transformers")
                if hasattr(model, "config"):
                    experiment._log_parameters(model.config, prefix="config/", framework="transformers")

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Run :meth:`setup` once, the first time training begins."""
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Send :obj:`logs` to the global Comet experiment, if there is one (world process zero only)."""
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if experiment is not None:
                experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")