enhance nlp text-classification examples (#1154)
xin3he committed Aug 15, 2022
1 parent cbdb79a commit 46d9453
Showing 13 changed files with 290 additions and 145 deletions.
@@ -19,7 +19,7 @@ Recommend python 3.6 or higher version.
#### Install transformers

```bash
pip install transformers==4.10.0
pip install transformers
```

#### Install dependency
@@ -35,44 +35,15 @@ pip install torch

## 2. Prepare pretrained model

Before using Intel® Neural Compressor, you should fine-tune the model to get a pretrained model. You should also install the additional packages required by the examples:
Before using Intel® Neural Compressor, you should fine-tune the model to get a pretrained model, or reuse a fine-tuned model from the [model hub](https://huggingface.co/models). You should also install the additional packages required by the examples.

### Text-classification

#### BERT
* For BERT base and GLUE tasks (the task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI...)

```shell
export TASK_NAME=MRPC

python run_glue_tune.py \
--model_name_or_path bert-base-cased \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir /tmp/$TASK_NAME/
```
> NOTE
>
> model_name_or_path : Path to pretrained model or model identifier from huggingface.co/models
>
> task_name : task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI.

The dev set results will be written to the text file 'eval_results.txt' in the specified output_dir. For MNLI, since there are two separate dev sets (matched and mismatched), there will be a separate output folder called '/tmp/MNLI-MM/' in addition to '/tmp/MNLI/'.

Please refer to [BERT base scripts and instructions](common/examples/text-classification/README.md#PyTorch version).

* After fine-tuning, you get a checkpoint dir which includes the pretrained model, tokenizer, and training arguments. This checkpoint dir will be used by neural_compressor tuning as shown below.

# Start neural_compressor tuning for model quantization
```shell
cd examples/pytorch/nlp/huggingface_models/text-classification/quantization/ptq_dynamic/eager
```
## GLUE task
### 1. To get the tuned model and its accuracy:

```bash
sh run_tuning.sh --topology=topology_name --dataset_location=/path/to/glue/data/dir --input_model=/path/to/checkpoint/dir
```
@@ -83,12 +54,73 @@ sh run_tuning.sh --topology=topology_name --dataset_location=/path/to/glue/data/dir
>
> /path/to/checkpoint/dir is the path to the fine-tuning output_dir
or

```bash
python -u ./run_glue_tune.py \
--model_name_or_path distilbert-base-uncased-finetuned-sst-2-english \
--task_name sst2 \
--do_eval \
--do_train \
--max_seq_length 128 \
--per_device_eval_batch_size 16 \
--no_cuda \
--output_dir ./int8_model_dir \
--tune \
--overwrite_output_dir
```

### 2. To get the benchmark of the tuned model, including batch_size and throughput:

```bash
python -u ./run_glue_tune.py \
--model_name_or_path ./int8_model_dir \
--task_name sst2 \
--do_eval \
--max_seq_length 128 \
--per_device_eval_batch_size 1 \
--no_cuda \
--output_dir ./output_log \
--benchmark \
--int8 \
--overwrite_output_dir
```
# HuggingFace model hub
## To upstream into HuggingFace model hub
We provide an API `save_for_huggingface_upstream` to collect the configuration files, tokenizer files, and int8 model weights in the format of [transformers](https://github.com/huggingface/transformers).
```python
from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream
...
save_for_huggingface_upstream(q_model, tokenizer, output_dir)
```
Users can upstream the files in `output_dir` to the model hub and reuse them with our `OptimizedModel` API.
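
For context, a minimal sketch of how this call fits the tuning flow used in `run_glue_tune.py`: build the quantizer from a yaml configuration, wrap the fine-tuned fp32 model, run accuracy-driven tuning, then save the result for upstreaming. The `conf.yaml` filename and the stub `eval_func` are placeholders; the real script computes GLUE accuracy in its evaluation function.

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from neural_compressor.experimental import Quantization, common
from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream

model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def eval_func(model_tuned):
    # placeholder: run_glue_tune.py evaluates accuracy on the GLUE dev set here
    return 1.0

quantizer = Quantization("conf.yaml")   # yaml with quantization approach and tuning criteria (assumed filename)
quantizer.model = common.Model(model)   # wrap the fp32 fine-tuned model
quantizer.eval_func = eval_func         # accuracy callback driving the tuning loop
q_model = quantizer.fit()               # returns the tuned int8 model

# collect config, tokenizer files and int8 weights in transformers format
save_for_huggingface_upstream(q_model, tokenizer, "./int8_model_dir")
```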

----

## To download from the HuggingFace model hub
We provide an API `OptimizedModel` to initialize int8 models from the HuggingFace model hub; its usage is the same as the model classes provided by [transformers](https://github.com/huggingface/transformers).
```python
from neural_compressor.utils.load_huggingface import OptimizedModel
model = OptimizedModel.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
```

We have also upstreamed several int8 models to the HuggingFace [model hub](https://huggingface.co/models?other=Intel%C2%AE%20Neural%20Compressor) so that users can ramp up quickly.
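
As an illustration, a minimal sketch of pulling one of these int8 models and running it locally; the model id below is an assumption, so substitute any model listed under the Intel® Neural Compressor tag on the hub:

```python
from transformers import AutoConfig, AutoTokenizer
from neural_compressor.utils.load_huggingface import OptimizedModel

# illustrative model id; replace with an int8 model published under the
# Intel Neural Compressor tag on the HuggingFace model hub
model_id = "Intel/distilbert-base-uncased-finetuned-sst-2-english-int8-static"

config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OptimizedModel.from_pretrained(model_id, config=config)

inputs = tokenizer("a touching and compelling story", return_tensors="pt")
print(model(**inputs).logits.argmax(dim=-1))  # predicted label id
```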

----

Examples of enabling Intel® Neural Compressor
============================================================

This is a tutorial on how to enable a BERT model with Intel® Neural Compressor.

# User Code Analysis
## User Code Analysis

Intel® Neural Compressor supports two usages:

@@ -3,6 +3,6 @@ datasets >= 1.1.3
sentencepiece != 0.1.92
protobuf
torch >= 1.3
transformers==4.10.0
transformers>=4.10.0
shap
scipy
@@ -109,13 +109,13 @@ function run_benchmark {
echo $extra_cmd

python -u run_glue_tune.py \
--model_name_or_path ${model_name_or_path} \
--model_name_or_path ${tuned_checkpoint} \
--task_name ${TASK_NAME} \
--do_eval \
--max_seq_length ${MAX_SEQ_LENGTH} \
--per_gpu_eval_batch_size ${batch_size} \
--per_device_eval_batch_size ${batch_size} \
--no_cuda \
--output_dir ${tuned_checkpoint} \
--output_dir ./output_log \
${mode_cmd} \
${extra_cmd}
}
@@ -294,14 +294,25 @@ def main():
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
model = AutoModelForSequenceClassification.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
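# With --int8, restore the quantized model saved by save_for_huggingface_upstream; otherwise load the fp32 checkpoint.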
if model_args.int8:
from neural_compressor.utils.load_huggingface import OptimizedModel
model = OptimizedModel.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
else:
model = AutoModelForSequenceClassification.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)

# Preprocessing the datasets
if data_args.task_name is not None:
@@ -426,18 +437,13 @@ def eval_func_for_nc(model_tuned):
quantizer.model = common.Model(model)
quantizer.eval_func = eval_func_for_nc
q_model = quantizer.fit()
q_model.save(training_args.output_dir)
from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream
save_for_huggingface_upstream(q_model, tokenizer, training_args.output_dir)
exit(0)

if model_args.accuracy_only:
if model_args.int8:
from neural_compressor.utils.pytorch import load
new_model = load(
os.path.abspath(os.path.expanduser(training_args.output_dir)), model)
else:
new_model = model
trainer = Trainer(
model=new_model,
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
@@ -454,18 +460,12 @@ def eval_func_for_nc(model_tuned):
print("Accuracy: %.5f" % acc)
print('Throughput: %.3f samples/sec' % (results["eval_samples_per_second"]))
print('Latency: %.3f ms' % (1 * 1000 / results["eval_samples_per_second"]))
print('Batch size = %d' % training_args.per_gpu_eval_batch_size)
print('Batch size = %d' % training_args.per_device_eval_batch_size)
exit(0)

if model_args.benchmark:
if model_args.int8:
from neural_compressor.utils.pytorch import load
new_model = load(
os.path.abspath(os.path.expanduser(training_args.output_dir)), model)
else:
new_model = model
trainer = Trainer(
model=new_model,
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
@@ -482,7 +482,7 @@ def eval_func_for_nc(model_tuned):
print("Accuracy: %.5f" % acc)
print('Throughput: %.3f samples/sec' % (results["eval_samples_per_second"]))
print('Latency: %.3f ms' % (1 * 1000 / results["eval_samples_per_second"]))
print('Batch size = %d' % training_args.per_gpu_eval_batch_size)
print('Batch size = %d' % training_args.per_device_eval_batch_size)
exit(0)

# Initialize our Trainer
@@ -82,7 +82,7 @@ function run_tuning {
--do_eval \
--do_train \
--max_seq_length ${MAX_SEQ_LENGTH} \
--per_gpu_eval_batch_size ${batch_size} \
--per_device_eval_batch_size ${batch_size} \
--no_cuda \
--output_dir ${tuned_checkpoint} \
--tune \
@@ -19,7 +19,7 @@ Recommend python 3.6 or higher version.
#### Install BERT model

```bash
pip install transformers==4.10.0
pip install transformers
```

#### Install dependency
@@ -35,27 +35,7 @@ pip install torch

## 2. Prepare pretrained model

Before using Intel® Neural Compressor, you should fine-tune the model to get a pretrained model. You should also install the additional packages required by the examples:

### BERT
For GLUE tasks (the task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI...)

```shell
cd examples/pytorch/nlp/huggingface_models/text-classification/quantization/ptq_static/fx
export TASK_NAME=MRPC
export batch_size=16

python run_glue.py \
--model_name_or_path bert-base-uncased \
--task_name ${TASK_NAME} \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size ${batch_size} \
--learning_rate 5e-5 \
--num_train_epochs 5 \
--output_dir /path/to/checkpoint/dir
```
Before using Intel® Neural Compressor, you should fine-tune the model to get a pretrained model, or reuse a fine-tuned model from the [model hub](https://huggingface.co/models). You should also install the additional packages required by the examples.

# Start neural_compressor tuning for model quantization
- Here we implement several models in FX mode.
@@ -64,27 +44,65 @@ cd examples/pytorch/nlp/huggingface_models/text-classification/quantization/ptq_
```
## GLUE task

### 1. To get the tuned model and its accuracy:
```bash
python -u ./run_glue.py \
--model_name_or_path /path/to/checkpoint/dir \
--task_name ${TASK_NAME} \
--model_name_or_path distilbert-base-uncased-finetuned-sst-2-english \
--task_name sst2 \
--do_eval \
--do_train \
--max_seq_length 128 \
--per_device_eval_batch_size ${batch_size} \
--per_device_eval_batch_size 16 \
--no_cuda \
--output_dir /path/to/checkpoint/dir \
--output_dir ./int8_model_dir \
--tune \
--overwrite_output_dir \
--dataloader_drop_last
--overwrite_output_dir
```

### 2. To get the benchmark of the tuned model, including batch_size and throughput:

```bash
python -u ./run_glue.py \
--model_name_or_path ./int8_model_dir \
--task_name sst2 \
--do_eval \
--max_seq_length 128 \
--per_device_eval_batch_size 1 \
--no_cuda \
--output_dir ./output_log \
--benchmark \
--int8 \
--overwrite_output_dir
```
> NOTE
>
> /path/to/checkpoint/dir is the path to the fine-tuning output_dir

# HuggingFace model hub
## To upstream into HuggingFace model hub
We provide an API `save_for_huggingface_upstream` to collect the configuration files, tokenizer files, and int8 model weights in the format of [transformers](https://github.com/huggingface/transformers).
```python
from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream
...
save_for_huggingface_upstream(q_model, tokenizer, output_dir)
```
Users can upstream the files in `output_dir` to the model hub and reuse them with our `OptimizedModel` API.

## To download from the HuggingFace model hub
We provide an API `OptimizedModel` to initialize int8 models from the HuggingFace model hub; its usage is the same as the model classes provided by [transformers](https://github.com/huggingface/transformers).
```python
from neural_compressor.utils.load_huggingface import OptimizedModel
model = OptimizedModel.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
```

---------------------
We have also upstreamed several int8 models to the HuggingFace [model hub](https://huggingface.co/models?other=Intel%C2%AE%20Neural%20Compressor) so that users can ramp up quickly.

----
## This is a tutorial on how to enable an NLP model with Intel® Neural Compressor.


@@ -108,13 +126,6 @@ device: cpu

quantization:
approach: post_training_static_quant
op_wise: {
# PyTorch limitation: PyTorch does not support specific qconfig for functions when version <= 1.10; will be removed in the future.
'default_qconfig': {
'activation': {'dtype': ['fp32']},
'weight': {'dtype': ['fp32']}
},
}

tuning:
accuracy_criterion:
@@ -4,6 +4,6 @@ protobuf
scipy
scikit-learn
Keras-Preprocessing
transformers == 4.16.0
transformers >= 4.16.0
--find-links https://download.pytorch.org/whl/torch_stable.html
torch >= 1.8.0+cpu
