diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index 600c9d2be11..f88db3bfc25 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -251,13 +251,14 @@ def create_optimizer(self): self.manager, steps_per_epoch=self.manager_steps_per_epoch, loggers=self.logger_manager, - grad_sampler=self.grad_sampler, + initialize_kwargs={"grad_sampler": self.grad_sampler}, ) if not self.manager.initialized: self.manager.initialize( self.model, loggers=self.logger_manager, distillation_teacher=self.teacher, + grad_sampler=self.grad_sampler, ) self.manager_initialized = True _LOGGER.info(