diff --git a/mmrazor/engine/runner/darts_loop.py b/mmrazor/engine/runner/darts_loop.py index 92494e585..65e4611df 100644 --- a/mmrazor/engine/runner/darts_loop.py +++ b/mmrazor/engine/runner/darts_loop.py @@ -47,15 +47,14 @@ def __init__(self, mutator_dataloader, seed=runner.seed) else: self.mutator_dataloader = mutator_dataloader - multi_loaders = [self.dataloader, self.mutator_dataloader] - self.multi_loaders = EpochMultiLoader(multi_loaders) + self.multi_loaders = [self.dataloader, self.mutator_dataloader] def run_epoch(self) -> None: """Iterate one epoch.""" self.runner.call_hook('before_train_epoch') self.runner.model.train() - for idx, data_batch in enumerate(self.multi_loaders): + for idx, data_batch in enumerate(EpochMultiLoader(self.multi_loaders)): self.run_iter(idx, data_batch) self.runner.call_hook('after_train_epoch')