diff --git a/src/sparseml/pytorch/image_classification/utils/helpers.py b/src/sparseml/pytorch/image_classification/utils/helpers.py
index b7f38114c68..6734093b1fc 100644
--- a/src/sparseml/pytorch/image_classification/utils/helpers.py
+++ b/src/sparseml/pytorch/image_classification/utils/helpers.py
@@ -370,13 +370,14 @@ def save_model_training(
     :param arch_key: if provided, the `arch_key` will be saved in the
         checkpoint
     """
-    has_top1 = "top1acc" in val_res.results
-    metric_name = "top-1 accuracy" if has_top1 else "val_loss"
-    metric = val_res.result_mean("top1acc" if has_top1 else DEFAULT_LOSS_KEY).item()
-    print(
-        f"Saving model for epoch {epoch} and {metric_name} "
-        f"{metric} to {save_dir} for {save_name}"
-    )
+    if val_res is not None:
+        has_top1 = "top1acc" in val_res.results
+        metric_name = "top-1 accuracy" if has_top1 else "val_loss"
+        metric = val_res.result_mean("top1acc" if has_top1 else DEFAULT_LOSS_KEY).item()
+        print(
+            f"Saving model for epoch {epoch} and {metric_name} "
+            f"{metric} to {save_dir} for {save_name}"
+        )
     exporter = ModuleExporter(model, save_dir)
     exporter.export_pytorch(
         optim,
diff --git a/src/sparseml/pytorch/image_classification/utils/trainer.py b/src/sparseml/pytorch/image_classification/utils/trainer.py
index e781e91e9ac..0fde1522a6f 100644
--- a/src/sparseml/pytorch/image_classification/utils/trainer.py
+++ b/src/sparseml/pytorch/image_classification/utils/trainer.py
@@ -124,7 +124,9 @@ def __init__(
         self.optim_name = optim_name
 
         self.epoch = 0
-
+        self._device_context = ModuleDeviceContext(
+            use_mixed_precision=self.use_mixed_precision,
+        )
         if self.train_loader is not None:
             (
                 self.epoch,
@@ -170,6 +172,10 @@ def run_one_epoch(
         train_mode = mode == "train"
         validation_mode = not train_mode
 
+        if torch.__version__ < "1.9" and self.manager.qat_active(epoch=self.epoch):
+            # switch off fp16
+            self._device_context.use_mixed_precision = False
+
         if validation_mode:
             return self._run_validation_epoch(
                 max_steps=max_steps,
@@ -225,15 +231,14 @@ def _initialize_scheduled_optimizer(self):
         return epoch, optim, manager
 
     def _initialize_module_trainer(self):
+
         trainer = ModuleTrainer(
             module=self.model,
             device=self.device,
             loss=self.train_loss,
             optimizer=self.optim,
             loggers=self.loggers,
-            device_context=ModuleDeviceContext(
-                use_mixed_precision=self.use_mixed_precision,
-            ),
+            device_context=self._device_context,
         )
         _LOGGER.info(f"created Module Trainer: {trainer}")
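
Review note on the `torch.__version__ < "1.9"` guard added in `run_one_epoch`: this is a lexicographic string comparison, so it also matches newer releases such as "1.10.0" (as strings, "1.10.0" < "1.9" is True) and would disable mixed precision there as well. A minimal sketch of a more robust check, assuming the `packaging` library is available; `torch_pre_19` is a hypothetical helper name, not part of this diff:

```python
import torch
from packaging import version


def torch_pre_19() -> bool:
    """Return True when the installed torch predates 1.9.

    version.parse compares release components numerically, so "1.10.0"
    correctly sorts after "1.9"; the plain string comparison
    torch.__version__ < "1.9" would wrongly treat "1.10.0" as pre-1.9.
    """
    return version.parse(torch.__version__) < version.parse("1.9")
```

Used in place of the string comparison, the guard would read `if torch_pre_19() and self.manager.qat_active(epoch=self.epoch):`, keeping the diff's behavior of falling back to full precision once QAT becomes active on older torch versions.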