diff --git a/examples/pytorch_sparse_quantized_transfer_learning/pytorch_sparse_quantized_transfer_learning.ipynb b/examples/pytorch_sparse_quantized_transfer_learning/pytorch_sparse_quantized_transfer_learning.ipynb
index b46729de0ce..63002433da0 100644
--- a/examples/pytorch_sparse_quantized_transfer_learning/pytorch_sparse_quantized_transfer_learning.ipynb
+++ b/examples/pytorch_sparse_quantized_transfer_learning/pytorch_sparse_quantized_transfer_learning.ipynb
@@ -332,13 +332,13 @@
     "from sparseml.pytorch.utils import ModuleExporter\n",
     "\n",
     "save_dir = \"pytorch_sparse_quantized_transfer_learning\"\n",
-    "qat_onnx_graph_name = \"resnet50_imagenette_pruned_qat.onnx\"\n",
-    "quantized_onnx_path = os.path.join(save_dir, \"resnet50_imagenette_pruned_quant.onnx\")\n",
+    "quant_onnx_graph_name = \"resnet50_imagenette_pruned_quant.onnx\"\n",
+    "quantized_onnx_path = os.path.join(save_dir, quant_onnx_graph_name)\n",
     "\n",
     "exporter = ModuleExporter(model, output_dir=save_dir)\n",
     "exporter.export_pytorch(name=\"resnet50_imagenette_pruned_qat.pth\")\n",
     "exporter.export_onnx(\n",
-    "    torch.randn(1, 3, 224, 224), name=qat_onnx_graph_name, convert_qat=True\n",
+    "    torch.randn(1, 3, 224, 224), name=quant_onnx_graph_name, convert_qat=True\n",
     ")\n",
     "\n",
     "print(f\"Sparse-Quantized ONNX model saved to {quantized_onnx_path}\")"
diff --git a/src/sparseml/pytorch/optim/manager.py b/src/sparseml/pytorch/optim/manager.py
index 89e7561c610..c0e1b1821b1 100644
--- a/src/sparseml/pytorch/optim/manager.py
+++ b/src/sparseml/pytorch/optim/manager.py
@@ -158,6 +158,43 @@ def step(self, *args, **kwargs):
         :param kwargs: Any kwargs to pass to the wrapped objects step function.
         :return: The return, if any, from the wrapped objects step function
         """
+        return self._perform_wrapped_step(*args, **kwargs)
+
+    def emulated_step(self):
+        """
+        Emulated step function to be called in place of step when the
+        number of steps per epoch varies across epochs.
+        Calling this keeps the effective steps_per_epoch the same for every epoch.
+        Does not call into the step function for the wrapped object,
+        but does call into the manager to increment the steps.
+        """
+        self._perform_wrapped_step(skip_orig_step=True)
+
+    def loss_update(self, loss: Tensor) -> Tensor:
+        """
+        Optional call to update modifiers based on the calculated loss.
+        Not needed unless one or more of the modifiers is using the loss
+        to make a modification or is modifying the loss itself.
+
+        :param loss: the calculated loss after running a forward pass and loss_fn
+        :return: the modified loss tensor
+        """
+        loss = self._wrapped_manager.loss_update(
+            loss,
+            self._wrapped_module,
+            self._wrapped_optimizer,
+            self._wrapped_epoch,
+            self._wrapped_steps_per_epoch,
+        )
+
+        return loss
+
+    def _perform_wrapped_step(self, *args, **kwargs) -> Any:
+        # pop the flag so it is not forwarded to the wrapped optimizer's
+        # step call below when a real step is performed
+        skip_orig_step = kwargs.pop("skip_orig_step", False)
+        ret = None
+
         if self._wrapped_manager.enabled:
             self._wrapped_manager.update(
                 self._wrapped_module,
@@ -172,7 +209,8 @@
                 self._wrapped_steps_per_epoch,
             )
 
-        ret = self._wrapped.step(*args, **kwargs)
+        if not skip_orig_step:
+            ret = self._wrapped.step(*args, **kwargs)
 
         if self._wrapped_manager.enabled:
             self._wrapped_manager.optimizer_post_step(
@@ -192,25 +230,6 @@
 
         return ret
 
-    def loss_update(self, loss: Tensor) -> Tensor:
-        """
-        Optional call to update modifiers based on the calculated loss.
-        Not needed unless one or more of the modifier is using the loss
-        to make a modification or is modifying the loss itself.
-
-        :param loss: the calculated loss after running a forward pass and loss_fn
-        :return: the modified loss tensor
-        """
-        loss = self._wrapped_manager.loss_update(
-            loss,
-            self._wrapped_module,
-            self._wrapped_optimizer,
-            self._wrapped_epoch,
-            self._wrapped_steps_per_epoch,
-        )
-
-        return loss
-
 
 class ScheduledModifierManager(BaseManager, Modifier):
     """
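Below is a minimal sketch of how the new `emulated_step` (alongside the `step` and `loss_update` methods shown above) might be driven from a training loop whose number of batches changes from epoch to epoch. Only those three methods come from this diff; the `ScheduledOptimizer` / `ScheduledModifierManager` import path, the constructor signature, the proxied `zero_grad` call, and the `recipe.yaml` recipe file are assumptions for illustration.

```python
# Sketch only: assumes the wrapper class whose step/emulated_step/loss_update
# methods appear in the diff is sparseml's ScheduledOptimizer; the import path,
# constructor signature, and "recipe.yaml" are assumptions, not from the diff.
import torch
from torch import nn

from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer

model = nn.Linear(16, 2)
loss_fn = nn.CrossEntropyLoss()
max_steps_per_epoch = 10  # the modifier schedule is defined against this fixed value

manager = ScheduledModifierManager.from_yaml("recipe.yaml")
optimizer = ScheduledOptimizer(
    torch.optim.SGD(model.parameters(), lr=0.1),
    model,
    manager,
    steps_per_epoch=max_steps_per_epoch,
)

for epoch in range(3):
    # pretend the dataloader yields a different number of batches each epoch
    num_batches = max_steps_per_epoch - epoch
    for _ in range(num_batches):
        inputs = torch.randn(8, 16)
        labels = torch.randint(0, 2, (8,))
        optimizer.zero_grad()  # assumed to be proxied to the wrapped optimizer
        loss = loss_fn(model(inputs), labels)
        loss = optimizer.loss_update(loss)  # let modifiers adjust the loss
        loss.backward()
        optimizer.step()  # real step: weights update and the schedule advances

    # pad the epoch so the modifier schedule always sees max_steps_per_epoch
    # steps; emulated_step advances the manager without touching the weights
    for _ in range(max_steps_per_epoch - num_batches):
        optimizer.emulated_step()
```

The point of `emulated_step` is that the modifier schedule counts the same number of steps every epoch even when the dataloader yields fewer batches, without performing an extra weight update.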