
Commit

Add bias quantization in QAT and refactor the code of weight quantization (#2914)
linbinskn committed Oct 10, 2020
1 parent 6126960 commit 0a6c234
Showing 3 changed files with 69 additions and 26 deletions.
5 changes: 2 additions & 3 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -481,11 +481,10 @@ def forward(self, *inputs):
self)

if 'weight' in self.config['quant_types'] and _check_weight(self.module):
new_weight = self.quantizer.quant_grad.apply(
self.quantizer.quant_grad.apply(
self.module.old_weight,
QuantType.QUANT_WEIGHT,
self)
self.module.weight = new_weight
result = self.module(*inputs)
else:
result = self.module(*inputs)
@@ -617,7 +616,7 @@ def forward(ctx, tensor, quant_type, wrapper, **kwargs):
if quant_type == QuantType.QUANT_INPUT:
return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
return wrapper.quantizer.quantize_weight(tensor, wrapper, **kwargs)
return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT:
return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
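The net effect of the two hunks above: for weight quantization, forward() no longer assigns the return value of quant_grad.apply back to module.weight, and the autograd dispatcher no longer passes the tensor into quantize_weight; each quantizer now reads the float shadow parameter module.old_weight and writes the quantized copy back to module.weight itself. A minimal sketch of a custom quantizer written against the new contract (illustrative only, not part of this commit; the class name is made up and the base-class defaults are assumed to suffice):

```python
import copy
import torch
from nni.compression.torch.compressor import Quantizer

class HalfStepQuantizer(Quantizer):
    """Toy quantizer showing the refactored quantize_weight(wrapper) signature."""

    def quantize_weight(self, wrapper, **kwargs):
        # read the float "shadow" parameter kept by the wrapper for training
        weight = copy.deepcopy(wrapper.module.old_weight.data)
        weight = torch.round(weight * 2) / 2   # fake-quantize onto a 0.5-spaced grid
        wrapper.module.weight = weight         # the quantizer, not forward(), writes back now
        return weight
```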
72 changes: 51 additions & 21 deletions src/sdk/pynni/nni/compression/torch/quantization/quantizers.py
@@ -2,6 +2,7 @@
# Licensed under the MIT license.

import logging
import copy
import torch
from schema import Schema, And, Or, Optional
from ..utils.config_validation import CompressorSchema
@@ -15,6 +16,7 @@
class NaiveQuantizer(Quantizer):
"""quantize weight to 8 bits
"""

def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.layer_scale = {}
@@ -29,13 +31,15 @@ def validate_config(self, model, config_list):

schema.validate(config_list)

def quantize_weight(self, weight, wrapper, **kwargs):
def quantize_weight(self, wrapper, **kwargs):
weight = copy.deepcopy(wrapper.module.old_weight.data)
new_scale = weight.abs().max() / 127
scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
self.layer_scale[wrapper.name] = scale
orig_type = weight.type() # TODO: user layer
return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)

weight = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
wrapper.module.weight = weight
return weight
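For reference, the refactored NaiveQuantizer path is plain symmetric 8-bit rounding: the scale is the running maximum of max(|w|) / 127, and the weight is cast to int8 and straight back to float. The same arithmetic as a standalone sketch (the helper name is made up):

```python
import torch

def naive_int8_fake_quant(weight: torch.Tensor, prev_scale: float = 0.0):
    """Symmetric 8-bit fake quantization mirroring NaiveQuantizer.quantize_weight above."""
    scale = max(prev_scale, (weight.abs().max() / 127).item())
    # cast to int8 and back: values get snapped onto the grid k * scale, k in [-127, 127]
    return weight.div(scale).type(torch.int8).type(weight.dtype).mul(scale), scale

fake_quantized, scale = naive_int8_fake_quant(torch.randn(4, 4))
```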

def update_ema(biased_ema, value, decay, step):
"""
@@ -60,6 +64,7 @@ def update_ema(biased_ema, value, decay, step):
unbiased_ema = biased_ema / (1 - decay ** step) # Bias correction
return biased_ema, unbiased_ema
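quantize_output further down feeds each batch's min/max through update_ema to keep a running estimate of the activation range. Only the bias-correction line is visible in this hunk, so the first line of the sketch below is an assumption about the collapsed part:

```python
def ema_update(biased_ema, value, decay, step):
    """EMA with bias correction; the update rule is assumed, the correction is shown above."""
    biased_ema = biased_ema * decay + (1 - decay) * value   # assumed update rule
    unbiased_ema = biased_ema / (1 - decay ** step)         # bias correction, as above
    return biased_ema, unbiased_ema

# With a constant signal the corrected estimate is right from step 1 (up to float rounding):
b = 0.0
for step in range(1, 4):
    b, u = ema_update(b, 1.0, decay=0.99, step=step)        # u stays ~1.0
```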


def update_quantization_param(bits, rmin, rmax):
"""
calculate the `zero_point` and `scale`.
@@ -116,6 +121,7 @@ class QAT_Quantizer(Quantizer):
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""

def __init__(self, model, config_list, optimizer=None):
"""
Parameters
@@ -215,20 +221,35 @@ def _dequantize(self, op, quantized_val):
real_val = op.scale * (quantized_val - op.zero_point)
return real_val

def quantize_weight(self, weight, wrapper, **kwargs):
def quantize_weight(self, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
weight = copy.deepcopy(wrapper.module.old_weight.data)
weight_bits = get_bits_length(config, 'weight')
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"

if quant_start_step > self.steps:
return weight

# if bias exists, quantize bias to uint32
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
bias = wrapper.module.bias.data
bias_bits = 32
rmin, rmax = torch.min(bias), torch.max(bias)
module.scale, module.zero_point = update_quantization_param(bias_bits, rmin, rmax)
bias = self._quantize(bias_bits, module, bias)
bias = self._dequantize(module, bias)
wrapper.module.bias.data = bias


# quantize weight
rmin, rmax = torch.min(weight), torch.max(weight)
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
out = self._quantize(weight_bits, module, weight)
out = self._dequantize(module, out)
return out
weight = self._quantize(weight_bits, module, weight)
weight = self._dequantize(module, weight)
wrapper.module.weight = weight
return weight
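The new bias branch reuses the same affine (asymmetric) scheme as the weight path, only with a fixed 32-bit width, so the bias survives the quantize/dequantize round trip essentially unchanged. A simplified standalone sketch of that scheme: it agrees with _dequantize above and with the scale/zero_point values asserted in the updated test at the bottom of this diff, but update_quantization_param itself is collapsed here, so the rmin/rmax handling is an assumption:

```python
import torch

def affine_fake_quant(x: torch.Tensor, bits: int) -> torch.Tensor:
    """Affine fake quantization: clamp(round(x / scale) + zero_point), then dequantize."""
    # Assumed: the range is stretched to include 0 so that 0.0 lands exactly on the grid.
    rmin, rmax = min(float(x.min()), 0.0), max(float(x.max()), 0.0)
    qmax = 2 ** bits - 1
    scale = (rmax - rmin) / qmax
    zero_point = round(-rmin / scale)
    q = torch.clamp(torch.round(x / scale) + zero_point, 0, qmax)
    return scale * (q - zero_point)   # mirrors _dequantize above

affine_fake_quant(torch.tensor([[1., 2.], [3., 5.]]), bits=8)   # scale = 5/255, zero_point = 0
```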

def quantize_output(self, output, wrapper, **kwargs):
config = wrapper.config
@@ -241,8 +262,10 @@ def quantize_output(self, output, wrapper, **kwargs):
return output

current_min, current_max = torch.min(output), torch.max(output)
module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, module.ema_decay, self.steps)
module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, module.ema_decay, self.steps)
module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
module.ema_decay, self.steps)
module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
module.ema_decay, self.steps)
module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
out = self._quantize(output_bits, module, output)
out = self._dequantize(module, out)
@@ -264,6 +287,7 @@ class DoReFaQuantizer(Quantizer):
Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
(https://arxiv.org/abs/1606.06160)
"""

def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)

@@ -287,17 +311,20 @@ def validate_config(self, model, config_list):

schema.validate(config_list)

def quantize_weight(self, weight, wrapper, **kwargs):
def quantize_weight(self, wrapper, **kwargs):
weight = copy.deepcopy(wrapper.module.old_weight.data)
weight_bits = get_bits_length(wrapper.config, 'weight')
out = weight.tanh()
out = out / (2 * out.abs().max()) + 0.5
out = self.quantize(out, weight_bits)
out = 2 * out -1
return out
weight = weight.tanh()
weight = weight / (2 * weight.abs().max()) + 0.5
weight = self.quantize(weight, weight_bits)
weight = 2 * weight - 1
wrapper.module.weight = weight
# wrapper.module.weight.data = weight
return weight

def quantize(self, input_ri, q_bits):
scale = pow(2, q_bits)-1
output = torch.round(input_ri*scale)/scale
scale = pow(2, q_bits) - 1
output = torch.round(input_ri * scale) / scale
return output
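The DoReFa weight path above squashes weights with tanh, normalizes them into [0, 1], snaps them onto a (2^k - 1)-level grid, and stretches back to [-1, 1]. The same transform as a standalone sketch:

```python
import torch

def dorefa_weight_quant(weight: torch.Tensor, k: int) -> torch.Tensor:
    """w -> 2 * quantize_k(tanh(w) / (2 * max|tanh(w)|) + 0.5) - 1, as in quantize_weight above."""
    levels = 2 ** k - 1
    w = weight.tanh()
    w = w / (2 * w.abs().max()) + 0.5     # normalize into [0, 1]
    w = torch.round(w * levels) / levels  # quantize_k: snap onto the k-bit grid
    return 2 * w - 1                      # stretch back into [-1, 1]

dorefa_weight_quant(torch.randn(3, 3), k=2)   # every value lands on {-1, -1/3, 1/3, 1}
```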


@@ -314,6 +341,7 @@ class BNNQuantizer(Quantizer):
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
(https://arxiv.org/abs/1602.02830)
"""

def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = ClipGrad
@@ -339,11 +367,13 @@ def validate_config(self, model, config_list):

schema.validate(config_list)

def quantize_weight(self, weight, wrapper, **kwargs):
out = torch.sign(weight)
def quantize_weight(self, wrapper, **kwargs):
weight = copy.deepcopy(wrapper.module.old_weight.data)
weight = torch.sign(weight)
# remove zeros
out[out == 0] = 1
return out
weight[weight == 0] = 1
wrapper.module.weight = weight
return weight

def quantize_output(self, output, wrapper, **kwargs):
out = torch.sign(output)
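BNNQuantizer's weight path binarizes with sign() and remaps exact zeros to +1; gradients are routed through the ClipGrad estimator registered in __init__, whose definition is outside this diff. A minimal sketch of just the forward binarization:

```python
import torch

def binarize(weight: torch.Tensor) -> torch.Tensor:
    """sign(w) with sign(0) remapped to +1, as in BNNQuantizer.quantize_weight above."""
    out = torch.sign(weight)
    out[out == 0] = 1
    return out

binarize(torch.tensor([-0.3, 0.0, 0.7]))   # tensor([-1., 1., 1.])
```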
18 changes: 16 additions & 2 deletions src/sdk/pynni/tests/test_compressor_torch.py
@@ -234,20 +234,34 @@ def test_torch_QAT_quantizer(self):
model.relu = torch.nn.ReLU()
quantizer = torch_compressor.QAT_Quantizer(model, config_list)
quantizer.compress()

# test quantize
# range not including 0
eps = 1e-7
weight = torch.tensor([[1, 2], [3, 5]]).float()
quantize_weight = quantizer.quantize_weight(weight, model.conv2)
model.conv2.module.old_weight.data = weight
quantizer.quantize_weight(model.conv2)
assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
assert model.conv2.module.zero_point == 0
# range including 0
weight = torch.tensor([[-1, 2], [3, 5]]).float()
quantize_weight = quantizer.quantize_weight(weight, model.conv2)
model.conv2.module.old_weight.data = weight
quantizer.quantize_weight(model.conv2)
assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
assert model.conv2.module.zero_point in (42, 43)
# test value of weight and bias after quantization
weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
model.conv2.module.old_weight.data = weight
model.conv2.module.bias.data = bias
quantizer.quantize_weight(model.conv2)
assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))

# test ema
eps = 1e-7
x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
out = model.relu(x)
assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
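As a sanity check on the new expected values: with 8-bit weights the range [0, 5.9723] gives scale ≈ 5.9723 / 255 ≈ 0.02342 and zero_point 0, so 1.1287 rounds to bin 48 and dequantizes to ≈ 1.1242, matching weight_valid; with 32 bits the grid is so fine that the bias is reproduced within the 1e-7 tolerance, which is why bias_valid simply equals bias. A quick reproduction of the weight numbers:

```python
import torch

scale = 5.9723 / 255   # rmax = 5.9723, rmin clamped to 0, so zero_point = 0
w = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
print(torch.round(w / scale) * scale)   # ~[[1.1242, 2.3421], [3.7707, 5.9723]]
```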
