diff --git a/train_ms.py b/train_ms.py index 1a5660777..11fffced0 100644 --- a/train_ms.py +++ b/train_ms.py @@ -72,8 +72,6 @@ def run(): num_replicas=n_gpus, rank=rank, shuffle=True, - persistent_workers=True, - prefetch_factor=4, ) collate_fn = TextAudioSpeakerCollate() train_loader = DataLoader( @@ -83,6 +81,8 @@ def run(): pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler, + persistent_workers=True, + prefetch_factor=4, ) # DataLoader config could be adjusted. if rank == 0: eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)