diff --git a/nncf/common/quantization/structs.py b/nncf/common/quantization/structs.py index abaa5210de2..c343d7e6a03 100644 --- a/nncf/common/quantization/structs.py +++ b/nncf/common/quantization/structs.py @@ -136,7 +136,8 @@ def __init__(self, num_bits: int, mode: QuantizationMode, signedness_to_force: bool, narrow_range: bool, - half_range: bool): + half_range: bool, + compression_lr_multiplier: Optional[float] = None): """ :param num_bits: Bitwidth of the quantization. :param mode: The mode of quantization (symmetric or asymmetric). @@ -147,23 +148,28 @@ def __init__(self, num_bits: int, naive case, False if all 2^`num_bits` quantizations should be used. :param half_range: If ``True`` effectively only a half of an quantizer range are used. False - the full range are used. + :param compression_lr_multiplier: Multiplier for gradient values """ self.num_bits = num_bits self.mode = mode self.signedness_to_force = signedness_to_force self.narrow_range = narrow_range self.half_range = half_range + self.compression_lr_multiplier = compression_lr_multiplier def __eq__(self, other: 'QuantizerSpec'): return self.__dict__ == other.__dict__ @classmethod - def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool) -> 'QuantizerSpec': + def from_config(cls, qconfig: QuantizerConfig, + narrow_range: bool, half_range: bool, + compression_lr_multiplier: Optional[float]) -> 'QuantizerSpec': return cls(qconfig.num_bits, qconfig.mode, qconfig.signedness_to_force, narrow_range, - half_range) + half_range, + compression_lr_multiplier) class QuantizationConstraints: diff --git a/nncf/tensorflow/functions.py b/nncf/tensorflow/functions.py index 2da9c106929..a8481e5c9a6 100644 --- a/nncf/tensorflow/functions.py +++ b/nncf/tensorflow/functions.py @@ -12,6 +12,7 @@ """ import tensorflow as tf +from typing import Callable, Optional @tf.function @@ -24,3 +25,15 @@ def st_threshold(input_): def grad(upstream): return upstream return tf.round(input_), grad + + +def get_id_with_multiplied_grad(grad_multiplier: Optional[float]) -> Callable[[tf.Tensor], tf.Tensor]: + @tf.custom_gradient + def id_with_multiplied_grad(x): + def grad(upstream): + if grad_multiplier is None: + return upstream + return grad_multiplier * upstream + return x, grad + + return id_with_multiplied_grad diff --git a/nncf/tensorflow/quantization/algorithm.py b/nncf/tensorflow/quantization/algorithm.py index c86680d9cf9..4afbf66f78b 100644 --- a/nncf/tensorflow/quantization/algorithm.py +++ b/nncf/tensorflow/quantization/algorithm.py @@ -11,10 +11,7 @@ limitations under the License.
""" from copy import deepcopy -from typing import Any -from typing import Dict -from typing import List -from typing import Tuple +from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf @@ -394,9 +391,11 @@ def _get_quantizer_setup(self, model: tf.keras.Model) -> TFQuantizationSetup: half_range = self._get_half_range(qconfig) applied_saturation_fix = applied_saturation_fix or half_range + compression_lr_multiplier = self._get_compression_lr_multiplier() quantizer_spec = TFQuantizerSpec.from_config(qconfig, narrow_range=not half_range, - half_range=half_range) + half_range=half_range, + compression_lr_multiplier=compression_lr_multiplier) target_point = TFLayerWeight(layer_info.layer_name, weight_def.weight_attr_name) qpoint = TFQuantizationPoint(op_name, quantizer_spec, target_point) setup.add_quantization_point(qpoint) @@ -407,7 +406,9 @@ def _get_quantizer_setup(self, model: tf.keras.Model) -> TFQuantizationSetup: target_node_name = ip.target_node_name input_port_id = ip.input_port_id fake_quantize_name = self._get_fake_quantize_name(target_node_name, input_port_id) - quantizer_spec = TFQuantizerSpec.from_config(qp.qconfig, narrow_range=False, half_range=False) + compression_lr_multiplier = self._get_compression_lr_multiplier() + quantizer_spec = TFQuantizerSpec.from_config(qp.qconfig, narrow_range=False, half_range=False, + compression_lr_multiplier=compression_lr_multiplier) fake_quantize_layer = FakeQuantize( quantizer_spec, name=fake_quantize_name) @@ -565,6 +566,10 @@ def _get_fake_quantize_name(self, node_name: NNCFNodeName, input_port_id: int = def _get_quantizer_operation_name(self, layer_name, weight_attr_name): return f'{layer_name}_{weight_attr_name}_quantizer' + def _get_compression_lr_multiplier(self) -> Optional[float]: + return self.config.get_redefinable_global_param_value_for_algo('compression_lr_multiplier', + self.name) + class QuantizationController(BaseCompressionAlgorithmController): def __init__(self, target_model, config, op_names: List[str]): diff --git a/nncf/tensorflow/quantization/layers.py b/nncf/tensorflow/quantization/layers.py index b6ba49fb2ba..f4d5c518be4 100644 --- a/nncf/tensorflow/quantization/layers.py +++ b/nncf/tensorflow/quantization/layers.py @@ -14,6 +14,7 @@ import tensorflow as tf from nncf.common.quantization.structs import QuantizationMode +from nncf.tensorflow.functions import get_id_with_multiplied_grad from nncf.tensorflow.layers.custom_objects import NNCF_CUSTOM_OBJECTS from nncf.tensorflow.layers.custom_objects import NNCF_QUANTIZATION_OPERATONS from nncf.tensorflow.layers.operation import InputType @@ -33,7 +34,8 @@ def __init__(self, config: TFQuantizerSpec, data_format: str ='channels_last', * self._op_name = f'{self.name}_quantizer' self._quantizer = self._create_quantizer(config, self._op_name) - self._quantizer_weights = {} + self._quantizer_weights_dict = {} + self.id_with_multiplied_grad = get_id_with_multiplied_grad(config.compression_lr_multiplier) @property def num_bits(self): @@ -79,6 +81,24 @@ def enabled(self): def enabled(self, v): self._quantizer.enabled = v + @property + def _quantizer_weights(self): + res = {} + for k in self._quantizer_weights_dict: + res[k] = self.id_with_multiplied_grad(self._quantizer_weights_dict[k]) + return res + + @property + def _real_quantizer_weights(self): + res = {} + for k in self._quantizer_weights_dict: + res[k] = self._quantizer_weights_dict[k] + return res + + @_quantizer_weights.setter + def _quantizer_weights(self, value): + self._quantizer_weights_dict = value 
+ def build(self, input_shape): self._quantizer_weights = self._quantizer.build( input_shape, InputType.INPUTS, self.name, self) @@ -91,7 +111,7 @@ def register_hook_pre_quantizer(self, hook): return self._quantizer.register_hook_pre_call(hook) def apply_range_initialization(self, min_values, max_values, min_range=0.1, eps=0.01): - self._quantizer.apply_range_initialization(self._quantizer_weights, min_values, max_values, min_range, eps) + self._quantizer.apply_range_initialization(self._real_quantizer_weights, min_values, max_values, min_range, eps) def _create_quantizer(self, qspec: TFQuantizerSpec, op_name: str) -> Quantizer: quantizer_cls = NNCF_QUANTIZATION_OPERATONS.get(qspec.mode) diff --git a/nncf/tensorflow/quantization/quantizers.py b/nncf/tensorflow/quantization/quantizers.py index 79e7ff657c0..8d9c77ef395 100644 --- a/nncf/tensorflow/quantization/quantizers.py +++ b/nncf/tensorflow/quantization/quantizers.py @@ -36,18 +36,21 @@ def __init__(self, num_bits: int, signedness_to_force: Optional[bool], narrow_range: bool, half_range: bool, - per_channel: bool): - super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range) + per_channel: bool, + compression_lr_multiplier: Optional[float] = None): + super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range, compression_lr_multiplier) self.per_channel = per_channel @classmethod - def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool) -> 'TFQuantizerSpec': + def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool, + compression_lr_multiplier: Optional[float] = None) -> 'TFQuantizerSpec': return cls(qconfig.num_bits, qconfig.mode, qconfig.signedness_to_force, narrow_range, half_range, - qconfig.per_channel) + qconfig.per_channel, + compression_lr_multiplier) def get_state(self) -> Dict[str, Any]: """ @@ -62,7 +65,8 @@ def get_state(self) -> Dict[str, Any]: 'signedness_to_force': self.signedness_to_force, 'narrow_range': self.narrow_range, 'half_range': self.half_range, - 'per_channel': self.per_channel + 'per_channel': self.per_channel, + 'compression_lr_multiplier': self.compression_lr_multiplier } @classmethod @@ -285,6 +289,7 @@ def __init__(self, name: str, qspec: TFQuantizerSpec): self.narrow_range = qspec.narrow_range self.signedness_to_force = qspec.signedness_to_force self._half_range = qspec.half_range + self.compression_lr_multiplier = qspec.compression_lr_multiplier @property def half_range(self): @@ -389,6 +394,7 @@ def get_config(self): 'narrow_range': self.narrow_range, 'half_range': self._half_range, 'per_channel': self.per_channel, + 'compression_lr_multiplier': self.compression_lr_multiplier, } config = { 'quantizer_spec': qspec_dict, @@ -404,12 +410,12 @@ def from_config(cls, config): signedness_to_force=qspec_dict['signedness_to_force'], narrow_range=qspec_dict['narrow_range'], half_range=qspec_dict['half_range'], - per_channel=qspec_dict['per_channel']) + per_channel=qspec_dict['per_channel'], + compression_lr_multiplier=qspec_dict['compression_lr_multiplier']) name = config['name'] return cls(name, qspec) - @NNCF_CUSTOM_OBJECTS.register() @NNCF_QUANTIZATION_OPERATONS.register(QuantizationMode.ASYMMETRIC) class AsymmetricQuantizer(Quantizer): @@ -419,6 +425,7 @@ def __init__(self, name: str, qspec: TFQuantizerSpec): self.narrow_range = qspec.narrow_range self.per_channel = qspec.per_channel self._half_range = qspec.half_range + self.compression_lr_multiplier = qspec.compression_lr_multiplier @property def 
half_range(self): @@ -519,6 +526,7 @@ def get_config(self): 'narrow_range': self.narrow_range, 'half_range': self._half_range, 'per_channel': self.per_channel, + 'compression_lr_multiplier': self.compression_lr_multiplier, } config = { 'quantizer_spec': qspec_dict, @@ -534,6 +542,7 @@ def from_config(cls, config): signedness_to_force=None, narrow_range=qspec_dict['narrow_range'], half_range=qspec_dict['half_range'], - per_channel=qspec_dict['per_channel']) + per_channel=qspec_dict['per_channel'], + compression_lr_multiplier=qspec_dict['compression_lr_multiplier']) name = config['name'] return cls(name, qspec) diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index 90c2a1b5d88..41f730d389a 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -59,17 +59,15 @@ def __init__(self, num_bits: int, half_range: bool, scale_shape: Tuple[int, ...], logarithm_scale: bool, - compression_lr_multiplier: float = None): - super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range) + compression_lr_multiplier: Optional[float] = None): + super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range, compression_lr_multiplier) self.scale_shape = scale_shape self.logarithm_scale = logarithm_scale - self.compression_lr_multiplier = compression_lr_multiplier - @classmethod def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool, scale_shape: Tuple[int], - logarithm_scale: bool, compression_lr_multiplier: float) -> 'PTQuantizerSpec': + logarithm_scale: bool, compression_lr_multiplier: Optional[float] = None) -> 'PTQuantizerSpec': return cls(qconfig.num_bits, qconfig.mode, qconfig.signedness_to_force, diff --git a/tests/common/helpers.py b/tests/common/helpers.py index 2088a7fe61f..e5a4128fb8c 100644 --- a/tests/common/helpers.py +++ b/tests/common/helpers.py @@ -10,13 +10,19 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - -from pathlib import Path +import numpy as np import os import shutil import subprocess import sys +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from typing import Callable, List, TypeVar + +TensorType = TypeVar('TensorType') + TEST_ROOT = Path(__file__).absolute().parents[1] PROJECT_ROOT = TEST_ROOT.parent.absolute() EXAMPLES_DIR = PROJECT_ROOT / 'examples' @@ -113,3 +119,43 @@ def run_install_checks(venv_path, tmp_path, package_type, test_dir, install_type install_mode, package_type), check=True, shell=True, cwd=run_path) + + +class BaseTensorListComparator(ABC): + @classmethod + @abstractmethod + def _to_numpy(cls, tensor: TensorType) -> np.ndarray: + pass + + @classmethod + def _compare_tensor_lists(cls, test: List[TensorType], reference: List[TensorType], + assert_fn: Callable[[np.ndarray, np.ndarray], bool]): + assert len(test) == len(reference) + + for x, y in zip(test, reference): + x = cls._to_numpy(x) + y = cls._to_numpy(y) + assert_fn(x, y) + + @classmethod + def check_equal(cls, test: List[TensorType], reference: List[TensorType], rtol: float = 1e-1): + cls._compare_tensor_lists(test, reference, + lambda x, y: np.testing.assert_allclose(x, y, rtol=rtol)) + + @classmethod + def check_not_equal(cls, test: List[TensorType], reference: List[TensorType], rtol: float = 1e-4): + cls._compare_tensor_lists(test, reference, + lambda x, y: np.testing.assert_raises(AssertionError, + np.testing.assert_allclose, x, y, rtol=rtol)) + + @classmethod + def check_less(cls, test: List[TensorType], reference: List[TensorType], rtol=1e-4): + # cls.check_not_equal(test, reference, rtol=rtol) + cls._compare_tensor_lists(test, reference, np.testing.assert_array_less) + + @classmethod + def check_greater(cls, test: List[TensorType], reference: List[TensorType], rtol=1e-4): + # cls.check_not_equal(test, reference, rtol=rtol) + cls._compare_tensor_lists(test, reference, + lambda x, y: np.testing.assert_raises(AssertionError, + np.testing.assert_array_less, x, y)) diff --git a/tests/common/test_compression_lr_multiplier.py b/tests/common/test_compression_lr_multiplier.py new file mode 100644 index 00000000000..281d50d068a --- /dev/null +++ b/tests/common/test_compression_lr_multiplier.py @@ -0,0 +1,356 @@ +import copy +import pytest +import numpy as np + +from abc import ABC +from abc import abstractmethod +from typing import Callable, Dict, Generator, List, Optional, Tuple, TypeVar + +from nncf import NNCFConfig +from tests.common.helpers import BaseTensorListComparator +from tests.common.helpers import TensorType + +ParameterType = TypeVar('ParameterType') +GradientType = TypeVar('GradientType') +ModelType = TypeVar('ModelType') +DatasetType = TypeVar('DatasetType') + + +def get_config_algos(config: NNCFConfig) -> List[Dict]: + if isinstance(config['compression'], list): + algos = config['compression'] + else: + algos = [config['compression']] + return algos + + +def add_multiplier_to_config(config: NNCFConfig, + local_multiplier: Optional[float] = None, + global_multiplier: Optional[float] = None) -> NNCFConfig: + config = copy.deepcopy(config) + + if local_multiplier is not None: + algos = get_config_algos(config) + for algo in algos: + algo.update({ + 'compression_lr_multiplier': local_multiplier + }) + + if global_multiplier is not None: + config['compression_lr_multiplier'] = global_multiplier + + return config + + +def get_multipliers_from_config(config: NNCFConfig) -> Dict[str, float]: + algo_to_multipliers = {} + + algos = get_config_algos(config) + 
global_multiplier = config.get('compression_lr_multiplier', 1) +    for algo in algos: +        algo_name = algo['algorithm'] +        algo_to_multipliers[algo_name] = algo.get('compression_lr_multiplier', global_multiplier) + +    return algo_to_multipliers + + +def merge_configs(configs: List[NNCFConfig], use_algo_list: bool = True) -> NNCFConfig: +    res_config = NNCFConfig({}) +    algos = [] + +    for source_config in configs: +        source_config = copy.deepcopy(source_config) +        algos.extend(get_config_algos(source_config)) +        del source_config['compression'] +        res_config.update(source_config) + +    if not use_algo_list: +        if len(algos) > 1: +            raise Exception('If there is more than one algorithm ' +                            'you can only use use_algo_list=True') +        res_config['compression'] = algos[0] +    else: +        res_config['compression'] = algos + +    res_config['model'] = 'merged_model' +    return res_config + + +def get_quantization_config(sample_size: List[int]) -> NNCFConfig: +    config = NNCFConfig() +    config.update({ +        'model': 'basic_quantization_config', + +        'input_info': { +            'sample_size': sample_size, +        }, + +        'compression': { +            'algorithm': 'quantization', +            'initializer': { +                'range': { +                    'num_init_samples': 10, +                }, +                'batchnorm_adaptation': { +                    'num_bn_adaptation_samples': 0, +                }, +            }, +        }, +    }) +    return config + + +def get_rb_sparsity_config(sample_size: List[int]) -> NNCFConfig: +    config = NNCFConfig() +    config.update({ +        'model': 'basic_rb_sparsity_config', + +        'input_info': { +            'sample_size': sample_size, +        }, + +        'compression': { +            'algorithm': 'rb_sparsity', +            'sparsity_init': 0.1, +            'params': { +                'schedule': 'polynomial', +                'sparsity_target': 0.5, +                'sparsity_target_epoch': 1, +                'sparsity_freeze_epoch': 1, +            }, +        }, +    }) +    return config + + +def get_binarization_config(sample_size: List[int]) -> NNCFConfig: +    config = NNCFConfig() +    config.update({ +        'model': 'basic_binarization_config', + +        'input_info': { +            'sample_size': sample_size, +        }, + +        'compression': { +            'algorithm': 'binarization', +            'mode': 'xnor', +            'params': { +                'activations_quant_start_epoch': 0, +                'weights_quant_start_epoch': 0, +            }, +        }, +    }) +    return config + + +def get_configs_building_params() -> List[Dict]: +    res = [] +    get_orig_config_fns = [get_quantization_config, get_rb_sparsity_config, get_binarization_config] +    num_orig_configs = len(get_orig_config_fns) + +    for global_multiplier in [0, 1, 10]: +        res.append({ +            'get_orig_config_fns': get_orig_config_fns, +            'multipliers': [None] * num_orig_configs, +            'global_multiplier': global_multiplier, +            'use_algo_list': True +        }) + +    global_multiplier = 10 +    multipliers = [global_multiplier * (1.1 ** i) for i in range(num_orig_configs)] + +    res.append({ +        'get_orig_config_fns': get_orig_config_fns, +        'multipliers': multipliers, +        'global_multiplier': global_multiplier, +        'use_algo_list': True +    }) + +    for i in range(num_orig_configs): +        cur_multipliers = copy.deepcopy(multipliers) +        cur_multipliers[i] = None +        res.append({ +            'get_orig_config_fns': get_orig_config_fns, +            'multipliers': cur_multipliers, +            'global_multiplier': None, +            'use_algo_list': True +        }) + +    for get_orig_config_fn in get_orig_config_fns: +        for use_algo_list in [False, True]: +            for global_multiplier, multiplier in [(11, 10), (11, None), (None, 10)]: +                res.append({ +                    'get_orig_config_fns': [get_orig_config_fn], +                    'multipliers': [multiplier], +                    'global_multiplier': global_multiplier, +                    'use_algo_list': use_algo_list +                }) + +    return res + + +class BaseCompressionLRMultiplierTester(ABC): +
ALGO_NAME_TO_PATH_MAP = {} + TensorListComparator: BaseTensorListComparator + + @pytest.fixture(name='configs_building_params', + params=get_configs_building_params()) + def configs_building_params_(self, request) -> Dict: + return request.param + + @pytest.fixture(name='ref_configs') + def ref_configs_(self, configs_building_params: Dict, sample_size: Tuple[int, ...]) -> List[NNCFConfig]: + return [get_ref_config_fn(sample_size) for get_ref_config_fn in configs_building_params['get_orig_config_fns']] + + @pytest.fixture(name='ref_config') + def ref_config_(self, ref_configs, configs_building_params) -> NNCFConfig: + return merge_configs(ref_configs, configs_building_params['use_algo_list']) + + @pytest.fixture(name='target_configs') + def target_configs_(self, ref_configs: List[NNCFConfig], configs_building_params: Dict) -> List[NNCFConfig]: + return [add_multiplier_to_config(config, local_multiplier=multiplier) + for config, multiplier in zip(ref_configs, configs_building_params['multipliers'])] + + @pytest.fixture(name='target_config') + def target_config_(self, target_configs: List[NNCFConfig], configs_building_params: Dict) -> NNCFConfig: + target_config = merge_configs(target_configs, configs_building_params['use_algo_list']) + return add_multiplier_to_config(target_config, global_multiplier=configs_building_params['global_multiplier']) + + @classmethod + @abstractmethod + def _get_layer_cls_and_params(cls, model: ModelType) -> Generator[Tuple[type, List[ParameterType]], None, None]: + pass + + @classmethod + def _get_params_grouped_by_algos(cls, model: ModelType) -> Dict[str, List[ParameterType]]: + algo_name_to_params = {} + for layer_cls, layer_params in cls._get_layer_cls_and_params(model): + if len(layer_params) == 0: + continue + cls_name = '.'.join([layer_cls.__module__, layer_cls.__name__]) + + for cur_algo_name, cur_algo_path in cls.ALGO_NAME_TO_PATH_MAP.items(): + if cur_algo_path in cls_name: + algo_name = cur_algo_name + break + else: + algo_name = 'regular' + + if algo_name not in algo_name_to_params: + algo_name_to_params[algo_name] = [] + algo_name_to_params[algo_name].extend(layer_params) + + return algo_name_to_params + + @classmethod + @abstractmethod + def _get_params_and_grads_after_training_steps(cls, model: ModelType, dataset: DatasetType, + num_steps: int = 1) -> Tuple[Dict[str, ParameterType], + Dict[str, GradientType]]: + pass + + def test_algos_in_config_add_params( + self, + get_ref_model_and_dataset: Callable[[], Tuple[ModelType, DatasetType]], + target_config: NNCFConfig + ): + algo_to_params, _algo_to_grads = \ + self._get_params_and_grads_after_training_steps(*get_ref_model_and_dataset(), num_steps=0) + algo_names = get_multipliers_from_config(target_config).keys() + + assert sorted(algo_to_params.keys()) == sorted(list(algo_names) + ['regular']) + + @classmethod + def _check_multipliers_in_config_multiplies_grads( + cls, + get_ref_model_and_dataset: Callable[[], Tuple[ModelType, DatasetType]], + get_target_model_and_dataset: Callable[[], Tuple[ModelType, DatasetType]], + multipliers: Dict[str, float] + ): + _ref_params, ref_grads = \ + cls._get_params_and_grads_after_training_steps(*get_ref_model_and_dataset()) + _target_params, target_grads = \ + cls._get_params_and_grads_after_training_steps(*get_target_model_and_dataset()) + + for algo in ref_grads: + cls.TensorListComparator.check_equal([multipliers[algo] * grad for grad in ref_grads[algo]], + target_grads[algo]) + + def test_multipliers_in_config_multiplies_grads( + self, + 
get_ref_model_and_dataset: Callable[[], Tuple[ModelType, DatasetType]], + get_target_model_and_dataset: Callable[[], Tuple[ModelType, DatasetType]], + target_config: NNCFConfig + ): + multipliers = get_multipliers_from_config(target_config) + multipliers['regular'] = 1 + + self._check_multipliers_in_config_multiplies_grads(get_ref_model_and_dataset, + get_target_model_and_dataset, + multipliers) + + @classmethod + def _check_zero_multiplier_freezes_training(cls, orig_params: List[ParameterType], params: List[ParameterType], + multiplier: float): + if multiplier == 0: + cls.TensorListComparator.check_equal(orig_params, params) + else: + cls.TensorListComparator.check_not_equal(orig_params, params) + + @classmethod + def _get_params_diff(cls, orig_params: List[ParameterType], params: List[ParameterType]) -> List[TensorType]: + param_diffs = [] + for param, orig_param in zip(params, orig_params): + param_diffs.append(np.abs(param - orig_param)) + return param_diffs + + @classmethod + def _check_multiplier_affects_training_speed(cls, orig_params: List[ParameterType], + ref_params: List[ParameterType], target_params: List[ParameterType], + multiplier: float): + assert len(ref_params) == len(orig_params) + assert len(target_params) == len(orig_params) + + ref_diff = cls._get_params_diff(ref_params, orig_params) + target_diff = cls._get_params_diff(target_params, orig_params) + + if pytest.approx(multiplier) == 1: + cls.TensorListComparator.check_equal(target_diff, ref_diff) + elif multiplier < 1: + cls.TensorListComparator.check_less(target_diff, ref_diff) + else: + cls.TensorListComparator.check_greater(target_diff, ref_diff) + + def _check_multipliers_in_config_affect_training_speed( + self, + get_ref_model_and_dataset: Callable[[], Tuple[ModelType, DatasetType]], + get_target_model_and_dataset: Callable[[], Tuple[ModelType, DatasetType]], + multipliers + ): + orig_params, _orig_grads = \ + self._get_params_and_grads_after_training_steps(*get_ref_model_and_dataset(), num_steps=0) + ref_params, _ref_grads = \ + self._get_params_and_grads_after_training_steps(*get_ref_model_and_dataset(), num_steps=1) + target_params, _target_grads = \ + self._get_params_and_grads_after_training_steps(*get_target_model_and_dataset(), num_steps=1) + + for algo in multipliers: + self._check_zero_multiplier_freezes_training(orig_params[algo], target_params[algo], multipliers[algo]) + self._check_multiplier_affects_training_speed( + orig_params[algo], ref_params[algo], target_params[algo], multipliers[algo] + ) + + def test_multipliers_in_config_affect_training_speed( + self, + get_ref_model_and_dataset: Callable[[], Tuple[ModelType, DatasetType]], + get_target_model_and_dataset: Callable[[], Tuple[ModelType, DatasetType]], + target_config: NNCFConfig, + ): + multipliers = get_multipliers_from_config(target_config) + multipliers['regular'] = 1 + + self._check_multipliers_in_config_affect_training_speed(get_ref_model_and_dataset, + get_target_model_and_dataset, + multipliers) diff --git a/tests/tensorflow/compressed_model.pb b/tests/tensorflow/compressed_model.pb new file mode 100644 index 00000000000..8d77471014d --- /dev/null +++ b/tests/tensorflow/compressed_model.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0597c58e97b963d37ea752ad62258c84fa392725d8901a964195488af209b99e +size 102969144 diff --git a/tests/tensorflow/helpers.py b/tests/tensorflow/helpers.py index 4de35c3f5a6..5660bcd8d62 100644 --- a/tests/tensorflow/helpers.py +++ b/tests/tensorflow/helpers.py @@ -16,11 +16,15 @@ 
import tensorflow as tf from nncf.common.compression import BaseCompressionAlgorithmController from tensorflow.python.ops.init_ops import Constant +from typing import Union from nncf import NNCFConfig from nncf.tensorflow.helpers.model_creation import create_compressed_model from examples.tensorflow.common.object_detection.datasets.builder import COCODatasetBuilder +from tests.common.helpers import BaseTensorListComparator + +TensorType = Union[tf.Tensor, np.ndarray] def get_conv_init_value(shape, value): @@ -86,6 +90,20 @@ def get_basic_n_conv_test_model(input_shape=(24, 24, 1), in_out_ch=((1, 3), (3, return tf.keras.Model(inputs=inputs, outputs=outputs) +def get_lenet_model(): + model = tf.keras.Sequential([ + tf.keras.layers.Conv2D(filters=6, kernel_size=(5, 5), activation='relu', input_shape=(32, 32, 1)), + tf.keras.layers.MaxPool2D(padding='valid'), + tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), activation='relu'), + tf.keras.layers.AveragePooling2D(), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(units=120, activation='relu'), + tf.keras.layers.Dense(units=84, activation='relu'), + tf.keras.layers.Dense(units=10, activation='softmax') + ]) + return model + + def create_compressed_model_and_algo_for_test(model, config, compression_state=None, force_no_init=False): assert isinstance(config, NNCFConfig) tf.keras.backend.clear_session() @@ -108,13 +126,6 @@ def create_conv(in_channels, out_channels, kernel_size, weight_init, bias_init, return conv_cls(**args) -def check_equal(test, reference, rtol=1e-4): - test = test.numpy() - reference = reference.numpy() - for i, (x, y) in enumerate(zip(test, reference)): - np.testing.assert_allclose(x, y, rtol=rtol, err_msg="Index: {}".format(i)) - - class MockCOCODatasetBuilder(COCODatasetBuilder): @property def num_examples(self): @@ -157,3 +168,15 @@ def get_op_by_cls(wrapper, cls): if isinstance(op, cls): return op return None + + +def to_numpy(tensor: TensorType) -> np.ndarray: + if isinstance(tensor, tf.Tensor): + return tensor.numpy() + return tensor + + +class TFTensorListComparator(BaseTensorListComparator): + @classmethod + def _to_numpy(cls, tensor: tf.Tensor) -> np.ndarray: + return to_numpy(tensor) diff --git a/tests/tensorflow/sparsity/magnitude/test_algorithm.py b/tests/tensorflow/sparsity/magnitude/test_algorithm.py index eb87e4caca7..229df79350a 100644 --- a/tests/tensorflow/sparsity/magnitude/test_algorithm.py +++ b/tests/tensorflow/sparsity/magnitude/test_algorithm.py @@ -21,7 +21,7 @@ from nncf.tensorflow.sparsity.magnitude.algorithm import MagnitudeSparsityController from nncf.tensorflow.sparsity.magnitude.functions import normed_magnitude from nncf.tensorflow.sparsity.magnitude.operation import BinaryMask -from tests.tensorflow.helpers import check_equal +from tests.tensorflow.helpers import TFTensorListComparator from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test from tests.tensorflow.helpers import get_basic_conv_test_model from tests.tensorflow.helpers import get_empty_config @@ -131,8 +131,8 @@ def test_magnitude_algo_set_binary_mask_on_forward(): sparse_model, compression_ctrl = create_compressed_model_and_algo_for_test(get_magnitude_test_model(), config) compression_ctrl.set_sparsity_level(0.3) - check_equal(ref_mask_1, sparse_model.layers[1].weights[-1]) - check_equal(ref_mask_2, sparse_model.layers[2].weights[-1]) + TFTensorListComparator.check_equal(ref_mask_1, sparse_model.layers[1].weights[-1]) + TFTensorListComparator.check_equal(ref_mask_2, 
sparse_model.layers[2].weights[-1]) def test_magnitude_algo_binary_masks_are_applied(): diff --git a/tests/tensorflow/sparsity/magnitude/test_helpers.py b/tests/tensorflow/sparsity/magnitude/test_helpers.py index 0ef009ce8a4..e5a51c4d6df 100644 --- a/tests/tensorflow/sparsity/magnitude/test_helpers.py +++ b/tests/tensorflow/sparsity/magnitude/test_helpers.py @@ -14,7 +14,8 @@ import tensorflow as tf from nncf import NNCFConfig -from tests.tensorflow.helpers import create_conv, check_equal +from tests.tensorflow.helpers import TFTensorListComparator +from tests.tensorflow.helpers import create_conv sub_tensor = tf.constant([[[[1., 0.], @@ -58,11 +59,11 @@ def test_magnitude_model_has_expected_params(): # OIHW -> HWIO ref_weights_2 = tf.transpose(ref_weights_2, (2, 3, 1, 0)) - check_equal(act_weights_1, ref_weights_1) - check_equal(act_weights_2, ref_weights_2) + TFTensorListComparator.check_equal(act_weights_1, ref_weights_1) + TFTensorListComparator.check_equal(act_weights_2, ref_weights_2) - check_equal(act_bias_1, tf.constant([-2., -2])) - check_equal(act_bias_2, tf.constant([0])) + TFTensorListComparator.check_equal(act_bias_1, tf.constant([-2., -2])) + TFTensorListComparator.check_equal(act_bias_2, tf.constant([0])) def get_basic_magnitude_sparsity_config(input_sample_size=None): diff --git a/tests/tensorflow/test_compression_lr_multiplier.py b/tests/tensorflow/test_compression_lr_multiplier.py new file mode 100644 index 00000000000..0d8c0ec35cc --- /dev/null +++ b/tests/tensorflow/test_compression_lr_multiplier.py @@ -0,0 +1,125 @@ +""" + Copyright (c) 2019-2020 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" +import itertools +import copy +import numpy as np +import pytest +import tensorflow as tf +from typing import Callable, Dict, Generator, List, Optional, Tuple + +from nncf import NNCFConfig +from nncf.tensorflow import register_default_init_args +from tests.common.test_compression_lr_multiplier import BaseCompressionLRMultiplierTester +from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test +from tests.tensorflow.helpers import get_lenet_model +from tests.tensorflow.helpers import TFTensorListComparator + + +def create_initialized_model_and_dataset(config: NNCFConfig) -> Tuple[tf.keras.Model, tf.data.Dataset]: + tf.random.set_seed(42) + np.random.seed(42) + config = copy.deepcopy(config) + + def generate_dataset(): + while True: + # pylint: disable = no-value-for-parameter + yield np.random.randn(1, 32, 32, 1), tf.one_hot([np.random.randint(10)], 10) + + dataset = tf.data.Dataset.from_generator( + generate_dataset, output_signature=( + tf.TensorSpec(shape=(1, 32, 32, 1), dtype=tf.float32), + tf.TensorSpec(shape=(1, 10), dtype=tf.float32) + )) + + config = register_default_init_args(config, dataset, batch_size=1) + model, _algo = create_compressed_model_and_algo_for_test(get_lenet_model(), config) + return model, dataset + + +@pytest.fixture(name='sample_size') +def sample_size_(): + return [1, 32, 32, 1] + + +@pytest.fixture(name='get_ref_model_and_dataset') +def get_ref_model_and_dataset_(ref_config: NNCFConfig) -> Callable[[], Tuple[tf.keras.Model, tf.data.Dataset]]: + def f(): + return create_initialized_model_and_dataset(ref_config) + return f + + +@pytest.fixture(name='get_target_model_and_dataset') +def get_target_model_and_dataset_(target_config: NNCFConfig) -> Callable[[], Tuple[tf.keras.Model, tf.data.Dataset]]: + def f(): + return create_initialized_model_and_dataset(target_config) + return f + + +class TestTFCompressionLRMultiplier(BaseCompressionLRMultiplierTester): + ALGO_NAME_TO_PATH_MAP = { + 'quantization': 'nncf.tensorflow.quantization', + 'rb_sparsity': 'nncf.tensorflow.sparsity.rb', + 'binarization': 'nncf.tensorflow.binarization' + } + + TensorListComparator = TFTensorListComparator + + @classmethod + def _perform_model_training_steps(cls, model: tf.keras.Model, train_data: tf.data.Dataset, + num_steps: int = 1) -> Tuple[tf.keras.Model, Optional[tf.GradientTape]]: + tf.random.set_seed(42) + loss_obj = tf.keras.losses.mean_squared_error + optimizer = tf.keras.optimizers.SGD(lr=0.1) + + @tf.function + def train_step(inputs, labels): + with tf.GradientTape() as grad_tape: + predictions = model(inputs, training=True) + loss = loss_obj(labels, predictions) + + grads = grad_tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients(zip(grads, model.trainable_variables)) + return grads + + grads = None + for input_batch, label_batch in itertools.islice(train_data, num_steps): + grads = train_step(input_batch, label_batch) + keys = map(lambda v: v.ref(), model.trainable_variables) + grads = dict(zip(keys, grads)) + + return model, grads + + @classmethod + def _get_grads(cls, algo_to_params: Dict[str, List[tf.Tensor]], grads) -> Dict[str, List[tf.Tensor]]: + res = {} + for algo, params in algo_to_params.items(): + res[algo] = [] + if grads is not None: + for param in params: + res[algo].append(grads[param.ref()]) + return res + + @classmethod + def _get_layer_cls_and_params(cls, model: tf.keras.Model) -> Generator[Tuple[type, List[tf.Tensor]], None, None]: + for layer in model.layers: + params = layer.trainable_weights + yield 
layer.__class__, params + +    @classmethod +    def _get_params_and_grads_after_training_steps(cls, model: tf.keras.Model, dataset: tf.data.Dataset, +                                                    num_steps: int = 1) -> Tuple[Dict[str, List[tf.Tensor]], +                                                                                 Dict[str, List[tf.Tensor]]]: +        model, grads = cls._perform_model_training_steps(model, dataset, num_steps) +        params = cls._get_params_grouped_by_algos(model) +        grads = cls._get_grads(params, grads) +        return params, grads diff --git a/tests/tensorflow/test_helpers.py b/tests/tensorflow/test_helpers.py index 14e2d69e4cb..ede0630189a 100644 --- a/tests/tensorflow/test_helpers.py +++ b/tests/tensorflow/test_helpers.py @@ -13,7 +13,8 @@ import tensorflow as tf -from tests.tensorflow.helpers import get_basic_conv_test_model, check_equal +from tests.tensorflow.helpers import TFTensorListComparator +from tests.tensorflow.helpers import get_basic_conv_test_model def test_basic_model_has_expected_params(): @@ -28,8 +29,8 @@ def test_basic_model_has_expected_params(): act_bias = model.layers[1].weights[1] ref_bias = default_bias -    check_equal(act_bias, ref_bias) -    check_equal(act_weights, ref_weights) +    TFTensorListComparator.check_equal(act_bias, ref_bias) +    TFTensorListComparator.check_equal(act_weights, ref_weights) def test_basic_model_is_valid(): @@ -37,4 +38,4 @@ def test_basic_model_is_valid(): input_ = tf.ones([1, 4, 4, 1]) ref_output = tf.ones((1, 3, 3, 2)) * (-4) act_output = model(input_) -    check_equal(ref_output, act_output) +    TFTensorListComparator.check_equal(ref_output, act_output) diff --git a/tests/torch/binarization/test_functions.py b/tests/torch/binarization/test_functions.py index 75c60d327e6..8fa88111ac0 100644 --- a/tests/torch/binarization/test_functions.py +++ b/tests/torch/binarization/test_functions.py @@ -17,7 +17,8 @@ from torch.autograd import Variable from nncf.torch.binarization.layers import xnor_binarize_op, dorefa_binarize_op, activation_bin_scale_threshold_op -from tests.torch.helpers import check_equal, get_grads +from tests.torch.helpers import PTTensorListComparator +from tests.torch.helpers import get_grads # reference impl @@ -138,7 +139,7 @@ def test_binarize_weights_forward(self, _seed, input_size, weight_bin_type, use_ ref_value = ReferenceDOREFABinarize.forward(ref_input) test_value = dorefa_binarize_op(test_input) -        check_equal(test_value, ref_value, rtol=1e-3) +        PTTensorListComparator.check_equal(test_value, ref_value, rtol=1e-3) def test_binarize_activations_forward(self, _seed, input_size, use_cuda): if not torch.cuda.is_available() and use_cuda is True: @@ -151,7 +152,7 @@ def test_binarize_activations_forward(self, _seed, input_size, use_cuda): ref_value = ReferenceActivationBinarize.forward(ref_input, ref_scale, ref_threshold) test_value = activation_bin_scale_threshold_op(test_input, test_scale, test_threshold) -        check_equal(test_value, ref_value, rtol=1e-3) +        PTTensorListComparator.check_equal(test_value, ref_value, rtol=1e-3) def test_binarize_activations_backward(self, _seed, input_size, use_cuda): if not torch.cuda.is_available() and use_cuda is True: @@ -169,4 +170,4 @@ def test_binarize_activations_backward(self, _seed, input_size, use_cuda): test_value.sum().backward() test_grads = get_grads([test_input, test_scale, test_threshold]) -        check_equal(test_grads, ref_grads, rtol=1e-3) +        PTTensorListComparator.check_equal(test_grads, ref_grads, rtol=1e-3) diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index b1cd9247836..64502150c7b 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py
@@ -12,7 +12,7 @@ """ from abc import ABC, abstractmethod from copy import deepcopy -from typing import Dict, Callable, Any, Union, List, Tuple, TypeVar +from typing import Dict, Callable, Any, Union, List, Tuple import contextlib import numpy as np @@ -39,8 +39,9 @@ from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.utils import get_all_modules_by_type from tests.common.command import Command as BaseCommand +from tests.common.helpers import BaseTensorListComparator -TensorType = TypeVar('TensorType', bound=Union[torch.Tensor, np.ndarray]) +TensorType = Union[torch.Tensor, np.ndarray] def fill_conv_weight(conv, value): @@ -208,48 +209,14 @@ def _create_input_info(): def get_grads(variables: List[nn.Parameter]) -> List[torch.Tensor]: - return [var.grad.clone() for var in variables] + grads = [var.grad for var in variables] + for i, grad in enumerate(grads): + if grad is not None: + grads[i] = grad.clone() + return grads -def to_numpy(tensor: TensorType) -> np.ndarray: - if isinstance(tensor, torch.Tensor): - return tensor.cpu().detach().numpy() - return tensor - - -def compare_tensor_lists(test: List[TensorType], reference: List[TensorType], - assert_fn: Callable[[np.ndarray, np.ndarray], bool]): - assert len(test) == len(reference) - - for x, y in zip(test, reference): - x = to_numpy(x) - y = to_numpy(y) - assert_fn(x, y) - - -def check_equal(test: List[TensorType], reference: List[TensorType], rtol: float = 1e-1): - compare_tensor_lists(test, reference, - lambda x, y: np.testing.assert_allclose(x, y, rtol=rtol)) - - -def check_not_equal(test: List[TensorType], reference: List[TensorType], rtol: float = 1e-4): - compare_tensor_lists(test, reference, - lambda x, y: np.testing.assert_raises(AssertionError, - np.testing.assert_allclose, x, y, rtol=rtol)) - - -def check_less(test: List[TensorType], reference: List[TensorType], rtol=1e-4): - check_not_equal(test, reference, rtol=rtol) - compare_tensor_lists(test, reference, np.testing.assert_array_less) - - -def check_greater(test: List[TensorType], reference: List[TensorType], rtol=1e-4): - check_not_equal(test, reference, rtol=rtol) - compare_tensor_lists(test, reference, - lambda x, y: np.testing.assert_raises(AssertionError, np.testing.assert_array_less, x, y)) - - -def create_compressed_model_and_algo_for_test(model: Module, config: NNCFConfig=None, +def create_compressed_model_and_algo_for_test(model: Module, config: NNCFConfig, dummy_forward_fn: Callable[[Module], Any] = None, wrap_inputs_fn: Callable[[Tuple, Dict], Tuple[Tuple, Dict]] = None, compression_state: Dict[str, Any] = None) \ @@ -443,3 +410,15 @@ def set_torch_seed(seed: int = 42): torch.manual_seed(seed) yield torch.manual_seed(saved_seed) + + +def to_numpy(tensor: TensorType) -> np.ndarray: + if isinstance(tensor, torch.Tensor): + return tensor.cpu().detach().numpy() + return tensor + + +class PTTensorListComparator(BaseTensorListComparator): + @classmethod + def _to_numpy(cls, tensor: torch.Tensor) -> np.ndarray: + return to_numpy(tensor) diff --git a/tests/torch/quantization/test_functions.py b/tests/torch/quantization/test_functions.py index 01d6fb08d96..a2cfd830a94 100644 --- a/tests/torch/quantization/test_functions.py +++ b/tests/torch/quantization/test_functions.py @@ -18,7 +18,8 @@ from nncf.torch.quantization.quantize_functions import asymmetric_quantize, symmetric_quantize from nncf.torch.utils import sum_like -from tests.torch.helpers import get_grads, check_equal +from tests.torch.helpers import PTTensorListComparator +from tests.torch.helpers import 
get_grads  EPS = 1e-6 @@ -127,7 +128,7 @@ def check_outputs_for_quantization_functions(test_val: torch.Tensor, ref_val: np # tensor equality - the test passes for FP32 cases, and the kernel implementation # is exactly the same for FP16 calculations-wise. return -    check_equal(test_val, ref_val, rtol) +    PTTensorListComparator.check_equal(test_val, ref_val, rtol) @pytest.mark.parametrize('input_size', diff --git a/tests/torch/sparsity/const/test_algo.py b/tests/torch/sparsity/const/test_algo.py index df1d23e4595..dbad37cc314 100644 --- a/tests/torch/sparsity/const/test_algo.py +++ b/tests/torch/sparsity/const/test_algo.py @@ -18,7 +18,7 @@ from nncf.torch.module_operations import UpdateWeight from nncf.torch.sparsity.const.algo import ConstSparsityController from nncf.torch.sparsity.layers import BinaryMask -from tests.torch.quantization.test_functions import check_equal +from tests.torch.helpers import PTTensorListComparator from tests.torch.sparsity.magnitude.test_helpers import MagnitudeTestModel from tests.torch.helpers import BasicConvTestModel, get_empty_config, create_compressed_model_and_algo_for_test, \ check_correct_nncf_modules_replacement @@ -71,10 +71,10 @@ def test_can_restore_binary_mask_on_magnitude_algo_resume(): load_state(const_sparse_model, sparse_model.state_dict()) op = const_sparse_model.conv1.pre_ops['0'] -    check_equal(ref_mask_1, op.operand.binary_mask) +    PTTensorListComparator.check_equal(ref_mask_1, op.operand.binary_mask) op = const_sparse_model.conv2.pre_ops['0'] -    check_equal(ref_mask_2, op.operand.binary_mask) +    PTTensorListComparator.check_equal(ref_mask_2, op.operand.binary_mask) @pytest.mark.parametrize("use_data_parallel", [True, False], ids=["dataparallel", "regular"]) @@ -105,7 +105,7 @@ def test_can_restore_binary_mask_on_magnitude_quant_algo_resume(tmp_path, use_da load_state(const_sparse_model, sparse_model.state_dict()) op = const_sparse_model.get_nncf_wrapped_model().conv1.pre_ops['0'] -    check_equal(ref_mask_1, op.operand.binary_mask) +    PTTensorListComparator.check_equal(ref_mask_1, op.operand.binary_mask) op = const_sparse_model.get_nncf_wrapped_model().conv2.pre_ops['0'] -    check_equal(ref_mask_2, op.operand.binary_mask) +    PTTensorListComparator.check_equal(ref_mask_2, op.operand.binary_mask) diff --git a/tests/torch/sparsity/magnitude/test_algo.py b/tests/torch/sparsity/magnitude/test_algo.py index ac60ead2398..5b6f1e6d186 100644 --- a/tests/torch/sparsity/magnitude/test_algo.py +++ b/tests/torch/sparsity/magnitude/test_algo.py @@ -21,7 +21,7 @@ from nncf.torch.sparsity.layers import BinaryMask from nncf.torch.sparsity.magnitude.algo import MagnitudeSparsityController from nncf.torch.sparsity.magnitude.functions import normed_magnitude -from tests.torch.quantization.test_functions import check_equal +from tests.torch.helpers import PTTensorListComparator from tests.torch.sparsity.const.test_algo import ref_mask_2, ref_mask_1 from tests.torch.sparsity.magnitude.test_helpers import MagnitudeTestModel, get_basic_magnitude_sparsity_config from tests.torch.helpers import create_compressed_model_and_algo_for_test, MockModel, BasicConvTestModel, \ @@ -122,10 +122,10 @@ def test_magnitude_algo_set_binary_mask_on_forward(): sparse_model(torch.ones([1, 1, 10, 10])) op = sparse_model.conv1.pre_ops['0'] -    check_equal(ref_mask_1, op.operand.binary_mask) +    PTTensorListComparator.check_equal(ref_mask_1, op.operand.binary_mask) op = sparse_model.conv2.pre_ops['0'] -    check_equal(ref_mask_2, op.operand.binary_mask) +    PTTensorListComparator.check_equal(ref_mask_2, 
op.operand.binary_mask) def test_magnitude_algo_binary_masks_are_applied(): diff --git a/tests/torch/sparsity/magnitude/test_helpers.py b/tests/torch/sparsity/magnitude/test_helpers.py index 09394c1bf9a..86ac5b9db1e 100644 --- a/tests/torch/sparsity/magnitude/test_helpers.py +++ b/tests/torch/sparsity/magnitude/test_helpers.py @@ -15,8 +15,8 @@ from torch import nn from nncf.config import NNCFConfig -from tests.torch.quantization.test_functions import check_equal from tests.torch.helpers import create_conv +from tests.torch.helpers import PTTensorListComparator class MagnitudeTestModel(nn.Module): @@ -44,11 +44,11 @@ def test_magnitude_model_has_expected_params(): [-10., -10., -9.]]]]) ref_weights_2 = torch.cat((sub_tensor, sub_tensor), 1) - check_equal(act_weights_1, ref_weights_1) - check_equal(act_weights_2, ref_weights_2) + PTTensorListComparator.check_equal(act_weights_1, ref_weights_1) + PTTensorListComparator.check_equal(act_weights_2, ref_weights_2) - check_equal(act_bias_1, torch.tensor([-2., -2])) - check_equal(act_bias_2, torch.tensor([0])) + PTTensorListComparator.check_equal(act_bias_1, torch.tensor([-2., -2])) + PTTensorListComparator.check_equal(act_bias_2, torch.tensor([0])) def get_basic_magnitude_sparsity_config(input_sample_size=None): diff --git a/tests/torch/test_compression_lr_multiplier.py b/tests/torch/test_compression_lr_multiplier.py index ccda1e2ab3c..4bf318ff619 100644 --- a/tests/torch/test_compression_lr_multiplier.py +++ b/tests/torch/test_compression_lr_multiplier.py @@ -11,11 +11,7 @@ limitations under the License. """ -import copy -from typing import Callable -from typing import Tuple -from typing import List -from typing import Dict +from typing import Callable, Dict, Generator, List, Tuple import pytest import torch @@ -26,179 +22,17 @@ from nncf import NNCFConfig from nncf.torch.layer_utils import CompressionParameter +from tests.common.test_compression_lr_multiplier import BaseCompressionLRMultiplierTester from tests.torch.helpers import create_initialized_compressed_model from tests.torch.helpers import create_random_mock_dataloader -from tests.torch.helpers import check_equal -from tests.torch.helpers import check_not_equal -from tests.torch.helpers import check_less -from tests.torch.helpers import check_greater from tests.torch.helpers import get_grads from tests.torch.helpers import LeNet +from tests.torch.helpers import PTTensorListComparator from tests.torch.helpers import RandomDatasetMock from tests.torch.helpers import set_torch_seed -from tests.torch.quantization.test_algo_quantization import get_quantization_config_without_range_init -from tests.torch.sparsity.rb.test_algo import get_basic_sparsity_config - - -ALGO_NAME_TO_PATH_MAP = { - 'quantization': 'nncf.torch.quantization', - 'rb_sparsity': 'nncf.torch.sparsity.rb', - 'binarization': 'nncf.torch.binarization' -} - - -def get_quantization_config() -> NNCFConfig: - config = get_quantization_config_without_range_init(LeNet.INPUT_SIZE[-1]) - config['compression']['initializer'] = { - 'range': { - 'num_init_samples': 10 - }, - 'batchnorm_adaptation': { - 'num_bn_adaptation_samples': 0, - } - } - return config - - -def get_sparsity_config() -> NNCFConfig: - config = get_basic_sparsity_config([1, *LeNet.INPUT_SIZE]) - return config - - -def get_binarization_config() -> NNCFConfig: - config = NNCFConfig() - config.update({ - "model": "resnet18", - - "input_info": { - "sample_size": [1, *LeNet.INPUT_SIZE] - }, - - "compression": [ - { - "algorithm": "binarization", - "mode": "xnor", - 
"params": { - "activations_quant_start_epoch": 0, - "weights_quant_start_epoch": 0 - } - } - ] - }) - return config - - -def get_config_algorithms(config: NNCFConfig) -> List[Dict]: - if isinstance(config['compression'], list): - algorithms = config['compression'] - else: - algorithms = [config['compression']] - return algorithms - - -def add_multiplier_to_config(config: NNCFConfig, - local_multiplier: float = None, global_multiplier: float = None) -> NNCFConfig: - config = copy.deepcopy(config) - - if local_multiplier is not None: - algorithms = get_config_algorithms(config) - - for algo in algorithms: - algo.update({ - 'compression_lr_multiplier': local_multiplier - }) - - if global_multiplier is not None: - config['compression_lr_multiplier'] = global_multiplier - - return config - - -def get_multipliers_from_config(config: NNCFConfig) -> Dict[str, float]: - algo_to_multipliers = {} - - algorithms = get_config_algorithms(config) - global_multiplier = config.get('compression_lr_multiplier', 1) - for algo in algorithms: - algo_name = algo['algorithm'] - algo_to_multipliers[algo_name] = algo.get('compression_lr_multiplier', global_multiplier) - - return algo_to_multipliers - - -def merge_configs(configs: List[NNCFConfig], use_algo_list: bool = True) -> NNCFConfig: - res_config = None - algorithms = [] - - for source_config in configs: - source_config = copy.deepcopy(source_config) - - algorithms.extend(get_config_algorithms(source_config)) - del source_config['compression'] - - if res_config is None: - res_config = source_config - res_config.update(source_config) - - if not use_algo_list: - if len(algorithms) > 1: - raise Exception('If there is more than one algorithm ' - 'you could use only use_algo_list=True') - res_config['compression'] = algorithms[0] - else: - res_config['compression'] = algorithms - - res_config['model'] = 'merged_model' - return res_config - - -def get_configs_building_params() -> List[Dict]: - res = [] - get_orig_config_fns = [get_quantization_config, get_sparsity_config, get_binarization_config] - num_orig_configs = len(get_orig_config_fns) - - for global_multiplier in [0, 1, 10]: - res.append({ - 'get_orig_config_fns': get_orig_config_fns, - 'multipliers': [None] * num_orig_configs, - 'global_multiplier': global_multiplier, - 'use_algo_list': True - }) - - global_multiplier = 10 - multipliers = [global_multiplier * (1.1 ** i) for i in range(num_orig_configs)] - - res.append({ - 'get_orig_config_fns': get_orig_config_fns, - 'multipliers': multipliers, - 'global_multiplier': global_multiplier, - 'use_algo_list': True - }) - - for i in range(num_orig_configs): - cur_multipliers = copy.deepcopy(multipliers) - cur_multipliers[i] = None - res.append({ - 'get_orig_config_fns': get_orig_config_fns, - 'multipliers': cur_multipliers, - 'global_multiplier': None, - 'use_algo_list': True - }) - - for get_orig_config_fn in get_orig_config_fns: - for use_algo_list in [False, True]: - for global_multiplier, multiplier in [(11, 10), (11, None), (None, 10)]: - res.append({ - 'get_orig_config_fns': [get_orig_config_fn], - 'multipliers': [multiplier], - 'global_multiplier': global_multiplier, - 'use_algo_list': use_algo_list - }) - - return res -def create_initialized_lenet_model_and_dataloader(config: NNCFConfig) -> Tuple[nn.Module, DataLoader]: +def create_initialized_model_and_dataset(config: NNCFConfig) -> Tuple[nn.Module, DataLoader]: with set_torch_seed(): train_loader = create_random_mock_dataloader(config, num_samples=10) model = LeNet() @@ -208,45 +42,22 @@ def 
create_initialized_lenet_model_and_dataloader(config: NNCFConfig) -> Tuple[n return model, train_loader -@pytest.fixture(name='configs_building_params', - params=get_configs_building_params()) -def configs_building_params_(request) -> Dict: - return request.param - - -@pytest.fixture(name='ref_configs') -def ref_configs_(configs_building_params: Dict) -> List[NNCFConfig]: - return [get_ref_config_fn() for get_ref_config_fn in configs_building_params['get_orig_config_fns']] - - -@pytest.fixture(name='ref_config') -def ref_config_(ref_configs, configs_building_params) -> NNCFConfig: - return merge_configs(ref_configs, configs_building_params['use_algo_list']) - - -@pytest.fixture(name='target_configs') -def target_configs_(ref_configs: List[NNCFConfig], configs_building_params: Dict) -> List[NNCFConfig]: - return [add_multiplier_to_config(config, local_multiplier=multiplier) - for config, multiplier in zip(ref_configs, configs_building_params['multipliers'])] - - -@pytest.fixture(name='target_config') -def target_config_(target_configs: List[NNCFConfig], configs_building_params: Dict) -> NNCFConfig: - target_config = merge_configs(target_configs, configs_building_params['use_algo_list']) - return add_multiplier_to_config(target_config, global_multiplier=configs_building_params['global_multiplier']) +@pytest.fixture(name='sample_size') +def sample_size_(): + return list((1,) + LeNet.INPUT_SIZE) -@pytest.fixture(name='get_ref_lenet_model_and_dataloader') -def get_ref_lenet_model_and_dataloader_(ref_config: NNCFConfig) -> Callable[[], Tuple[nn.Module, DataLoader]]: +@pytest.fixture(name='get_ref_model_and_dataset') +def get_ref_model_and_dataset_(ref_config: NNCFConfig) -> Callable[[], Tuple[nn.Module, DataLoader]]: def f(): - return create_initialized_lenet_model_and_dataloader(ref_config) + return create_initialized_model_and_dataset(ref_config) return f -@pytest.fixture(name='get_target_lenet_model_and_dataloader') -def get_target_lenet_model_and_dataloader_(target_config: NNCFConfig) -> Callable[[], Tuple[nn.Module, DataLoader]]: +@pytest.fixture(name='get_target_model_and_dataset') +def get_target_model_and_dataset_(target_config: NNCFConfig) -> Callable[[], Tuple[nn.Module, DataLoader]]: def f(): - return create_initialized_lenet_model_and_dataloader(target_config) + return create_initialized_model_and_dataset(target_config) return f @@ -285,9 +96,9 @@ def get_one_parameter_model_creation_params(for_training: bool = False) -> List[ return params -def create_initialized_one_parameter_model_and_dataloader(parameter_cls: type, init_requires_grad: bool, - requires_grad_settings: List[Tuple[str, bool]], - multiplier: float = None) -> [nn.Module, DataLoader]: +def create_initialized_one_parameter_model_and_dataset(parameter_cls: type, init_requires_grad: bool, + requires_grad_settings: List[Tuple[str, bool]], + multiplier: float = None) -> [nn.Module, DataLoader]: with set_torch_seed(): data = torch.randn(size=(1, 1, 5, 5)) if parameter_cls is nn.Parameter: @@ -316,8 +127,8 @@ def create_initialized_one_parameter_model_and_dataloader(parameter_cls: type, i def get_ref_one_parameter_model_and_dataloader_(one_parameter_model_creation_params: Dict) -> \ Callable[[], Tuple[nn.Module, DataLoader]]: def f(): - return create_initialized_one_parameter_model_and_dataloader(nn.Parameter, - **one_parameter_model_creation_params) + return create_initialized_one_parameter_model_and_dataset(nn.Parameter, + **one_parameter_model_creation_params) return f @@ -325,205 +136,107 @@ def f(): def 
get_target_one_parameter_model_and_dataloader_(one_parameter_model_creation_params: Dict) -> \ Callable[[], Tuple[nn.Module, DataLoader]]: def f(): - return create_initialized_one_parameter_model_and_dataloader(CompressionParameter, - **one_parameter_model_creation_params) + return create_initialized_one_parameter_model_and_dataset(CompressionParameter, + **one_parameter_model_creation_params) return f -def perform_model_training_steps(model: nn.Module, train_loader: DataLoader, num_steps: int = 1) -> nn.Module: - with set_torch_seed(): - train_loader = iter(train_loader) - optimizer = SGD(model.parameters(), lr=0.1) - - # This block of code is needed to initialize scale in the binarization algorithm - # TODO: perform binarization scale init in the same way as for quantization - with torch.no_grad(): - x, y_gt = next(train_loader) - model(x) - - for _ in range(num_steps): - optimizer.zero_grad() - x, y_gt = next(train_loader) - y = model(x) - loss = F.mse_loss(y.sum(), y_gt) - - loss.backward() - optimizer.step() - - return model - - -def get_params_grouped_by_algorithms(model: nn.Module) -> Dict[str, List[nn.Parameter]]: - cls_name_to_params = {} - for module in model.modules(): - params = list(module.parameters(recurse=False)) - full_cls_name = '.'.join([module.__class__.__module__, module.__class__.__name__]) - if full_cls_name not in cls_name_to_params: - cls_name_to_params[full_cls_name] = [] - cls_name_to_params[full_cls_name].extend(params) - - algo_name_to_params = {} - for cls_name, params in cls_name_to_params.items(): - params = [param for param in params if param.requires_grad] - if len(params) == 0: - continue - - algo_name = 'regular' - for cur_algo_name, cur_algo_path in ALGO_NAME_TO_PATH_MAP.items(): - if cur_algo_path in cls_name: - algo_name = cur_algo_name - - if algo_name not in algo_name_to_params: - algo_name_to_params[algo_name] = [] - algo_name_to_params[algo_name].extend(params) - - return algo_name_to_params - - -def get_lenet_params_after_training_steps(model: nn.Module, train_loader: DataLoader, - num_steps: int = 1) -> Dict[str, List[nn.Parameter]]: - with set_torch_seed(): - model = perform_model_training_steps(model, train_loader, num_steps) - return get_params_grouped_by_algorithms(model) - - -def get_one_parameter_model_params_after_training_steps(model: nn.Module, train_loader: DataLoader, - num_steps: int = 1) -> List[nn.Parameter]: - with set_torch_seed(): - model = perform_model_training_steps(model, train_loader, num_steps) - return list(model.parameters()) - +@pytest.mark.usefixtures('ref_config') +class TestPTCompressionLRMultiplier(BaseCompressionLRMultiplierTester): + ALGO_NAME_TO_PATH_MAP = { + 'quantization': 'nncf.torch.quantization', + 'rb_sparsity': 'nncf.torch.sparsity.rb', + 'binarization': 'nncf.torch.binarization' + } -def test_if_algorithms_add_params( - get_target_lenet_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]], - ref_config: NNCFConfig -): - algo_to_params = get_lenet_params_after_training_steps(*get_target_lenet_model_and_dataloader(), num_steps=0) - algo_names = get_multipliers_from_config(ref_config).keys() - - assert sorted(algo_to_params.keys()) == sorted(list(algo_names) + ['regular']) - - -@pytest.mark.parametrize('one_parameter_model_creation_params', - get_one_parameter_model_creation_params()) -def test_if_parameter_is_initialized_correctly( - get_ref_one_parameter_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]], - get_target_one_parameter_model_and_dataloader: Callable[[], 
-):
-    ref_model, _ref_loader = get_ref_one_parameter_model_and_dataloader()
-    target_model, target_loader = get_target_one_parameter_model_and_dataloader()
-
-    assert pytest.approx(ref_model.param.data) == target_model.param.data
-    assert ref_model.param.requires_grad == target_model.param.requires_grad
-
-    if ref_model.param.requires_grad:
-        get_one_parameter_model_params_after_training_steps(target_model, target_loader)
-    else:
-        with pytest.raises(Exception):
-            get_one_parameter_model_params_after_training_steps(target_model, target_loader)
-
-
-def check_if_grads_are_multiplied(ref_params: List[nn.Parameter], target_params: List[nn.Parameter],
-                                  multiplier: float):
-    ref_grads = get_grads(ref_params)
-    ref_grads = [multiplier * grad for grad in ref_grads]
-    target_grads = get_grads(target_params)
-
-    check_equal(ref_grads, target_grads)
-
-
-def test_if_setting_multipliers_in_config_multiplies_grads_values(
-        get_ref_lenet_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
-        get_target_lenet_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
-        target_config: NNCFConfig
-):
-    ref_params = get_lenet_params_after_training_steps(*get_ref_lenet_model_and_dataloader())
-    target_params = get_lenet_params_after_training_steps(*get_target_lenet_model_and_dataloader())
-    multipliers = get_multipliers_from_config(target_config)
-    multipliers['regular'] = 1
-
-    for algo in ref_params:
-        check_if_grads_are_multiplied(ref_params[algo], target_params[algo], multipliers[algo])
-
-
-@pytest.mark.parametrize('one_parameter_model_creation_params',
-                         get_one_parameter_model_creation_params(for_training=True))
-def test_if_setting_multiplier_in_parameter_multiplies_grads_values(
-        get_ref_one_parameter_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
-        get_target_one_parameter_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
-        one_parameter_model_creation_params: Dict
-):
-    ref_params = get_one_parameter_model_params_after_training_steps(*get_ref_one_parameter_model_and_dataloader())
-    target_params = \
-        get_one_parameter_model_params_after_training_steps(*get_target_one_parameter_model_and_dataloader())
-
-    assert target_params[0].requires_grad
-    check_if_grads_are_multiplied(ref_params, target_params, one_parameter_model_creation_params['multiplier'])
-
-
-def check_if_zero_multiplier_freezes_training(orig_params: List[nn.Parameter], params: List[nn.Parameter],
-                                              multiplier: float):
-    if multiplier == 0:
-        check_equal(orig_params, params)
-    else:
-        check_not_equal(orig_params, params)
-
-
-def get_params_diff(orig_params: List[nn.Parameter], params: List[nn.Parameter]) -> List[torch.Tensor]:
-    param_diffs = []
-    for param, orig_param in zip(params, orig_params):
-        param_diffs.append((param - orig_param).abs())
-    return param_diffs
-
-
-def check_params_affect_training_speed(orig_params: List[nn.Parameter],
-                                       ref_params: List[nn.Parameter], target_params: List[nn.Parameter],
-                                       compression_lr_multiplier: float):
-    assert len(ref_params) == len(orig_params)
-    assert len(target_params) == len(orig_params)
-
-    ref_diff = get_params_diff(ref_params, orig_params)
-    target_diff = get_params_diff(target_params, orig_params)
-
-    if pytest.approx(compression_lr_multiplier) == 1:
-        check_equal(target_diff, ref_diff)
-    elif compression_lr_multiplier < 1:
-        check_less(target_diff, ref_diff)
-    else:
-        check_greater(target_diff, ref_diff)
-
-
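The tests removed above, and the base-class checks that replace them further down, all verify the same contract: after identical, seeded training steps, the gradients of the target model's compression parameters equal the reference gradients multiplied by compression_lr_multiplier. The snippet below is a standalone sketch of that contract in plain PyTorch using a gradient hook; it is not how NNCF wires the multiplier in (that goes through CompressionParameter and custom-gradient code), only an illustration of the effect the tests expect.

# Standalone sketch (not NNCF code): reproduces the invariant
# target_grad == multiplier * ref_grad that the tests check.
import torch

def grads_of_identical_params(multiplier: float):
    torch.manual_seed(0)
    data = torch.randn(1, 4)

    ref_param = torch.nn.Parameter(torch.ones(4))
    target_param = torch.nn.Parameter(torch.ones(4))
    # The hook stands in for the multiplier wiring inside NNCF.
    target_param.register_hook(lambda grad: grad * multiplier)

    (data * ref_param).sum().backward()
    (data * target_param).sum().backward()
    return ref_param.grad, target_param.grad

ref_grad, target_grad = grads_of_identical_params(10.0)
assert torch.allclose(target_grad, ref_grad * 10.0)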
-def test_if_setting_multipliers_in_config_affect_training_speed(
-        get_ref_lenet_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
-        get_target_lenet_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
-        target_config: NNCFConfig
-):
-    orig_params = get_lenet_params_after_training_steps(*get_ref_lenet_model_and_dataloader(), num_steps=0)
-    target_params = get_lenet_params_after_training_steps(*get_target_lenet_model_and_dataloader(), num_steps=1)
-    multipliers = get_multipliers_from_config(target_config)
-    multipliers['regular'] = 1
-
-    for algo in orig_params:
-        check_if_zero_multiplier_freezes_training(orig_params[algo], target_params[algo], multipliers[algo])
-
-
-@pytest.mark.parametrize('one_parameter_model_creation_params',
-                         get_one_parameter_model_creation_params(for_training=True))
-def test_if_setting_multiplier_in_parameter_affect_training_speed(
-        get_ref_one_parameter_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
-        get_target_one_parameter_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
-        one_parameter_model_creation_params: Dict
-):
-    orig_params = \
-        get_one_parameter_model_params_after_training_steps(*get_ref_one_parameter_model_and_dataloader(), num_steps=0)
-    ref_params = \
-        get_one_parameter_model_params_after_training_steps(*get_ref_one_parameter_model_and_dataloader(), num_steps=1)
-    target_params = \
-        get_one_parameter_model_params_after_training_steps(*get_target_one_parameter_model_and_dataloader(),
-                                                            num_steps=1)
-
-    assert target_params[0].requires_grad
-    check_if_zero_multiplier_freezes_training(orig_params, target_params,
-                                              one_parameter_model_creation_params['multiplier'])
-    check_params_affect_training_speed(orig_params, ref_params, target_params,
-                                       one_parameter_model_creation_params['multiplier'])
+    TensorListComparator = PTTensorListComparator
+
+    @classmethod
+    def _perform_model_training_steps(cls, model: nn.Module, dataset: DataLoader,
+                                      num_steps: int = 1) -> nn.Module:
+        with set_torch_seed():
+            dataset = iter(dataset)
+            optimizer = SGD(model.parameters(), lr=0.1)
+
+            # This block of code is needed to initialize scale in the binarization algorithm
+            # TODO: perform binarization scale init in the same way as for quantization
+            with torch.no_grad():
+                x, y_gt = next(dataset)
+                model(x)
+
+            for _ in range(num_steps):
+                optimizer.zero_grad()
+                x, y_gt = next(dataset)
+                y = model(x)
+                loss = F.mse_loss(y.sum(), y_gt)
+
+                loss.backward()
+                optimizer.step()
+
+        return model
+
+    @classmethod
+    def _get_layer_cls_and_params(cls, model: nn.Module) -> Generator[Tuple[type, List[nn.Parameter]], None, None]:
+        for module in model.modules():
+            params = list(filter(lambda param: param.requires_grad, module.parameters(recurse=False)))
+            yield module.__class__, params
+
+    @classmethod
+    def _get_grads(cls, params: Dict[str, List[nn.Parameter]]) -> Dict[str, List[torch.Tensor]]:
+        return {k: get_grads(v) for k, v in params.items()}
+
+    @classmethod
+    def _get_params_and_grads_after_training_steps(cls, model: nn.Module, dataset: DataLoader,
+                                                   num_steps: int = 1) -> Tuple[Dict[str, List[nn.Parameter]],
+                                                                                Dict[str, List[torch.Tensor]]]:
+        with set_torch_seed():
+            model = cls._perform_model_training_steps(model, dataset, num_steps)
+            params = cls._get_params_grouped_by_algos(model)
+            grads = cls._get_grads(params)
+            params = {algo: [param.cpu().detach() for param in params[algo]] for algo in params}
+        return params, grads
+
+    @pytest.mark.parametrize('one_parameter_model_creation_params',
+                             get_one_parameter_model_creation_params())
+    def test_compression_parameter_is_initialized_correctly(
+            self,
+            get_ref_one_parameter_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
+            get_target_one_parameter_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]]
+    ):
+        ref_model, _ref_loader = get_ref_one_parameter_model_and_dataloader()
+        target_model, target_loader = get_target_one_parameter_model_and_dataloader()
+
+        assert pytest.approx(ref_model.param.data) == target_model.param.data
+        assert ref_model.param.requires_grad == target_model.param.requires_grad
+
+        if ref_model.param.requires_grad:
+            self._perform_model_training_steps(target_model, target_loader)
+        else:
+            with pytest.raises(Exception):
+                self._perform_model_training_steps(target_model, target_loader)
+
+    @pytest.mark.parametrize('one_parameter_model_creation_params',
+                             get_one_parameter_model_creation_params(for_training=True))
+    def test_multiplier_in_parameter_multiplies_grads(
+            self,
+            get_ref_one_parameter_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
+            get_target_one_parameter_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
+            one_parameter_model_creation_params: Dict
+    ):
+        multipliers = {'regular': one_parameter_model_creation_params['multiplier']}
+        self._check_multipliers_in_config_multiplies_grads(get_ref_one_parameter_model_and_dataloader,
+                                                           get_target_one_parameter_model_and_dataloader,
+                                                           multipliers)
+
+    @pytest.mark.parametrize('one_parameter_model_creation_params',
+                             get_one_parameter_model_creation_params(for_training=True))
+    def test_multiplier_in_parameter_affect_training_speed(
+            self,
+            get_ref_one_parameter_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
+            get_target_one_parameter_model_and_dataloader: Callable[[], Tuple[nn.Module, DataLoader]],
+            one_parameter_model_creation_params: Dict,
+    ):
+        multipliers = {'regular': one_parameter_model_creation_params['multiplier']}
+        self._check_multipliers_in_config_affect_training_speed(get_ref_one_parameter_model_and_dataloader,
+                                                                get_target_one_parameter_model_and_dataloader,
+                                                                multipliers)
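The remaining hunks switch the helper tests from the module-level check_equal function to PTTensorListComparator.check_equal from tests/torch/helpers.py. The comparator itself is not part of this diff; the sketch below only illustrates the kind of element-wise, list-aware comparison the call sites rely on, and the real implementation may differ.

# Rough sketch of the comparison API assumed by the tests below; the actual
# PTTensorListComparator lives in tests/torch/helpers.py and may differ.
from typing import List, Union

import torch

TensorOrList = Union[torch.Tensor, List[torch.Tensor]]


class TensorListComparatorSketch:
    @staticmethod
    def _as_list(x: TensorOrList) -> List[torch.Tensor]:
        return list(x) if isinstance(x, (list, tuple)) else [x]

    @classmethod
    def check_equal(cls, actual: TensorOrList, expected: TensorOrList, rtol: float = 1e-5):
        actual, expected = cls._as_list(actual), cls._as_list(expected)
        assert len(actual) == len(expected)
        for act, exp in zip(actual, expected):
            # assert_close raises with a readable report on any mismatch
            torch.testing.assert_close(act, exp.to(act.dtype), rtol=rtol, atol=1e-8)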
diff --git a/tests/torch/test_helpers.py b/tests/torch/test_helpers.py
index bacc540f9ea..abcdd0889f5 100644
--- a/tests/torch/test_helpers.py
+++ b/tests/torch/test_helpers.py
@@ -15,9 +15,9 @@
 import torch
 
 from nncf import NNCFConfig
+from tests.torch.helpers import PTTensorListComparator
 from tests.torch.helpers import BasicConvTestModel
 from tests.torch.helpers import TwoConvTestModel
-from tests.torch.helpers import check_equal
 from tests.torch.helpers import create_compressed_model_and_algo_for_test
 
 
@@ -28,8 +28,8 @@ def test_basic_model_has_expected_params():
     act_bias = model.conv.bias.data
     ref_bias = BasicConvTestModel.default_bias()
 
-    check_equal(act_bias, ref_bias)
-    check_equal(act_weights, ref_weights)
+    PTTensorListComparator.check_equal(act_bias, ref_bias)
+    PTTensorListComparator.check_equal(act_weights, ref_weights)
 
     assert act_weights.nonzero().size(0) == model.nz_weights_num
     assert act_bias.nonzero().size(0) == model.nz_bias_num
@@ -42,7 +42,7 @@ def test_basic_model_is_valid():
     input_ = torch.ones([1, 1, 4, 4])
     ref_output = torch.ones((1, 2, 3, 3)) * (-4)
     act_output = model(input_)
-    check_equal(act_output, ref_output)
+    PTTensorListComparator.check_equal(act_output, ref_output)
 
 
 def test_two_conv_model_has_expected_params():
@@ -56,11 +56,11 @@ def test_two_conv_model_has_expected_params():
     channel = torch.eye(3, 3).reshape([1, 1, 3, 3])
     ref_weights_2 = torch.cat((channel, channel), 1)
 
-    check_equal(act_weights_1, ref_weights_1)
-    check_equal(act_weights_2, ref_weights_2)
+    PTTensorListComparator.check_equal(act_weights_1, ref_weights_1)
+    PTTensorListComparator.check_equal(act_weights_2, ref_weights_2)
 
-    check_equal(act_bias_1, BasicConvTestModel.default_bias())
-    check_equal(act_bias_2, torch.tensor([0]))
+    PTTensorListComparator.check_equal(act_bias_1, BasicConvTestModel.default_bias())
+    PTTensorListComparator.check_equal(act_bias_2, torch.tensor([0]))
 
     assert act_weights_1.nonzero().size(0) + act_weights_2.nonzero().size(0) == model.nz_weights_num
     assert act_bias_1.nonzero().size(0) + act_bias_2.nonzero().size(0) == model.nz_bias_num
@@ -73,7 +73,7 @@ def test_two_conv_model_is_valid():
     input_ = torch.ones([1, 1, 4, 4])
     ref_output = torch.tensor(-24).reshape((1, 1, 1, 1))
     act_output = model(input_)
-    check_equal([act_output], [ref_output])
+    PTTensorListComparator.check_equal([act_output], [ref_output])
 
 
 def load_exported_onnx_version(nncf_config: NNCFConfig, model: torch.nn.Module,
diff --git a/tests/torch/test_load_model_state.py b/tests/torch/test_load_model_state.py
index 5eda1ed7123..33c42fa4961 100644
--- a/tests/torch/test_load_model_state.py
+++ b/tests/torch/test_load_model_state.py
@@ -26,8 +26,8 @@
 from nncf.torch.layers import NNCF_PADDING_VALUE_ATTR_NAME
 from nncf.torch.nncf_network import EXTERNAL_QUANTIZERS_STORAGE_NAME
 from nncf.torch.nncf_network import LEGACY_ACT_STORAGE_NAME
+from tests.torch.helpers import PTTensorListComparator
 from tests.torch.helpers import BasicConvTestModel
-from tests.torch.helpers import check_equal
 
 
 def test_export_sq_11_is_ok(tmp_path):
@@ -49,8 +49,8 @@ def test_load_state_skips_not_matched_params__from_larger_to_smaller():
     act_bias = model_load.conv.bias.data
     act_weights = model_load.conv.weight.data
     assert num_loaded == 0
-    check_equal(act_bias, ref_bias)
-    check_equal(act_weights, ref_weights)
+    PTTensorListComparator.check_equal(act_bias, ref_bias)
+    PTTensorListComparator.check_equal(act_weights, ref_weights)
 
 
 def test_can_skip_padding_value():
@@ -93,8 +93,8 @@ def test_load_state_skips_not_matched_params__from_smaller_to_larger():
     assert num_loaded == 0
     act_bias = model_load.conv.bias.data
     act_weights = model_load.conv.weight.data
-    check_equal(act_bias, ref_bias)
-    check_equal(act_weights, ref_weights)
+    PTTensorListComparator.check_equal(act_bias, ref_bias)
+    PTTensorListComparator.check_equal(act_weights, ref_weights)
 
 
 class MatchKeyDesc:
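The multiplier exercised by these tests comes from the NNCF config: it is resolved through get_redefinable_global_param_value_for_algo('compression_lr_multiplier', ...), i.e. a global value that an individual algorithm section can override. A hypothetical config using both levels could look roughly like the sketch below; the exact key placement is an assumption, not something shown in this diff.

# Illustrative config sketch (assumed schema): a global multiplier plus a
# per-algorithm override, matching how the value is looked up in the code.
from nncf import NNCFConfig

nncf_config = NNCFConfig.from_dict({
    "input_info": {"sample_size": [1, 1, 32, 32]},
    # Global value, used by algorithms that do not redefine it
    "compression_lr_multiplier": 10,
    "compression": [
        {"algorithm": "quantization"},
        # Per-algorithm override of the global value
        {"algorithm": "rb_sparsity", "compression_lr_multiplier": 0.5},
    ],
})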