1 change: 0 additions & 1 deletion setup.py
@@ -127,7 +127,6 @@ def _setup_entry_points() -> Dict:
"question_answering",
"text_classification",
"token_classification",
"language_modeling",
]:
entry_points["console_scripts"].extend(
[
1 change: 1 addition & 0 deletions src/sparseml/pytorch/utils/__init__.py
@@ -27,6 +27,7 @@
 from .mfac_helpers import *
 from .model import *
 from .module import *
+from .sparsification import *
 from .ssd_helpers import *
 from .yolo_helpers import *

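In effect, the new wildcard import re-exports the sparsification utilities at the package root. A one-line sketch of the import path this enables (the class name comes from the sparsification module added later in this diff):

# Enabled by the re-export above; ModuleSparsificationInfo is defined in
# the new sparsification module introduced below
from sparseml.pytorch.utils import ModuleSparsificationInfo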
73 changes: 66 additions & 7 deletions src/sparseml/pytorch/utils/helpers.py
@@ -27,15 +27,23 @@
 import torch
 from torch import Tensor
 from torch.nn import Linear, Module, Parameter
-from torch.nn.modules.conv import _ConvNd
+from torch.nn.modules.conv import Conv2d, Conv3d, _ConvNd
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader


 try:
+    quant_err = None
+    from torch.nn.qat import Conv2d as QATConv2d
+    from torch.nn.qat import Conv3d as QATConv3d
     from torch.nn.qat import Linear as QATLinear
     from torch.quantization import QuantWrapper
-except Exception:
+except Exception as _err:
+    quant_err = _err
     QuantWrapper = None
     QATLinear = None
+    QATConv2d = None
+    QATConv3d = None

 from sparseml.utils import create_dirs, save_numpy

@@ -64,6 +72,7 @@
"get_conv_layers",
"get_linear_layers",
"get_prunable_layers",
"get_quantizable_layers",
"get_named_layers_and_params_by_regex",
"any_str_or_regex_matches_param_name",
"NamedLayerParam",
@@ -751,13 +760,63 @@ def get_prunable_layers(module: Module) -> List[Tuple[str, Module]]:
     :return: a list containing the names and modules of the prunable layers
         (Linear, ConvNd)
     """
-    layers = []
+    return [
+        (name, mod)
+        for (name, mod) in module.named_modules()
+        if (
+            isinstance(mod, Linear)
+            or isinstance(mod, _ConvNd)
+            or (QATLinear and isinstance(mod, QATLinear))
+            or (QATConv2d and isinstance(mod, QATConv2d))
+            or (QATConv3d and isinstance(mod, QATConv3d))
+        )
+    ]
+
+
+def get_quantizable_layers(module: Module) -> List[Tuple[str, Module]]:
+    """
+    :param module: the module to get the quantizable layers from
+    :return: a list containing the names and modules of the quantizable layers
+        (Linear, Conv2d, Conv3d)
+    """
+    if QATLinear is None:
+        raise ImportError(
+            "PyTorch version is not set up for quantization. "
+            "Please install a QAT-compatible version of PyTorch"
+        )
+
+    return [
+        (name, mod)
+        for (name, mod) in module.named_modules()
+        if (
+            isinstance(mod, Linear)
+            or isinstance(mod, Conv2d)
+            or isinstance(mod, Conv3d)
+        )
+    ]

-    for name, mod in module.named_modules():
-        if isinstance(mod, Linear) or isinstance(mod, _ConvNd):
-            layers.append((name, mod))

-    return layers
+def get_quantized_layers(module: Module) -> List[Tuple[str, Module]]:
+    """
+    :param module: the module to get the quantized layers from
+    :return: a list containing the names and modules of the quantized layers
+        (QATLinear, QATConv2d, QATConv3d)
+    """
+    if QATLinear is None:
+        raise ImportError(
+            "PyTorch version is not set up for quantization. "
+            "Please install a QAT-compatible version of PyTorch"
+        )
+
+    return [
+        (name, mod)
+        for (name, mod) in module.named_modules()
+        if (
+            (QATLinear and isinstance(mod, QATLinear))
+            or (QATConv2d and isinstance(mod, QATConv2d))
+            or (QATConv3d and isinstance(mod, QATConv3d))
+        )
+    ]


 def get_layer_param(param: str, layer: str, module: Module) -> Parameter:
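For context, a minimal sketch of how the reworked helpers behave on a toy model. The model below is illustrative only and not part of this change; note that the quantization helpers raise ImportError on torch builds without QAT support:

import torch

from sparseml.pytorch.utils.helpers import (
    get_prunable_layers,
    get_quantizable_layers,
    get_quantized_layers,
)

# Toy model for illustration only
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)

print([name for name, _ in get_prunable_layers(model)])     # ['0', '2']
print([name for name, _ in get_quantizable_layers(model)])  # ['0', '2']
print([name for name, _ in get_quantized_layers(model)])    # [] before QAT conversion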
169 changes: 169 additions & 0 deletions src/sparseml/pytorch/utils/sparsification.py
@@ -0,0 +1,169 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Helper functions for retrieving information related to model sparsification
"""

import json
from typing import Dict

import torch
from torch.nn import Module

from sparseml.pytorch.utils.helpers import (
    get_prunable_layers,
    get_quantizable_layers,
    get_quantized_layers,
    tensor_sparsity,
)


__all__ = ["ModuleSparsificationInfo"]


class ModuleSparsificationInfo:
"""
Helper class for providing information related to torch Module parameters
and the amount of sparsification applied. Includes information for pruning
and quantization

:param module: torch Module to analyze
"""

def __init__(self, module: Module):
self.module = module
self.trainable_params = list(
filter(lambda param: param.requires_grad, self.module.parameters())
)

def __str__(self):
return json.dumps(
{
"params_summary": {
"total": self.params_total,
"sparse": self.params_sparse,
"sparsity_percent": self.params_sparse_percent,
"prunable": self.params_prunable_total,
"prunable_sparse": self.params_prunable_sparse,
"prunable_sparsity_percent": self.params_prunable_sparse_percent,
"quantizable": self.params_quantizable,
"quantized": self.params_quantized,
"quantized_percent": self.params_quantized_percent,
},
"params_info": self.params_info,
}
)

@property
def params_total(self) -> int:
"""
:return: total number of trainable parameters in the model
"""
return sum(torch.numel(param) for param in self.trainable_params)

@property
def params_sparse(self) -> int:
"""
:return: total number of sparse (0) trainable parameters in the model
"""
return sum(
round(tensor_sparsity(param).item() * torch.numel(param))
for param in self.trainable_params
)

@property
def params_sparse_percent(self) -> float:
"""
:return: percent of sparsified parameters in the entire model
"""
return self.params_sparse / float(self.params_total) * 100

@property
def params_prunable_total(self) -> int:
"""
:return: total number of parameters across prunable layers
"""
return sum(
torch.numel(layer.weight)
for (name, layer) in get_prunable_layers(self.module)
)

@property
def params_prunable_sparse(self) -> int:
"""
        :return: total number of sparse (0) parameters across prunable layers
"""
return sum(
round(tensor_sparsity(layer.weight).item() * torch.numel(layer.weight))
for (name, layer) in get_prunable_layers(self.module)
)

@property
def params_prunable_sparse_percent(self) -> float:
"""
:return: percent of prunable parameters that have been pruned
"""
return self.params_prunable_sparse / float(self.params_prunable_total) * 100

@property
def params_quantizable(self) -> int:
"""
:return: number of parameters that are included in quantizable layers
"""
return sum(
torch.numel(layer.weight)
+ (
torch.numel(layer.bias)
if hasattr(layer, "bias") and layer.bias is not None
else 0
)
for (name, layer) in get_quantizable_layers(self.module)
)

@property
def params_quantized(self) -> int:
"""
:return: number of parameters across quantized layers
"""
return sum(
torch.numel(layer.weight)
+ (
torch.numel(layer.bias)
if hasattr(layer, "bias") and layer.bias is not None
else 0
)
for (name, layer) in get_quantized_layers(self.module)
)

@property
def params_quantized_percent(self) -> float:
"""
:return: percentage of parameters that have been quantized
"""
return self.params_quantized / float(self.params_quantizable) * 100

@property
def params_info(self) -> Dict[str, Dict]:
"""
:return: dict of parameter name to its sparsification information
"""
return {
f"{name}.weight": {
"numel": torch.numel(layer.weight),
"sparsity": tensor_sparsity(layer.weight).item(),
"quantized": hasattr(layer, "weight_fake_quant"),
}
for (name, layer) in get_prunable_layers(self.module)
}
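A short usage sketch for the new class; the toy model and printed values are illustrative assumptions, not part of this change:

import json

import torch

from sparseml.pytorch.utils import ModuleSparsificationInfo

# Dense toy model; a pruned or quantized network would report non-zero stats
model = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.Linear(64, 10))

info = ModuleSparsificationInfo(model)
print(info.params_total)           # 128*64 + 64 + 64*10 + 10 = 8906
print(info.params_sparse_percent)  # ~0.0 for a dense random init

# __str__ serializes the summary as JSON; the quantization properties require
# a QAT-capable torch build (see get_quantizable_layers above)
summary = json.loads(str(info))
print(summary["params_summary"]["prunable_sparsity_percent"])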
11 changes: 9 additions & 2 deletions src/sparseml/transformers/export.py
@@ -80,12 +80,14 @@ def _load_task_model(task: str, model_path: str, config: Any) -> Module:
         return SparseAutoModel.masked_language_modeling_from_pretrained(
             model_name_or_path=model_path,
             config=config,
+            model_type="model",
         )

     if task == "question-answering" or task == "qa":
         return SparseAutoModel.question_answering_from_pretrained(
             model_name_or_path=model_path,
             config=config,
+            model_type="model",
         )

     if (
@@ -97,12 +99,14 @@ def _load_task_model(task: str, model_path: str, config: Any) -> Module:
         return SparseAutoModel.text_classification_from_pretrained(
             model_name_or_path=model_path,
             config=config,
+            model_type="model",
         )

     if task == "token-classification" or task == "ner":
         return SparseAutoModel.token_classification_from_pretrained(
             model_name_or_path=model_path,
             config=config,
+            model_type="model",
         )

     raise ValueError(f"unrecognized task given of {task}")
@@ -177,8 +181,11 @@ def export_transformer_to_onnx(
"", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value
).data # Dict[Tensor]
inputs_shapes = {
key: f"{type(val)}({val.shape if hasattr(val, 'shape') else 'unknown'})"
for key, val in inputs
key: (
f"{val.dtype if hasattr(val, 'dtype') else 'unknown'}: "
f"{list(val.shape) if hasattr(val, 'shape') else 'unknown'}"
)
for key, val in inputs.items()
}
_LOGGER.info(f"Created sample inputs for the ONNX export process: {inputs_shapes}")

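For reference, a standalone sketch of the reworked shape summary; the input names below are assumed tokenizer outputs, not taken from this change:

import torch

# Sample inputs with assumed key names
inputs = {
    "input_ids": torch.zeros((1, 128), dtype=torch.int64),
    "attention_mask": torch.zeros((1, 128), dtype=torch.int64),
}
inputs_shapes = {
    key: (
        f"{val.dtype if hasattr(val, 'dtype') else 'unknown'}: "
        f"{list(val.shape) if hasattr(val, 'shape') else 'unknown'}"
    )
    for key, val in inputs.items()  # old code iterated the dict without .items()
}
print(inputs_shapes)
# {'input_ids': 'torch.int64: [1, 128]', 'attention_mask': 'torch.int64: [1, 128]'}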