Add feature of distillation for quantization (#1219)
XinyuYe-Intel committed Sep 6, 2022
1 parent 8c2330d commit 03f1f3e
Showing 8 changed files with 820 additions and 40 deletions.
@@ -0,0 +1,62 @@
Step-by-Step
============

This document illustrates how to run the distillation for quantization examples.
<br>
These examples take an NLP model fine-tuned on a downstream task, use a copy of it as the teacher model, and perform distillation during quantization-aware training.
<br>
For more information on this algorithm, please refer to the paper [ZeroQuant: Efficient and Affordable Post-Training Quantization for Large-Scale Transformers](https://arxiv.org/abs/2206.01861).
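
The idea can be sketched as follows: the fine-tuned FP32 model acts as the frozen teacher, while its copy (the student) is trained with quantization-aware training plus an additional loss that pulls its intermediate outputs toward the teacher's. Below is a minimal PyTorch sketch of this idea, assuming a Hugging Face style model that returns `loss` and `hidden_states`; the function names, the `alpha` weighting, and matching only the last hidden state are illustrative assumptions, not the example's actual implementation (which is driven by the YAML configurations described below).

```python
import copy
import torch
import torch.nn.functional as F

def make_teacher(fine_tuned_model):
    # The teacher is simply a frozen copy of the fine-tuned model.
    teacher = copy.deepcopy(fine_tuned_model)
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad_(False)
    return teacher

def distillation_qat_step(student, teacher, batch, optimizer, alpha=0.5):
    # `student` is assumed to already contain QAT fake-quant/observer modules;
    # `batch` is a dict of tensors that includes `labels`.
    with torch.no_grad():
        t_out = teacher(**batch, output_hidden_states=True)
    s_out = student(**batch, output_hidden_states=True)
    task_loss = s_out.loss                              # cross-entropy on the labels
    distill_loss = F.mse_loss(s_out.hidden_states[-1],  # match the teacher's last hidden state
                              t_out.hidden_states[-1])
    loss = alpha * task_loss + (1.0 - alpha) * distill_loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```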

# Prerequisite

## Python Version

Python 3.7 or a higher version is recommended.


## Install dependency

```shell
pip install -r requirements.txt
```

# Run the neural_compressor implementation of distillation for quantization

Below are example NLP tasks that use distillation for quantization to quantize a BERT model fine-tuned on each task.
<br>
Each example requires a pre-trained task-specific model, such as `yoshitomo-matsubara/bert-base-uncased-sst2` from the yoshitomo-matsubara Hugging Face hub, which serves both as the teacher model for distillation and as the model to be quantized.
<br>
The distillation configuration is specified in the YAML file `distillation.yaml`, and the quantization-aware training configuration is specified in `qat.yaml`.
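
Inside `run_glue_no_trainer.py`, the two configurations are combined through Intel Neural Compressor. A rough sketch of how such a combination is typically wired with the library's 1.x experimental API is shown below; the exact attribute and method names (in particular `Scheduler.combine`, `train_func` and `eval_func`) are assumptions about that API, not code taken from this commit.

```python
# Rough sketch only; attribute/method names follow the neural_compressor 1.x
# experimental API and are assumptions, not code copied from this commit.
from neural_compressor.experimental import Distillation, Quantization, common
from neural_compressor.experimental.scheduler import Scheduler

distiller = Distillation("distillation.yaml")          # distillation criterion/optimizer settings
distiller.teacher_model = common.Model(teacher_model)  # copy of the fine-tuned model
quantizer = Quantization("qat.yaml")                   # quant_aware_training approach

scheduler = Scheduler()
scheduler.model = common.Model(model)
combination = scheduler.combine(distiller, quantizer)  # run both in one training loop
combination.train_func = train_func                    # user-provided training loop
combination.eval_func = eval_func                      # user-provided evaluation
scheduler.append(combination)
quantized_model = scheduler.fit()
```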

## SST-2 task

```bash
python run_glue_no_trainer.py \
    --task_name sst2 \
    --model_name_or_path yoshitomo-matsubara/bert-base-uncased-sst2 \
    --teacher_model_name_or_path yoshitomo-matsubara/bert-base-uncased-sst2 \
    --batch_size 32 \
    --do_eval \
    --do_quantization \
    --do_distillation \
    --pad_to_max_length \
    --num_train_epochs 9 \
    --output_dir /path/to/output_dir
```

## MNLI task

```bash
python run_glue_no_trainer.py \
    --task_name mnli \
    --model_name_or_path yoshitomo-matsubara/bert-base-uncased-mnli \
    --teacher_model_name_or_path yoshitomo-matsubara/bert-base-uncased-mnli \
    --batch_size 32 \
    --do_eval \
    --do_quantization \
    --do_distillation \
    --pad_to_max_length \
    --num_train_epochs 9 \
    --output_dir /path/to/output_dir
```

## QQP task

```bash
python run_glue_no_trainer.py \
    --task_name qqp \
    --model_name_or_path yoshitomo-matsubara/bert-base-uncased-qqp \
    --teacher_model_name_or_path yoshitomo-matsubara/bert-base-uncased-qqp \
    --batch_size 32 \
    --do_eval \
    --do_quantization \
    --do_distillation \
    --pad_to_max_length \
    --num_train_epochs 9 \
    --output_dir /path/to/output_dir
```

## QNLI task

```bash
python run_glue_no_trainer.py \
    --task_name qnli \
    --model_name_or_path yoshitomo-matsubara/bert-base-uncased-qnli \
    --teacher_model_name_or_path yoshitomo-matsubara/bert-base-uncased-qnli \
    --batch_size 32 \
    --do_eval \
    --do_quantization \
    --do_distillation \
    --pad_to_max_length \
    --num_train_epochs 9 \
    --output_dir /path/to/output_dir
```

# Results
We list the results of 4 distillation for quantization experiments; for comparison, we also list the results of plain QAT as well as the baseline metrics of the FP32 models. These experiments use BERT-Base models fine-tuned on 4 GLUE tasks (SST-2, QNLI, QQP and MNLI). The FP32 column gives the metrics of the 4 fine-tuned BERT-Base models, the INT8 (QAT) column gives the metrics of the 4 INT8 models produced by plain QAT, and the INT8 (Distillation for Quantization) column gives the metrics of the 4 INT8 models produced by distillation for quantization.

| | FP32 | INT8 (QAT) | INT8 (Distillation for Quantization) |
|---------------|----------------|--------------------------|--------------------------|
| SST-2 (ACC) | 92.48% | 91.90% | 92.01% |
| QNLI (ACC) | 91.58% | 89.49% | 90.33% |
| QQP (ACC/F1) | 90.95%/87.83% | 89.60%/86.56% | 91.07%/87.91% |
| MNLI (ACC) | 84.20% | 78.67% | 84.42% |
@@ -0,0 +1,66 @@
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
version: 1.0

model:
  name: bert_distillation
  framework: pytorch_fx

distillation:
  train:
    optimizer:
      SGD:
        learning_rate: 0.001
    criterion:
      IntermediateLayersKnowledgeDistillationLoss:
        layer_mappings: [
            ['bert.encoder.layer.0.output', 'bert.encoder.layer.0.output'],
            ['bert.encoder.layer.0.attention', '1', 'bert.encoder.layer.0.attention', '1'],
            ['bert.encoder.layer.1.output', 'bert.encoder.layer.1.output'],
            ['bert.encoder.layer.1.attention', '1', 'bert.encoder.layer.1.attention', '1'],
            ['bert.encoder.layer.2.output', 'bert.encoder.layer.2.output'],
            ['bert.encoder.layer.2.attention', '1', 'bert.encoder.layer.2.attention', '1'],
            ['bert.encoder.layer.3.output', 'bert.encoder.layer.3.output'],
            ['bert.encoder.layer.3.attention', '1', 'bert.encoder.layer.3.attention', '1'],
            ['bert.encoder.layer.4.output', 'bert.encoder.layer.4.output'],
            ['bert.encoder.layer.4.attention', '1', 'bert.encoder.layer.4.attention', '1'],
            ['bert.encoder.layer.5.output', 'bert.encoder.layer.5.output'],
            ['bert.encoder.layer.5.attention', '1', 'bert.encoder.layer.5.attention', '1'],
            ['bert.encoder.layer.6.output', 'bert.encoder.layer.6.output'],
            ['bert.encoder.layer.6.attention', '1', 'bert.encoder.layer.6.attention', '1'],
            ['bert.encoder.layer.7.output', 'bert.encoder.layer.7.output'],
            ['bert.encoder.layer.7.attention', '1', 'bert.encoder.layer.7.attention', '1'],
            ['bert.encoder.layer.8.output', 'bert.encoder.layer.8.output'],
            ['bert.encoder.layer.8.attention', '1', 'bert.encoder.layer.8.attention', '1'],
            ['bert.encoder.layer.9.output', 'bert.encoder.layer.9.output'],
            ['bert.encoder.layer.9.attention', '1', 'bert.encoder.layer.9.attention', '1'],
            ['bert.encoder.layer.10.output', 'bert.encoder.layer.10.output'],
            ['bert.encoder.layer.10.attention', '1', 'bert.encoder.layer.10.attention', '1'],
            ['bert.encoder.layer.11.output', 'bert.encoder.layer.11.output'],
            ['bert.encoder.layer.11.attention', '1', 'bert.encoder.layer.11.attention', '1'],
            ['classifier', 'classifier'],
        ]
        loss_weights: [1, 1, 1, 1, 1,
                       1, 1, 1, 1, 1,
                       1, 1, 1, 1, 1,
                       1, 1, 1, 1, 1,
                       1, 1, 1, 1, 1]

tuning:
  accuracy_criterion:
    relative: 0.01                       # the tuning target of accuracy loss percentage: 1%
  exit_policy:
    timeout: 0                           # tuning timeout (seconds)
  random_seed: 9527                      # random seed
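
Each `layer_mappings` entry pairs a student layer with the teacher layer whose output it should match; entries with an extra `'1'` select the second element of that module's output tuple (for `BertAttention`, typically the attention probabilities). As a hedged illustration of what such a name refers to, the snippet below captures the output of a named submodule with a plain PyTorch forward hook; this is not neural_compressor's internal mechanism, just a way to see what is being matched.

```python
import torch
from torch import nn

def capture_output(model: nn.Module, module_name: str, storage: dict):
    """Store the output of a named submodule (e.g. 'bert.encoder.layer.0.output')
    every time it runs; layer_mappings entries refer to names of this form."""
    module = dict(model.named_modules())[module_name]

    def hook(_module, _inputs, output):
        # Some modules return tuples; an index such as the '1' entries above
        # would select one element of that tuple.
        storage[module_name] = output

    return module.register_forward_hook(hook)
```

During training, the same hook would be registered on both the student and the teacher, and the stored activations compared (for example with an MSE loss) to form the intermediate-layer distillation loss, weighted by `loss_weights`.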
@@ -0,0 +1,26 @@
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

version: 1.0

model:                                   # mandatory. used to specify model specific information.
  name: bert_qat
  framework: pytorch_fx                  # mandatory. supported values are tensorflow, pytorch, pytorch_ipex, onnxrt_integer, onnxrt_qlinear or mxnet; allow new framework backend extension.

quantization:                            # optional. required for QAT and PTQ.
  approach: quant_aware_training

tuning:
  random_seed: 9527                      # optional. random seed for deterministic tuning.
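
For the plain QAT baseline reported in the Results table, only this configuration is needed. A rough sketch of standalone QAT with the neural_compressor 1.x experimental API follows; the `q_func`/`eval_func` attribute names are assumptions about that API, not code taken from this commit.

```python
# Sketch of standalone QAT; attribute names are assumptions based on the
# neural_compressor 1.x experimental API, not code from this commit.
from neural_compressor.experimental import Quantization, common

quantizer = Quantization("qat.yaml")
quantizer.model = common.Model(model)  # the fine-tuned FP32 model
quantizer.q_func = train_func          # user-provided QAT training loop
quantizer.eval_func = eval_func        # user-provided evaluation
q_model = quantizer.fit()
```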
@@ -0,0 +1,5 @@
--find-links https://download.pytorch.org/whl/torch_stable.html
transformers
torch
datasets
accelerate