Skip to content

Commit

Permalink
Add compression_lr_multiplier for tf part
Browse files Browse the repository at this point in the history
  • Loading branch information
a-ignatyev committed Aug 11, 2021
1 parent 169cef8 commit 3333974
Show file tree
Hide file tree
Showing 22 changed files with 833 additions and 532 deletions.
12 changes: 9 additions & 3 deletions nncf/common/quantization/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def __init__(self, num_bits: int,
mode: QuantizationMode,
signedness_to_force: bool,
narrow_range: bool,
half_range: bool):
half_range: bool,
compression_lr_multiplier: Optional[float] = None):
"""
:param num_bits: Bitwidth of the quantization.
:param mode: The mode of quantization (symmetric or asymmetric).
Expand All @@ -147,23 +148,28 @@ def __init__(self, num_bits: int,
naive case, False if all 2^`num_bits` quantizations should be used.
:param half_range: If ``True``, effectively only half of the quantizer range is used.
If ``False``, the full range is used.
:param compression_lr_multiplier: Multiplier for gradient values.
"""
self.num_bits = num_bits
self.mode = mode
self.signedness_to_force = signedness_to_force
self.narrow_range = narrow_range
self.half_range = half_range
self.compression_lr_multiplier = compression_lr_multiplier

def __eq__(self, other: 'QuantizerSpec'):
    """
    Compares two specs by their instance attribute dictionaries.

    :param other: The object to compare against.
    :return: True if both objects hold equal attribute dictionaries, False if they differ,
        or `NotImplemented` if `other` has no instance dictionary (e.g. `None`).
    """
    try:
        other_state = vars(other)
    except TypeError:
        # `other` carries no `__dict__` (None, numbers, slotted objects, ...): let Python fall
        # back to its default comparison instead of raising AttributeError from inside `==`.
        return NotImplemented
    return self.__dict__ == other_state

@classmethod
def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool) -> 'QuantizerSpec':
def from_config(cls, qconfig: QuantizerConfig,
narrow_range: bool, half_range: bool,
compression_lr_multiplier: Optional[float]) -> 'QuantizerSpec':
return cls(qconfig.num_bits,
qconfig.mode,
qconfig.signedness_to_force,
narrow_range,
half_range)
half_range,
compression_lr_multiplier)


class QuantizationConstraints:
Expand Down
13 changes: 13 additions & 0 deletions nncf/tensorflow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""

import tensorflow as tf
from typing import Callable


@tf.function
Expand All @@ -24,3 +25,15 @@ def st_threshold(input_):
def grad(upstream):
return upstream
return tf.round(input_), grad


def get_id_with_multiplied_grad(grad_multiplier: float) -> Callable[[tf.Tensor], tf.Tensor]:
    """
    Builds an identity function whose gradient is scaled by `grad_multiplier`.

    :param grad_multiplier: Factor applied to the upstream gradient. If it is `None`,
        the upstream gradient is passed through unchanged.
    :return: A function wrapped with `tf.custom_gradient` that returns its input as is
        and pairs it with the scaled-gradient function.
    """
    def _scaled_grad(upstream):
        # `None` means no multiplier was provided: leave the gradient untouched.
        return upstream if grad_multiplier is None else grad_multiplier * upstream

    @tf.custom_gradient
    def id_with_multiplied_grad(x):
        return x, _scaled_grad

    return id_with_multiplied_grad
17 changes: 11 additions & 6 deletions nncf/tensorflow/quantization/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
limitations under the License.
"""
from copy import deepcopy
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf

Expand Down Expand Up @@ -395,9 +392,11 @@ def _get_quantizer_setup(self, model: tf.keras.Model) -> TFQuantizationSetup:

half_range = self._get_half_range(qconfig)
applied_saturation_fix = applied_saturation_fix or half_range
compression_lr_multiplier = self._get_compression_lr_multiplier()
quantizer_spec = TFQuantizerSpec.from_config(qconfig,
narrow_range=not half_range,
half_range=half_range)
half_range=half_range,
compression_lr_multiplier=compression_lr_multiplier)
target_point = TFLayerWeight(layer_info.layer_name, weight_def.weight_attr_name)
qpoint = TFQuantizationPoint(op_name, quantizer_spec, target_point)
setup.add_quantization_point(qpoint)
Expand All @@ -408,7 +407,9 @@ def _get_quantizer_setup(self, model: tf.keras.Model) -> TFQuantizationSetup:
target_node_name = ip.target_node_name
input_port_id = ip.input_port_id
fake_quantize_name = self._get_fake_quantize_name(target_node_name, input_port_id)
quantizer_spec = TFQuantizerSpec.from_config(qp.qconfig, narrow_range=False, half_range=False)
compression_lr_multiplier = self._get_compression_lr_multiplier()
quantizer_spec = TFQuantizerSpec.from_config(qp.qconfig, narrow_range=False, half_range=False,
compression_lr_multiplier=compression_lr_multiplier)
fake_quantize_layer = FakeQuantize(
quantizer_spec,
name=fake_quantize_name)
Expand Down Expand Up @@ -566,6 +567,10 @@ def _get_fake_quantize_name(self, node_name: NNCFNodeName, input_port_id: int =
def _get_quantizer_operation_name(self, layer_name, weight_attr_name):
return f'{layer_name}_{weight_attr_name}_quantizer'

def _get_compression_lr_multiplier(self) -> Optional[float]:
    """
    Looks up the `compression_lr_multiplier` parameter for this algorithm in the config.

    :return: The value the config reports for the `compression_lr_multiplier` key,
        resolved against this algorithm's name.
    """
    param_name = 'compression_lr_multiplier'
    config = self.config
    return config.get_redefinable_global_param_value_for_algo(param_name, self.name)


class QuantizationController(BaseCompressionAlgorithmController):
def __init__(self, target_model, config, op_names: List[str]):
Expand Down
36 changes: 28 additions & 8 deletions nncf/tensorflow/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nncf.tensorflow.layers.custom_objects import NNCF_QUANTIZATION_OPERATONS
from nncf.tensorflow.layers.data_layout import get_channel_axis
from nncf.tensorflow.layers.data_layout import get_channel_size
from nncf.tensorflow.functions import get_id_with_multiplied_grad
from nncf.tensorflow.layers.operation import NNCFOperation
from nncf.tensorflow.quantization.functions import asymmetric_quantize
from nncf.tensorflow.quantization.functions import symmetric_quantize
Expand All @@ -36,18 +37,21 @@ def __init__(self, num_bits: int,
signedness_to_force: Optional[bool],
narrow_range: bool,
half_range: bool,
per_channel: bool):
super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range)
per_channel: bool,
compression_lr_multiplier: Optional[float] = None):
super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range, compression_lr_multiplier)
self.per_channel = per_channel

@classmethod
def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool) -> 'TFQuantizerSpec':
def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool,
compression_lr_multiplier: Optional[float] = None) -> 'TFQuantizerSpec':
return cls(qconfig.num_bits,
qconfig.mode,
qconfig.signedness_to_force,
narrow_range,
half_range,
qconfig.per_channel)
qconfig.per_channel,
compression_lr_multiplier)

def get_state(self) -> Dict[str, Any]:
"""
Expand All @@ -62,7 +66,8 @@ def get_state(self) -> Dict[str, Any]:
'signedness_to_force': self.signedness_to_force,
'narrow_range': self.narrow_range,
'half_range': self.half_range,
'per_channel': self.per_channel
'per_channel': self.per_channel,
'compression_lr_multiplier': self.compression_lr_multiplier
}

@classmethod
Expand Down Expand Up @@ -285,6 +290,8 @@ def __init__(self, name: str, qspec: TFQuantizerSpec):
self.narrow_range = qspec.narrow_range
self.signedness_to_force = qspec.signedness_to_force
self._half_range = qspec.half_range
self.compression_lr_multiplier = qspec.compression_lr_multiplier
self.id_with_multiplied_grad = get_id_with_multiplied_grad(qspec.compression_lr_multiplier)

@property
def half_range(self):
Expand Down Expand Up @@ -336,6 +343,10 @@ def apply_saturation_fix(self, weights):
self._half_range = False

def quantize(self, inputs, weights, _):
if self.compression_lr_multiplier is not None:
for k in weights.keys():
weights[k] = self.id_with_multiplied_grad(weights[k])

def _half_range_quantize():
return symmetric_quantize(
inputs,
Expand Down Expand Up @@ -389,6 +400,7 @@ def get_config(self):
'narrow_range': self.narrow_range,
'half_range': self._half_range,
'per_channel': self.per_channel,
'compression_lr_multiplier': self.compression_lr_multiplier,
}
config = {
'quantizer_spec': qspec_dict,
Expand All @@ -404,12 +416,12 @@ def from_config(cls, config):
signedness_to_force=qspec_dict['signedness_to_force'],
narrow_range=qspec_dict['narrow_range'],
half_range=qspec_dict['half_range'],
per_channel=qspec_dict['per_channel'])
per_channel=qspec_dict['per_channel'],
compression_lr_multiplier=qspec_dict['compression_lr_multiplier'])
name = config['name']
return cls(name, qspec)



@NNCF_CUSTOM_OBJECTS.register()
@NNCF_QUANTIZATION_OPERATONS.register(QuantizationMode.ASYMMETRIC)
class AsymmetricQuantizer(Quantizer):
Expand All @@ -419,6 +431,8 @@ def __init__(self, name: str, qspec: TFQuantizerSpec):
self.narrow_range = qspec.narrow_range
self.per_channel = qspec.per_channel
self._half_range = qspec.half_range
self.compression_lr_multiplier = qspec.compression_lr_multiplier
self.id_with_multiplied_grad = get_id_with_multiplied_grad(qspec.compression_lr_multiplier)

@property
def half_range(self):
Expand Down Expand Up @@ -466,6 +480,10 @@ def apply_saturation_fix(self, weights):
self._half_range = False

def quantize(self, inputs, weights, _):
if self.compression_lr_multiplier is not None:
for k in weights.keys():
weights[k] = self.id_with_multiplied_grad(weights[k])

def _half_range_quantize():
return asymmetric_quantize(
inputs,
Expand Down Expand Up @@ -519,6 +537,7 @@ def get_config(self):
'narrow_range': self.narrow_range,
'half_range': self._half_range,
'per_channel': self.per_channel,
'compression_lr_multiplier': self.compression_lr_multiplier,
}
config = {
'quantizer_spec': qspec_dict,
Expand All @@ -534,6 +553,7 @@ def from_config(cls, config):
signedness_to_force=None,
narrow_range=qspec_dict['narrow_range'],
half_range=qspec_dict['half_range'],
per_channel=qspec_dict['per_channel'])
per_channel=qspec_dict['per_channel'],
compression_lr_multiplier=qspec_dict['compression_lr_multiplier'])
name = config['name']
return cls(name, qspec)
8 changes: 3 additions & 5 deletions nncf/torch/quantization/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,15 @@ def __init__(self, num_bits: int,
half_range: bool,
scale_shape: Tuple[int, ...],
logarithm_scale: bool,
compression_lr_multiplier: float = None):
super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range)
compression_lr_multiplier: Optional[float] = None):
super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range, compression_lr_multiplier)
self.scale_shape = scale_shape
self.logarithm_scale = logarithm_scale
self.compression_lr_multiplier = compression_lr_multiplier


@classmethod
def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool,
half_range: bool, scale_shape: Tuple[int],
logarithm_scale: bool, compression_lr_multiplier: float) -> 'PTQuantizerSpec':
logarithm_scale: bool, compression_lr_multiplier: Optional[float] = None) -> 'PTQuantizerSpec':
return cls(qconfig.num_bits,
qconfig.mode,
qconfig.signedness_to_force,
Expand Down
59 changes: 57 additions & 2 deletions tests/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from pathlib import Path
import numpy as np
import os
import shutil
import subprocess
import sys

from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import Callable, List, TypeVar, Union

TensorType = TypeVar('TensorType')

TEST_ROOT = Path(__file__).absolute().parents[1]
PROJECT_ROOT = TEST_ROOT.parent.absolute()
EXAMPLES_DIR = PROJECT_ROOT / 'examples'
Expand Down Expand Up @@ -113,3 +119,52 @@ def run_install_checks(venv_path, tmp_path, package_type, test_dir, install_type
install_mode,
package_type),
check=True, shell=True, cwd=run_path)


class BaseTensorListComparator(ABC):
    """
    Helper for comparing tensors, or lists of tensors, through `numpy.testing` assertions.

    Subclasses implement `_to_numpy` to convert their tensor type into a `numpy.ndarray`;
    every check converts both operands with it before comparing.
    """

    @classmethod
    @abstractmethod
    def _to_numpy(cls, tensor: TensorType) -> np.ndarray:
        """
        Converts a single tensor into a `numpy.ndarray`.

        :param tensor: The tensor to convert.
        :return: The converted array.
        """

    @classmethod
    def _compare_tensor_lists(cls, test: Union[TensorType, List[TensorType]],
                              reference: Union[TensorType, List[TensorType]],
                              assert_fn: Callable[[np.ndarray, np.ndarray], None]):
        """
        Applies `assert_fn` to every (test, reference) pair after conversion to numpy.

        :param test: A tensor or a list of tensors under test. A non-list value is wrapped
            into a single-element list.
        :param reference: A tensor or a list of tensors to compare against, wrapped likewise.
        :param assert_fn: Assertion callable invoked as `assert_fn(test_array, reference_array)`;
            it is expected to raise on failure.
        """
        if not isinstance(test, list):
            test = [test]
        if not isinstance(reference, list):
            reference = [reference]
        # `zip` would silently drop the tail of the longer list, so the lengths must match.
        assert len(test) == len(reference), \
            f'Tensor lists differ in length: {len(test)} vs {len(reference)}'

        for x, y in zip(test, reference):
            assert_fn(cls._to_numpy(x), cls._to_numpy(y))

    @classmethod
    def check_equal(cls, test: Union[TensorType, List[TensorType]], reference: Union[TensorType, List[TensorType]],
                    rtol: float = 1e-1):
        """
        Asserts that each test tensor is close to its reference within `rtol`.

        :param test: A tensor or a list of tensors under test.
        :param reference: A tensor or a list of tensors to compare against.
        :param rtol: Relative tolerance passed to `np.testing.assert_allclose`.
        """
        cls._compare_tensor_lists(test, reference,
                                  lambda x, y: np.testing.assert_allclose(x, y, rtol=rtol))

    @classmethod
    def check_not_equal(cls, test: Union[TensorType, List[TensorType]], reference: Union[TensorType, List[TensorType]],
                        rtol: float = 1e-4):
        """
        Asserts that each test tensor is NOT close to its reference within `rtol`.

        :param test: A tensor or a list of tensors under test.
        :param reference: A tensor or a list of tensors to compare against.
        :param rtol: Relative tolerance passed to `np.testing.assert_allclose`.
        """
        cls._compare_tensor_lists(test, reference,
                                  lambda x, y: np.testing.assert_raises(AssertionError,
                                                                        np.testing.assert_allclose, x, y, rtol=rtol))

    @classmethod
    def check_less(cls, test: Union[TensorType, List[TensorType]], reference: Union[TensorType, List[TensorType]],
                   rtol=1e-4):
        """
        Asserts that each test tensor differs from its reference (within `rtol`) and is
        element-wise strictly less than it.

        :param test: A tensor or a list of tensors under test.
        :param reference: A tensor or a list of tensors to compare against.
        :param rtol: Relative tolerance used by the not-equal pre-check.
        """
        cls.check_not_equal(test, reference, rtol=rtol)
        cls._compare_tensor_lists(test, reference, np.testing.assert_array_less)

    @classmethod
    def check_greater(cls, test: Union[TensorType, List[TensorType]], reference: Union[TensorType, List[TensorType]],
                      rtol=1e-4):
        """
        Asserts that each test tensor differs from its reference (within `rtol`) and is
        element-wise strictly greater than it.

        :param test: A tensor or a list of tensors under test.
        :param reference: A tensor or a list of tensors to compare against.
        :param rtol: Relative tolerance used by the not-equal pre-check.
        """
        cls.check_not_equal(test, reference, rtol=rtol)
        # Swap the operands so the assertion is `reference < test` for every element.
        # Asserting that `test < reference` fails is weaker: it also accepts inputs where
        # only some elements are greater.
        cls._compare_tensor_lists(test, reference,
                                  lambda x, y: np.testing.assert_array_less(y, x))
Loading

0 comments on commit 3333974

Please sign in to comment.