From 11dc0d36747fd54c1b0f10bc2e6ebfd81fbb769b Mon Sep 17 00:00:00 2001
From: Mark Kurtz
Date: Fri, 10 Sep 2021 17:46:35 -0400
Subject: [PATCH] Force conversion of LR variables to float due to type coercion from manager serialization

---
 src/sparseml/optim/learning_rate.py       | 3 +++
 src/sparseml/pytorch/optim/modifier_lr.py | 6 ++++++
 2 files changed, 9 insertions(+)

diff --git a/src/sparseml/optim/learning_rate.py b/src/sparseml/optim/learning_rate.py
index c25fa553ba1..36d48ab77b4 100644
--- a/src/sparseml/optim/learning_rate.py
+++ b/src/sparseml/optim/learning_rate.py
@@ -156,6 +156,9 @@ def validate_lr_info(self):
         else:
             raise ValueError("unknown lr_class given of {}".format(self._lr_class))
 
+        if isinstance(self._init_lr, str):
+            self._init_lr = float(self._init_lr)
+
         if self._init_lr <= 0.0:
             raise ValueError("init_lr must be greater than 0")
 
diff --git a/src/sparseml/pytorch/optim/modifier_lr.py b/src/sparseml/pytorch/optim/modifier_lr.py
index c2c6fa6b822..340390c65fa 100644
--- a/src/sparseml/pytorch/optim/modifier_lr.py
+++ b/src/sparseml/pytorch/optim/modifier_lr.py
@@ -414,6 +414,9 @@ def validate(self):
         if self.lr_func not in lr_funcs:
             raise ValueError(f"lr_func must be one of {lr_funcs}")
 
+        if isinstance(self.init_lr, str):
+            self.init_lr = float(self.init_lr)
+
         if (
             (not self.init_lr and self.init_lr != 0)
             or self.init_lr < 0.0
@@ -423,6 +426,9 @@
                 f"init_lr must be within range [0.0, 1.0], given {self.init_lr}"
             )
 
+        if isinstance(self.final_lr, str):
+            self.final_lr = float(self.final_lr)
+
         if (
             (not self.final_lr and self.final_lr != 0)
             or self.final_lr < 0.0
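
Note on the failure mode the guards address -- a minimal reproduction sketch,
assuming the manager round-trips recipes through PyYAML (the recipe keys below
are illustrative, not the manager's exact serialization format). PyYAML's
YAML 1.1 resolver only recognizes floats written with a decimal point and a
signed exponent, so bare scientific notation such as 1e-3 deserializes as a
string, and the later <= / < range checks in validate() would then raise a
TypeError:

    import yaml  # PyYAML

    # Scientific notation without a decimal point is not matched by PyYAML's
    # YAML 1.1 float resolver, so it loads back as a string.
    recipe = yaml.safe_load("init_lr: 1e-3\nfinal_lr: 1.0e-5\n")
    print(type(recipe["init_lr"]))   # <class 'str'>   -- "1e-3" stays text
    print(type(recipe["final_lr"]))  # <class 'float'> -- dot + signed exponent parses

    # Without the guard, the range checks would fail on the string value:
    try:
        recipe["init_lr"] <= 0.0
    except TypeError as err:
        print(err)  # '<=' not supported between instances of 'str' and 'float'

    # The guard this patch adds, shown in isolation:
    init_lr = recipe["init_lr"]
    if isinstance(init_lr, str):
        init_lr = float(init_lr)  # "1e-3" -> 0.001
    assert 0.0 < init_lr <= 1.0

A side benefit of coercing before validating: float() raises a clear
ValueError for non-numeric strings, so a malformed recipe value still fails
loudly inside validate() rather than surfacing later during training.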