From 056ca3d0520b938efcd095a9aa342610b1d4a842 Mon Sep 17 00:00:00 2001
From: Mark Kurtz
Date: Thu, 13 May 2021 12:43:02 -0400
Subject: [PATCH] PyTorch modifier fixes

- allow block_shape to have more than two dimensions, since on export it
  is expanded to the weight shape
- move the initialized check inside the try/except in __del__ for safety
---
 .../pytorch/optim/mask_creator_pruning.py | 16 +++++++++++++---
 src/sparseml/pytorch/optim/modifier.py    |  5 ++---
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/src/sparseml/pytorch/optim/mask_creator_pruning.py b/src/sparseml/pytorch/optim/mask_creator_pruning.py
index c83259e4af7..428f55ab606 100644
--- a/src/sparseml/pytorch/optim/mask_creator_pruning.py
+++ b/src/sparseml/pytorch/optim/mask_creator_pruning.py
@@ -474,13 +474,23 @@ def __init__(
         block_shape: List[int],
         grouping_fn_name: str = "mean",
     ):
-        if len(block_shape) != 2:
+        if len(block_shape) < 2:
             raise ValueError(
                 (
-                    "Invalid block_shape: {}"
-                    " ,block_shape must have length == 2 for in and out channels"
+                    "Invalid block_shape: {}, "
+                    "block_shape must have length >= 2 for in and out channels"
                 ).format(block_shape)
             )
+
+        if len(block_shape) > 2 and not all(shape == 1 for shape in block_shape[2:]):
+            # after the in and out channel dims, all other dims must be 1
+            raise ValueError(
+                (
+                    "Invalid block_shape: {}, "
+                    "block_shape values past the in and out channel dims must be 1"
+                ).format(block_shape)
+            )
+
         self._block_shape = block_shape
         self._grouping_fn_name = grouping_fn_name
 
diff --git a/src/sparseml/pytorch/optim/modifier.py b/src/sparseml/pytorch/optim/modifier.py
index 55ac01bace4..ec9ef674fa6 100644
--- a/src/sparseml/pytorch/optim/modifier.py
+++ b/src/sparseml/pytorch/optim/modifier.py
@@ -104,10 +104,9 @@ def __init__(self, log_types: Union[str, List[str]] = None, **kwargs):
         self._loggers = None
 
     def __del__(self):
-        if not self.initialized:
-            return
-
         try:
+            if not self.initialized:
+                return
             self.finalize()
         except Exception:
             pass
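
Reviewer note: below is a minimal usage sketch of the new block_shape
validation. It assumes the constructor patched above belongs to sparseml's
BlockPruningMaskCreator (the class name and import path are inferred from
the file, not shown in the diff), so treat it as illustrative rather than
authoritative.

    from sparseml.pytorch.optim.mask_creator_pruning import (
        BlockPruningMaskCreator,
    )

    # unchanged behavior: a 2-D block over the in and out channel dims
    creator = BlockPruningMaskCreator(block_shape=[1, 4])

    # newly accepted: extra trailing dims, as long as each is 1; on
    # export they are expanded to match the full weight shape
    creator = BlockPruningMaskCreator(block_shape=[1, 4, 1, 1])

    # still rejected: any trailing dim other than 1 raises a ValueError
    try:
        BlockPruningMaskCreator(block_shape=[1, 4, 3, 3])
    except ValueError as err:
        print(err)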