Add LSQ quantizer #3503

Merged: 18 commits merged into microsoft:master on May 18, 2021
Conversation

@chenbohua3 (Contributor) commented Mar 31, 2021
This PR contains an implementation of the LSQ quantizer (Learned Step Size Quantization, ICLR 2020, see here). It uses gradients to update the quantization scales and achieves solid results in our production environment, especially at lower bit widths.

In the MNIST experiment it reaches about 99.20% top-1 accuracy. Results on ImageNet-1k are still in progress.
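
For readers new to the method, a one-paragraph summary of the paper: LSQ quantizes a value v with a learnable step size s as v_hat = round(clip(v / s, -Q_N, Q_P)) * s, where Q_N and Q_P are the numbers of negative and positive quantization levels. The gradient reaches s through the rounding via a straight-through estimator and is multiplied by a factor g = 1 / sqrt(N * Q_P) (N being the number of weights or features in the layer) to keep the step-size updates on the same order as the weight updates; s is initialized to 2 * mean(|v|) / sqrt(Q_P).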

@ghost commented Mar 31, 2021

CLA assistant check: all CLA requirements met.

@@ -146,7 +146,7 @@ def __init__(self, model, config_list, optimizer=None):
             types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
         super().__init__(model, config_list, optimizer)
-        self.quant_grad = QATGrad
+        self.quant_grad = QATGrad.apply
Contributor:
Why do we have to move apply here instead of using it directly?

Contributor:
So it is for avoiding STE in the LSQ quantizer.

Contributor Author:
Yes, it is aimed at unifying the framework between quantizers with customized gradients and quantizers that rely on auto-grad. Also, using .apply is the way recommended by PyTorch (see here).
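
For context, a minimal sketch of the pattern under discussion (the class and values below are illustrative, not NNI's actual QATGrad): a torch.autograd.Function is invoked through .apply rather than being called or instantiated directly, so storing the .apply callable lets the wrapper treat a custom-gradient Function and a plain autograd-friendly callable the same way.

    import torch

    class _STEQuant(torch.autograd.Function):
        # Illustrative straight-through-estimator Function (not the real QATGrad).
        @staticmethod
        def forward(ctx, x, scale):
            return torch.round(x / scale) * scale

        @staticmethod
        def backward(ctx, grad_output):
            # STE: pass the gradient straight through to x; no gradient for scale.
            return grad_output, None

    x = torch.randn(4, requires_grad=True)
    y = _STEQuant.apply(x, torch.tensor(0.1))   # .apply is the supported entry point
    y.sum().backward()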



class LsqQuantizer(Quantizer):
@linbinskn (Contributor) commented Apr 9, 2021:
Please add a docstring like the other quantizers have, especially for the parameters and return values.

Contributor Author:
done

if "weight" in config.get("quant_types", []):
# todo: support per-channel quantization for weight since TensorRT supports it for conv weight
q_bit = get_bits_length(config, "weight")
Contributor:
In the current implementation, does LsqQuantizer only support a single bit width? Can we support mixed precision right now?

Contributor Author:
Mixed-precision quantization seems to be supported by this implementation already, since each layer has its own q_bit. We can achieve mixed quantization through specific settings in config_list, like:

configure_list = [{
        'quant_types': ['weight'],
        'quant_bits': 8,
        'op_types': ['Conv2d'],
        'op_names': ['features.3']
    }, {
        'quant_types': ['weight'],
        'quant_bits': 7,
        'op_types': ['Conv2d'],
        'op_names': ['features.6']
    }]

# todo: in the origin paper, the initial value of activation is calculated from first input batch
if "output" in config.get("quant_types", []):
q_bit = get_bits_length(config, "output")
Contributor:
Same question as for the weight bit width.

Contributor Author:
Same as for the weights.

def quantize(self, x, scale, zero_point, qmin, qmax):
grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
scale = self.grad_scale(scale, grad_scale_factor)
Contributor:
A little confused about the naming of this value and function. Can we polish the naming here or in the grad_scale function? For instance, change the second parameter name 'scale' to 'scale_factor'.

Contributor Author:
The names of the functions and variables are the same as those defined in the paper.
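
For readers who have not read the paper, here is a rough, self-contained sketch of the quantize step using the paper's names (grad_scale, round_pass, scale); the exact body may differ from the code added in this PR.

    import torch

    def grad_scale(x, scale):
        # forward value is x; the gradient flowing back into x is multiplied by `scale`
        y_grad = x * scale
        return (x - y_grad).detach() + y_grad

    def round_pass(x):
        # round in the forward pass, identity in the backward pass (straight-through)
        return (x.round() - x).detach() + x

    def quantize(x, scale, zero_point, qmin, qmax):
        # the step-size gradient is scaled by g = 1 / sqrt(qmax * numel), as in the paper
        grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
        scale = grad_scale(scale, grad_scale_factor)
        x = torch.clamp(x / scale + zero_point, qmin, qmax)
        x = round_pass(x)
        return (x - zero_point) * scale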

module = wrapper.module
output = self.quantize(output, module.scale, module.zero_point, module.activation_qmin, module.activation_qmax)
return output
Contributor:
Can this quantization algorithm support exporting the model and the related quantization parameters? If yes, maybe we can add an export_model() function based on which parameters should be exported to an inference framework like TensorRT.

Contributor Author:
I will check it out.

def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = QuantForward()
Contributor:
If we keep the original forward and backward structure, LSQ can run the forward pass as usual and the backward pass via STE. Would anything go wrong that way? It might have something to do with the update of scale and zero_point.

Contributor Author:
Nothing will go wrong if the gradients are handled carefully. However, the original framework has one major limitation: we must customize the gradients for all learnable parameters. If a gradient-based algorithm becomes complex, that customization gets troublesome and error-prone. In this situation, I think letting the auto-grad system determine the gradients is more convenient for users.

@linbinskn (Contributor):
Many good points in this PR! Please test the exported model on TensorRT and modify the initialization of the activation scale.

@chenbohua3 (Contributor Author):
I have added code that uses the first batch of data to initialize the activation scale. After that, we get about 99.20% top-1 accuracy with the provided example.

I have also tested exporting the model to TensorRT. The converted TensorRT model gets almost the same accuracy as the PyTorch model.
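
For reference, the export path goes through the quantizer's export_model API, roughly as below; the file names are placeholders and the exact argument names are an assumption, so check the NNI Quantizer docs. Here `quantizer` is the LsqQuantizer instance from the example.

    # sketch only: export quantized weights plus the calibration information
    calibration_config = quantizer.export_model('mnist_model.pth', 'mnist_calibration.pth')
    # calibration_config carries the bit widths and dynamic ranges that the
    # TensorRT integration consumes when building the engine.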

qmax = module.activation_qmax
init_oup_scale = output.data.detach().abs().mean() * 2 / (qmax ** 0.5)
module.scale.data = init_oup_scale
Contributor:
It seems that the weight and the activation use the same scale within a single module, i.e. they share one rescale parameter, and that scale will be updated by the gradients of the weight and the activation simultaneously. What would happen if we quantized both the weight and the activation of the same layer? Could that cause problems?

Contributor Author:
Yes, you are right :) Now each layer constructs input_scale/weight_scale/output_scale according to the config setting.
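
A small illustrative sketch of what "one scale per quantization type" can look like; the helper and parameter names below are hypothetical, not necessarily the PR's exact attributes.

    import torch
    import torch.nn as nn

    def add_lsq_scales(module: nn.Module, quant_types):
        # attach one learnable step size per configured quantization type
        for qtype in quant_types:                  # e.g. ["weight", "input", "output"]
            module.register_parameter(f"{qtype}_scale", nn.Parameter(torch.tensor(1.0)))

    conv = nn.Conv2d(3, 16, 3)
    add_lsq_scales(conv, ["weight", "output"])
    # conv.weight_scale and conv.output_scale are now separate learnable parameters,
    # so the weight and the activation no longer share one scale.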

calibration_config[name]['tracked_min_activation'] = -abs_max_activation
calibration_config[name]['tracked_max_activation'] = abs_max_activation
if hasattr(module, 'input_bit'):
calibration_config[name]['weight_bit'] = int(module.input_bit)
Contributor:
Why assign calibration_config[name]['weight_bit'] from module.input_bit instead of module.weight_bit? If weight_bit is not equal to input_bit in the config, the exported result will be incorrect.

Contributor Author:
According to here, weight_bit is used to decide whether to set the input tensor's dynamic range, which I think may not be appropriate. Assigning input_bit to weight_bit here is just to stay consistent with that behavior.

Contributor:
Currently we record the range of the input tensor while quantizing the weight in the QAT algorithm. We handle it this way because the TensorRT integration needs the input tensor's dynamic range when setting a layer's precision to 8 bit, so we record the input dynamic range as here.
If we want to export an LSQ model to TensorRT, the input dynamic range should also be set in most situations, and input_bit should be the same as weight_bit.
However, it is still strange not to set calibration_config[name]['weight_bit'] from weight_bit, since we already have the value of weight_bit.

Contributor Author:
Got it. How about changing the code like this:

 if hasattr(module, 'weight_bit'):
     calibration_config[name]['weight_bit'] = int(module.weight_bit)
     abs_max_input = float(module.input_scale * module.input_qmax)
     calibration_config[name]['tracked_min_input'] = -abs_max_input
     calibration_config[name]['tracked_max_input'] = abs_max_input

@linbinskn (Contributor) commented May 17, 2021:
Looks good. Completing the related docs is necessary. Please refer to overview, quantization and Quantizer. Feel free to ask me if you have any questions.

Learned Step Size Quantization (ICLR 2020)
https://arxiv.org/pdf/1902.08153.pdf
"""
Contributor:
please align

Contributor Author:
done

new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax)
list_inp = list(inputs)
list_inp[0] = new_input
Contributor:
Why do we only quantize the first input?

Contributor Author:
It seems that the quantization framework currently only supports layers with a single input (see here; so does the TensorRT backend, see here), so the current implementation does not support layers with multiple inputs. It may be better to extend the LSQ quantizer to multi-input layers once the framework supports them.

Contributor:
Got it, that is reasonable.

@chenbohua3 (Contributor Author):
Docs have been added :)

type of quantization you want to apply, currently support 'weight', 'input', 'output'
- quant_bits : int or dict of {str : int}
bits length of quantization, key is the quantization type, value is the length, eg. {'weight', 8},
Contributor:
{'weight', 8} -> {'weight': 8}

Contributor Author:
done


if "input" in config.get("quant_types", []):
# scale of activation will be initialized using the first batch data
Contributor:
activation -> input

Contributor Author:
done

def grad_scale(x, scale):
"""
Used to scale the gradient
Contributor:
I recommend explaining this function in detail, since both reviewers were confused while reviewing this part. Besides, I think this function is a key part of the LSQ implementation, and explaining it can help others understand the insight of the algorithm.

Contributor Author:
done
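
For reference: grad_scale(x, g) returns a value equal to x in the forward pass while multiplying the gradient that flows into x by g, via the (x - x*g).detach() + x*g pattern sketched earlier in this thread. The paper chooses g = 1 / sqrt(N * Q_P) so that the step-size update stays on roughly the same order as the weight updates; for example, an 8-bit signed weight tensor with N = 4608 elements has Q_P = 127 and g = 1 / sqrt(127 * 4608) ≈ 0.0013.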


..

We introduce a novel means to estimate and scale the task loss gradient at each weight and activation layer’s quantizer step size, such that it can be learned in conjunction with other network parameters.
Contributor:
We -> The authors

Contributor Author:
done

quantizer = LsqQuantizer(model, configure_list, optimizer)
quantizer.compress()

You can view example for more information
Contributor:
better to add a hyperlink to the example

Contributor Author:
done
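
For completeness, a slightly fuller sketch of that usage snippet; the import path follows the NNI v2.x layout and the config values are placeholders, so treat them as assumptions.

    import torch
    from nni.algorithms.compression.pytorch.quantization import LsqQuantizer

    configure_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_types': ['Conv2d', 'Linear']
    }]

    model = Mnist()   # the model class from the linked example
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    quantizer = LsqQuantizer(model, configure_list, optimizer)
    quantizer.compress()
    # train as usual afterwards; the step sizes receive gradients and are updated by the optimizer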

model = Mnist()
'''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(model)
'''
Contributor:
this comment can be removed

Contributor Author:
done

v2.3 automation moved this from Review in progress to Reviewer approved on May 18, 2021

@QuanluZhang (Contributor):
@chenbohua3 looks great, thanks for your contribution!

@QuanluZhang merged commit af929fd into microsoft:master on May 18, 2021
v2.3 automation moved this from Reviewer approved to Done on May 18, 2021