Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,22 @@

_dev_deps = [
"beautifulsoup4==4.9.3",
"black>=20.8b1",
"flake8>=3.8.3",
"isort>=5.7.0",
"black==21.5b2",
"flake8==3.9.2",
"isort==5.8.0",
"m2r2~=0.2.7",
"myst-parser~=0.14.0",
"rinohtype>=0.4.2",
"sphinx>=3.4.0",
"sphinx-copybutton>=0.3.0",
"sphinx-markdown-tables>=0.0.15",
"sphinx-multiversion==0.2.4",
"sphinx-pydantic>=0.1.0",
"sphinx-rtd-theme>=0.5.0",
"rinohtype~=0.4.2",
"sphinx~=3.5.0",
"sphinx-copybutton~=0.3.0",
"sphinx-markdown-tables~=0.0.15",
"sphinx-multiversion~=0.2.4",
"sphinx-pydantic~=0.1.0",
"sphinx-rtd-theme~=0.5.0",
"wheel>=0.36.2",
"pytest>=6.0.0",
"pytest-mock>=3.6.1",
"flaky>=3.0.0",
"pytest~=6.2.0",
"pytest-mock~=3.6.0",
"flaky~=3.7.0",
"sphinx-rtd-theme",
]

Expand Down Expand Up @@ -112,25 +112,35 @@ def _setup_extras() -> Dict:
}


_transformers_entry_point_template = (
"sparseml.transformers.train.{task}=sparseml.transformers.train.{task}:main"
)


def _setup_entry_points() -> Dict:
return {
entry_points = {
"console_scripts": [
# sparsification
"sparseml.benchmark=sparseml.benchmark.info:_main",
"sparseml.framework=sparseml.framework.info:_main",
"sparseml.sparsification=sparseml.sparsification.info:_main",
_transformers_entry_point_template.format(task="question_answering"),
_transformers_entry_point_template.format(task="text_classification"),
_transformers_entry_point_template.format(task="token_classification"),
_transformers_entry_point_template.format(task="language_modeling"),
"sparseml.transformers.export_onnx=sparseml.transformers.utils.export:main",
]
}

# transformers integration
for task in [
"question_answering",
"text_classification",
"token_classification",
]:
entry_points["console_scripts"].extend(
[
f"sparseml.transformers.{task}=sparseml.transformers.{task}:main",
f"sparseml.transformers.train.{task}=sparseml.transformers.{task}:main",
]
)

entry_points["console_scripts"].append(
"sparseml.transformers.export_onnx=sparseml.transformers.export:main"
)

return entry_points


def _setup_long_description() -> Tuple[str, str]:
return open("README.md", "r", encoding="utf-8").read(), "text/markdown"
Expand Down
7 changes: 4 additions & 3 deletions src/sparseml/keras/optim/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
Also handles loading modifiers from yaml files
"""

from typing import List, Union
from typing import Any, Dict, List, Optional, Union

from tensorflow import Tensor

from sparseml.keras.optim.modifier import Modifier, ScheduledModifier
from sparseml.keras.utils.compat import keras
from sparseml.keras.utils.logger import KerasLogger
from sparseml.optim import BaseManager, load_recipe_yaml_str
from sparseml.optim import BaseManager, load_recipe_yaml_str, parse_recipe_variables
from sparsezoo.objects import Recipe


Expand All @@ -41,7 +41,7 @@ class ScheduledModifierManager(BaseManager, Modifier):
def from_yaml(
file_path: Union[str, Recipe],
add_modifiers: List[Modifier] = None,
**recipe_variables,
recipe_variables: Optional[Union[Dict[str, Any], str]] = None,
):
"""
Convenience function used to create the manager of multiple modifiers from a
Expand All @@ -59,6 +59,7 @@ def from_yaml(
with (i.e. num_epochs, init_lr)
:return: ScheduledModifierManager() created from the recipe file
"""
recipe_variables = parse_recipe_variables(recipe_variables)
yaml_str = load_recipe_yaml_str(file_path, **recipe_variables)
modifiers = Modifier.load_list(yaml_str)
if add_modifiers:
Expand Down
60 changes: 59 additions & 1 deletion src/sparseml/optim/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
Helper functions for base Modifier and Manager utilities
"""

import json
import re
from typing import Any, Dict, Tuple, Union
from contextlib import suppress
from typing import Any, Dict, Optional, Tuple, Union

import yaml

Expand All @@ -32,6 +34,7 @@
"rewrite_recipe_yaml_string_with_classes",
"update_recipe_variables",
"evaluate_recipe_yaml_str_equations",
"parse_recipe_variables",
]


Expand Down Expand Up @@ -137,6 +140,61 @@ def rewrite_recipe_yaml_string_with_classes(recipe_contianer: Any) -> str:
return pattern.sub(r"!\g<class_name>", updated_yaml_str)


def parse_recipe_variables(
    recipe_variables: Optional[Union[Dict[str, Any], str]] = None
) -> Dict[str, Any]:
    """
    Parse input recipe_variables into a dictionary that can be used to overload
    variables at the root of a recipe.
    Supports dictionaries as well as parsing a string in either json or
    csv key=value format

    :param recipe_variables: the recipe_variables string or dictionary to parse
        for variables used with overloading recipes
    :return: the parsed recipe variables
    :raises ValueError: if recipe_variables is neither a dict nor a string, or
        if a string value contains a malformed key=value pair
    """
    if not recipe_variables:
        return {}

    # use the builtin dict for the runtime check rather than the typing alias
    if isinstance(recipe_variables, dict):
        return recipe_variables

    if not isinstance(recipe_variables, str):
        raise ValueError(
            f"recipe_args must be a string for parsing, given {recipe_variables}"
        )

    # assume json first, try and parse; only accept the result when it is a
    # mapping so the declared Dict return contract holds (a bare json scalar
    # or list would otherwise leak through to callers expecting a dict)
    with suppress(Exception):
        loaded = json.loads(recipe_variables)
        if isinstance(loaded, dict):
            return loaded

    # assume csv, and standardize to format key=val
    orig_recipe_variables = recipe_variables
    recipe_vars_str = recipe_variables.replace(":", "=")
    recipe_variables = {}
    for arg_val in recipe_vars_str.split(","):
        vals = arg_val.split("=")
        if len(vals) != 2:
            raise ValueError(
                "Improper key=val given in csv for recipe variables with value "
                f"{arg_val} in {orig_recipe_variables}"
            )
        key = vals[0].strip()
        if any(char in key for char in ["{", "!", "=", "}"]):
            raise ValueError(
                "Improper key given in csv for recipe variables with value "
                f"{key} in {orig_recipe_variables}"
            )
        val = vals[1].strip()
        with suppress(Exception):
            # check if val should be a number, otherwise fall back on string
            val = float(val)
        recipe_variables[key] = val

    return recipe_variables


def update_recipe_variables(recipe_yaml_str: str, variables: Dict[str, Any]) -> str:
"""
:param recipe_yaml_str: YAML string of a SparseML recipe
Expand Down
89 changes: 72 additions & 17 deletions src/sparseml/optim/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,8 @@
from functools import cmp_to_key
from typing import List

from sparseml.optim.modifier import (
BaseModifier,
BaseObject,
BaseScheduled,
ModifierProp,
)
from sparseml.optim.modifier import BaseModifier, BaseObject, ModifierProp
from sparseml.sparsification.types import SparsificationTypes
from sparseml.utils import clean_path, create_parent_dirs


Expand All @@ -42,7 +38,7 @@ class BaseManager(BaseObject):
:param modifiers: the modifiers to wrap
"""

def __init__(self, modifiers: List[BaseScheduled], **kwargs):
def __init__(self, modifiers: List[BaseModifier], **kwargs):
super().__init__(**kwargs)
# sort modifiers by when they start and end so that later modifiers
# can overwrite in a deterministic order such as when initializing
Expand All @@ -57,44 +53,88 @@ def __del__(self):
def __str__(self) -> str:
return "\n".join(self.to_string_lines())

def __eq__(self, compare: object) -> bool:
return str(self) == str(compare)

@ModifierProp(serializable=False)
def modifiers(self) -> List[BaseScheduled]:
def modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe
"""
return self._modifiers

@ModifierProp(serializable=False)
def epoch_modifiers(self) -> List[BaseScheduled]:
def epoch_modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe that modify the
epoch range
"""
return [mod for mod in self._modifiers if "Epoch" in str(type(mod))]
return [
mod
for mod in self._modifiers
if SparsificationTypes.epoch in mod.sparsification_types
]

@ModifierProp(serializable=False)
def learning_rate_modifiers(self) -> List[BaseScheduled]:
def learning_rate_modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe that modify the
LearningRate schedule
"""
return [mod for mod in self._modifiers if "LearningRate" in str(type(mod))]
return [
mod
for mod in self._modifiers
if SparsificationTypes.learning_rate in mod.sparsification_types
]

@ModifierProp(serializable=False)
def pruning_modifiers(self) -> List[BaseScheduled]:
def pruning_modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe that manage
model sparsity
"""
return [mod for mod in self._modifiers if "Pruning" in str(type(mod))]
return [
mod
for mod in self._modifiers
if SparsificationTypes.pruning in mod.sparsification_types
]

@ModifierProp(serializable=False)
def quantization_modifiers(self) -> List[BaseScheduled]:
def quantization_modifiers(self) -> List[BaseModifier]:
"""
:return: list of all SparseML modifiers in the managed recipe that manage
model quantization
"""
return [mod for mod in self._modifiers if "Quantization" in str(type(mod))]
return [
mod
for mod in self._modifiers
if SparsificationTypes.quantization in mod.sparsification_types
]

@ModifierProp(serializable=False)
def distillation_modifiers(self) -> List[BaseModifier]:
    """
    :return: list of all SparseML modifiers in the managed recipe that manage
        Distillation
    """
    matched = []
    for modifier in self._modifiers:
        if SparsificationTypes.distillation in modifier.sparsification_types:
            matched.append(modifier)
    return matched

@ModifierProp(serializable=False)
def structured_modifiers(self) -> List[BaseModifier]:
    """
    :return: list of all SparseML modifiers in the managed recipe that manage
        structure changes to a model such as layer pruning, filter pruning,
        and quantization
    """
    matched = []
    for modifier in self._modifiers:
        if SparsificationTypes.structured in modifier.sparsification_types:
            matched.append(modifier)
    return matched

@ModifierProp(serializable=False)
def min_epochs(self) -> int:
Expand Down Expand Up @@ -154,7 +194,7 @@ def to_string_lines(self) -> List[str]:

return yaml_str_lines

def modifiers_to_string_lines(self, modifiers: List[BaseScheduled]) -> List[str]:
def modifiers_to_string_lines(self, modifiers: List[BaseModifier]) -> List[str]:
"""
:param modifiers: the modifiers to convert into string / yaml representation
for within the manage
Expand All @@ -176,3 +216,18 @@ def modifiers_to_string_lines(self, modifiers: List[BaseScheduled]) -> List[str]
yaml_str_lines.append("")

return yaml_str_lines

def qat_active(self, epoch: float) -> bool:
    """
    :param epoch: the epoch to check if quantization aware training will be
        active during
    :return: True if quantization aware training will be active at the start
        of or within the given epoch, False otherwise
    """
    quant_modifiers = self.quantization_modifiers
    if not quant_modifiers:
        # no quantization scheduled anywhere in the recipe
        return False
    earliest_start = min(mod.start_epoch for mod in quant_modifiers)
    # active if the earliest quantization start falls before the next epoch
    return earliest_start < epoch + 1
8 changes: 8 additions & 0 deletions src/sparseml/optim/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import yaml

from sparseml.optim.helpers import evaluate_recipe_yaml_str_equations
from sparseml.sparsification.types import SparsificationTypes
from sparseml.utils import ALL_TOKEN, validate_str_iterable


Expand Down Expand Up @@ -466,6 +467,13 @@ def __repr__(self):
self.props(only_serializable=False, format_repr=True),
)

@ModifierProp(serializable=False)
def sparsification_types(self) -> List[SparsificationTypes]:
    """
    :return: the sparsification types this modifier instance will apply
    """
    # base implementation applies no sparsification types;
    # subclasses override to declare the categories they affect
    types: List[SparsificationTypes] = []
    return types

@ModifierProp(serializable=True)
def log_types(self) -> Union[None, str, List[str]]:
"""
Expand Down
Loading