Skip to content

Commit

Permalink
Add compression_lr_multiplier for tf part
Browse files Browse the repository at this point in the history
  • Loading branch information
a-ignatyev committed Aug 12, 2021
1 parent 169cef8 commit d25cd6a
Show file tree
Hide file tree
Showing 28 changed files with 902 additions and 539 deletions.
12 changes: 9 additions & 3 deletions nncf/common/quantization/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def __init__(self, num_bits: int,
mode: QuantizationMode,
signedness_to_force: bool,
narrow_range: bool,
half_range: bool):
half_range: bool,
compression_lr_multiplier: Optional[float] = None):
"""
:param num_bits: Bitwidth of the quantization.
:param mode: The mode of quantization (symmetric or asymmetric).
Expand All @@ -147,23 +148,28 @@ def __init__(self, num_bits: int,
naive case, False if all 2^`num_bits` quantizations should be used.
:param half_range: If ``True``, effectively only half of the quantizer range is used.
False - the full range is used.
:param compression_lr_multiplier: Multiplier for gradient values.
"""
self.num_bits = num_bits
self.mode = mode
self.signedness_to_force = signedness_to_force
self.narrow_range = narrow_range
self.half_range = half_range
self.compression_lr_multiplier = compression_lr_multiplier

def __eq__(self, other: 'QuantizerSpec') -> bool:
    """
    Compares this spec with another one attribute by attribute.

    :param other: Object to compare with; it has to expose `__dict__`,
        otherwise the attribute access below raises.
    :return: True if the instance `__dict__` of both objects are equal.
    """
    return self.__dict__ == other.__dict__

@classmethod
def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool) -> 'QuantizerSpec':
def from_config(cls, qconfig: QuantizerConfig,
narrow_range: bool, half_range: bool,
compression_lr_multiplier: Optional[float]) -> 'QuantizerSpec':
return cls(qconfig.num_bits,
qconfig.mode,
qconfig.signedness_to_force,
narrow_range,
half_range)
half_range,
compression_lr_multiplier)


class QuantizationConstraints:
Expand Down
13 changes: 13 additions & 0 deletions nncf/tensorflow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""

import tensorflow as tf
from typing import Callable


@tf.function
Expand All @@ -24,3 +25,15 @@ def st_threshold(input_):
def grad(upstream):
return upstream
return tf.round(input_), grad


def get_id_with_multiplied_grad(grad_multiplier: float) -> Callable[[tf.Tensor], tf.Tensor]:
    """
    Builds an identity function whose gradient is scaled by `grad_multiplier`.

    :param grad_multiplier: Factor the upstream gradient is multiplied by.
        If it is None the upstream gradient is passed through unchanged.
    :return: A function wrapped with `tf.custom_gradient` that returns its
        input as-is on the forward pass.
    """
    def scale_upstream(upstream):
        # A missing multiplier means the gradient must stay untouched.
        return upstream if grad_multiplier is None else grad_multiplier * upstream

    @tf.custom_gradient
    def id_with_multiplied_grad(x):
        # Forward pass is the identity; only the backward pass is customised.
        return x, scale_upstream

    return id_with_multiplied_grad
4 changes: 3 additions & 1 deletion nncf/tensorflow/pruning/base_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,12 @@ def _get_insertion_command_binary_mask(self, layer_name: str,
attr_name: str) -> TFInsertionCommand:
op_name = self._get_pruning_operation_name(layer_name, attr_name)
self._op_names.append(op_name)
compression_lr_multiplier = \
self.config.get_redefinable_global_param_value_for_algo('compression_lr_multiplier', self.name)

return TFInsertionCommand(
target_point=TFLayerWeight(layer_name, attr_name),
callable_object=BinaryMask(op_name),
callable_object=BinaryMask(op_name, compression_lr_multiplier=compression_lr_multiplier),
priority=TransformationPriority.PRUNING_PRIORITY
)

Expand Down
17 changes: 11 additions & 6 deletions nncf/tensorflow/quantization/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
limitations under the License.
"""
from copy import deepcopy
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf

Expand Down Expand Up @@ -364,6 +361,7 @@ def _get_quantizer_setup(self, model: tf.keras.Model) -> TFQuantizationSetup:
quantizable_weighted_layer_nodes,
custom_layer_nodes)
setup = TFQuantizationSetup()
compression_lr_multiplier = self._get_compression_lr_multiplier()

quantized_layer_names_vs_qconfigs = {} # type: Dict[str, QuantizerConfig]

Expand Down Expand Up @@ -397,7 +395,8 @@ def _get_quantizer_setup(self, model: tf.keras.Model) -> TFQuantizationSetup:
applied_saturation_fix = applied_saturation_fix or half_range
quantizer_spec = TFQuantizerSpec.from_config(qconfig,
narrow_range=not half_range,
half_range=half_range)
half_range=half_range,
compression_lr_multiplier=compression_lr_multiplier)
target_point = TFLayerWeight(layer_info.layer_name, weight_def.weight_attr_name)
qpoint = TFQuantizationPoint(op_name, quantizer_spec, target_point)
setup.add_quantization_point(qpoint)
Expand All @@ -408,7 +407,9 @@ def _get_quantizer_setup(self, model: tf.keras.Model) -> TFQuantizationSetup:
target_node_name = ip.target_node_name
input_port_id = ip.input_port_id
fake_quantize_name = self._get_fake_quantize_name(target_node_name, input_port_id)
quantizer_spec = TFQuantizerSpec.from_config(qp.qconfig, narrow_range=False, half_range=False)
compression_lr_multiplier = self._get_compression_lr_multiplier()
quantizer_spec = TFQuantizerSpec.from_config(qp.qconfig, narrow_range=False, half_range=False,
compression_lr_multiplier=compression_lr_multiplier)
fake_quantize_layer = FakeQuantize(
quantizer_spec,
name=fake_quantize_name)
Expand Down Expand Up @@ -566,6 +567,10 @@ def _get_fake_quantize_name(self, node_name: NNCFNodeName, input_port_id: int =
def _get_quantizer_operation_name(self, layer_name, weight_attr_name):
return f'{layer_name}_{weight_attr_name}_quantizer'

def _get_compression_lr_multiplier(self) -> Optional[float]:
    """
    Reads the 'compression_lr_multiplier' setting for this algorithm from the config.

    :return: The value reported by the config's
        `get_redefinable_global_param_value_for_algo` for this algorithm's name.
    """
    param_name = 'compression_lr_multiplier'
    algo_name = self.name
    return self.config.get_redefinable_global_param_value_for_algo(param_name, algo_name)


class QuantizationController(BaseCompressionAlgorithmController):
def __init__(self, target_model, config, op_names: List[str]):
Expand Down
15 changes: 15 additions & 0 deletions nncf/tensorflow/quantization/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
"""

import tensorflow as tf
from typing import Dict, Optional

from nncf.tensorflow.functions import get_id_with_multiplied_grad


def symmetric_quantize(inputs,
Expand Down Expand Up @@ -46,6 +49,18 @@ def asymmetric_quantize(inputs,
narrow_range, per_channel)


def add_id_with_multiplied_grad_op(weights: Dict[str, tf.Tensor],
                                   compression_lr_multiplier: Optional[float] = None) -> Dict[str, tf.Tensor]:
    """
    Passes every tensor in `weights` through an identity op with a scaled gradient.

    :param weights: Mapping from weight name to tensor.
    :param compression_lr_multiplier: Multiplier handed to
        `get_id_with_multiplied_grad`. If it is None, `weights` is returned
        untouched (the very same object, no ops are added).
    :return: `weights` itself when no multiplier is given, otherwise a new
        dict with the same keys whose values are the wrapped tensors.
    """
    if compression_lr_multiplier is None:
        return weights

    # Build the wrapper once and reuse it for every weight.
    id_with_multiplied_grad = get_id_with_multiplied_grad(compression_lr_multiplier)
    return {name: id_with_multiplied_grad(tensor) for name, tensor in weights.items()}


def _fake_quant_with_min_max_vars(inputs, min_var, max_var, num_bits, narrow_range,
per_channel):
if per_channel:
Expand Down
30 changes: 22 additions & 8 deletions nncf/tensorflow/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nncf.tensorflow.layers.data_layout import get_channel_axis
from nncf.tensorflow.layers.data_layout import get_channel_size
from nncf.tensorflow.layers.operation import NNCFOperation
from nncf.tensorflow.quantization.functions import add_id_with_multiplied_grad_op
from nncf.tensorflow.quantization.functions import asymmetric_quantize
from nncf.tensorflow.quantization.functions import symmetric_quantize

Expand All @@ -36,18 +37,21 @@ def __init__(self, num_bits: int,
signedness_to_force: Optional[bool],
narrow_range: bool,
half_range: bool,
per_channel: bool):
super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range)
per_channel: bool,
compression_lr_multiplier: Optional[float] = None):
super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range, compression_lr_multiplier)
self.per_channel = per_channel

@classmethod
def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool) -> 'TFQuantizerSpec':
def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: bool,
compression_lr_multiplier: Optional[float] = None) -> 'TFQuantizerSpec':
return cls(qconfig.num_bits,
qconfig.mode,
qconfig.signedness_to_force,
narrow_range,
half_range,
qconfig.per_channel)
qconfig.per_channel,
compression_lr_multiplier)

def get_state(self) -> Dict[str, Any]:
"""
Expand All @@ -62,7 +66,8 @@ def get_state(self) -> Dict[str, Any]:
'signedness_to_force': self.signedness_to_force,
'narrow_range': self.narrow_range,
'half_range': self.half_range,
'per_channel': self.per_channel
'per_channel': self.per_channel,
'compression_lr_multiplier': self.compression_lr_multiplier
}

@classmethod
Expand Down Expand Up @@ -285,6 +290,7 @@ def __init__(self, name: str, qspec: TFQuantizerSpec):
self.narrow_range = qspec.narrow_range
self.signedness_to_force = qspec.signedness_to_force
self._half_range = qspec.half_range
self.compression_lr_multiplier = qspec.compression_lr_multiplier

@property
def half_range(self):
Expand Down Expand Up @@ -336,6 +342,8 @@ def apply_saturation_fix(self, weights):
self._half_range = False

def quantize(self, inputs, weights, _):
weights = add_id_with_multiplied_grad_op(weights, self.compression_lr_multiplier)

def _half_range_quantize():
return symmetric_quantize(
inputs,
Expand Down Expand Up @@ -389,6 +397,7 @@ def get_config(self):
'narrow_range': self.narrow_range,
'half_range': self._half_range,
'per_channel': self.per_channel,
'compression_lr_multiplier': self.compression_lr_multiplier,
}
config = {
'quantizer_spec': qspec_dict,
Expand All @@ -404,12 +413,12 @@ def from_config(cls, config):
signedness_to_force=qspec_dict['signedness_to_force'],
narrow_range=qspec_dict['narrow_range'],
half_range=qspec_dict['half_range'],
per_channel=qspec_dict['per_channel'])
per_channel=qspec_dict['per_channel'],
compression_lr_multiplier=qspec_dict['compression_lr_multiplier'])
name = config['name']
return cls(name, qspec)



@NNCF_CUSTOM_OBJECTS.register()
@NNCF_QUANTIZATION_OPERATONS.register(QuantizationMode.ASYMMETRIC)
class AsymmetricQuantizer(Quantizer):
Expand All @@ -419,6 +428,7 @@ def __init__(self, name: str, qspec: TFQuantizerSpec):
self.narrow_range = qspec.narrow_range
self.per_channel = qspec.per_channel
self._half_range = qspec.half_range
self.compression_lr_multiplier = qspec.compression_lr_multiplier

@property
def half_range(self):
Expand Down Expand Up @@ -466,6 +476,8 @@ def apply_saturation_fix(self, weights):
self._half_range = False

def quantize(self, inputs, weights, _):
weights = add_id_with_multiplied_grad_op(weights, self.compression_lr_multiplier)

def _half_range_quantize():
return asymmetric_quantize(
inputs,
Expand Down Expand Up @@ -519,6 +531,7 @@ def get_config(self):
'narrow_range': self.narrow_range,
'half_range': self._half_range,
'per_channel': self.per_channel,
'compression_lr_multiplier': self.compression_lr_multiplier,
}
config = {
'quantizer_spec': qspec_dict,
Expand All @@ -534,6 +547,7 @@ def from_config(cls, config):
signedness_to_force=None,
narrow_range=qspec_dict['narrow_range'],
half_range=qspec_dict['half_range'],
per_channel=qspec_dict['per_channel'])
per_channel=qspec_dict['per_channel'],
compression_lr_multiplier=qspec_dict['compression_lr_multiplier'])
name = config['name']
return cls(name, qspec)
10 changes: 8 additions & 2 deletions nncf/tensorflow/sparsity/magnitude/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa
converter = TFModelConverterFactory.create(model)
nncf_graph = converter.convert()
transformations = TFTransformationLayout()
compression_lr_multiplier = \
self.config.get_redefinable_global_param_value_for_algo('compression_lr_multiplier', self.name)

processed_shared_layer_names = set() # type: Set[str]

Expand Down Expand Up @@ -87,7 +89,7 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa
transformations.register(
TFInsertionCommand(
target_point=TFLayerWeight(layer_info.layer_name, weight_def.weight_attr_name),
callable_object=BinaryMask(op_name),
callable_object=BinaryMask(op_name, compression_lr_multiplier=compression_lr_multiplier),
priority=TransformationPriority.SPARSIFICATION_PRIORITY
))
elif node.metatype in WEIGHTABLE_TF_OP_METATYPES:
Expand All @@ -101,7 +103,11 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa
transformations.register(
TFInsertionCommand(
target_point=TFLayerWeight(layer_info.layer_name, weight_attr_name),
callable_object=BinaryMaskWithWeightsBackup(op_name, weight_attr_name),
callable_object=BinaryMaskWithWeightsBackup(
op_name,
weight_attr_name,
compression_lr_multiplier=compression_lr_multiplier
),
priority=TransformationPriority.SPARSIFICATION_PRIORITY
))

Expand Down
11 changes: 9 additions & 2 deletions nncf/tensorflow/sparsity/magnitude/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@
"""

import tensorflow as tf
from typing import Optional

from nncf.tensorflow.graph.utils import get_weight_by_name
from nncf.tensorflow.layers.custom_objects import NNCF_CUSTOM_OBJECTS
from nncf.tensorflow.layers.operation import InputType
from nncf.tensorflow.layers.operation import NNCFOperation
from nncf.tensorflow.sparsity.magnitude.functions import apply_mask
from nncf.tensorflow.sparsity.rb.operation import add_id_with_multiplied_grad_op


@NNCF_CUSTOM_OBJECTS.register()
class BinaryMask(NNCFOperation):
def __init__(self, name: str, compression_lr_multiplier: Optional[float] = None):
    """
    :param name: Name of the operation; forwarded to the base class constructor.
    :param compression_lr_multiplier: Optional multiplier, stored on the instance as-is.
        NOTE(review): presumably meant to scale the gradients of the operation
        weights - confirm against how `call` uses it.
    """
    super().__init__(name)
    self.compression_lr_multiplier = compression_lr_multiplier

def build(self, input_shape, input_type, name, layer):
if input_type is not InputType.WEIGHTS:
raise ValueError(
Expand All @@ -40,6 +46,7 @@ def build(self, input_shape, input_type, name, layer):
}

def call(self, inputs, weights, _):
    """
    Applies the binary mask kept in the operation weights to `inputs`.

    :param inputs: Value handed to `apply_mask` as the data to be masked.
    :param weights: Dict of operation weights; must contain the 'mask' entry.
    :param _: Unused.
    :return: The result of `apply_mask` for `inputs` and the 'mask' weight.
    """
    # Pass the multiplier stored by __init__ so the helper actually receives it,
    # the same way the quantizer operations call this helper.
    weights = add_id_with_multiplied_grad_op(weights, self.compression_lr_multiplier)
    return apply_mask(inputs, weights['mask'])

@staticmethod
Expand All @@ -55,8 +62,8 @@ def get_binary_mask(op_weights):

@NNCF_CUSTOM_OBJECTS.register()
class BinaryMaskWithWeightsBackup(BinaryMask):
def __init__(self, name: str, w_name_to_bkup: str = None):
super().__init__(name)
def __init__(self, name: str, w_name_to_bkup: str = None, compression_lr_multiplier: Optional[float] = None):
super().__init__(name, compression_lr_multiplier)
self.w_name_to_bkup = w_name_to_bkup
self.bkup_var = None

Expand Down
7 changes: 6 additions & 1 deletion nncf/tensorflow/sparsity/rb/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa
converter = TFModelConverterFactory.create(model)
nncf_graph = converter.convert()
transformations = TFTransformationLayout()
compression_lr_multiplier = \
self.config.get_redefinable_global_param_value_for_algo('compression_lr_multiplier', self.name)

processed_shared_layer_names = set() # type: Set[str]

Expand All @@ -74,7 +76,10 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa
transformations.register(
TFInsertionCommand(
target_point=TFLayerWeight(layer_info.layer_name, weight_def.weight_attr_name),
callable_object=RBSparsifyingWeight(op_name),
callable_object=RBSparsifyingWeight(
op_name,
compression_lr_multiplier=compression_lr_multiplier
),
priority=TransformationPriority.SPARSIFICATION_PRIORITY
))

Expand Down
Loading

0 comments on commit d25cd6a

Please sign in to comment.