diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index 7153b857a37..b380d6476b8 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -231,7 +231,10 @@ def create_optimizer(self): self.manager, steps_per_epoch=self.manager_steps_per_epoch, loggers=self.manager_loggers, - initialize_kwargs={"grad_sampler": self.grad_sampler}, + initialize_kwargs={ + "grad_sampler": self.grad_sampler, + "distillation_teacher": self.teacher, + }, ) if not self.manager.initialized: self.manager.initialize(