diff --git a/integrations/ultralytics/README.md b/integrations/ultralytics/README.md
index 8ab4d3d8f21..ade7a4ae19a 100644
--- a/integrations/ultralytics/README.md
+++ b/integrations/ultralytics/README.md
@@ -119,9 +119,12 @@ cp sparseml/integrations/ultralytics/deepsparse/*.py yolov5
 cd yolov5
 
 # install deepsparse and server dependencies
-pip install deepsparse flask flask-cors
+pip install deepsparse sparseml flask flask-cors
 ```
 
+Note: on new Ubuntu systems, installing `cv2` may require running
+`sudo apt-get update && sudo apt-get install -y python3-opencv`.
+
 ### Benchmarking
 
 `benchmarking.py` is a script for benchmarking sparsified and quantized YOLOv3
@@ -130,8 +133,8 @@ performance with DeepSparse. For a full list of options run `python benchmarkin
 To run a benchmark run:
 ```bash
 python benchmark.py \
-    zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_90 \
-    --batch-size 32 \
+    zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94 \
+    --batch-size 1 \
     --quantized-inputs
 ```
 
diff --git a/integrations/ultralytics/deepsparse/SERVER.md b/integrations/ultralytics/deepsparse/SERVER.md
index 3cd0a19b582..04a7e1c1418 100644
--- a/integrations/ultralytics/deepsparse/SERVER.md
+++ b/integrations/ultralytics/deepsparse/SERVER.md
@@ -49,11 +49,14 @@ pip install deepsparse sparseml flask flask-cors
 
 ### Server
 
-First, start up the host `server.py` with your model of choice.
+First, start up the host `server.py` with your model of choice; SparseZoo stubs are
+also supported.
 
 Example command:
 ```bash
-python server.py ~/models/yolov3-pruned_quant.onnx
+python server.py \
+    zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94 \
+    --quantized-inputs
 ```
 
 You can leave that running as a detached process or in a spare terminal.
diff --git a/integrations/ultralytics/deepsparse/benchmark.py b/integrations/ultralytics/deepsparse/benchmark.py
index 13bc1841b28..ada02bb1afd 100644
--- a/integrations/ultralytics/deepsparse/benchmark.py
+++ b/integrations/ultralytics/deepsparse/benchmark.py
@@ -21,8 +21,9 @@
 usage: benchmark.py [-h] [-e {deepsparse,onnxruntime,torch}]
                     [--data-path DATA_PATH]
                     [--image-shape IMAGE_SHAPE [IMAGE_SHAPE ...]]
-                    [-b BATCH_SIZE] [-c NUM_CORES] [-i NUM_ITERATIONS]
-                    [-w NUM_WARMUP_ITERATIONS] [-q] [--fp16] [--device DEVICE]
+                    [-b BATCH_SIZE] [-c NUM_CORES] [-s NUM_SOCKETS]
+                    [-i NUM_ITERATIONS] [-w NUM_WARMUP_ITERATIONS] [-q]
+                    [--fp16] [--device DEVICE]
                     model_filepath
 
 Benchmark sparsified YOLOv3 models
@@ -53,7 +54,13 @@
                         The batch size to run the benchmark for
   -c NUM_CORES, --num-cores NUM_CORES
                         The number of physical cores to run the benchmark on,
-                        defaults to all physical cores available on the system
+                        defaults to None where it uses all physical cores
+                        available on the system. For DeepSparse benchmarks,
+                        this value is the number of cores per socket
+  -s NUM_SOCKETS, --num-sockets NUM_SOCKETS
+                        For DeepSparse benchmarks only. The number of sockets
+                        to run the benchmark on. Defaults to None where it
+                        uses all sockets available on the system
   -i NUM_ITERATIONS, --num-iterations NUM_ITERATIONS
                         The number of iterations the benchmark will be run for
   -w NUM_WARMUP_ITERATIONS, --num-warmup-iterations NUM_WARMUP_ITERATIONS
@@ -110,14 +117,17 @@
 from deepsparse import compile_model, cpu
 from deepsparse.benchmark import BenchmarkResults
 
-from deepsparse_utils import load_image, postprocess_nms, pre_nms_postprocess
+from deepsparse_utils import (
+    YoloPostprocessor,
+    load_image,
+    modify_yolo_onnx_input_shape,
+    postprocess_nms,
+)
 from sparseml.onnx.utils import override_model_batch_size
 from sparsezoo.models.detection import yolo_v3 as zoo_yolo_v3
 from sparsezoo.utils import load_numpy_list
 
 
-CORES_PER_SOCKET, AVX_TYPE, _ = cpu.cpu_details()
-
 DEEPSPARSE_ENGINE = "deepsparse"
 ORT_ENGINE = "onnxruntime"
 TORCH_ENGINE = "torch"
@@ -180,10 +190,22 @@ def parse_args():
         "-c",
         "--num-cores",
         type=int,
-        default=CORES_PER_SOCKET,
+        default=None,
        help=(
             "The number of physical cores to run the benchmark on, "
-            "defaults to all physical cores available on the system"
+            "defaults to None where it uses all physical cores available on the system. "
+            "For DeepSparse benchmarks, this value is the number of cores per socket"
         ),
     )
+    parser.add_argument(
+        "-s",
+        "--num-sockets",
+        type=int,
+        default=None,
+        help=(
+            "For DeepSparse benchmarks only. The number of sockets to run the "
+            "benchmark on. Defaults to None where it uses all sockets available "
+            "on the system"
+        ),
+    )
     parser.add_argument(
@@ -227,7 +249,6 @@ def parse_args():
     )
 
     args = parser.parse_args()
-
     if args.engine == TORCH_ENGINE and args.device is None:
         args.device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
@@ -273,30 +294,46 @@ def _load_model(args) -> Any:
         raise ValueError(f"half precision is not supported for {args.engine}")
     if args.quantized_inputs and args.engine == TORCH_ENGINE:
         raise ValueError(f"quantized inputs not supported for {args.engine}")
-    if args.num_cores != CORES_PER_SOCKET and args.engine == TORCH_ENGINE:
+    if args.num_cores is not None and args.engine == TORCH_ENGINE:
         raise ValueError(
             f"overriding default num_cores not supported for {args.engine}"
         )
     if (
-        args.num_cores != CORES_PER_SOCKET
+        args.num_cores is not None
         and args.engine == ORT_ENGINE
         and onnxruntime.__version__ < "1.7"
     ):
-        print(
+        raise ValueError(
             "overriding default num_cores not supported for onnxruntime < 1.7.0. "
             "If using an older build with OpenMP, try setting the OMP_NUM_THREADS "
             "environment variable"
         )
+    if args.num_sockets is not None and args.engine != DEEPSPARSE_ENGINE:
+        raise ValueError(f"Overriding num_sockets is not supported for {args.engine}")
+
+    # scale static ONNX graph to desired image shape
+    if args.engine in [DEEPSPARSE_ENGINE, ORT_ENGINE]:
+        args.model_filepath, _ = modify_yolo_onnx_input_shape(
+            args.model_filepath, args.image_shape
+        )
 
     # load model
     if args.engine == DEEPSPARSE_ENGINE:
         print(f"Compiling deepsparse model for {args.model_filepath}")
-        model = compile_model(args.model_filepath, args.batch_size, args.num_cores)
+        model = compile_model(
+            args.model_filepath, args.batch_size, args.num_cores, args.num_sockets
+        )
+        if args.quantized_inputs and not model.cpu_vnni:
+            print(
+                "WARNING: VNNI instructions not detected, "
+                "quantization speedup not well supported"
+            )
     elif args.engine == ORT_ENGINE:
         print(f"loading onnxruntime model for {args.model_filepath}")
 
         sess_options = onnxruntime.SessionOptions()
-        sess_options.intra_op_num_threads = args.num_cores
+        if args.num_cores is not None:
+            sess_options.intra_op_num_threads = args.num_cores
         sess_options.log_severity_level = 3
         sess_options.graph_optimization_level = (
             onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -376,7 +413,7 @@ def _run_model(
 def benchmark_yolo(args):
     model = _load_model(args)
     print("Loading dataset")
-    dataset = _load_images(args.data_path, args.image_shape)
+    dataset = _load_images(args.data_path, tuple(args.image_shape))
 
     total_iterations = args.num_iterations + args.num_warmup_iterations
     data_loader = _iter_batches(dataset, args.batch_size, total_iterations)
@@ -388,6 +425,12 @@
         flush=True,
     )
 
+    postprocessor = (
+        YoloPostprocessor(args.image_shape)
+        if args.engine in [DEEPSPARSE_ENGINE, ORT_ENGINE]
+        else None
+    )
+
     results = BenchmarkResults()
     progress_bar = tqdm(total=args.num_iterations)
 
@@ -403,8 +446,8 @@
         outputs = _run_model(args, model, batch)
 
         # post-processing
-        if args.engine != TORCH_ENGINE:
-            outputs = pre_nms_postprocess(outputs)
+        if postprocessor:
+            outputs = postprocessor.pre_nms_postprocess(outputs)
 
         # NMS
         outputs = postprocess_nms(outputs)
diff --git a/integrations/ultralytics/deepsparse/deepsparse_utils.py b/integrations/ultralytics/deepsparse/deepsparse_utils.py
index 9eecb8bf22f..d2ce24bab1a 100644
--- a/integrations/ultralytics/deepsparse/deepsparse_utils.py
+++ b/integrations/ultralytics/deepsparse/deepsparse_utils.py
@@ -20,12 +20,17 @@
 """
 
-from typing import List, Tuple, Union
+from tempfile import NamedTemporaryFile
+from typing import List, Optional, Tuple, Union
 
 import cv2
 import numpy
+import onnx
 import torch
 
+from sparseml.onnx.utils import get_tensor_dim_shape, set_tensor_dim_shape
+from sparsezoo import Zoo
+
 # ultralytics/yolov5 imports
 from utils.general import non_max_suppression
 
 
@@ -37,13 +42,6 @@
 ]
 
 
-def _get_grid(size: int) -> torch.Tensor:
-    # adapted from yolov5.yolo.Detect._make_grid
-    coords_y, coords_x = torch.meshgrid([torch.arange(size), torch.arange(size)])
-    grid = torch.stack((coords_x, coords_y), 2)
-    return grid.view(1, 1, size, size, 2).float()
-
-
 # Yolo V3 specific variables
 _YOLO_V3_ANCHORS = [
     torch.Tensor([[10, 13], [16, 30], [33, 23]]),
@@ -51,8 +49,6 @@
     torch.Tensor([[116, 90], [156, 198], [373, 326]]),
 ]
 _YOLO_V3_ANCHOR_GRIDS = [t.clone().view(1, -1, 1, 1, 2) for t in _YOLO_V3_ANCHORS]
-_YOLO_V3_OUTPUT_SHAPES = [80, 40, 20]
-_YOLO_V3_GRIDS = [_get_grid(grid_size) for grid_size in _YOLO_V3_OUTPUT_SHAPES]
 
 
 def load_image(
@@ -70,28 +66,53 @@
     return img
 
 
-def pre_nms_postprocess(outputs: List[numpy.ndarray]) -> torch.Tensor:
+class YoloPostprocessor:
     """
-    :param outputs: raw outputs of a YOLOv3 model before anchor grid processing
-    :return: post-processed model outputs without NMS.
-    """
-    # postprocess and transform raw outputs into single torch tensor
-    processed_outputs = []
-    for idx, pred in enumerate(outputs):
-        pred = torch.from_numpy(pred)
-        pred = pred.sigmoid()
+    Class for performing postprocessing of YOLOv3 model predictions
 
-        # get grid and stride
-        grid = _YOLO_V3_GRIDS[idx]
-        anchor_grid = _YOLO_V3_ANCHOR_GRIDS[idx]
-        stride = 640 / _YOLO_V3_OUTPUT_SHAPES[idx]
+    :param image_size: size of input image to model, used to calculate stride based on
+        output shapes
+    """
 
-        # decode xywh box values
-        pred[..., 0:2] = (pred[..., 0:2] * 2.0 - 0.5 + grid) * stride
-        pred[..., 2:4] = (pred[..., 2:4] * 2) ** 2 * anchor_grid
-        # flatten anchor and grid dimensions -> (bs, num_predictions, num_classes + 5)
-        processed_outputs.append(pred.view(pred.size(0), -1, pred.size(-1)))
-    return torch.cat(processed_outputs, 1)
+    def __init__(self, image_size: Tuple[int]):
+        self._image_size = image_size
+        self._grids = {}  # Dict[Tuple[int], torch.Tensor]
+
+    def pre_nms_postprocess(self, outputs: List[numpy.ndarray]) -> torch.Tensor:
+        """
+        :param outputs: raw outputs of a YOLOv3 model before anchor grid processing
+        :return: post-processed model outputs without NMS.
+        """
+        # postprocess and transform raw outputs into single torch tensor
+        processed_outputs = []
+        for idx, pred in enumerate(outputs):
+            pred = torch.from_numpy(pred)
+            pred = pred.sigmoid()
+
+            # get grid and stride
+            grid_shape = pred.shape[2:4]
+            grid = self._get_grid(grid_shape)
+            anchor_grid = _YOLO_V3_ANCHOR_GRIDS[idx]
+            stride = self._image_size[0] / grid_shape[0]
+
+            # decode xywh box values
+            pred[..., 0:2] = (pred[..., 0:2] * 2.0 - 0.5 + grid) * stride
+            pred[..., 2:4] = (pred[..., 2:4] * 2) ** 2 * anchor_grid
+            # flatten anchor and grid dimensions -> (bs, num_predictions, num_classes + 5)
+            processed_outputs.append(pred.view(pred.size(0), -1, pred.size(-1)))
+        return torch.cat(processed_outputs, 1)
+
+    def _get_grid(self, grid_shape: Tuple[int]) -> torch.Tensor:
+        if grid_shape not in self._grids:
+            # adapted from yolov5.yolo.Detect._make_grid
+            coords_y, coords_x = torch.meshgrid(
+                [torch.arange(grid_shape[0]), torch.arange(grid_shape[1])]
+            )
+            grid = torch.stack((coords_x, coords_y), 2)
+            self._grids[grid_shape] = grid.view(
+                1, 1, grid_shape[0], grid_shape[1], 2
+            ).float()
+        return self._grids[grid_shape]
 
 
 def postprocess_nms(outputs: torch.Tensor) -> List[numpy.ndarray]:
@@ -102,3 +123,60 @@
     # run nms in PyTorch, only post-process first output
     nms_outputs = non_max_suppression(outputs)
     return [output.cpu().numpy() for output in nms_outputs]
+
+
+def modify_yolo_onnx_input_shape(
+    model_path: str, image_shape: Tuple[int]
+) -> Tuple[str, Optional[NamedTemporaryFile]]:
+    """
+    Creates a new YOLOv3 ONNX model from the given path that accepts the given input
+    shape. If the given model already has the given input shape, no modifications are
+    made. Uses a tempfile to store the modified model file.
+
+    :param model_path: file path to YOLOv3 ONNX model or SparseZoo stub of the model
+        to be loaded
+    :param image_shape: 2-tuple of the image shape to resize this yolo model to
+    :return: filepath to an ONNX model reshaped to the given input shape; will be
+        the original path if the shape is the same. Additionally returns the
+        NamedTemporaryFile for managing the scope of the object for file deletion
+    """
+    original_model_path = model_path
+    if model_path.startswith("zoo:"):
+        # load SparseZoo Model from stub
+        model = Zoo.load_model_from_stub(model_path)
+        model_path = model.onnx_file.downloaded_path()
+        print(f"Downloaded {original_model_path} to {model_path}")
+
+    model = onnx.load(model_path)
+    model_input = model.graph.input[0]
+
+    initial_x = get_tensor_dim_shape(model_input, 2)
+    initial_y = get_tensor_dim_shape(model_input, 3)
+
+    if not (isinstance(initial_x, int) and isinstance(initial_y, int)):
+        return model_path, None  # model graph does not have static integer input shape
+
+    if (initial_x, initial_y) == tuple(image_shape):
+        return model_path, None  # no shape modification needed
+
+    scale_x = initial_x / image_shape[0]
+    scale_y = initial_y / image_shape[1]
+    set_tensor_dim_shape(model_input, 2, image_shape[0])
+    set_tensor_dim_shape(model_input, 3, image_shape[1])
+
+    for model_output in model.graph.output:
+        output_x = get_tensor_dim_shape(model_output, 2)
+        output_y = get_tensor_dim_shape(model_output, 3)
+        set_tensor_dim_shape(model_output, 2, int(output_x / scale_x))
+        set_tensor_dim_shape(model_output, 3, int(output_y / scale_y))
+
+    tmp_file = NamedTemporaryFile()  # file will be deleted after program exit
+    onnx.save(model, tmp_file.name)
+
+    print(
+        f"Overwriting original model shape {(initial_x, initial_y)} to {image_shape}\n"
+        f"Original model path: {original_model_path}, new temporary model saved to "
+        f"{tmp_file.name}"
+    )
+
+    return tmp_file.name, tmp_file
diff --git a/integrations/ultralytics/train.py b/integrations/ultralytics/train.py
index ab4b474a001..c0d949246b6 100644
--- a/integrations/ultralytics/train.py
+++ b/integrations/ultralytics/train.py
@@ -104,7 +104,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
 
     # Model
-    pretrained = weights.endswith('.pt')
+    pretrained = weights.endswith('.pt') or weights.endswith('.pth')  # SparseML integration
     if pretrained:
         with torch_distributed_zero_first(rank):
             attempt_download(weights)  # download if not found locally
@@ -269,7 +269,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     ####################################################################################
     from sparseml.pytorch.nn import replace_activations
     from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
-    from sparseml.pytorch.utils import PythonLogger, TensorBoardLogger
+    from sparseml.pytorch.utils import is_parallel_model, PythonLogger, TensorBoardLogger
 
     if not opt.no_leaky_relu_override:  # use LeakyReLU activations
         model = replace_activations(model, "lrelu", inplace=True)
@@ -277,7 +277,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     manager = ScheduledModifierManager.from_yaml(opt.sparseml_recipe)
     optimizer = ScheduledOptimizer(
         optimizer,
-        model,
+        model if not is_parallel_model(model) else model.module,
         manager,
         steps_per_epoch=len(dataloader),
         loggers=[PythonLogger(), TensorBoardLogger(writer=tb_writer)]
@@ -516,7 +516,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                                   dataloader=testloader,
                                   save_dir=save_dir,
                                   save_json=save_json,
-                                  plots=False)
+                                  plots=False,
+                                  half_precision=opt.use_amp)  # SparseML integration
 
     #################################################################################
     # Start SparseML ONNX Export
@@ -527,9 +528,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         logger.info(
             f"training complete, exporting ONNX to {onnx_path}"
         )
-        model.model[-1].export = True  # do not export grid post-procesing
-        exporter = ModuleExporter(model, save_dir)
-        exporter.export_onnx(torch.randn((1, 3, *imgsz)), convert_qat=True)
+        export_model = model.module if is_parallel_model(model) else model
+        export_model.model[-1].export = True  # do not export grid post-processing
+        exporter = ModuleExporter(export_model, save_dir)
+        exporter.export_onnx(torch.randn(1, 3, imgsz, imgsz), convert_qat=True)
         if qat:
             skip_onnx_input_quantize(onnx_path, onnx_path)
         #################################################################################
@@ -621,7 +623,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
     set_logging(opt.global_rank)
     if opt.global_rank in [-1, 0]:
-        check_git_status()
+        # check_git_status()  SparseML integration, will be out of sync with master
         check_requirements()
 
     ####################################################################################
@@ -641,14 +643,14 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 "Attempting to load weights from SparseZoo recipe, but not given a "
                 "SparseZoo recipe stub. When --weights is set to 'zoo'. "
                 "sparseml-recipe must start with 'zoo:' and be a SparseZoo model "
-                f"stub. sparseml-recipe was set to {args.sparseml_recipe}"
+                f"stub. sparseml-recipe was set to {opt.sparseml_recipe}"
             )
         elif opt.weights.startswith("zoo:"):
             # Load weights from a SparseZoo model stub
             zoo_model = Zoo.load_model_from_stub(opt.weights)
-            args.initial_checkpoint = zoo_model.download_framework_files(
+            opt.weights = zoo_model.download_framework_files(
                 extensions=[".pt", ".pth"]
-            )
+            )[0]
     ####################################################################################
     # End - SparseML optional load weights from SparseZoo
     ####################################################################################
diff --git a/src/sparseml/keras/optim/__init__.py b/src/sparseml/keras/optim/__init__.py
index 1a1a3ae40e2..19346e0d876 100644
--- a/src/sparseml/keras/optim/__init__.py
+++ b/src/sparseml/keras/optim/__init__.py
@@ -25,5 +25,6 @@
 from .modifier import *
 from .modifier_epoch import *
 from .modifier_lr import *
+from .modifier_params import *
 from .modifier_pruning import *
 from .utils import *
diff --git a/src/sparseml/keras/optim/modifier.py b/src/sparseml/keras/optim/modifier.py
index 136fd0940fb..c858d81a8be 100644
--- a/src/sparseml/keras/optim/modifier.py
+++ b/src/sparseml/keras/optim/modifier.py
@@ -162,14 +162,6 @@ def __init__(
             **kwargs,
         )
 
-    @property
-    def start_epoch(self):
-        return self._start_epoch
-
-    @property
-    def end_epoch(self):
-        return self._end_epoch
-
     def start_end_steps(self, steps_per_epoch, after_optim: bool) -> Tuple[int, int]:
         """
         Calculate the start and end steps for this modifier given a certain
diff --git a/src/sparseml/keras/utils/exporter.py b/src/sparseml/keras/utils/exporter.py
index 92e3fb010ee..4cde5d10819 100644
--- a/src/sparseml/keras/utils/exporter.py
+++ b/src/sparseml/keras/utils/exporter.py
@@ -60,6 +60,7 @@ def export_onnx(
         opset: int = DEFAULT_ONNX_OPSET,
         doc_string: str = "",
         debug_mode: bool = True,
+        raise_on_tf_support: bool = True,
         **kwargs,
     ):
         """
@@ -74,6 +75,16 @@ def export_onnx(
         if keras2onnx_import_error is not None:
             raise keras2onnx_import_error
 
+        if raise_on_tf_support:
+            import tensorflow
+            v = tensorflow.__version__
+            if v >= "2.3.0":
+                raise ValueError(
+                    f"Tensorflow version {v} is greater than the currently supported "
+                    "version for keras2onnx. Please downgrade to Tensorflow <2.3.0, "
+                    "or set raise_on_tf_support to False to continue."
+                )
+
         model_name = self._model.name or name.split(".onnx")[0]
         onnx_model = keras2onnx.convert_keras(
             self._model,
diff --git a/src/sparseml/onnx/utils/helpers.py b/src/sparseml/onnx/utils/helpers.py
index 35a81a14c75..d128e21ea55 100644
--- a/src/sparseml/onnx/utils/helpers.py
+++ b/src/sparseml/onnx/utils/helpers.py
@@ -70,6 +70,8 @@
     "get_kernel_shape",
     "calculate_flops",
     "get_quantize_parent_for_dequantize_node",
+    "get_tensor_dim_shape",
+    "set_tensor_dim_shape",
 ]
 
 
@@ -1189,3 +1191,23 @@ def get_quantize_parent_for_dequantize_node(
         input_nodes = get_node_input_nodes(quantized_model, curr_node)
         curr_node = input_nodes[0] if input_nodes else None
     return curr_node
+
+
+def get_tensor_dim_shape(tensor: onnx.TensorProto, dim: int) -> int:
+    """
+    :param tensor: ONNX tensor to get the shape of a dimension of
+    :param dim: dimension index of the tensor to get the shape of
+    :return: shape of the tensor at the given dimension
+    """
+    return tensor.type.tensor_type.shape.dim[dim].dim_value
+
+
+def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: int):
+    """
+    Sets the shape of the tensor at the given dimension to the given value
+
+    :param tensor: ONNX tensor to modify the shape of
+    :param dim: dimension index of the tensor to modify the shape of
+    :param value: new shape for the given dimension
+    """
+    tensor.type.tensor_type.shape.dim[dim].dim_value = value
diff --git a/src/sparseml/pytorch/utils/loss.py b/src/sparseml/pytorch/utils/loss.py
index 65d9aaf288d..ac8aae1c2bb 100644
--- a/src/sparseml/pytorch/utils/loss.py
+++ b/src/sparseml/pytorch/utils/loss.py
@@ -220,7 +220,7 @@
 
 class InceptionCrossEntropyLossWrapper(LossWrapper):
     """
-    Loss wrapper for training an inception model that as an aux output
+    Loss wrapper for training an inception model that has an aux output
     with cross entropy. Defines the loss in the following way:
 
diff --git a/tests/sparseml/keras/optim/test_modifier.py b/tests/sparseml/keras/optim/test_modifier.py
new file mode 100644
index 00000000000..04cafb791bb
--- /dev/null
+++ b/tests/sparseml/keras/optim/test_modifier.py
@@ -0,0 +1,183 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Callable, List, Union
+
+import pytest
+
+from sparseml.keras.optim import (
+    KerasModifierYAML,
+    Modifier,
+    ScheduledModifier,
+    ScheduledUpdateModifier,
+)
+from sparseml.keras.utils import keras
+from sparseml.utils import KERAS_FRAMEWORK
+from tests.sparseml.keras.optim.mock import mnist_model
+from tests.sparseml.optim.test_modifier import (
+    BaseModifierTest,
+    BaseScheduledTest,
+    BaseUpdateTest,
+)
+
+
+@pytest.mark.skipif(
+    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
+    reason="Skipping keras tests",
+)
+class ModifierTest(BaseModifierTest):
+    # noinspection PyMethodOverriding
+    def test_constructor(
+        self,
+        modifier_lambda: Callable[[], Modifier],
+        model_lambda: Callable[[], keras.models.Model],
+        steps_per_epoch: int,
+    ):
+        super().test_constructor(modifier_lambda, framework=KERAS_FRAMEWORK)
+
+    # noinspection PyMethodOverriding
+    def test_yaml(
+        self,
+        modifier_lambda: Callable[[], Modifier],
+        model_lambda: Callable[[], keras.models.Model],
+        steps_per_epoch: int,
+    ):
+        super().test_yaml(modifier_lambda, framework=KERAS_FRAMEWORK)
+
+    # noinspection PyMethodOverriding
+    def test_yaml_key(
+        self,
+        modifier_lambda: Callable[[], Modifier],
+        model_lambda: Callable[[], keras.models.Model],
+        steps_per_epoch: int,
+    ):
+        super().test_yaml_key(modifier_lambda, framework=KERAS_FRAMEWORK)
+
+    # noinspection PyMethodOverriding
+    def test_repr(
+        self,
+        modifier_lambda: Callable[[], Modifier],
+        model_lambda: Callable[[], keras.models.Model],
+        steps_per_epoch: int,
+    ):
+        super().test_repr(modifier_lambda, framework=KERAS_FRAMEWORK)
+
+    # noinspection PyMethodOverriding
+    def test_props(
+        self,
+        modifier_lambda: Callable[[], Modifier],
+        model_lambda: Callable[[], keras.models.Model],
+        steps_per_epoch: int,
+    ):
+        super().test_props(modifier_lambda, framework=KERAS_FRAMEWORK)
+
+
+@pytest.mark.skipif(
+    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
+    reason="Skipping keras tests",
+)
+class ScheduledModifierTest(ModifierTest, BaseScheduledTest):
+    # noinspection PyMethodOverriding
+    def test_props_start(
+        self,
+        modifier_lambda: Callable[[], ScheduledModifier],
+        model_lambda: Callable[[], keras.models.Model],
+        steps_per_epoch: int,
+    ):
+        super().test_props_start(modifier_lambda, framework=KERAS_FRAMEWORK)
+
+    # noinspection PyMethodOverriding
+    def test_props_end(
+        self,
+        modifier_lambda: Callable[[], ScheduledModifier],
+        model_lambda: Callable[[], keras.models.Model],
+        steps_per_epoch: int,
+    ):
+        super().test_props_end(modifier_lambda, framework=KERAS_FRAMEWORK)
+
+
+@pytest.mark.skipif(
+    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
+    reason="Skipping keras tests",
+)
+class ScheduledUpdateModifierTest(ScheduledModifierTest, BaseUpdateTest):
+    # noinspection PyMethodOverriding
+    def test_props_frequency(
+        self,
+        modifier_lambda: Callable[[], ScheduledUpdateModifier],
+        model_lambda: Callable[[], keras.models.Model],
+        steps_per_epoch: int,
+    ):
+        super().test_props_frequency(modifier_lambda, framework=KERAS_FRAMEWORK)
+
+
+@KerasModifierYAML()
+class ModifierImpl(Modifier):
+    def __init__(self, log_types: Union[str, List[str]] = ["python"]):
+        super().__init__(log_types)
+
+
+@pytest.mark.skipif(
+    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
+    reason="Skipping keras tests",
+)
+@pytest.mark.parametrize("modifier_lambda", [ModifierImpl], scope="function")
+@pytest.mark.parametrize("model_lambda", [mnist_model], scope="function")
+@pytest.mark.parametrize("steps_per_epoch", [100], scope="function")
+class TestModifierImpl(ModifierTest):
+    pass
+
+
+@KerasModifierYAML()
+class ScheduledModifierImpl(ScheduledModifier):
+    def __init__(
+        self,
+        log_types: Union[str, List[str]] = ["python"],
+        end_epoch: float = -1.0,
+        start_epoch: float = -1.0,
+    ):
+        super().__init__(log_types)
+
+
+@pytest.mark.parametrize("modifier_lambda", [ScheduledModifierImpl], scope="function")
+@pytest.mark.parametrize("model_lambda", [mnist_model], scope="function")
+@pytest.mark.parametrize("steps_per_epoch", [100], scope="function")
+class TestScheduledModifierImpl(ScheduledModifierTest):
+    pass
+
+
+@KerasModifierYAML()
+class ScheduledUpdateModifierImpl(ScheduledUpdateModifier):
+    def __init__(
+        self,
+        log_types: Union[str, List[str]] = ["python"],
+        end_epoch: float = -1.0,
+        start_epoch: float = -1.0,
+        update_frequency: float = -1,
+    ):
+        super().__init__(log_types)
+
+
+@pytest.mark.skipif(
+    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
+    reason="Skipping keras tests",
+)
+@pytest.mark.parametrize(
+    "modifier_lambda", [ScheduledUpdateModifierImpl], scope="function"
+)
+@pytest.mark.parametrize("model_lambda", [mnist_model], scope="function")
+@pytest.mark.parametrize("steps_per_epoch", [100], scope="function")
+class TestScheduledUpdateModifierImpl(ScheduledUpdateModifierTest):
+    pass
diff --git a/tests/sparseml/keras/optim/test_modifier_epoch.py b/tests/sparseml/keras/optim/test_modifier_epoch.py
new file mode 100644
index 00000000000..2d0c66d121f
--- /dev/null
+++ b/tests/sparseml/keras/optim/test_modifier_epoch.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pytest
+
+from sparseml.keras.optim import EpochRangeModifier
+from tests.sparseml.keras.optim.mock import mnist_model
+from tests.sparseml.keras.optim.test_modifier import ScheduledModifierTest
+
+
+@pytest.mark.skipif(
+    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
+    reason="Skipping keras tests",
+)
+@pytest.mark.parametrize(
+    "modifier_lambda",
+    [lambda: EpochRangeModifier(0.0, 10.0), lambda: EpochRangeModifier(5.0, 15.0)],
+    scope="function",
+)
+@pytest.mark.parametrize("model_lambda", [mnist_model], scope="function")
+@pytest.mark.parametrize("steps_per_epoch", [100], scope="function")
+class TestEpochRangeModifierImpl(ScheduledModifierTest):
+    pass
+
+
+@pytest.mark.skipif(
+    os.getenv("NM_ML_SKIP_KERAS_TESTS", False),
+    reason="Skipping keras tests",
+)
+def test_epoch_range_yaml():
+    start_epoch = 5.0
+    end_epoch = 15.0
+    yaml_str = """
+    !EpochRangeModifier
+        start_epoch: {start_epoch}
+        end_epoch: {end_epoch}
+    """.format(
+        start_epoch=start_epoch, end_epoch=end_epoch
+    )
+    yaml_modifier = EpochRangeModifier.load_obj(yaml_str)  # type: EpochRangeModifier
+    serialized_modifier = EpochRangeModifier.load_obj(
+        str(yaml_modifier)
+    )  # type: EpochRangeModifier
+    obj_modifier = EpochRangeModifier(start_epoch=start_epoch, end_epoch=end_epoch)
+
+    assert isinstance(yaml_modifier, EpochRangeModifier)
+    assert (
+        yaml_modifier.start_epoch
+        == serialized_modifier.start_epoch
+        == obj_modifier.start_epoch
+    )
+    assert (
+        yaml_modifier.end_epoch
+        == serialized_modifier.end_epoch
+        == obj_modifier.end_epoch
+    )
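
For reviewers trying the new pieces end to end: the helpers this diff adds to `deepsparse_utils.py` (`modify_yolo_onnx_input_shape`, `YoloPostprocessor`, `postprocess_nms`) compose into a minimal standalone inference pass. The sketch below is illustrative only and not part of the diff; it assumes DeepSparse's standard `Engine.run(List[numpy.ndarray])` interface and substitutes a random tensor for real image loading and preprocessing (see `benchmark.py` above for the full pipeline).

```python
# Minimal sketch (not part of this diff) of the new deepsparse_utils helpers.
# Assumptions: deepsparse's Engine.run(List[ndarray]) call, and a random NCHW
# float32 tensor standing in for a real preprocessed image batch.
import numpy

from deepsparse import compile_model
from deepsparse_utils import (
    YoloPostprocessor,
    modify_yolo_onnx_input_shape,
    postprocess_nms,
)

image_shape = (640, 640)

# rewrite the model's static input shape if it differs from image_shape;
# keep the returned NamedTemporaryFile alive so the rewritten model file
# is not deleted while the engine still needs it
model_path, _tmp_file = modify_yolo_onnx_input_shape(
    "zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94",
    image_shape,
)

engine = compile_model(model_path, batch_size=1)

batch = numpy.random.rand(1, 3, *image_shape).astype(numpy.float32)

# raw feature-map outputs -> anchor/grid decoding -> NMS
outputs = engine.run([batch])
boxes = postprocess_nms(YoloPostprocessor(image_shape).pre_nms_postprocess(outputs))
print([b.shape for b in boxes])  # one (num_detections, 6) array per image
```

The temp-file handling shown here is the same contract `benchmark.py` relies on when it calls `modify_yolo_onnx_input_shape` before compiling the model.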