Skip to content

Commit

Permalink
Add compression_lr_multiplier for tf part
Browse files Browse the repository at this point in the history
  • Loading branch information
a-ignatyev committed Jun 21, 2021
1 parent 5f8c132 commit db1da64
Show file tree
Hide file tree
Showing 14 changed files with 486 additions and 244 deletions.
9 changes: 9 additions & 0 deletions nncf/common/framework_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from importlib.util import find_spec


def check_torch_installed():
    """
    Looks up the ``torch`` package via ``importlib.util.find_spec``.

    :return: The module spec of ``torch`` (a truthy object) if the package
        can be found, otherwise ``None``.
    """
    torch_spec = find_spec('torch')
    return torch_spec


def check_tf_installed():
    """
    Looks up the ``tensorflow`` package via ``importlib.util.find_spec``.

    :return: The module spec of ``tensorflow`` (a truthy object) if the package
        can be found, otherwise ``None``.
    """
    tf_spec = find_spec('tensorflow')
    return tf_spec
12 changes: 9 additions & 3 deletions nncf/common/quantization/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def __init__(self, num_bits: int,
mode: QuantizationMode,
signedness_to_force: bool,
narrow_range: bool,
half_range: bool):
half_range: bool,
compression_lr_multiplier: Optional[float] = None):
"""
:param num_bits: Bitwidth of the quantization.
:param mode: The mode of quantization (symmetric or asymmetric).
Expand All @@ -126,23 +127,28 @@ def __init__(self, num_bits: int,
naive case, False if all 2^`num_bits` quantizations should be used.
:param half_range: If ``True``, effectively only half of the quantizer range is used.
False - the full range is used.
:param compression_lr_multiplier: Multiplier for gradient values.
"""
self.num_bits = num_bits
self.mode = mode
self.signedness_to_force = signedness_to_force
self.narrow_range = narrow_range
self.half_range = half_range
self.compression_lr_multiplier = compression_lr_multiplier

def __eq__(self, other: 'QuantizerSpec'):
    """
    Compares two specs by the full set of their instance attributes.

    :param other: The object to compare with.
    :return: ``True`` if both objects hold equal instance dictionaries, ``False`` if
        they differ, and ``NotImplemented`` if `other` has no instance dictionary.
    """
    try:
        other_state = vars(other)
    except TypeError:
        # `other` carries no `__dict__` (e.g. `None` or a number). Defer to Python's
        # default comparison handling instead of failing with an AttributeError.
        return NotImplemented
    return self.__dict__ == other_state

@classmethod
def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool) -> 'QuantizerSpec':
def from_config(cls, qconfig: QuantizerConfig,
narrow_range: bool, half_range: bool,
compression_lr_multiplier: Optional[float]) -> 'QuantizerSpec':
return cls(qconfig.num_bits,
qconfig.mode,
qconfig.signedness_to_force,
narrow_range,
half_range)
half_range,
compression_lr_multiplier)


class QuantizationConstraints:
Expand Down
13 changes: 13 additions & 0 deletions nncf/tensorflow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""

import tensorflow as tf
from typing import Callable


@tf.function
Expand All @@ -24,3 +25,15 @@ def st_threshold(input_):
def grad(upstream):
return upstream
return tf.round(input_), grad


def get_id_with_multiplied_grad(grad_multiplier: float) -> Callable[[tf.Tensor], tf.Tensor]:
    """
    Builds an identity function whose gradient is scaled by a constant factor.

    :param grad_multiplier: Factor by which the upstream gradient is multiplied.
        ``None`` is also handled (despite the annotation) and leaves the
        upstream gradient unchanged.
    :return: A function wrapped with `tf.custom_gradient` that returns its input
        as is and applies the multiplier to the gradient.
    """
    @tf.custom_gradient
    def id_with_multiplied_grad(x):
        def grad(dy):
            # Without a multiplier the upstream gradient is passed through untouched.
            return dy if grad_multiplier is None else grad_multiplier * dy

        return x, grad

    return id_with_multiplied_grad
8 changes: 6 additions & 2 deletions nncf/tensorflow/quantization/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa
transformations = TFTransformationLayout()
qconfig = self._get_default_qconfig(self.global_quantizer_constraints[QuantizerGroup.WEIGHTS])
half_range = self._get_half_range(qconfig)
compression_lr_multiplier = self.config.get('compression_lr_multiplier', None)
processed_shared_layer_names = set() # type: Set[str]
for node in nodes:
if node.is_shared():
Expand All @@ -164,7 +165,8 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa
op_name,
TFQuantizerSpec.from_config(qconfig,
narrow_range=not half_range,
half_range=half_range))
half_range=half_range,
compression_lr_multiplier=compression_lr_multiplier))

transformations.register(
TFInsertionCommand(
Expand All @@ -177,7 +179,9 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa
for original_node_name, instance_index in insertion_points:
fake_quantize_name = self._get_fake_quantize_name(original_node_name, instance_index)
fake_quantize_layer = FakeQuantize(
TFQuantizerSpec.from_config(qconfig, narrow_range=False, half_range=False),
TFQuantizerSpec.from_config(qconfig,
narrow_range=False, half_range=False,
compression_lr_multiplier=compression_lr_multiplier),
name=fake_quantize_name)
self._op_names.append(fake_quantize_layer.op_name)

Expand Down
15 changes: 14 additions & 1 deletion nncf/tensorflow/quantization/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import tensorflow as tf

from nncf.tensorflow.functions import get_id_with_multiplied_grad
from nncf.tensorflow.layers.custom_objects import NNCF_CUSTOM_OBJECTS
from nncf.tensorflow.layers.custom_objects import NNCF_QUANTIZATION_OPERATONS
from nncf.tensorflow.layers.operation import InputType
Expand All @@ -32,7 +33,8 @@ def __init__(self, config: TFQuantizerSpec, data_format: str ='channels_last', *

self._op_name = f'{self.name}_quantizer'
self._quantizer = self._create_quantizer(config, self._op_name)
self._quantizer_weights = {}
self._quantizer_weights_dict = {}
self.id_with_multiplied_grad = get_id_with_multiplied_grad(config.compression_lr_multiplier)

@property
def num_bits(self):
Expand Down Expand Up @@ -62,6 +64,17 @@ def enabled(self):
def enabled(self, v):
self._quantizer.enabled = v

@property
def _quantizer_weights(self):
    """
    Quantizer weights passed through `self.id_with_multiplied_grad`.

    A new dictionary is built on every access: each stored weight is looked up
    by its key and wrapped by `self.id_with_multiplied_grad`.
    """
    stored_weights = self._quantizer_weights_dict
    wrap_weight = self.id_with_multiplied_grad
    return {name: wrap_weight(stored_weights[name]) for name in stored_weights}

@_quantizer_weights.setter
def _quantizer_weights(self, value):
    """
    Stores the raw (unwrapped) weights; wrapping is done by the getter.
    """
    self._quantizer_weights_dict = value

def build(self, input_shape):
self._quantizer_weights = self._quantizer.build(
input_shape, InputType.INPUTS, self.name, self)
Expand Down
11 changes: 7 additions & 4 deletions nncf/tensorflow/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,21 @@ def __init__(self, num_bits: int,
signedness_to_force: Optional[bool],
narrow_range: bool,
half_range: bool,
per_channel: bool):
super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range)
per_channel: bool,
compression_lr_multiplier: Optional[float] = None):
super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range, compression_lr_multiplier)
self.per_channel = per_channel

@classmethod
def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool) -> 'TFQuantizerSpec':
def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool,
compression_lr_multiplier: Optional[float] = None) -> 'TFQuantizerSpec':
return cls(qconfig.num_bits,
qconfig.mode,
qconfig.signedness_to_force,
narrow_range,
half_range,
qconfig.per_channel)
qconfig.per_channel,
compression_lr_multiplier)


class Quantizer(NNCFOperation):
Expand Down
8 changes: 3 additions & 5 deletions nncf/torch/quantization/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,15 @@ def __init__(self, num_bits: int,
half_range: bool,
scale_shape: Tuple[int, ...],
logarithm_scale: bool,
compression_lr_multiplier: float = None):
super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range)
compression_lr_multiplier: Optional[float] = None):
super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range, compression_lr_multiplier)
self.scale_shape = scale_shape
self.logarithm_scale = logarithm_scale
self.compression_lr_multiplier = compression_lr_multiplier


@classmethod
def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool,
half_range: bool, scale_shape: Tuple[int],
logarithm_scale: bool, compression_lr_multiplier: float) -> 'PTQuantizerSpec':
logarithm_scale: bool, compression_lr_multiplier: Optional[float] = None) -> 'PTQuantizerSpec':
return cls(qconfig.num_bits,
qconfig.mode,
qconfig.signedness_to_force,
Expand Down
52 changes: 52 additions & 0 deletions tests/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,20 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np

from pathlib import Path
from typing import Callable, List, TypeVar

from nncf.common.framework_checker import check_tf_installed
from nncf.common.framework_checker import check_torch_installed

if check_tf_installed():
import tensorflow as tf
if check_torch_installed():
import torch

TensorType = TypeVar('TensorType')

TEST_ROOT = Path(__file__).absolute().parents[1]
PROJECT_ROOT = TEST_ROOT.parent.absolute()
Expand All @@ -26,3 +38,43 @@ def get_cli_dict_args(args):
if val is not None:
cli_args[cli_key] = str(val)
return cli_args


def to_numpy(tensor: TensorType) -> np.ndarray:
    """
    Converts a TensorFlow or PyTorch tensor to a NumPy array.

    The TensorFlow check is made first; a framework is only consulted when the
    corresponding `check_*_installed()` helper reports it as available.

    :param tensor: The object to convert.
    :return: `tensor.numpy()` for a `tf.Tensor`, `tensor.cpu().detach().numpy()`
        for a `torch.Tensor`, and `tensor` itself, unchanged, for anything else.
    """
    if check_tf_installed() and isinstance(tensor, tf.Tensor):
        converted = tensor.numpy()
    elif check_torch_installed() and isinstance(tensor, torch.Tensor):
        converted = tensor.cpu().detach().numpy()
    else:
        converted = tensor
    return converted


def compare_tensor_lists(test: List[TensorType], reference: List[TensorType],
                         assert_fn: Callable[[np.ndarray, np.ndarray], bool]):
    """
    Applies an assertion function to each pair of corresponding tensors.

    :param test: Tensors under test.
    :param reference: Reference tensors; must have the same length as `test`.
    :param assert_fn: Called as `assert_fn(test_item, reference_item)` after both
        items have been passed through `to_numpy`.
    """
    assert len(test) == len(reference)

    for test_item, reference_item in zip(test, reference):
        test_array = to_numpy(test_item)
        reference_array = to_numpy(reference_item)
        assert_fn(test_array, reference_array)


def check_equal(test: List[TensorType], reference: List[TensorType], rtol: float = 1e-1):
    """
    Asserts that every tensor in `test` is close to its counterpart in `reference`.

    :param test: Tensors under test.
    :param reference: Reference tensors.
    :param rtol: Relative tolerance passed to `np.testing.assert_allclose`.
    """
    def assert_close(actual, desired):
        np.testing.assert_allclose(actual, desired, rtol=rtol)

    compare_tensor_lists(test, reference, assert_close)


def check_not_equal(test: List[TensorType], reference: List[TensorType], rtol: float = 1e-4):
    """
    Asserts that no tensor in `test` is close to its counterpart in `reference`.

    :param test: Tensors under test.
    :param reference: Reference tensors.
    :param rtol: Relative tolerance passed to `np.testing.assert_allclose`.
    """
    def assert_differs(actual, desired):
        # The pair passes only if the closeness check itself fails.
        with np.testing.assert_raises(AssertionError):
            np.testing.assert_allclose(actual, desired, rtol=rtol)

    compare_tensor_lists(test, reference, assert_differs)


def check_less(test: List[TensorType], reference: List[TensorType], rtol=1e-4):
    """
    Asserts that every tensor in `test` is strictly less, element-wise, than its
    counterpart in `reference`.

    :param test: Tensors under test.
    :param reference: Reference tensors.
    :param rtol: Relative tolerance for the preceding `check_not_equal` check.
    """
    # The pairs must first differ beyond the tolerance, then satisfy the ordering.
    check_not_equal(test, reference, rtol=rtol)
    compare_tensor_lists(test, reference, assert_fn=np.testing.assert_array_less)


def check_greater(test: List[TensorType], reference: List[TensorType], rtol=1e-4):
    """
    Asserts that each tensor in `test` differs from its counterpart in `reference`
    and is not strictly less than it.

    A pair passes the second check whenever `np.testing.assert_array_less(test, reference)`
    fails for it.
    NOTE(review): this is weaker than an element-wise "greater" comparison - confirm
    that this is the intended contract.

    :param test: Tensors under test.
    :param reference: Reference tensors.
    :param rtol: Relative tolerance for the preceding `check_not_equal` check.
    """
    def assert_not_less(actual, desired):
        with np.testing.assert_raises(AssertionError):
            np.testing.assert_array_less(actual, desired)

    check_not_equal(test, reference, rtol=rtol)
    compare_tensor_lists(test, reference, assert_not_less)
Loading

0 comments on commit db1da64

Please sign in to comment.