-
Notifications
You must be signed in to change notification settings - Fork 104
/
trainer.py
2168 lines (1905 loc) · 87.8 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -*- coding: utf-8 -*-
import gc
import importlib
import logging
import os
import platform
import shutil
import sys
import time
import traceback
from contextlib import nullcontext
from dataclasses import dataclass, field
from inspect import signature
from typing import Callable, Dict, List, Tuple, Union
import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from trainer.analytics import ping_training_run
from trainer.callbacks import TrainerCallback
from trainer.generic_utils import (
KeepAverage,
count_parameters,
get_experiment_folder_path,
get_git_branch,
isimplemented,
remove_experiment_folder,
set_partial_state_dict,
to_cuda,
)
from trainer.io import (
copy_model_files,
get_last_checkpoint,
load_fsspec,
save_best_model,
save_checkpoint,
)
from trainer.logging import ConsoleLogger, DummyLogger, logger_factory
from trainer.trainer_utils import (
get_optimizer,
get_scheduler,
is_apex_available,
print_training_env,
setup_torch_training_env,
)
from trainer.utils.cuda_memory import cuda_meminfo, should_reduce_batch_size
from trainer.utils.distributed import (
get_rank,
init_distributed,
rank_zero_logger_info,
rank_zero_only,
)
logger = logging.getLogger("trainer")
if is_apex_available():
from apex import amp # pylint: disable=import-error
@dataclass
class TrainerConfig(Coqpit):
    """Config fields tweaking the Trainer for a model.

    A ```ModelConfig```, by inheriting ```TrainerConfig```, must be defined for using 👟.
    Inherit this by a new model config and override the fields as needed.
    All the fields can be overridden from command-line as ```--coqpit.arg_name=value```.

    Example::

        Run the training code by overriding the ```lr``` and ```plot_step``` fields.

        >>> python train.py --coqpit.plot_step=22 --coqpit.lr=0.001

        Defining a model using ```TrainerConfig```.

        >>> from trainer import TrainerConfig
        >>> class MyModelConfig(TrainerConfig):
        ...     optimizer: str = "Adam"
        ...     lr: float = 0.001
        ...     epochs: int = 1
        ...     ...
        >>> class MyModel(nn.module):
        ...     def __init__(self, config):
        ...         ...
        >>> model = MyModel(MyModelConfig())
    """

    # Fields for the run
    # NOTE(review): several `str`/`int` fields default to None — presumably
    # Coqpit tolerates this; `Optional[...]` would be the precise annotation.
    output_path: str = field(default="output")
    logger_uri: str = field(
        default=None,
        metadata={
            "help": "URI to save training artifacts by the logger. If not set, logs will be saved in the output_path. Defaults to None"
        },
    )
    run_name: str = field(default="run", metadata={"help": "Name of the run. Defaults to 'run'"})
    project_name: str = field(default=None, metadata={"help": "Name of the project. Defaults to None"})
    run_description: str = field(
        default="🐸Coqui trainer run.",
        metadata={"help": "Notes and description about the run. Defaults to '🐸Coqui trainer run.'"},
    )
    # Fields for logging
    print_step: int = field(
        default=25, metadata={"help": "Print training stats on the terminal every print_step steps. Defaults to 25"}
    )
    plot_step: int = field(
        default=100, metadata={"help": "Plot training stats on the logger every plot_step steps. Defaults to 100"}
    )
    model_param_stats: bool = field(
        default=False, metadata={"help": "Log model parameters stats on the logger dashboard. Defaults to False"}
    )
    wandb_entity: str = field(default=None, metadata={"help": "Wandb entity to log the run. Defaults to None"})
    dashboard_logger: str = field(
        default="tensorboard", metadata={"help": "Logger to use for the tracking dashboard. Defaults to 'tensorboard'"}
    )
    # Fields for checkpointing
    save_on_interrupt: bool = field(
        default=True, metadata={"help": "Save checkpoint on interrupt (Ctrl+C). Defaults to True"}
    )
    log_model_step: int = field(
        default=None,
        metadata={
            "help": "Save checkpoint to the logger every log_model_step steps. If not defined `save_step == log_model_step`."
        },
    )
    save_step: int = field(
        default=10000, metadata={"help": "Save local checkpoint every save_step steps. Defaults to 10000"}
    )
    save_n_checkpoints: int = field(default=5, metadata={"help": "Keep n local checkpoints. Defaults to 5"})
    save_checkpoints: bool = field(default=True, metadata={"help": "Save checkpoints locally. Defaults to True"})
    save_all_best: bool = field(
        default=False, metadata={"help": "Save all best checkpoints and keep the older ones. Defaults to False"}
    )
    save_best_after: int = field(
        default=10000, metadata={"help": "Wait N steps to save best checkpoints. Defaults to 10000"}
    )
    target_loss: str = field(
        default=None, metadata={"help": "Target loss name to select the best model. Defaults to None"}
    )
    # Fields for eval and test run
    print_eval: bool = field(default=False, metadata={"help": "Print eval steps on the terminal. Defaults to False"})
    test_delay_epochs: int = field(default=0, metadata={"help": "Wait N epochs before running the test. Defaults to 0"})
    run_eval: bool = field(
        default=True, metadata={"help": "Run evalulation epoch after training epoch. Defaults to True"}
    )
    run_eval_steps: int = field(
        default=None,
        metadata={
            "help": "Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None"
        },
    )
    # Fields for distributed training
    distributed_backend: str = field(
        default="nccl", metadata={"help": "Distributed backend to use. Defaults to 'nccl'"}
    )
    distributed_url: str = field(
        default="tcp://localhost:54321",
        metadata={"help": "Distributed url to use. Defaults to 'tcp://localhost:54321'"},
    )
    # Fields for training specs
    mixed_precision: bool = field(default=False, metadata={"help": "Use mixed precision training. Defaults to False"})
    precision: str = field(
        default="fp16",
        metadata={
            "help": "Precision to use in mixed precision training. `fp16` for float16 and `bf16` for bfloat16. Defaults to 'f16'"
        },
    )
    epochs: int = field(default=1000, metadata={"help": "Number of epochs to train. Defaults to 1000"})
    batch_size: int = field(default=32, metadata={"help": "Batch size to use. Defaults to 32"})
    eval_batch_size: int = field(default=16, metadata={"help": "Batch size to use for eval. Defaults to 16"})
    grad_clip: float = field(
        default=0.0, metadata={"help": "Gradient clipping value. Disabled if <= 0. Defaults to 0.0"}
    )
    scheduler_after_epoch: bool = field(
        default=True,
        metadata={"help": "Step the scheduler after each epoch else step after each iteration. Defaults to True"},
    )
    # Fields for optimization
    lr: Union[float, List[float]] = field(
        default=0.001, metadata={"help": "Learning rate for each optimizer. Defaults to 0.001"}
    )
    optimizer: Union[str, List[str]] = field(default=None, metadata={"help": "Optimizer(s) to use. Defaults to None"})
    optimizer_params: Union[Dict, List[Dict]] = field(
        default_factory=dict, metadata={"help": "Optimizer(s) arguments. Defaults to {}"}
    )
    lr_scheduler: Union[str, List[str]] = field(
        default=None, metadata={"help": "Learning rate scheduler(s) to use. Defaults to None"}
    )
    lr_scheduler_params: Dict = field(
        default_factory=dict, metadata={"help": "Learning rate scheduler(s) arguments. Defaults to {}"}
    )
    use_grad_scaler: bool = field(
        default=False,
        metadata={
            "help": "Enable/disable gradient scaler explicitly. It is enabled by default with AMP training. Defaults to False"
        },
    )
    allow_tf32: bool = field(
        default=False,
        metadata={
            "help": "A bool that controls whether TensorFloat-32 tensor cores may be used in matrix multiplications on Ampere or newer GPUs. Default to False."
        },
    )
    cudnn_enable: bool = field(default=True, metadata={"help": "Enable/disable cudnn explicitly. Defaults to True"})
    cudnn_deterministic: bool = field(
        default=False,
        metadata={
            "help": "Enable/disable deterministic cudnn operations. Set this True for reproducibility but it slows down training significantly. Defaults to False."
        },
    )
    cudnn_benchmark: bool = field(
        default=False,
        metadata={
            "help": "Enable/disable cudnn benchmark explicitly. Set this False if your input size change constantly. Defaults to False"
        },
    )
    training_seed: int = field(
        default=54321,
        metadata={"help": "Global seed for torch, random and numpy random number generator. Defaults to 54321"},
    )
@dataclass
class TrainerArgs(Coqpit):
    """Trainer arguments that can be accessed from the command line.

    Examples::

        >>> python train.py --restore_path /path/to/checkpoint.pth
    """

    continue_path: str = field(
        default="",
        metadata={
            "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder."
        },
    )
    restore_path: str = field(
        default="",
        metadata={
            "help": "Path to a model checkpoit. Restore the model with the given checkpoint and start a new training."
        },
    )
    best_path: str = field(
        default="",
        metadata={
            "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
        },
    )
    use_ddp: bool = field(
        default=False,
        metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."},
    )
    use_accelerate: bool = field(default=False, metadata={"help": "Use HF Accelerate as the back end for training."})
    grad_accum_steps: int = field(
        default=1,
        metadata={
            "help": "Number of gradient accumulation steps. It is used to accumulate gradients over multiple batches."
        },
    )
    overfit_batch: bool = field(default=False, metadata={"help": "Overfit a single batch for debugging."})
    skip_train_epoch: bool = field(
        default=False,
        metadata={"help": "Skip training and only run evaluation and test."},
    )
    start_with_eval: bool = field(
        default=False,
        metadata={"help": "Start with evaluation and test."},
    )
    small_run: int = field(
        default=None,
        metadata={
            "help": "Only use a subset of the samples for debugging. Set the number of samples to use. Defaults to None. "
        },
    )
    gpu: int = field(
        default=None, metadata={"help": "GPU ID to use if ```CUDA_VISIBLE_DEVICES``` is not set. Defaults to None."}
    )
    # only for DDP — populated by `distribute.py`, not by the user
    rank: int = field(default=0, metadata={"help": "Process rank in a distributed training. Don't set manually."})
    group_id: str = field(
        default="", metadata={"help": "Process group id in a distributed training. Don't set manually."}
    )
class Trainer:
def __init__( # pylint: disable=dangerous-default-value
self,
args: TrainerArgs,
config: Coqpit,
output_path: str,
c_logger: ConsoleLogger = None,
dashboard_logger: "Logger" = None,
model: nn.Module = None,
get_model: Callable = None,
get_data_samples: Callable = None,
train_samples: List = None,
eval_samples: List = None,
test_samples: List = None,
train_loader: DataLoader = None,
eval_loader: DataLoader = None,
training_assets: Dict = {},
parse_command_line_args: bool = True,
callbacks: Dict[str, Callable] = {},
gpu: int = None,
) -> None:
"""Simple yet powerful 🐸💬 TTS trainer for PyTorch. It can train all the available `tts` and `vocoder` models
or easily be customized.
Notes:
Supports Automatic Mixed Precision training. If `Apex` is availabe, it automatically picks that, else
it uses PyTorch's native `amp` module. `Apex` may provide more stable training in some cases.
Args:
args (Union[Coqpit, Namespace]): Training arguments parsed either from console by `argparse` or `TrainerArgs`
config object.
config (Coqpit): Model config object. It includes all the values necessary for initializing, training, evaluating
and testing the model.
output_path (str): Path to the output training folder. All the files are saved under thi path.
c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default
console logger is used. Defaults to None.
dashboard_logger Union[TensorboardLogger, WandbLogger]: Dashboard logger. If not provided, the tensorboard logger is used.
Defaults to None.
model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
initializes a model from the provided config. Defaults to None.
get_model (Callable):
A function that returns a model. It is used to initialize the model when `model` is not provided.
It either takes the config as the only argument or does not take any argument.
Defaults to None
get_data_samples (Callable):
A function that returns a list of training and evaluation samples. Used if `train_samples` and
`eval_samples` are None. Defaults to None.
train_samples (List):
A list of training samples used by the model's `get_train_data_loader` to init the `dataset` and the
`data_loader`. Defaults to None.
eval_samples (List):
A list of evaluation samples used by the model's `get_eval_data_loader` to init the `dataset` and the
`data_loader`. Defaults to None.
train_loader (DataLoader):
A pytorch data loader object for training epochs. Leave as None if you want it to be made during training. Defaults to None.
eval_loader (DataLoader):
A pytorch data loader object for evaluation epochs. Leave as None to be generated during training. Defaults to None.
test_samples (List):
A list of test samples used by the model's `get_test_data_loader` to init the `dataset` and the
`data_loader`. If None, the ```model.test_run()``` is expected to load the data. Defaults to None.
training_assets (Dict):
A dictionary of assets to be used at training and passed to the model's ```train_log(), eval_log(), get_data_loader()```
during training. It can include `AudioProcessor` or/and `Tokenizer`. Defaults to {}.
parse_command_line_args (bool):
If true, parse command-line arguments and update `TrainerArgs` and model `config` values. Set it
to false if you parse the arguments yourself. Defaults to True.
callbacks (Dict[str, Callable]):
A dictionary of callbacks to be used during training. The keys are the callback names and the values
gpu (int):
GPU ID to use for training If "CUDA_VISIBLE_DEVICES" is not set. Defaults to None.
Example::
Running trainer with a model.
>>> args = TrainerArgs(...)
>>> config = ModelConfig(...)
>>> model = Model(config)
>>> trainer = Trainer(args, config, output_path, model=model)
>>> trainer.fit()
TODO:
- Wrap model for not calling .module in DDP.
- Deepspeed integration
- Profiler integration.
- Overfitting to a batch.
- TPU training
"""
if parse_command_line_args:
# parse command-line arguments to override TrainerArgs()
args, coqpit_overrides = self.parse_argv(args)
# get ready for training and parse command-line arguments to override the model config
config, new_fields = self.init_training(args, coqpit_overrides, config)
elif args.continue_path or args.restore_path:
config, new_fields = self.init_training(args, {}, config)
else:
new_fields = {}
# set the output path
if args.continue_path:
# use the same path as the continuing run
output_path = args.continue_path
else:
# override the output path if it is provided
output_path = config.output_path if output_path is None else output_path
# create a new output folder name
output_path = get_experiment_folder_path(config.output_path, config.run_name)
os.makedirs(output_path, exist_ok=True)
# copy training assets to the output folder
copy_model_files(config, output_path, new_fields)
# init class members
self.args = args
self.config = config
self.output_path = output_path
self.training_assets = training_assets
self.grad_accum_steps = args.grad_accum_steps
self.overfit_batch = args.overfit_batch
self.skip_train_epoch = args.skip_train_epoch
self.start_with_eval = args.start_with_eval
assert self.grad_accum_steps > 0, " [!] grad_accum_steps must be greater than 0."
# setup logging
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
# setup training environment
self.use_cuda, self.num_gpus = self.setup_training_environment(args=args, config=config, gpu=gpu)
# init loggers
self.dashboard_logger, self.c_logger = self.init_loggers(self.config, output_path, dashboard_logger, c_logger)
# self.c_logger.logger = logger
if not self.config.log_model_step:
self.config.log_model_step = self.config.save_step
# make sure that start_with_eval is disabled if eval is disabled
if not self.config.run_eval and self.start_with_eval:
self.start_with_eval = False
self.total_steps_done = 0
self.epochs_done = 0
self.restore_step = 0
self.restore_epoch = 0
self.best_loss = {"train_loss": float("inf"), "eval_loss": float("inf") if self.config.run_eval else None}
self.train_loader = None
self.test_loader = None
self.eval_loader = None
self.keep_avg_train = None
self.keep_avg_eval = None
self.use_amp_scaler = (
self.use_cuda
if self.config.mixed_precision and self.config.precision == "fp16"
else self.config.use_grad_scaler
)
if train_samples is not None:
# use the provided samples
self.train_samples = train_samples
self.eval_samples = eval_samples
self.test_samples = test_samples
elif get_data_samples is not None:
# run `get_data_samples` to init the data samples
( # pylint: disable=unbalanced-tuple-unpacking
self.train_samples,
self.eval_samples,
self.test_samples,
) = self.run_get_data_samples(config, get_data_samples)
else:
# expecting to load the samples in `model.get_data_loader()`
self.train_samples = None
self.eval_samples = None
self.test_samples = None
# define custom train and eval loader
self.train_loader = train_loader
self.eval_loader = eval_loader
# only use a subset of the samples if small_run is set
self.setup_small_run(args.small_run)
# init the model
if model is None and get_model is None:
raise ValueError("[!] `model` and `get_model` cannot both be None.")
if model is not None:
self.model = model
else:
self.run_get_model(self.config, get_model)
# init model's training assets
if isimplemented(self.model, "init_for_training"):
self.model.init_for_training()
# setup criterion
self.criterion = self.get_criterion(self.model)
# DISTRUBUTED
if self.use_pt_ddp:
rank_zero_logger_info(" > Using PyTorch DDP", logger)
init_distributed(
args.rank,
self.num_gpus,
args.group_id,
self.config.distributed_backend,
self.config.distributed_url,
)
if self.use_cuda:
self.model.cuda()
if isinstance(self.criterion, list):
for criterion in self.criterion:
if isinstance(criterion, torch.nn.Module):
criterion.cuda()
else:
if isinstance(self.criterion, torch.nn.Module):
self.criterion.cuda()
# setup optimizer
self.optimizer = self.get_optimizer(self.model, self.config)
# If multiple-optimizer setup with grad accumulation and without custom optimize method raise an error
if (
self.grad_accum_steps != 1
and isinstance(self.optimizer, list)
and not isimplemented(self.model, "optimize")
):
raise ValueError(
" [!] Coqui Trainer does not support grad_accum_steps for multiple-optimizer setup, please set grad_accum_steps to 1 or implement in your model a custom method called ´optimize` that need to deal with dangling gradients in multiple-optimizer setup!"
)
# CALLBACK
self.callbacks = TrainerCallback()
self.callbacks.parse_callbacks_dict(callbacks)
self.callbacks.on_init_start(self)
# init AMP
if self.use_amp_scaler:
if self.use_apex:
self.scaler = None
self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1")
self.scaler = torch.cuda.amp.GradScaler()
else:
self.scaler = None
# restore model
if self.args.restore_path:
(self.model, self.optimizer, self.scaler, self.restore_step, self.restore_epoch) = self.restore_model(
self.config, args.restore_path, self.model, self.optimizer, self.scaler
)
self.scaler = torch.cuda.amp.GradScaler()
# setup scheduler
self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer)
self.scheduler = self.restore_scheduler(
self.scheduler, self.args, self.config, self.restore_epoch, self.restore_step
)
# DISTRIBUTED
if self.use_pt_ddp:
self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
# setup accelerator
self.setup_accelerate()
# count model size
num_params = count_parameters(self.model)
rank_zero_logger_info(f"\n > Model has {num_params} parameters", logger)
self.callbacks.on_init_end(self)
self.dashboard_logger.add_config(config)
self.save_training_script()
ping_training_run()
@property
def use_apex(self):
"""Return True if using APEX."""
return not self.args.use_accelerate and self._is_apex_available()
@property
def use_pt_ddp(self):
"""Return True if using PyTorch DDP."""
return self.num_gpus > 1 and not self.use_accelerate
@property
def use_accelerate(self):
"""Return True if using HF Accelerate."""
return self.args.use_accelerate
def setup_accelerate(self):
if self.use_accelerate:
self.model, self.optimizer, self.train_loader, self.scheduler, self.accelerator = self.init_accelerate(
model=self.model,
optimizer=self.optimizer,
training_dataloader=self.train_loader,
scheduler=self.scheduler,
grad_accum_steps=self.grad_accum_steps,
mixed_precision=self.config.mixed_precision,
precision=self.config.precision,
)
def prepare_accelerate_loader(self, data_loader):
"""Prepare the accelerator for the training."""
if self.use_accelerate:
return self.accelerator.prepare_data_loader(data_loader)
return data_loader
@staticmethod
def init_accelerate(model, optimizer, training_dataloader, scheduler, grad_accum_steps, mixed_precision, precision):
"""Setup HF Accelerate for the training."""
# check if accelerate is installed
try:
from accelerate import Accelerator # pylint:disable=import-outside-toplevel
except ImportError as e:
raise ImportError("Please install accelerate to use this feature.") from e
_precision = precision if precision is not None else "f16" if mixed_precision else None
if _precision == "float16":
_precision = "f16"
elif _precision == "float8":
_precision = "f8"
elif _precision == "bfloat16":
_precision = "bf16"
accelerator = Accelerator(gradient_accumulation_steps=grad_accum_steps, mixed_precision=_precision)
if isinstance(model, torch.nn.Module):
model = accelerator.prepare_model(model)
if isinstance(optimizer, dict):
for key, optim in optimizer.items():
optimizer[key] = accelerator.prepare_optimizer(optim)
elif isinstance(optimizer, list):
for i, optim in enumerate(optimizer):
optimizer[i] = accelerator.prepare_optimizer(optim)
elif optimizer is not None:
optimizer = accelerator.prepare_optimizer(optimizer)
if isinstance(training_dataloader, torch.utils.data.DataLoader):
training_dataloader = accelerator.prepare_data_loader(training_dataloader)
if isinstance(scheduler, dict):
for key, sched in scheduler.items():
scheduler[key] = accelerator.prepare_scheduler(sched)
elif isinstance(scheduler, list):
for i, sched in enumerate(scheduler):
scheduler[i] = accelerator.prepare_scheduler(sched)
elif scheduler is not None:
scheduler = accelerator.prepare_scheduler(scheduler)
return model, optimizer, training_dataloader, scheduler, accelerator
def save_training_script(self):
"""Save the training script to tracking dashboard and output path."""
file_path = sys.argv[0]
if os.path.isfile(file_path):
file_name = os.path.basename(file_path)
self.dashboard_logger.add_artifact(file_or_dir=file_path, name=file_name, artifact_type="file")
with open(file_path, "r", encoding="utf8") as f:
self.dashboard_logger.add_text("training-script", f"{f.read()}", 0)
shutil.copyfile(file_path, os.path.join(self.output_path, file_name))
@staticmethod
def parse_argv(args: Union[Coqpit, List]):
"""Parse command line arguments to init or override `TrainerArgs()`."""
if isinstance(args, Coqpit):
parser = args.init_argparse(arg_prefix="")
else:
train_config = TrainerArgs()
parser = train_config.init_argparse(arg_prefix="")
training_args, coqpit_overrides = parser.parse_known_args()
args.parse_args(training_args)
return args, coqpit_overrides
@staticmethod
def init_loggers(config: "Coqpit", output_path: str, dashboard_logger=None, c_logger=None):
"""Init console and dashboard loggers.
Use the given logger if passed externally else use config values to pick the right logger.
Return a dashboard logger only for the rank 0 process in DDP
Define a console logger for each process in DDP
Args:
config (Coqpit): Model config.
output_path (str): Output path to save the training artifacts.
dashboard_logger (DashboardLogger): Object passed to the trainer from outside.
c_logger (ConsoleLogger): Object passed to the trained from outside.
Returns:
Initialized dashboard_logger and console_logger objects.
"""
c_logger = ConsoleLogger() if c_logger is None else c_logger
# only allow dashboard logging for the main process in DDP mode
if get_rank() > 0:
return DummyLogger(), c_logger
if dashboard_logger is None:
dashboard_logger = logger_factory(config, output_path)
return dashboard_logger, c_logger
def setup_small_run(self, small_run: int = None):
"""Use a subset of samples for training, evaluation and testing."""
if small_run is not None:
logger.info("[!] Small Run, only using %i samples.", small_run)
self.train_samples = None if self.train_samples is None else self.train_samples[:small_run]
self.eval_samples = None if self.eval_samples is None else self.eval_samples[:small_run]
self.test_samples = None if self.test_samples is None else self.test_samples[:small_run]
@staticmethod
def init_training(args: TrainerArgs, coqpit_overrides: Dict, config: Coqpit = None):
"""Initialize training and update model configs from command line arguments.
Args:
args (argparse.Namespace or dict like): Parsed trainer arguments.
config_overrides (argparse.Namespace or dict like): Parsed config overriding arguments.
config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
Returns:
config (Coqpit): Config paramaters.
"""
# set arguments for continuing training
if args.continue_path:
args.config_path = os.path.join(args.continue_path, "config.json")
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path:
args.best_path = best_model
# use the same config
if config:
config.load_json(args.config_path)
else:
coqpit = Coqpit()
coqpit.load_json(args.config_path)
# override config values from command-line args
# TODO: Maybe it is better to do it outside
if len(coqpit_overrides) > 0:
config.parse_known_args(coqpit_overrides, relaxed_parser=True)
# update the config.json fields and copy it to the output folder
new_fields = {}
if args.rank == 0:
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
return config, new_fields
@staticmethod
def setup_training_environment(args, config, gpu):
if platform.system() != "Windows":
# https://github.com/pytorch/pytorch/issues/973
import resource # pylint: disable=import-outside-toplevel
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
# set and initialize Pytorch runtime
use_cuda, num_gpus = setup_torch_training_env(
args=args,
cudnn_enable=config.cudnn_enable,
cudnn_deterministic=config.cudnn_deterministic,
cudnn_benchmark=config.cudnn_benchmark,
use_ddp=args.use_ddp,
training_seed=config.training_seed,
allow_tf32=config.allow_tf32,
gpu=gpu if args.gpu is None else args.gpu,
)
print_training_env(args, config)
return use_cuda, num_gpus
@staticmethod
def run_get_model(config: Coqpit, get_model: Callable) -> nn.Module:
"""Run the `get_model` function and return the model.
Args:
config (Coqpit): Model config.
Returns:
nn.Module: initialized model.
"""
if len(signature(get_model).sig.parameters) == 1:
model = get_model(config)
else:
model = get_model()
return model
@staticmethod
def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> nn.Module:
if callable(get_data_samples):
if len(signature(get_data_samples).sig.parameters) == 1:
train_samples, eval_samples = get_data_samples(config)
else:
train_samples, eval_samples = get_data_samples()
return train_samples, eval_samples
return None, None
    def restore_model(
        self,
        config: Coqpit,
        restore_path: str,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scaler: torch.cuda.amp.GradScaler = None,
    ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int, int]:
        """Restore training from an old run. It restores model, optimizer, AMP scaler and training stats.

        Falls back to a partial state-dict load (``set_partial_state_dict``) when
        the checkpoint does not exactly match the model.

        Args:
            config (Coqpit): Model config.
            restore_path (str): Path to the restored training run.
            model (nn.Module): Model to restored.
            optimizer (torch.optim.Optimizer): Optimizer to restore.
            scaler (torch.cuda.amp.GradScaler, optional): AMP scaler to restore. Defaults to None.

        Returns:
            Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int, int]:
                Restored model, optimizer, scaler, next training step and epoch.
        """

        def _restore_list_objs(states, obj):
            # Load state into a single object, or mirror the checkpoint's
            # list/dict container shape when multiple states are stored.
            if isinstance(obj, list):
                for idx, state in enumerate(states):
                    obj[idx].load_state_dict(state)
            elif isinstance(obj, dict):
                for key, state in states.items():
                    obj[key].load_state_dict(state)
            else:
                obj.load_state_dict(states)
            return obj

        logger.info(" > Restoring from %s ...", os.path.basename(restore_path))
        # load to CPU first; moved to GPU later by the caller if needed
        checkpoint = load_fsspec(restore_path, map_location="cpu")
        try:
            logger.info(" > Restoring Model...")
            model.load_state_dict(checkpoint["model"])
            logger.info(" > Restoring Optimizer...")
            try:
                optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
            except (KeyError, TypeError, RuntimeError):
                # best-effort: keep the freshly built optimizer on mismatch
                logger.info(" > Optimizer is not compatible with the restored model.")
            if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
                logger.info(" > Restoring Scaler...")
                scaler = _restore_list_objs(checkpoint["scaler"], scaler)
        except (KeyError, RuntimeError, ValueError):
            # checkpoint does not fully match the model — load whatever fits
            logger.info(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_partial_state_dict(model_dict, checkpoint["model"], config)
            model.load_state_dict(model_dict)
            del model_dict
        # reset learning rates unless this is a continuation of the same run
        optimizer = self.restore_lr(config, self.args, model, optimizer)
        logger.info(" > Model restored from step %i", checkpoint["step"])
        restore_step = checkpoint["step"] + 1  # +1 not to immediately checkpoint if the model is restored
        restore_epoch = checkpoint["epoch"]
        torch.cuda.empty_cache()
        return model, optimizer, scaler, restore_step, restore_epoch
def restore_lr(self, config, args, model, optimizer):
    """Reset the optimizer learning rate(s) from the config when a run is
    restored (rather than continued).

    When ``args.continue_path`` is set, the learning rate stored in the
    checkpoint is kept untouched.

    Args:
        config (Coqpit): Model config.
        args (TrainerArgs): Trainer CLI/launch arguments.
        model (nn.Module): The model, used to query its configured lr.
        optimizer: A single optimizer, or a list/dict of optimizers.

    Returns:
        The optimizer(s) with updated ``lr`` in every param group.
    """
    if not args.continue_path:
        # PERF: `self.get_lr()` is loop-invariant — the original recomputed
        # the whole lr structure once per param group.
        lrs = self.get_lr(model, config)
        if isinstance(optimizer, list):
            for idx, optim in enumerate(optimizer):
                for group in optim.param_groups:
                    group["lr"] = lrs[idx]
        elif isinstance(optimizer, dict):
            for optim_name, optim in optimizer.items():
                for group in optim.param_groups:
                    group["lr"] = lrs[optim_name]
        else:
            for group in optimizer.param_groups:
                group["lr"] = lrs
    return optimizer
#########################
# DATA LOADING FUNCTIONS
#########################
def _get_loader(
    self,
    model: nn.Module,
    config: Coqpit,
    assets: Dict,
    is_eval: bool,
    samples: List,
    verbose: bool,
    num_gpus: int,
) -> DataLoader:
    """Build a data loader through the model's ``get_data_loader()``.

    Args:
        model (nn.Module): Model (DDP-wrapped when ``num_gpus > 1``).
        config (Coqpit): Model config.
        assets (Dict): Training assets forwarded to the loader.
        is_eval (bool): Build an evaluation (True) or training (False) loader.
            (BUG FIX: was annotated `str`; callers pass booleans.)
        samples (List): Data samples.
        verbose (bool): enable/disable loader stats logging.
        num_gpus (int): Number of GPUs in use.

    Returns:
        DataLoader: Initialized data loader.

    Raises:
        NotImplementedError: If the model does not implement `get_data_loader()`.
    """
    loader = None
    if num_gpus > 1:
        # Under DDP the user model lives behind `model.module`.
        if isimplemented(model.module, "get_data_loader"):
            loader = model.module.get_data_loader(
                config,
                assets,
                is_eval,
                samples,
                verbose,
                num_gpus,
                self.args.rank,
            )
    else:
        if isimplemented(model, "get_data_loader"):
            loader = model.get_data_loader(
                config=config, assets=assets, is_eval=is_eval, samples=samples, verbose=verbose, num_gpus=num_gpus
            )
    # BUG FIX: when `get_data_loader` was not implemented, the function crashed
    # with an opaque `UnboundLocalError` on `loader` — fail explicitly instead.
    if loader is None:
        raise NotImplementedError(" ❗ Model must implement `get_data_loader()` to use the default loaders.")
    assert (
        len(loader) > 0
    ), " ❗ len(DataLoader) returns 0. Make sure your dataset is not empty or len(dataset) > 0. "
    return loader
def get_train_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader:
    """Initialize and return a training data loader.

    Call ```model.get_train_data_loader``` if it is implemented, else call ```model.get_data_loader```
    and set ```is_eval=False```.

    Args:
        training_assets (Dict): Training assets (e.g. audio processor) forwarded to the loader.
        samples (List): Data samples used for training.
        verbose (bool): enable/disable printing loader stats at initialization.

    Returns:
        DataLoader: Initialized training data loader.
    """
    # BUG FIX 1: the docstring documented a nonexistent `ap (AudioProcessor)` param.
    # BUG FIX 2: the custom-loader branches ignored the `training_assets`
    # argument and read `self.training_assets` instead, inconsistent with the
    # `_get_loader` fallback below — use the argument everywhere.
    if self.num_gpus > 1:
        if isimplemented(self.model.module, "get_train_data_loader"):
            loader = self.model.module.get_train_data_loader(
                self.config,
                training_assets,
                samples,
                verbose,
                self.num_gpus,
                self.args.rank,
            )
            return loader
    else:
        if isimplemented(self.model, "get_train_data_loader"):
            loader = self.model.get_train_data_loader(
                self.config, training_assets, samples, verbose, self.num_gpus
            )
            return loader
    return self._get_loader(
        self.model,
        self.config,
        training_assets,
        False,
        samples,
        verbose,
        self.num_gpus,
    )
def get_eval_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader:
    """Initialize and return an evaluation data loader.

    Call ```model.get_eval_data_loader``` if it is implemented, else call ```model.get_data_loader```
    and set ```is_eval=True```.

    Args:
        training_assets (Dict): Training assets (e.g. audio processor) forwarded to the loader.
        samples (List): Data samples used for evaluation.
        verbose (bool): enable/disable printing loader stats at initialization.

    Returns:
        DataLoader: Initialized evaluation data loader.
    """
    # BUG FIX 1: docstring documented a nonexistent `ap (AudioProcessor)` param
    # and called this a "training" loader.
    # BUG FIX 2: the custom-loader branches ignored the `training_assets`
    # argument and read `self.training_assets` instead, inconsistent with the
    # `_get_loader` fallback below — use the argument everywhere.
    if self.num_gpus > 1:
        if isimplemented(self.model.module, "get_eval_data_loader"):
            loader = self.model.module.get_eval_data_loader(
                self.config,
                training_assets,
                samples,
                verbose,
                self.num_gpus,
                self.args.rank,
            )
            return loader
    else:
        if isimplemented(self.model, "get_eval_data_loader"):
            loader = self.model.get_eval_data_loader(
                self.config, training_assets, samples, verbose, self.num_gpus
            )
            return loader
    return self._get_loader(
        self.model,
        self.config,
        training_assets,
        True,
        samples,
        verbose,
        self.num_gpus,
    )
def get_test_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader: