/
integrations.py
968 lines (801 loc) 路 39.3 KB
/
integrations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integrations with other Python libraries.
"""
import functools
import importlib.util
import numbers
import os
import sys
import tempfile
from pathlib import Path
from .utils import is_datasets_available, logging
logger = logging.get_logger(__name__)
# comet_ml has to be imported before any ML framework, so its availability is probed eagerly here.
# The integration is considered usable only when the package is installed, `COMET_MODE` is not
# "DISABLED" and an API key is configured.
_has_comet = False
if importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED":
    try:
        import comet_ml  # noqa: F401

        if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
            _has_comet = True
        elif os.getenv("COMET_MODE", "").upper() != "DISABLED":
            # Installed but unusable: tell the user why the integration stays off.
            logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
    except (ImportError, ValueError):
        _has_comet = False
from .trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from .utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
# Integration functions:
def is_wandb_available():
    """
    Return whether the Weights & Biases (`wandb`) integration can be used.

    When the `WANDB_DISABLED` environment variable is set to one of the values listed in
    `ENV_VARS_TRUE_VALUES` (compared case-insensitively), a deprecation warning is logged and `False` is
    returned. Otherwise the result is whether the `wandb` package can be found on the import path.
    """
    # Only the truthy values listed in ENV_VARS_TRUE_VALUES disable wandb; other values are ignored.
    if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
        logger.warning(
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
            "--report_to flag to control the integrations used for logging result (for instance --report_to none)."
        )
        return False
    return importlib.util.find_spec("wandb") is not None
def is_comet_available():
    """Return the module-level `_has_comet` flag computed once when this module was imported."""
    return _has_comet
def is_tensorboard_available():
    """Return `True` when either `tensorboard` or `tensorboardX` can be found on the import path."""
    if importlib.util.find_spec("tensorboard") is not None:
        return True
    return importlib.util.find_spec("tensorboardX") is not None
def is_optuna_available():
    """Return whether the `optuna` package can be found on the import path."""
    optuna_spec = importlib.util.find_spec("optuna")
    return optuna_spec is not None
def is_ray_available():
    """Return whether the `ray` package can be found on the import path."""
    ray_spec = importlib.util.find_spec("ray")
    return ray_spec is not None
def is_ray_tune_available():
    """
    Return whether `ray.tune` can be found on the import path.

    The parent `ray` package is checked first so that `ray.tune` is only looked up when `ray` exists.
    """
    return is_ray_available() and importlib.util.find_spec("ray.tune") is not None
def is_sigopt_available():
    """Return whether the `sigopt` package can be found on the import path."""
    sigopt_spec = importlib.util.find_spec("sigopt")
    return sigopt_spec is not None
def is_azureml_available():
    """
    Return whether `azureml.core.run` can be found on the import path.

    Each level of the dotted path is probed in turn, outermost first, and the first missing level
    makes the function return `False` without probing the deeper ones.
    """
    for module_name in ("azureml", "azureml.core", "azureml.core.run"):
        if importlib.util.find_spec(module_name) is None:
            return False
    return True
def is_mlflow_available():
    """
    Return whether the MLflow integration can be used.

    The integration is off when the `DISABLE_MLFLOW_INTEGRATION` environment variable equals "TRUE"
    (case-insensitive); otherwise the result is whether `mlflow` can be found on the import path.
    """
    disabled = os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE"
    return (not disabled) and importlib.util.find_spec("mlflow") is not None
def is_fairscale_available():
    """Return whether the `fairscale` package can be found on the import path."""
    fairscale_spec = importlib.util.find_spec("fairscale")
    return fairscale_spec is not None
def is_neptune_available():
    """Return whether the `neptune` package can be found on the import path."""
    neptune_spec = importlib.util.find_spec("neptune")
    return neptune_spec is not None
def is_codecarbon_available():
    """Return whether the `codecarbon` package can be found on the import path."""
    codecarbon_spec = importlib.util.find_spec("codecarbon")
    return codecarbon_spec is not None
def hp_params(trial):
    """
    Extract the hyperparameters carried by a trial object.

    Args:
        trial: An `optuna.Trial` (its `params` are returned) or a `dict` (returned unchanged when one of the
            Ray Tune, SigOpt or wandb backends is available).

    Returns:
        The hyperparameters of the trial.

    Raises:
        RuntimeError: If no available backend recognizes the type of `trial`.
    """
    if is_optuna_available():
        import optuna

        if isinstance(trial, optuna.Trial):
            return trial.params

    # The three remaining backends all represent a trial as a plain dict. The availability checks are
    # run one after the other, in this order, exactly as separate `if` statements would.
    for backend_is_available in (is_ray_tune_available, is_sigopt_available, is_wandb_available):
        if backend_is_available() and isinstance(trial, dict):
            return trial

    raise RuntimeError(f"Unknown type for trial {trial.__class__}")
def default_hp_search_backend():
    """
    Pick the hyperparameter-search backend to use by default.

    Returns:
        `"optuna"`, `"ray"` or `"sigopt"` for the first available backend, checked in that order, or
        `None` when none of them is available.
    """
    candidates = (
        ("optuna", is_optuna_available),
        ("ray", is_ray_tune_available),
        ("sigopt", is_sigopt_available),
    )
    for backend_name, backend_is_available in candidates:
        if backend_is_available():
            return backend_name
    return None
def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search with Optuna.

    Args:
        trainer: Object whose `train`, `evaluate` and `compute_objective` methods are called for each trial.
        n_trials (`int`): Number of trials passed to `study.optimize`.
        direction (`str`): Optimization direction passed to `optuna.create_study`.
        **kwargs: `timeout` (default `None`) and `n_jobs` (default `1`) are forwarded to `study.optimize`;
            everything else is forwarded to `optuna.create_study`.

    Returns:
        `BestRun`: The number (as a string), value and params of the best trial of the study.
    """
    import optuna

    def _objective(trial, checkpoint_dir=None):
        resume_path = None
        if checkpoint_dir:
            # When several entries match, the last one returned by `os.listdir` wins.
            for entry in os.listdir(checkpoint_dir):
                if entry.startswith(PREFIX_CHECKPOINT_DIR):
                    resume_path = os.path.join(checkpoint_dir, entry)
        trainer.objective = None
        trainer.train(resume_from_checkpoint=resume_path, trial=trial)
        # No evaluation happened during training: evaluate once to get an objective.
        if getattr(trainer, "objective", None) is None:
            trainer.objective = trainer.compute_objective(trainer.evaluate())
        return trainer.objective

    optimize_kwargs = {
        "timeout": kwargs.pop("timeout", None),
        "n_jobs": kwargs.pop("n_jobs", 1),
    }
    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_objective, n_trials=n_trials, **optimize_kwargs)
    winner = study.best_trial
    return BestRun(str(winner.number), winner.value, winner.params)
def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search with Ray Tune.

    Args:
        trainer: The trainer to tune. It is mutated here: its memory tracker may be replaced, its model is set to
            `None`, any `TensorBoardCallback` is popped for the duration of the search, and `args._n_gpu` is
            overwritten with the number of GPUs allotted per trial.
        n_trials (`int`): Passed to `ray.tune.run` as `num_samples`.
        direction (`str`): Its first three characters are used as the `mode` when selecting the best trial.
        **kwargs: Forwarded to `ray.tune.run`. `resources_per_trial` and `progress_reporter` receive defaults when
            absent; `keep_checkpoints_num` and `scheduler` are inspected for sanity checks.

    Returns:
        `BestRun`: The trial id, last reported objective and config of the best trial.

    Raises:
        RuntimeError: If an early-stopping/PBT scheduler is used while evaluation during training is disabled.
    """
    import ray

    def _objective(trial, local_trainer, checkpoint_dir=None):
        # Swap the notebook progress callback (if one was registered) for the plain `ProgressCallback`.
        try:
            from transformers.utils.notebook import NotebookProgressCallback

            if local_trainer.pop_callback(NotebookProgressCallback):
                local_trainer.add_callback(ProgressCallback)
        except ModuleNotFoundError:
            pass

        checkpoint = None
        if checkpoint_dir:
            # When several entries match, the last one returned by `os.listdir` wins.
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        local_trainer.objective = None
        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(local_trainer, "objective", None) is None:
            metrics = local_trainer.evaluate()
            local_trainer.objective = local_trainer.compute_objective(metrics)
            local_trainer._tune_save_checkpoint()
            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)

    if not trainer._memory_tracker.skip_memory_metrics:
        from .trainer_utils import TrainerMemoryTracker

        logger.warning(
            "Memory tracking for your Trainer is currently "
            "enabled. Automatically disabling the memory tracker "
            "since the memory tracker is not serializable."
        )
        trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)

    # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None

    # Setup default `resources_per_trial`.
    if "resources_per_trial" not in kwargs:
        # Default to 1 CPU and 1 GPU (if applicable) per trial.
        kwargs["resources_per_trial"] = {"cpu": 1}
        if trainer.args.n_gpu > 0:
            kwargs["resources_per_trial"]["gpu"] = 1
        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
        logger.info(
            "No `resources_per_trial` arg was passed into "
            "`hyperparameter_search`. Setting it to a default value "
            f"of {resource_msg} for each trial."
        )
    # Make sure each trainer only uses GPUs that were allocated per trial.
    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
    trainer.args._n_gpu = gpus_per_trial

    # Setup default `progress_reporter`.
    if "progress_reporter" not in kwargs:
        from ray.tune import CLIReporter

        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
        # `keep_checkpoints_num=0` would disable checkpointing
        trainer.use_tune_checkpoints = True
        if kwargs["keep_checkpoints_num"] > 1:
            logger.warning(
                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
                "Checkpoints are usually huge, "
                "consider setting `keep_checkpoints_num=1`."
            )
    if "scheduler" in kwargs:
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

        # Check if checkpointing is enabled for PopulationBasedTraining
        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
            if not trainer.use_tune_checkpoints:
                logger.warning(
                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
                    "This means your trials will train from scratch everytime they are exploiting "
                    "new configurations. Consider enabling checkpointing by passing "
                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
                )

        # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
        if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
        ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
            raise RuntimeError(
                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                "This means your trials will not report intermediate results to Ray Tune, and "
                "can thus not be stopped early or used to exploit other trials parameters. "
                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                "make sure you pass `do_eval=True` and `evaluation_strategy='steps'` in the "
                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )

    # Bind the (model-less) trainer to the objective as its `local_trainer` argument.
    trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)

    @functools.wraps(trainable)
    def dynamic_modules_import_trainable(*args, **kwargs):
        """
        Wrapper around `tune.with_parameters` to ensure datasets_modules are loaded on each Actor.

        Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.

        Assumes that `_objective`, defined above, is a function.
        """
        if is_datasets_available():
            import datasets.load

            dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
            # load dynamic_modules from path
            spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
            datasets_modules = importlib.util.module_from_spec(spec)
            # Register the module under its spec name before executing it.
            sys.modules[spec.name] = datasets_modules
            spec.loader.exec_module(datasets_modules)
        return trainable(*args, **kwargs)

    # special attr set by tune.with_parameters
    if hasattr(trainable, "__mixins__"):
        dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__

    analysis = ray.tune.run(
        dynamic_modules_import_trainable,
        config=trainer.hp_space(None),
        num_samples=n_trials,
        **kwargs,
    )
    # `direction[:3]` is the first three characters of `direction` (e.g. "min" for "minimize").
    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
    # Put back the TensorBoard callback that was popped before the search.
    if _tb_writer is not None:
        trainer.add_callback(_tb_writer)
    return best_run
def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search with SigOpt.

    Args:
        trainer: Object whose `hp_space`, `train`, `evaluate` and `compute_objective` methods drive the search.
        n_trials (`int`): Used as the experiment's `observation_budget`.
        direction (`str`): Used as the `objective` of the experiment's single metric.
        **kwargs: Only `proxies` is read; when given it is passed to `Connection.set_proxies`.

    Returns:
        `BestRun`: The id, value and assignments of the first entry of the experiment's best assignments.
    """
    from sigopt import Connection

    conn = Connection()
    proxies = kwargs.pop("proxies", None)
    if proxies is not None:
        conn.set_proxies(proxies)

    objective_metric = {"name": "objective", "objective": direction, "strategy": "optimize"}
    experiment = conn.experiments().create(
        name="huggingface-tune",
        parameters=trainer.hp_space(None),
        metrics=[objective_metric],
        parallel_bandwidth=1,
        observation_budget=n_trials,
        project="huggingface",
    )
    logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")

    # One suggestion is trained and reported per iteration until the budget is used up.
    while experiment.progress.observation_count < experiment.observation_budget:
        suggestion = conn.experiments(experiment.id).suggestions().create()
        trainer.objective = None
        trainer.train(resume_from_checkpoint=None, trial=suggestion)
        # No evaluation happened during training: evaluate once to get an objective.
        if getattr(trainer, "objective", None) is None:
            trainer.objective = trainer.compute_objective(trainer.evaluate())

        reported_values = [{"name": "objective", "value": trainer.objective}]
        obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=reported_values)
        logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
        # Refresh the experiment so the loop condition sees the updated observation count.
        experiment = conn.experiments(experiment.id).fetch()

    best_pages = conn.experiments(experiment.id).best_assignments().fetch().iterate_pages()
    best = list(best_pages)[0]
    return BestRun(best.id, best.value, best.assignments)
def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    """
    Run a hyperparameter search driven by a Weights & Biases sweep.

    Args:
        trainer: Object whose `hp_space`, `train`, `evaluate` and `compute_objective` methods drive each run.
            A `WandbCallback` is added to it when none is registered, and `args.report_to` is set to `"wandb"`.
        n_trials (`int`): Passed to `wandb.agent` as `count`.
        direction (`str`): Stored as the sweep metric goal. `"minimize"` / `"maximize"` also decide whether a run
            replaces the best one seen so far.
        **kwargs: `sweep_id`, `project`, `name`, `entity` and `metric` (default `"eval/loss"`) are read from here.

    Returns:
        `BestRun`: The run id, objective and hyperparameters of the best run recorded by this agent.

    Raises:
        ImportError: If wandb is not available.
    """
    from .integrations import is_wandb_available

    if not is_wandb_available():
        raise ImportError("This function needs wandb installed: `pip install wandb`")
    import wandb

    # add WandbCallback if not already added in trainer callbacks
    already_reporting = any(isinstance(cb, WandbCallback) for cb in trainer.callback_handler.callbacks)
    if not already_reporting:
        trainer.add_callback(WandbCallback())
    trainer.args.report_to = "wandb"

    # Mutable record shared with the closure below.
    best_trial = {"run_id": None, "objective": None, "hyperparameters": None}

    sweep_id = kwargs.pop("sweep_id", None)
    project = kwargs.pop("project", None)
    name = kwargs.pop("name", None)
    entity = kwargs.pop("entity", None)
    metric = kwargs.pop("metric", "eval/loss")

    sweep_config = trainer.hp_space(None)
    sweep_config["metric"]["goal"] = direction
    sweep_config["metric"]["name"] = metric
    if name:
        sweep_config["name"] = name

    def _objective():
        run = wandb.run if wandb.run else wandb.init()
        trainer.state.trial_name = run.name
        run.config.update({"assignments": {}, "metric": metric})
        config = wandb.config

        trainer.objective = None
        trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
            format_metrics = rewrite_logs(metrics)
            if metric not in format_metrics:
                logger.warning(
                    f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available metrics are {format_metrics.keys()}"
                )

        # The first run is always recorded; later runs only when they beat the recorded objective.
        is_first_run = best_trial["run_id"] is None
        if is_first_run:
            improved = False
        elif direction == "minimize":
            improved = trainer.objective < best_trial["objective"]
        elif direction == "maximize":
            improved = trainer.objective > best_trial["objective"]
        else:
            improved = False

        if improved or is_first_run:
            best_trial["run_id"] = run.id
            best_trial["objective"] = trainer.objective
            best_trial["hyperparameters"] = dict(config)

        return trainer.objective

    # Create a new sweep only when the caller did not supply an existing one.
    if not sweep_id:
        sweep_id = wandb.sweep(sweep_config, project=project, entity=entity)
    logger.info(f"wandb sweep id - {sweep_id}")
    wandb.agent(sweep_id, function=_objective, count=n_trials)

    return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
def get_available_reporting_integrations():
    """
    List the names of the reporting integrations that are currently available.

    Returns:
        `list` of `str`: A subset of `"azure_ml"`, `"comet_ml"`, `"mlflow"`, `"tensorboard"`, `"wandb"` and
        `"codecarbon"`, in that order.
    """
    probes = (
        ("azure_ml", is_azureml_available),
        ("comet_ml", is_comet_available),
        ("mlflow", is_mlflow_available),
        ("tensorboard", is_tensorboard_available),
        ("wandb", is_wandb_available),
        ("codecarbon", is_codecarbon_available),
    )
    # Every probe is called exactly once, in the order listed above.
    return [integration_name for integration_name, is_available in probes if is_available()]
def rewrite_logs(d):
    """
    Re-key a flat log dictionary into `eval/`, `test/` and `train/` namespaces.

    Keys starting with `eval_` or `test_` have that prefix replaced by `eval/` or `test/`; every other key is
    prefixed with `train/`. Values are left untouched and a new dictionary is returned.
    """
    prefix_to_namespace = (("eval_", "eval/"), ("test_", "test/"))
    rewritten = {}
    for key, value in d.items():
        for prefix, namespace in prefix_to_namespace:
            if key.startswith(prefix):
                rewritten[namespace + key[len(prefix) :]] = value
                break
        else:
            # Neither prefix matched: treat the entry as a training log.
            rewritten["train/" + key] = value
    return rewritten
class TensorBoardCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).

    Args:
        tb_writer (`SummaryWriter`, *optional*):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        """
        Resolve the `SummaryWriter` class to use and store the optional pre-built writer.

        Raises:
            RuntimeError: If neither `tensorboard` nor `tensorboardX` is available.
        """
        if not is_tensorboard_available():
            raise RuntimeError(
                "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
            )
        # Prefer the writer shipped with PyTorch and fall back to tensorboardX. `_SummaryWriter` stays
        # `None` when neither import works, in which case no writer is ever created by this callback.
        try:
            from torch.utils.tensorboard import SummaryWriter  # noqa: F401

            self._SummaryWriter = SummaryWriter
        except ImportError:
            try:
                from tensorboardX import SummaryWriter

                self._SummaryWriter = SummaryWriter
            except ImportError:
                self._SummaryWriter = None
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        """Create `self.tb_writer` in `log_dir` (default: `args.logging_dir`) if a writer class was resolved."""
        log_dir = log_dir or args.logging_dir
        if self._SummaryWriter is not None:
            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        """On the main process, make sure a writer exists and record the arguments and model config as text."""
        if not state.is_world_process_zero:
            return

        log_dir = None

        # During a hyperparameter search each named trial gets its own sub-directory.
        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        if self.tb_writer is None:
            self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)
            # Version of TensorBoard coming from tensorboardX does not have this method.
            if hasattr(self.tb_writer, "add_hparams"):
                self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        """On the main process, write every `int`/`float` log entry as a scalar at `state.global_step`."""
        if not state.is_world_process_zero:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer is not None:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
            self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        """Close the writer, if any, and forget it so that a later training run creates a fresh one."""
        if self.tb_writer:
            self.tb_writer.close()
            self.tb_writer = None
class WandbCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Weights & Biases](https://www.wandb.com/).
    """

    def __init__(self):
        """
        Import `wandb` and read the `WANDB_LOG_MODEL` setting.

        Raises:
            RuntimeError: If wandb is not available.
        """
        if not is_wandb_available():
            raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
        import wandb

        self._wandb = wandb
        self._initialized = False
        # log outputs
        truthy_values = ENV_VARS_TRUE_VALUES.union({"TRUE"})
        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in truthy_values

    def setup(self, args, state, model, **kwargs):
        """
        Setup the optional Weights & Biases (*wandb*) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        [here](https://docs.wandb.ai/integrations/huggingface). You can also override the following environment
        variables:

        Environment:
            WANDB_LOG_MODEL (`bool`, *optional*, defaults to `False`):
                Whether or not to log model as artifact at the end of training. Use along with
                *TrainingArguments.load_best_model_at_end* to upload best model.
            WANDB_WATCH (`str`, *optional* defaults to `"gradients"`):
                Can be `"gradients"`, `"all"` or `"false"`. Set to `"false"` to disable gradient logging or `"all"` to
                log gradients and parameters.
            WANDB_PROJECT (`str`, *optional*, defaults to `"huggingface"`):
                Set this to a custom string to store results in a different project.
            WANDB_DISABLED (`bool`, *optional*, defaults to `False`):
                Whether or not to disable wandb entirely. Set *WANDB_DISABLED=true* to disable.
        """
        if self._wandb is None:
            return
        self._initialized = True
        # Only the main process talks to wandb.
        if not state.is_world_process_zero:
            return

        logger.info(
            'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
        )
        combined_dict = {**args.to_sanitized_dict()}
        if hasattr(model, "config") and model.config is not None:
            # Training arguments take precedence over model config entries with the same key.
            combined_dict = {**model.config.to_dict(), **combined_dict}

        # A named trial becomes the run name, and the configured run name becomes the group.
        init_args = {}
        run_name = args.run_name
        trial_name = state.trial_name
        if trial_name is not None:
            run_name = trial_name
            init_args["group"] = args.run_name

        if self._wandb.run is None:
            self._wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"),
                name=run_name,
                **init_args,
            )
        # add config parameters (run may have been created manually)
        self._wandb.config.update(combined_dict, allow_val_change=True)

        # define default x-axis (for latest wandb versions)
        if getattr(self._wandb, "define_metric", None):
            self._wandb.define_metric("train/global_step")
            self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

        # keep track of model topology and gradients, unsupported on TPU
        if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
            self._wandb.watch(
                model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
            )

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Run `setup` if needed; during a hyperparameter search the current run is finished first."""
        if self._wandb is None:
            return
        if state.is_hyper_param_search:
            # Each trial gets a fresh run: finish the current one and force `setup` to run again.
            self._wandb.finish()
            self._initialized = False
            args.run_name = None
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        """When model logging is enabled, save the model to a temporary directory and log it as an artifact."""
        if self._wandb is None:
            return
        if not (self._log_model and self._initialized and state.is_world_process_zero):
            return

        from .trainer import Trainer

        fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
        with tempfile.TemporaryDirectory() as temp_dir:
            fake_trainer.save_model(temp_dir)
            if args.load_best_model_at_end:
                # NOTE(review): the key is spelled "total_floss" and is kept byte-for-byte here.
                metadata = {
                    f"eval/{args.metric_for_best_model}": state.best_metric,
                    "train/total_floss": state.total_flos,
                }
            else:
                # Numeric summary entries whose key does not start with an underscore.
                metadata = {
                    k: v
                    for k, v in dict(self._wandb.summary).items()
                    if isinstance(v, numbers.Number) and not k.startswith("_")
                }
            artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
            # Only regular files at the top level of the temporary directory are added.
            for saved_path in Path(temp_dir).glob("*"):
                if saved_path.is_file():
                    with artifact.new_file(saved_path.name, mode="wb") as artifact_file:
                        artifact_file.write(saved_path.read_bytes())
            self._wandb.run.log_artifact(artifact)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """On the main process, send the re-keyed logs together with the current global step to wandb."""
        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            logs = rewrite_logs(logs)
            self._wandb.log({**logs, "train/global_step": state.global_step})
class CometCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.ml/site/).
    """

    def __init__(self):
        """
        Raises:
            RuntimeError: If the Comet integration was found unavailable when this module was imported.
        """
        if not _has_comet:
            raise RuntimeError("CometCallback requires comet-ml to be installed. Run `pip install comet-ml`.")
        self._initialized = False
        self._log_assets = False

    def setup(self, args, state, model):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE (`str`, *optional*):
                Whether to create an online, offline experiment or disable Comet logging. Can be "OFFLINE", "ONLINE",
                or "DISABLED". Defaults to "ONLINE".
            COMET_PROJECT_NAME (`str`, *optional*):
                Comet project name for experiments
            COMET_OFFLINE_DIRECTORY (`str`, *optional*):
                Folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
            COMET_LOG_ASSETS (`str`, *optional*):
                Whether or not to log training assets (tf event logs, checkpoints, etc), to Comet. Can be "TRUE", or
                "FALSE". Defaults to "FALSE".

        For a number of configurable items in the environment, see
        [here](https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables).
        """
        self._initialized = True
        log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
        if log_assets in {"TRUE", "1"}:
            self._log_assets = True
        if state.is_world_process_zero:
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            experiment = None
            experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            # Any other mode leaves `experiment` as None and nothing is logged below.
            if experiment is not None:
                experiment._set_model_graph(model, framework="transformers")
                experiment._log_parameters(args, prefix="args/", framework="transformers")
                if hasattr(model, "config"):
                    experiment._log_parameters(model.config, prefix="config/", framework="transformers")

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Run `setup` once, the first time training begins."""
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """On the main process, forward `logs` to the global Comet experiment when there is one."""
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if experiment is not None:
                experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")

    def on_train_end(self, args, state, control, **kwargs):
        """On the main process, optionally upload `args.output_dir` as assets, then end the experiment."""
        if self._initialized and state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            # There may be no global experiment; in that case there is nothing to upload or to end.
            if experiment is not None:
                if self._log_assets is True:
                    logger.info("Logging checkpoints. This may take time.")
                    experiment.log_asset_folder(
                        args.output_dir, recursive=True, log_file_name=True, step=state.global_step
                    )
                experiment.end()
class AzureMLCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
    """

    def __init__(self, azureml_run=None):
        """
        Args:
            azureml_run (*optional*):
                Run object to log to. When left unset, one is obtained in `on_init_end`.

        Raises:
            RuntimeError: If the AzureML SDK is not available.
        """
        if not is_azureml_available():
            raise RuntimeError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
        self.azureml_run = azureml_run

    def on_init_end(self, args, state, control, **kwargs):
        """On the main process, fall back to `Run.get_context()` when no run was supplied."""
        from azureml.core.run import Run

        needs_run = self.azureml_run is None and state.is_world_process_zero
        if needs_run:
            self.azureml_run = Run.get_context()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """On the main process, log every `int`/`float` entry of `logs` to the AzureML run."""
        if not (self.azureml_run and state.is_world_process_zero):
            return
        for metric_name, metric_value in logs.items():
            # Non-numeric entries are skipped.
            if isinstance(metric_value, (int, float)):
                self.azureml_run.log(metric_name, metric_value, description=metric_name)
class MLflowCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting
environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`.
"""
def __init__(self):
if not is_mlflow_available():
raise RuntimeError("MLflowCallback requires mlflow to be installed. Run `pip install mlflow`.")
import mlflow
self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
self._initialized = False
self._log_artifacts = False
self._ml_flow = mlflow
def setup(self, args, state, model):
"""
Setup the optional MLflow integration.
Environment:
HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*):
Whether to use MLflow .log_artifact() facility to log artifacts.
This only makes sense if logging to a remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy
whatever is in [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it
without a remote storage will just copy the files to your artifact location.
"""
log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
if log_artifacts in {"TRUE", "1"}:
self._log_artifacts = True
if state.is_world_process_zero:
if self._ml_flow.active_run is None:
self._ml_flow.start_run(run_name=args.run_name)
combined_dict = args.to_dict()
if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict()
combined_dict = {**model_config, **combined_dict}
# remove params that are too long for MLflow
for name, value in list(combined_dict.items()):
# internally, all values are converted to str in MLflow
if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
logger.warning(
f"Trainer is attempting to log a value of "
f'"{value}" for key "{name}" as a parameter. '
f"MLflow's log_param() only accepts values no longer than "
f"250 characters so we dropped this attribute."
)
del combined_dict[name]
# MLflow cannot log more than 100 values in one go, so we have to split it
combined_dict_items = list(combined_dict.items())
for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
self._initialized = True
def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
def on_log(self, args, state, control, logs, model=None, **kwargs):
    """
    Send the numeric entries of `logs` to MLflow as metrics for step `state.global_step`.

    Runs `setup` first if the callback is not initialized yet. Only the world-zero process logs; entries whose
    value is neither `int` nor `float` are dropped with a warning.
    """
    if not self._initialized:
        self.setup(args, state, model)
    if not state.is_world_process_zero:
        return
    numeric_logs = {}
    for k, v in logs.items():
        if not isinstance(v, (int, float)):
            logger.warning(
                f"Trainer is attempting to log a value of "
                f'"{v}" of type {type(v)} for key "{k}" as a metric. '
                f"MLflow's log_metric() only accepts float and "
                f"int types so we dropped this attribute."
            )
            continue
        numeric_logs[k] = v
    self._ml_flow.log_metrics(metrics=numeric_logs, step=state.global_step)
def on_train_end(self, args, state, control, **kwargs):
    """
    Upload the content of `args.output_dir` as MLflow artifacts when training ends.

    Nothing happens unless the callback is initialized, this is the world-zero process and artifact logging
    was enabled.
    """
    should_upload = self._initialized and state.is_world_process_zero and self._log_artifacts
    if not should_upload:
        return
    logger.info("Logging artifacts. This may take time.")
    self._ml_flow.log_artifacts(args.output_dir)
def __del__(self):
    """End the MLflow run that is still active, if any, when the callback is garbage collected."""
    # if the previous run is not terminated correctly, the fluent API will
    # not let you start a new run before the previous one is killed
    # `_ml_flow` is assigned in `__init__`; tolerate instances on which it was never set
    ml_flow = getattr(self, "_ml_flow", None)
    # `active_run` is a function and must be called: the function object itself is never None, which made
    # the previous check always true
    if ml_flow is not None and ml_flow.active_run() is not None:
        ml_flow.end_run()
class NeptuneCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Neptune](https://neptune.ai).

    The Neptune run is created lazily by `setup`, on the world-zero process only, and stopped when the callback
    is garbage collected.
    """

    def __init__(self):
        if not is_neptune_available():
            raise ValueError(
                "NeptuneCallback requires neptune-client to be installed. Run `pip install neptune-client`."
            )
        # imported here so that the module can be loaded without neptune installed
        import neptune.new as neptune

        self._initialized = False
        self._log_artifacts = False
        self._neptune = neptune

    def setup(self, args, state, model):
        """
        Setup the Neptune integration.

        On the world-zero process, creates the Neptune run and stores the training arguments, merged with the
        model configuration when one is available, under the run's `parameters` field.

        Environment:
            NEPTUNE_PROJECT (`str`, *required*):
                The project ID for neptune.ai account. Should be in format *workspace_name/project_name*
            NEPTUNE_API_TOKEN (`str`, *required*):
                API-token for neptune.ai account
            NEPTUNE_CONNECTION_MODE (`str`, *optional*):
                Neptune connection mode. *async* by default
            NEPTUNE_RUN_NAME (`str`, *optional*):
                The name of run process on Neptune dashboard
            NEPTUNE_RUN_ID (`str`, *optional*):
                Forwarded as the `run` argument of `neptune.init` (presumably the ID of an existing run to
                resume; see the Neptune documentation)
        """
        if state.is_world_process_zero:
            environ = os.environ
            self._neptune_run = self._neptune.init(
                project=environ.get("NEPTUNE_PROJECT"),
                api_token=environ.get("NEPTUNE_API_TOKEN"),
                mode=environ.get("NEPTUNE_CONNECTION_MODE", "async"),
                name=environ.get("NEPTUNE_RUN_NAME"),
                run=environ.get("NEPTUNE_RUN_ID"),
            )
            run_parameters = args.to_dict()
            model_config = getattr(model, "config", None)
            if model_config is not None:
                # on key clashes the training arguments take precedence over the model config
                run_parameters = {**model_config.to_dict(), **run_parameters}
            self._neptune_run["parameters"] = run_parameters
        # set on every process so that non-zero ranks do not call setup() again on each event
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self._initialized:
            return
        self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if not state.is_world_process_zero:
            return
        # every log entry becomes a series on the run, indexed by the global step
        for field_name, field_value in logs.items():
            self._neptune_run[field_name].log(field_value, step=state.global_step)

    def __del__(self):
        """
        Stop the Neptune run, if one was created.

        Environment:
            NEPTUNE_STOP_TIMEOUT (`int`, *optional*):
                Number of seconds to wait for all Neptune.ai tracking calls to finish, before stopping the tracked
                run. If not set it will wait for all tracking calls to finish.
        """
        try:
            raw_timeout = os.getenv("NEPTUNE_STOP_TIMEOUT")
            timeout_seconds = int(raw_timeout) if raw_timeout else None
            self._neptune_run.stop(seconds=timeout_seconds)
        except AttributeError:
            # no run was ever created on this instance (setup not reached, or not the world-zero process)
            pass
class CodeCarbonCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that tracks the CO2 emission of training.

    An emissions tracker is created on the local-zero process once the trainer is initialized; it is started
    when training begins and stopped when training ends.
    """

    def __init__(self):
        if not is_codecarbon_available():
            raise RuntimeError(
                "CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
            )
        # imported here so that the module can be loaded without codecarbon installed
        import codecarbon

        self.tracker = None
        self._codecarbon = codecarbon

    def on_init_end(self, args, state, control, **kwargs):
        if self.tracker is not None or not state.is_local_process_zero:
            return
        # CodeCarbon will automatically handle environment variables for configuration
        self.tracker = self._codecarbon.EmissionsTracker(output_dir=args.output_dir)

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self.tracker or not state.is_local_process_zero:
            return
        self.tracker.start()

    def on_train_end(self, args, state, control, **kwargs):
        if not self.tracker or not state.is_local_process_zero:
            return
        self.tracker.stop()
# Registry mapping each supported reporting integration name to its callback class.
# `get_reporting_integration_callbacks` looks names up here; the key order below is also the order in which
# the supported names are listed in its error message.
INTEGRATION_TO_CALLBACK = {
"azure_ml": AzureMLCallback,
"comet_ml": CometCallback,
"mlflow": MLflowCallback,
"neptune": NeptuneCallback,
"tensorboard": TensorBoardCallback,
"wandb": WandbCallback,
"codecarbon": CodeCarbonCallback,
}
def get_reporting_integration_callbacks(report_to):
    """
    Resolve reporting integration names to their callback classes.

    Args:
        report_to (`str` or iterable of `str`, *optional*):
            Name(s) of the integrations to report to. A single name may be given as a plain string; `None`
            yields an empty list.

    Returns:
        `list`: The callback classes (not instances) registered in `INTEGRATION_TO_CALLBACK`, in the order the
        names were given.

    Raises:
        ValueError: If a name is not a key of `INTEGRATION_TO_CALLBACK`.
    """
    if report_to is None:
        return []
    if isinstance(report_to, str):
        # a bare string would otherwise be iterated character by character
        report_to = [report_to]
    # validate and collect in a single pass so that one-shot iterables (e.g. generators) work too
    callbacks = []
    for integration in report_to:
        if integration not in INTEGRATION_TO_CALLBACK:
            raise ValueError(
                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
            )
        callbacks.append(INTEGRATION_TO_CALLBACK[integration])
    return callbacks