From 74c36ccd23a998e1ff764ad6bafaada17e57c193 Mon Sep 17 00:00:00 2001
From: Mark Kurtz
Date: Thu, 13 May 2021 10:03:17 -0400
Subject: [PATCH 1/2] Add PyTorch emulated_step in manager wrapper for
 differing steps_per_epoch (#236)

* Add PyTorch emulated_step in manager wrapper for differing steps_per_epoch

* Add PyTorch emulated_step in manager wrapper for differing steps_per_epoch
---
 src/sparseml/pytorch/optim/manager.py | 59 ++++++++++++++++++---------
 1 file changed, 39 insertions(+), 20 deletions(-)

diff --git a/src/sparseml/pytorch/optim/manager.py b/src/sparseml/pytorch/optim/manager.py
index 89e7561c610..c0e1b1821b1 100644
--- a/src/sparseml/pytorch/optim/manager.py
+++ b/src/sparseml/pytorch/optim/manager.py
@@ -158,6 +158,43 @@ def step(self, *args, **kwargs):
         :param kwargs: Any kwargs to pass to the wrapped objects step function.
         :return: The return, if any, from the wrapped objects step function
         """
+        return self._perform_wrapped_step(*args, **kwargs)
+
+    def emulated_step(self):
+        """
+        Emulated step function to be called in place of step when the
+        number of steps_per_epoch varies across epochs.
+        The emulated function should be called to keep the steps_per_epoch the same.
+        Does not call into the step function for the wrapped object,
+        but does call into the manager to increment the steps.
+        """
+        self._perform_wrapped_step(skip_orig_step=True)
+
+    def loss_update(self, loss: Tensor) -> Tensor:
+        """
+        Optional call to update modifiers based on the calculated loss.
+        Not needed unless one or more of the modifier is using the loss
+        to make a modification or is modifying the loss itself.
+
+        :param loss: the calculated loss after running a forward pass and loss_fn
+        :return: the modified loss tensor
+        """
+        loss = self._wrapped_manager.loss_update(
+            loss,
+            self._wrapped_module,
+            self._wrapped_optimizer,
+            self._wrapped_epoch,
+            self._wrapped_steps_per_epoch,
+        )
+
+        return loss
+
+    def _perform_wrapped_step(self, *args, **kwargs) -> Any:
+        skip_orig_step = (
+            kwargs["skip_orig_step"] if "skip_orig_step" in kwargs else False
+        )
+        ret = None
+
         if self._wrapped_manager.enabled:
             self._wrapped_manager.update(
                 self._wrapped_module,
@@ -172,7 +209,8 @@ def step(self, *args, **kwargs):
                 self._wrapped_steps_per_epoch,
             )
 
-        ret = self._wrapped.step(*args, **kwargs)
+        if not skip_orig_step:
+            ret = self._wrapped.step(*args, **kwargs)
 
         if self._wrapped_manager.enabled:
             self._wrapped_manager.optimizer_post_step(
@@ -192,25 +230,6 @@ def step(self, *args, **kwargs):
 
         return ret
 
-    def loss_update(self, loss: Tensor) -> Tensor:
-        """
-        Optional call to update modifiers based on the calculated loss.
-        Not needed unless one or more of the modifier is using the loss
-        to make a modification or is modifying the loss itself.
-
-        :param loss: the calculated loss after running a forward pass and loss_fn
-        :return: the modified loss tensor
-        """
-        loss = self._wrapped_manager.loss_update(
-            loss,
-            self._wrapped_module,
-            self._wrapped_optimizer,
-            self._wrapped_epoch,
-            self._wrapped_steps_per_epoch,
-        )
-
-        return loss
-
 
 class ScheduledModifierManager(BaseManager, Modifier):
     """

From bc61347d5e4e983a8a05622b22869d32772d304a Mon Sep 17 00:00:00 2001
From: Benjamin Fineran
Date: Thu, 13 May 2021 10:39:20 -0400
Subject: [PATCH 2/2] update transfer learning notebook quant file name (#237)

---
 .../pytorch_sparse_quantized_transfer_learning.ipynb | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/pytorch_sparse_quantized_transfer_learning/pytorch_sparse_quantized_transfer_learning.ipynb b/examples/pytorch_sparse_quantized_transfer_learning/pytorch_sparse_quantized_transfer_learning.ipynb
index b46729de0ce..63002433da0 100644
--- a/examples/pytorch_sparse_quantized_transfer_learning/pytorch_sparse_quantized_transfer_learning.ipynb
+++ b/examples/pytorch_sparse_quantized_transfer_learning/pytorch_sparse_quantized_transfer_learning.ipynb
@@ -332,13 +332,13 @@
     "from sparseml.pytorch.utils import ModuleExporter\n",
     "\n",
     "save_dir = \"pytorch_sparse_quantized_transfer_learning\"\n",
-    "qat_onnx_graph_name = \"resnet50_imagenette_pruned_qat.onnx\"\n",
-    "quantized_onnx_path = os.path.join(save_dir, \"resnet50_imagenette_pruned_quant.onnx\")\n",
+    "quant_onnx_graph_name = \"resnet50_imagenette_pruned_quant.onnx\"\n",
+    "quantized_onnx_path = os.path.join(save_dir, quant_onnx_graph_name)\n",
     "\n",
     "exporter = ModuleExporter(model, output_dir=save_dir)\n",
     "exporter.export_pytorch(name=\"resnet50_imagenette_pruned_qat.pth\")\n",
     "exporter.export_onnx(\n",
-    "    torch.randn(1, 3, 224, 224), name=qat_onnx_graph_name, convert_qat=True\n",
+    "    torch.randn(1, 3, 224, 224), name=quant_onnx_graph_name, convert_qat=True\n",
     ")\n",
     "\n",
     "print(f\"Sparse-Quantized ONNX model saved to {quantized_onnx_path}\")"
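
Note (not part of the patches above): a minimal sketch of how the new emulated_step might be used
when the number of batches differs between epochs. It assumes `wrapped_optim` is the manager's
optimizer wrapper that the first patch extends (the object exposing step(), emulated_step(), and
loss_update()); `model`, `loss_fn`, and `epoch_loaders` are placeholder names, and how
`wrapped_optim` is constructed is omitted because it depends on the surrounding sparseml version.

    # Keep steps_per_epoch constant across epochs by padding shorter epochs with
    # emulated steps that advance the manager without stepping the optimizer.
    steps_per_epoch = max(len(loader) for loader in epoch_loaders)

    for loader in epoch_loaders:
        steps_taken = 0
        for inputs, labels in loader:
            model.zero_grad()
            loss = loss_fn(model(inputs), labels)
            # Optional hook: lets loss-aware modifiers adjust or record the loss.
            loss = wrapped_optim.loss_update(loss)
            loss.backward()
            wrapped_optim.step()  # real optimizer step plus manager update
            steps_taken += 1

        for _ in range(steps_per_epoch - steps_taken):
            wrapped_optim.emulated_step()  # manager update only, no optimizer step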