diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py
index 941f777827b7..2c5006ae73e2 100644
--- a/src/transformers/sparse.py
+++ b/src/transformers/sparse.py
@@ -102,6 +102,24 @@ def create_optimizer(self):
             self.optimizer, self.model, self.manager, steps_per_epoch=steps_per_epoch, loggers=self.loggers
         )
 
+    def create_scheduler(self, num_training_steps: int):
+        """
+        Override LR scheduler if the SparseML manager has LR modifiers, otherwise
+        set default scheduler
+        """
+        if self.lr_scheduler is not None:
+            # scheduler already set
+            return
+
+        if self.manager.learning_rate_modifiers:
+            # allow SparseML to manage LR and set a dummy scheduler
+            self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer, lambda _: 1.0, -1
+            )
+        else:
+            # default scheduler
+            super().create_scheduler(num_training_steps)
+
     def save_model(self, output_dir: Optional[str] = None):
         """
         Save model during or after training. The sparsification recipe will also be saved.
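
For context on the dummy scheduler above: a minimal standalone sketch (not part of this diff, using a hypothetical toy model and optimizer) showing that a `LambdaLR` whose multiplier is always `1.0` never changes the optimizer's learning rate, which is why it works as a no-op placeholder while SparseML's LR modifiers drive the actual schedule.

```python
import torch

# Hypothetical toy setup, only to illustrate the dummy-scheduler behavior.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Same construction as in the diff: multiplier is constant 1.0, last_epoch=-1.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0, -1)

for _ in range(3):
    optimizer.step()
    scheduler.step()
    # Prints [0.1] every iteration: the scheduler never alters the LR,
    # so any LR change must come from the SparseML manager's modifiers.
    print(scheduler.get_last_lr())
```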