Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sparseml/keras/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@
from .modifier import *
from .modifier_epoch import *
from .modifier_lr import *
from .modifier_params import *
from .modifier_pruning import *
from .utils import *
9 changes: 1 addition & 8 deletions src/sparseml/keras/optim/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"ModifierProp",
"KerasModifierYAML",
"Modifier",
"ModifierProp",
"ScheduledModifier",
"ScheduledUpdateModifier",
]
Expand Down Expand Up @@ -162,14 +163,6 @@ def __init__(
**kwargs,
)

@property
def start_epoch(self):
return self._start_epoch

@property
def end_epoch(self):
return self._end_epoch

def start_end_steps(self, steps_per_epoch, after_optim: bool) -> Tuple[int, int]:
"""
Calculate the start and end steps for this modifier given a certain
Expand Down
183 changes: 183 additions & 0 deletions tests/sparseml/keras/optim/test_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Callable, List, Union

import pytest

from sparseml.keras.optim import (
KerasModifierYAML,
Modifier,
ScheduledModifier,
ScheduledUpdateModifier,
)
from sparseml.keras.utils import keras
from sparseml.utils import KERAS_FRAMEWORK
from tests.sparseml.keras.optim.mock import mnist_model
from tests.sparseml.optim.test_modifier import (
BaseModifierTest,
BaseScheduledTest,
BaseUpdateTest,
)


@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
    reason="Skipping keras tests",
)
class ModifierTest(BaseModifierTest):
    """
    Keras-specific wrapper around the framework-agnostic BaseModifierTest.

    Each override re-declares the fixture parameters used by the Keras test
    classes (modifier_lambda, model_lambda, steps_per_epoch) so pytest can
    inject the values supplied by the parametrize decorators on the concrete
    Test*Impl classes below, then delegates to the base implementation with
    framework=KERAS_FRAMEWORK. model_lambda and steps_per_epoch are unused in
    the bodies; they exist only to satisfy fixture injection.

    NOTE(review): os.getenv returns a string whenever the env var is set, so
    any non-empty value (even "0") makes this skip trigger.
    """

    # noinspection PyMethodOverriding
    def test_constructor(
        self,
        modifier_lambda: Callable[[], Modifier],
        model_lambda: Callable[[], keras.models.Model],
        steps_per_epoch: int,
    ):
        """Delegate to BaseModifierTest.test_constructor for Keras."""
        super().test_constructor(modifier_lambda, framework=KERAS_FRAMEWORK)

    # noinspection PyMethodOverriding
    def test_yaml(
        self,
        modifier_lambda: Callable[[], Modifier],
        model_lambda: Callable[[], keras.models.Model],
        steps_per_epoch: int,
    ):
        """Delegate to BaseModifierTest.test_yaml for Keras."""
        super().test_yaml(modifier_lambda, framework=KERAS_FRAMEWORK)

    # noinspection PyMethodOverriding
    def test_yaml_key(
        self,
        modifier_lambda: Callable[[], Modifier],
        model_lambda: Callable[[], keras.models.Model],
        steps_per_epoch: int,
    ):
        """Delegate to BaseModifierTest.test_yaml_key for Keras."""
        super().test_yaml_key(modifier_lambda, framework=KERAS_FRAMEWORK)

    # noinspection PyMethodOverriding
    def test_repr(
        self,
        modifier_lambda: Callable[[], Modifier],
        model_lambda: Callable[[], keras.models.Model],
        steps_per_epoch: int,
    ):
        """Delegate to BaseModifierTest.test_repr for Keras."""
        super().test_repr(modifier_lambda, framework=KERAS_FRAMEWORK)

    # noinspection PyMethodOverriding
    def test_props(
        self,
        modifier_lambda: Callable[[], Modifier],
        model_lambda: Callable[[], keras.models.Model],
        steps_per_epoch: int,
    ):
        """Delegate to BaseModifierTest.test_props for Keras."""
        super().test_props(modifier_lambda, framework=KERAS_FRAMEWORK)


@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
    reason="Skipping keras tests",
)
class ScheduledModifierTest(ModifierTest, BaseScheduledTest):
    """
    Keras wrapper adding the scheduled-property checks from BaseScheduledTest
    on top of the base ModifierTest suite. Overrides re-declare the Keras
    fixture parameters and forward with framework=KERAS_FRAMEWORK.
    """

    # noinspection PyMethodOverriding
    def test_props_start(
        self,
        modifier_lambda: Callable[[], ScheduledModifier],
        model_lambda: Callable[[], keras.models.Model],
        steps_per_epoch: int,
    ):
        """Delegate to BaseScheduledTest.test_props_start for Keras."""
        super().test_props_start(modifier_lambda, framework=KERAS_FRAMEWORK)

    # noinspection PyMethodOverriding
    def test_props_end(
        self,
        modifier_lambda: Callable[[], ScheduledModifier],
        model_lambda: Callable[[], keras.models.Model],
        steps_per_epoch: int,
    ):
        """Delegate to BaseScheduledTest.test_props_end for Keras."""
        super().test_props_end(modifier_lambda, framework=KERAS_FRAMEWORK)


@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
    reason="Skipping keras tests",
)
class ScheduledUpdateModifierTest(ScheduledModifierTest, BaseUpdateTest):
    """
    Keras wrapper adding the update-frequency checks from BaseUpdateTest on
    top of the scheduled-modifier suite. The override re-declares the Keras
    fixture parameters and forwards with framework=KERAS_FRAMEWORK.
    """

    # noinspection PyMethodOverriding
    def test_props_frequency(
        self,
        modifier_lambda: Callable[[], ScheduledUpdateModifier],
        model_lambda: Callable[[], keras.models.Model],
        steps_per_epoch: int,
    ):
        """Delegate to BaseUpdateTest.test_props_frequency for Keras."""
        super().test_props_frequency(modifier_lambda, framework=KERAS_FRAMEWORK)


@KerasModifierYAML()
class ModifierImpl(Modifier):
    """Minimal concrete Modifier used to run the base Modifier test suite."""

    def __init__(self, log_types: Union[str, List[str], None] = None):
        """
        :param log_types: logging targets forwarded to Modifier; defaults to
            ["python"] when not given.
        """
        # Fix: avoid a mutable default argument (shared list across calls);
        # the effective default value is unchanged.
        super().__init__(log_types if log_types is not None else ["python"])


@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
    reason="Skipping keras tests",
)
@pytest.mark.parametrize("modifier_lambda", [ModifierImpl], scope="function")
@pytest.mark.parametrize("model_lambda", [mnist_model], scope="function")
@pytest.mark.parametrize("steps_per_epoch", [100], scope="function")
class TestModifierImpl(ModifierTest):
    """Run the inherited ModifierTest suite against ModifierImpl."""

    pass


@KerasModifierYAML()
class ScheduledModifierImpl(ScheduledModifier):
    """Minimal concrete ScheduledModifier for the scheduled-modifier tests."""

    def __init__(
        self,
        log_types: Union[str, List[str], None] = None,
        end_epoch: float = -1.0,
        start_epoch: float = -1.0,
    ):
        """
        :param log_types: logging targets forwarded to the base class;
            defaults to ["python"] when not given
        :param end_epoch: epoch at which the modifier ends; forwarded
        :param start_epoch: epoch at which the modifier starts; forwarded
        """
        # Fix: start_epoch/end_epoch were accepted but silently dropped, so
        # constructing with explicit epochs had no effect. Forward them by
        # keyword (the base __init__ presumably routes them via **kwargs —
        # confirm against ScheduledModifier's signature).
        # Also fix the mutable default argument for log_types.
        super().__init__(
            log_types if log_types is not None else ["python"],
            start_epoch=start_epoch,
            end_epoch=end_epoch,
        )


# Fix (consistency): every other concrete test class in this file carries the
# NM_ML_SKIP_KERAS_TESTS skipif marker; this one was missing it, so these
# tests could run even when Keras tests are meant to be skipped.
@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
    reason="Skipping keras tests",
)
@pytest.mark.parametrize("modifier_lambda", [ScheduledModifierImpl], scope="function")
@pytest.mark.parametrize("model_lambda", [mnist_model], scope="function")
@pytest.mark.parametrize("steps_per_epoch", [100], scope="function")
class TestScheduledModifierImpl(ScheduledModifierTest):
    """Run the inherited ScheduledModifierTest suite against ScheduledModifierImpl."""

    pass


@KerasModifierYAML()
class ScheduledUpdateModifierImpl(ScheduledUpdateModifier):
    """Minimal concrete ScheduledUpdateModifier for the update-modifier tests."""

    def __init__(
        self,
        log_types: Union[str, List[str], None] = None,
        end_epoch: float = -1.0,
        start_epoch: float = -1.0,
        update_frequency: float = -1,
    ):
        """
        :param log_types: logging targets forwarded to the base class;
            defaults to ["python"] when not given
        :param end_epoch: epoch at which the modifier ends; forwarded
        :param start_epoch: epoch at which the modifier starts; forwarded
        :param update_frequency: epochs between updates; forwarded
        """
        # Fix: the schedule arguments were accepted but never forwarded to
        # the base class, so explicit epochs/frequency had no effect. Forward
        # them by keyword (base __init__ presumably routes them via **kwargs
        # — confirm against ScheduledUpdateModifier's signature).
        # Also fix the mutable default argument for log_types.
        super().__init__(
            log_types if log_types is not None else ["python"],
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            update_frequency=update_frequency,
        )


@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
    reason="Skipping keras tests",
)
@pytest.mark.parametrize(
    "modifier_lambda", [ScheduledUpdateModifierImpl], scope="function"
)
@pytest.mark.parametrize("model_lambda", [mnist_model], scope="function")
@pytest.mark.parametrize("steps_per_epoch", [100], scope="function")
class TestScheduledUpdateModifierImpl(ScheduledUpdateModifierTest):
    """Run the inherited ScheduledUpdateModifierTest suite against ScheduledUpdateModifierImpl."""

    pass
69 changes: 69 additions & 0 deletions tests/sparseml/keras/optim/test_modifier_epoch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from sparseml.keras.optim import EpochRangeModifier
from tests.sparseml.keras.optim.mock import mnist_model
from tests.sparseml.keras.optim.test_modifier import ScheduledModifierTest


@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
    reason="Skipping keras tests",
)
@pytest.mark.parametrize(
    "modifier_lambda",
    [lambda: EpochRangeModifier(0.0, 10.0), lambda: EpochRangeModifier(5.0, 15.0)],
    scope="function",
)
@pytest.mark.parametrize("model_lambda", [mnist_model], scope="function")
@pytest.mark.parametrize("steps_per_epoch", [100], scope="function")
class TestEpochRangeModifierImpl(ScheduledModifierTest):
    """Run the inherited ScheduledModifierTest suite against two EpochRangeModifier configurations."""

    pass


@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
    reason="Skipping keras tests",
)
def test_epoch_range_yaml():
    """
    Check that an EpochRangeModifier built from YAML, one round-tripped
    through its own string form, and one built directly all agree on
    start_epoch and end_epoch.
    """
    start_epoch = 5.0
    end_epoch = 15.0
    yaml_str = f"""
    !EpochRangeModifier
        start_epoch: {start_epoch}
        end_epoch: {end_epoch}
    """

    yaml_modifier = EpochRangeModifier.load_obj(yaml_str)  # type: EpochRangeModifier
    # round-trip: serialize the loaded modifier and load it again
    serialized_modifier = EpochRangeModifier.load_obj(
        str(yaml_modifier)
    )  # type: EpochRangeModifier
    obj_modifier = EpochRangeModifier(start_epoch=start_epoch, end_epoch=end_epoch)

    assert isinstance(yaml_modifier, EpochRangeModifier)
    modifiers = (yaml_modifier, serialized_modifier, obj_modifier)
    for prop in ("start_epoch", "end_epoch"):
        values = {getattr(mod, prop) for mod in modifiers}
        assert len(values) == 1, f"mismatched {prop}: {values}"