
onnx.load() | DecodeError: Error parsing message #10342

Closed
danielbellhv opened this issue Jan 20, 2022 · 16 comments
Labels: quantization (issues related to quantization)

Comments

danielbellhv commented Jan 20, 2022

Bug issue.

Goal: re-develop this BERT Notebook to use textattack/albert-base-v2-MRPC.

Kernel: conda_pytorch_p36. Deleted all output files and did Restart & Run All.

I can successfully create and save an ONNX model from a HuggingFace Transformers model held in runtime memory. The error occurs on onnx.load(), when loading the model back from storage into memory.

Are my ONNX files corrupted?

albert.onnx and albert.opt.onnx here.


Section 2.1 - export in-memory PyTorch model as ONNX model:

import torch
import onnxruntime

# logger and configs are defined earlier in the notebook
def export_onnx_model(args, model, tokenizer, onnx_model_path):
    with torch.no_grad():
        inputs = {'input_ids':      torch.ones(1, 128, dtype=torch.int64),
                  'attention_mask': torch.ones(1, 128, dtype=torch.int64),
                  'token_type_ids': torch.ones(1, 128, dtype=torch.int64)}
        outputs = model(**inputs)

        symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
        torch.onnx.export(model,                                   # model being run
                          (inputs['input_ids'],                    # model inputs (a tuple for multiple inputs)
                           inputs['attention_mask'],
                           inputs['token_type_ids']),
                          onnx_model_path,                         # where to save the model (file or file-like object)
                          opset_version=11,                        # the ONNX opset version to export the model to
                          do_constant_folding=True,                # whether to execute constant folding for optimization
                          input_names=['input_ids',                # the model's input names
                                       'input_mask',
                                       'segment_ids'],
                          output_names=['output'],                 # the model's output names
                          dynamic_axes={'input_ids': symbolic_names,   # variable-length axes
                                        'input_mask': symbolic_names,
                                        'segment_ids': symbolic_names})
        logger.info("ONNX Model exported to {0}".format(onnx_model_path))

export_onnx_model(configs, model, tokenizer, "albert.onnx")

Then optimisation:

pip install torch_optimizer
import torch_optimizer as optim

optimizer = optim.DiffGrad(model.parameters(), lr=0.001)
optimizer.step()

torch.save(optimizer.state_dict(), 'albert.opt.onnx')

Section 2.2 Quantize ONNX model:

import os
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

def quantize_onnx_model(onnx_model_path, quantized_model_path):    
    onnx_opt_model = onnx.load(onnx_model_path)  # DecodeError
    quantize_dynamic(onnx_model_path,
                     quantized_model_path,
                     weight_type=QuantType.QInt8)

    logger.info(f"quantized model saved to:{quantized_model_path}")

quantize_onnx_model('albert.opt.onnx', 'albert.opt.quant.onnx')

print('ONNX full precision model size (MB):', os.path.getsize("albert.opt.onnx")/(1024*1024))
print('ONNX quantized model size (MB):', os.path.getsize("albert.opt.quant.onnx")/(1024*1024))

Traceback:

---------------------------------------------------------------------------
DecodeError                               Traceback (most recent call last)
<ipython-input-16-2d2d32b0a667> in <module>
     10     logger.info(f"quantized model saved to:{quantized_model_path}")
     11 
---> 12 quantize_onnx_model('albert.opt.onnx', 'albert.opt.quant.onnx')
     13 
     14 print('ONNX full precision model size (MB):', os.path.getsize("albert.opt.onnx")/(1024*1024))

<ipython-input-16-2d2d32b0a667> in quantize_onnx_model(onnx_model_path, quantized_model_path)
      3 
      4 def quantize_onnx_model(onnx_model_path, quantized_model_path):
----> 5     onnx_opt_model = onnx.load(onnx_model_path)
      6     quantize_dynamic(onnx_model_path,
      7                      quantized_model_path,

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/onnx/__init__.py in load_model(f, format, load_external_data)
    119     '''
    120     s = _load_bytes(f)
--> 121     model = load_model_from_string(s, format=format)
    122 
    123     if load_external_data:

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/onnx/__init__.py in load_model_from_string(s, format)
    156     Loaded in-memory ModelProto
    157     '''
--> 158     return _deserialize(s, ModelProto())
    159 
    160 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/onnx/__init__.py in _deserialize(s, proto)
     97                          '\ntype is {}'.format(type(proto)))
     98 
---> 99     decoded = cast(Optional[int], proto.ParseFromString(s))
    100     if decoded is not None and decoded != len(s):
    101         raise google.protobuf.message.DecodeError(

DecodeError: Error parsing message

Output Files:

albert.onnx  # original save
albert.opt.onnx  # optimised version save

Please let me know if there's anything else I can add to this post.

danielbellhv commented Jan 20, 2022

I had to re-develop this notebook in order to work with ALBERT. Might the config variables, therefore, have something to do with this?

configs = Namespace()

# The output directory for the fine-tuned model, $OUT_DIR.
configs.output_dir = "./MRPC/"

# The data directory for the MRPC task in the GLUE benchmark, $GLUE_DIR/$TASK_NAME.
configs.data_dir = "./glue_data/MRPC"

# The model name or path for the pre-trained model.
configs.model_name_or_path = "albert-base-v2"
# The maximum length of an input sequence
configs.max_seq_length = 128

# Prepare GLUE task.
configs.task_name = "MRPC".lower()
configs.processor = processors[configs.task_name]()
configs.output_mode = output_modes[configs.task_name]
configs.label_list = configs.processor.get_labels()
configs.model_type = "albert".lower()
configs.do_lower_case = True

# Set the device, batch size, topology, and caching flags.
configs.device = "cpu"
configs.eval_batch_size = 1
configs.n_gpu = 0
configs.local_rank = -1
configs.overwrite_cache = False

@danielbellhv

I can successfully evaluate the model before exporting to ONNX. Yet the export code hasn't been altered.

edgchen1 added the quantization label on Jan 20, 2022
@yuslepukhin (Member)

I can successfully evaluate the model before exporting to ONNX. Yet the export code hasn't been altered.

You can also file an issue with ONNX for a broader audience (the onnx library is a separate product).

@tianleiwu (Contributor)

We have a tool that could export and test ALBert. Currently, it only supports one input (input_ids):
python -m onnxruntime.transformers.benchmark -m albert-base-v2 -i 1 -t 100 -b 1 -s 128 -e onnxruntime --model_class AutoModel -p int8 -o -v

Both the fp32 and int8 ONNX models generated this way run on my machine.
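
For reference, the exported files can be run directly with an onnxruntime InferenceSession; a minimal sketch, assuming the tool's default ./onnx_models output directory and its single input_ids input:

import numpy as np
import onnxruntime as ort

# Sketch only: the path assumes the benchmark tool's default output directory.
sess = ort.InferenceSession("./onnx_models/albert_base_v2_1_int8_cpu.onnx")

# The tool exports a single input (input_ids), so build the feed from the session's own input name.
input_name = sess.get_inputs()[0].name
dummy_ids = np.ones((1, 128), dtype=np.int64)

outputs = sess.run(None, {input_name: dummy_ids})
print([o.shape for o in outputs])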

@danielbellhv

Both the fp32 and int8 ONNX models generated this way run on my machine.

Is fp32 a regular ONNX model, and int8 a quantised ONNX model?

If so, where can I run this in my AWS SageMaker Jupyter Lab?

Thanks for getting back guys :)

@danielbellhv

it only supports one input (input_ids)

I updated the torch.onnx.export() code to only include that input variable (so not a tuple):

torch.onnx.export(model,                                   # model being run
                  inputs['input_ids'],                     # single model input (no longer a tuple)
                  # inputs['attention_mask'],
                  # inputs['token_type_ids'],
                  onnx_model_path,                         # where to save the model (file or file-like object)
                  opset_version=11,                        # the ONNX opset version to export the model to
                  do_constant_folding=True,                # whether to execute constant folding for optimization
                  input_names=['input_ids',                # the model's input names
                               'input_mask',
                               'segment_ids'],
                  output_names=['output'],                 # the model's output names
                  dynamic_axes={'input_ids': symbolic_names,   # variable-length axes
                                'input_mask': symbolic_names,
                                'segment_ids': symbolic_names})
logger.info("ONNX Model exported to {0}".format(onnx_model_path))

This led to the same DecodeError on onnx.load().

danielbellhv commented Jan 21, 2022

We have a tool that could export and test ALBert.

I had to pip install ... various dependencies...

Terminal:

sh-4.2$ python -m onnxruntime.transformers.benchmark -m albert-base-v2 -i 1 -t 100 -b 1 -s 128 -e onnxruntime --model_class AutoModel -p int8 -o -v
Arguments: Namespace(batch_sizes=[1], cache_dir='./cache_models', detail_csv=None, disable_ort_io_binding=False, engines=['onnxruntime'], fusion_csv=None, input_counts=[1], model_class='AutoModel', model_source='pt', models=['albert-base-v2'], num_threads=[2], onnx_dir='./onnx_models', optimize_onnx=True, overwrite=False, precision=<Precision.INT8: 'int8'>, result_csv=None, sequence_lengths=[128], test_times=100, use_gpu=False, validate_onnx=True, verbose=False)
Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertModel: ['predictions.LayerNorm.weight', 'predictions.dense.weight', 'predictions.bias', 'predictions.decoder.weight', 'predictions.dense.bias', 'predictions.LayerNorm.bias', 'predictions.decoder.bias']
- This IS expected if you are initializing AlbertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Skip export since model existed: ./onnx_models/albert_base_v2_1.onnx
./onnx_models/albert_base_v2_1.onnx is a valid ONNX model
inference result of onnxruntime is validated on ./onnx_models/albert_base_v2_1.onnx
Removed 1 Cast nodes with output type same as input
Fused LayerNormalization count: 25
Fused FastGelu count: 12
Fused SkipLayerNormalization count: 24
Fused Attention count: 11
Graph pruned: 0 inputs, 0 outputs and 121 nodes are removed
Fused Shape count: 2
Graph pruned: 0 inputs, 0 outputs and 8 nodes are removed
Remove reshape node Reshape_18 since its input shape is same as output: [2]
Graph pruned: 0 inputs, 0 outputs and 1 nodes are removed
Fused FastGelu(add bias) count: 12
Fused SkipLayerNormalization(add bias) count: 23
opset verion: 12
Optimized operators:{'EmbedLayerNormalization': 0, 'Attention': 11, 'Gelu': 0, 'FastGelu': 12, 'BiasGelu': 0, 'LayerNormalization': 1, 'SkipLayerNormalization': 24}
Sort graphs in topological order
Model saved to ./onnx_models/albert_base_v2_1_int8_cpu.onnx
./onnx_models/albert_base_v2_1_int8_cpu.onnx is a valid ONNX model
inference result of onnxruntime is validated on ./onnx_models/albert_base_v2_1_int8_cpu.onnx
Quantizing model: ./onnx_models/albert_base_v2_1_int8_cpu.onnx
Size of full precision ONNX model(MB):341.691819190979
onnxruntime.quantization.quantize is deprecated.
         Please use quantize_static for static quantization, quantize_dynamic for dynamic quantization.
Warning: Unsupported operator LayerNormalization. No schema registered for this operator.
Quantization parameters for tensor:"91" not specified
Quantization parameters for tensor:"94" not specified
Quantization parameters for tensor:"118" not specified
Quantization parameters for tensor:"148" not specified
Quantization parameters for tensor:"153" not specified
Quantization parameters for tensor:"147" not specified
Quantization parameters for tensor:"163" not specified
Quantization parameters for tensor:"178" not specified
Quantization parameters for tensor:"194" not specified
Quantization parameters for tensor:"209" not specified
Quantization parameters for tensor:"278" not specified
Quantization parameters for tensor:"293" not specified
Quantization parameters for tensor:"309" not specified
Quantization parameters for tensor:"324" not specified
Quantization parameters for tensor:"393" not specified
Quantization parameters for tensor:"408" not specified
Quantization parameters for tensor:"424" not specified
Quantization parameters for tensor:"439" not specified
Quantization parameters for tensor:"508" not specified
Quantization parameters for tensor:"523" not specified
Quantization parameters for tensor:"539" not specified
Quantization parameters for tensor:"554" not specified
Quantization parameters for tensor:"623" not specified
Quantization parameters for tensor:"638" not specified
Quantization parameters for tensor:"654" not specified
Quantization parameters for tensor:"669" not specified
Quantization parameters for tensor:"738" not specified
Quantization parameters for tensor:"753" not specified
Quantization parameters for tensor:"769" not specified
Quantization parameters for tensor:"784" not specified
Quantization parameters for tensor:"853" not specified
Quantization parameters for tensor:"868" not specified
Quantization parameters for tensor:"884" not specified
Quantization parameters for tensor:"899" not specified
Quantization parameters for tensor:"968" not specified
Quantization parameters for tensor:"983" not specified
Quantization parameters for tensor:"999" not specified
Quantization parameters for tensor:"1014" not specified
Quantization parameters for tensor:"1083" not specified
Quantization parameters for tensor:"1098" not specified
Quantization parameters for tensor:"1114" not specified
Quantization parameters for tensor:"1129" not specified
Quantization parameters for tensor:"1198" not specified
Quantization parameters for tensor:"1213" not specified
Quantization parameters for tensor:"1229" not specified
Quantization parameters for tensor:"1244" not specified
Quantization parameters for tensor:"1313" not specified
Quantization parameters for tensor:"1328" not specified
Quantization parameters for tensor:"1344" not specified
Quantization parameters for tensor:"1359" not specified
Quantization parameters for tensor:"1428" not specified
Quantization parameters for tensor:"1443" not specified
Quantization parameters for tensor:"1459" not specified
quantized model saved to:./onnx_models/albert_base_v2_1_int8_cpu.onnx
Size of quantized ONNX model(MB):87.26316165924072
Finished quantizing model: ./onnx_models/albert_base_v2_1_int8_cpu.onnx
Run onnxruntime on albert-base-v2 with input shape [1, 128]
{'engine': 'onnxruntime', 'version': '1.10.0', 'device': 'cpu', 'optimizer': True, 'precision': <Precision.INT8: 'int8'>, 'io_binding': True, 'model_name': 'albert-base-v2', 'inputs': 1, 'threads': 2, 'batch_size': 1, 'sequence_length': 128, 'datetime': '2022-01-21 10:51:21.112257', 'test_times': 100, 'latency_variance': '0.00', 'latency_90_percentile': '88.98', 'latency_95_percentile': '89.74', 'latency_99_percentile': '90.06', 'average_latency_ms': '88.02', 'QPS': '11.36'}
Fusion statistics is saved to csv file: benchmark_fusion_20220121-105130.csv
Detail results are saved to csv file: benchmark_detail_20220121-105130.csv
Summary results are saved to csv file: benchmark_summary_20220121-105130.csv

Outputs:

benchmark_summary_20220121-105130.csv
benchmark_detail_20220121-105130.csv
benchmark_fusion_20220121-105130.csv
onnx_models/  albert_base_v2_1_int8_cpu.onnx  albert_base_v2_1.onnx
sh-4.2$ ls
anaconda3                             benchmark_summary_20220121-105130.csv  LICENSE                README            sample-notebooks-1642761825  tutorials
benchmark_detail_20220121-105130.csv  cache_models                           Nvidia_Cloud_EULA.pdf  SageMaker         src
benchmark_fusion_20220121-105130.csv  examples                               onnx_models            sample-notebooks  tools

sh-4.2$ cd onnx_models/

sh-4.2$ ls
albert_base_v2_1_int8_cpu.onnx  albert_base_v2_1.onnx

Move outputs so I can use/download them:

sh-4.2$ mv bench* SageMaker/

sh-4.2$ mv onnx_models/* SageMaker/

@danielbellhv

We have a tool that could export and test ALBert.

Can you pass a parameter to have the models optimised?

Is there usage documentation for these commands?

danielbellhv commented Jan 21, 2022

@tianleiwu I ran both albert_base_v2_1_int8_cpu.onnx and albert_base_v2_1.onnx (each as MODEL) through Section 2.3, Evaluate ONNX quantization performance and accuracy; both error out...

Traceback:

/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/transformers/data/processors/glue.py:175: FutureWarning: This processor will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-classification/run_glue.py
  warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
Evaluating:   0%|          | 0/408 [00:00<?, ?it/s]
Evaluating ONNXRuntime quantization accuracy and performance:

---------------------------------------------------------------------------
InvalidArgument                           Traceback (most recent call last)
<ipython-input-16-974877a1c2e5> in <module>
     81 
     82 print('Evaluating ONNXRuntime quantization accuracy and performance:')
---> 83 time_ort_model_evaluation(MODEL, configs, tokenizer, "onnx.opt.quant")

<ipython-input-16-974877a1c2e5> in time_ort_model_evaluation(model_path, configs, tokenizer, prefix)
     71 def time_ort_model_evaluation(model_path, configs, tokenizer, prefix=""):
     72     eval_start_time = time.time()
---> 73     result = evaluate_onnx(configs, model_path, tokenizer, prefix=prefix)
     74     eval_end_time = time.time()
     75     eval_duration_time = eval_end_time - eval_start_time

<ipython-input-16-974877a1c2e5> in evaluate_onnx(args, model_path, tokenizer, prefix)
     39                                 'segment_ids': batch[2]
     40                             }
---> 41             logits = np.reshape(session.run(None, ort_inputs), (-1,2))
     42             if preds is None:
     43                 preds = logits

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py in run(self, output_names, input_feed, run_options)
    186             output_names = [output.name for output in self._outputs_meta]
    187         try:
--> 188             return self._sess.run(output_names, input_feed, run_options)
    189         except C.EPFail as err:
    190             if self._enable_fallback:

InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:segment_ids
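
The feed names have to match the graph's actual input names. A minimal sketch of how one might inspect them before building ort_inputs (assuming onnxruntime is installed in the kernel, and MODEL points at whichever file is being evaluated):

import onnxruntime

# Sketch only: point this at whichever exported file is being evaluated.
MODEL = "albert_base_v2_1.onnx"

session = onnxruntime.InferenceSession(MODEL)

# The benchmark-exported graph only declares input_ids, so a feed that also
# includes segment_ids / input_mask raises INVALID_ARGUMENT.
print([i.name for i in session.get_inputs()])
print([o.name for o in session.get_outputs()])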

@danielbellhv

The problem occurs before quantisation. Any quantisation code I try throws the same error.

I've swapped out:

from transformers import (BertConfig, BertForSequenceClassification, BertTokenizer,)

for:

from transformers import (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer,)

However, AlbertConfig / BertConfig is not mentioned in the original Notebook.
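
For reference, a minimal sketch of how the swapped-in Albert classes would load the fine-tuned checkpoint (the directory name and num_labels=2 are assumptions for MRPC, not taken from the original Notebook):

from transformers import AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer

# Sketch only: model_dir and num_labels are assumptions for the MRPC task (2 classes).
model_dir = "albert-base-v2-MRPC"

config = AlbertConfig.from_pretrained(model_dir, num_labels=2)
tokenizer = AlbertTokenizer.from_pretrained(model_dir, do_lower_case=True)
model = AlbertForSequenceClassification.from_pretrained(model_dir, config=config)
model.eval()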

@danielbellhv

The problem isn't optimisation. I've run my version of the Notebook with and without the optimisation code; the same error returns.

Running the original Notebook without optimisation gives the same outputs and scores.

danielbellhv commented Jan 24, 2022

Might my issue be with the configs variables? I'm going to look into this.

Also, how can I check whether my .onnx files are corrupted? @edgchen1 @yuslepukhin @tianleiwu
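
One way to sanity-check a saved file (a sketch, not an exhaustive test) is to load it and run the ONNX checker; a file that isn't actually an ONNX protobuf fails at the load step with exactly this DecodeError:

import os
import onnx

def check_onnx_file(path):
    # Sketch only: loads the protobuf and runs the structural checker.
    print(path, os.path.getsize(path) / (1024 * 1024), "MB")
    try:
        model = onnx.load(path)          # DecodeError here means the file isn't an ONNX protobuf
        onnx.checker.check_model(model)  # raises if the graph itself is malformed
        print("looks like a valid ONNX model")
    except Exception as e:
        print("not a valid ONNX model:", e)

check_onnx_file("albert.onnx")
check_onnx_file("albert.opt.onnx")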

@danielbellhv

I am using textattack/albert-base-v2-MRPC:

git clone https://huggingface.co/textattack/albert-base-v2-MRPC
mv albert-base-v2-MRPC/ SageMaker/

And updated:

configs.model_name_or_path = "albert-base-v2-MRPC"

Only for it to throw the same error.

danielbellhv commented Jan 25, 2022

I've set up a Google Colab with the runtime.

@danielbellhv

The problem was that the config variables hadn't been updated for my new model.

Changes:

configs.output_dir = "albert-base-v2-MRPC"
configs.model_name_or_path = "albert-base-v2-MRPC"

I then came across a separate issue, where I hadn't git cloned my model properly. The question and answer are detailed here.

Lastly, HuggingFace 🤗 does not have an equivalent to BertOptimizationOptions for ALBert. I had tried the general PyTorch optimisers offered by torch_optimizer on the ONNX model, but it seems they aren't compatible with ONNX models.
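
For what it's worth, ONNX-level graph optimisation is separate from the PyTorch training optimisers in torch_optimizer. A minimal sketch using onnxruntime's transformer optimizer, assuming the BERT-style fusion path (model_type='bert', 12 heads, 768 hidden size for albert-base-v2) is acceptable for the exported graph:

from onnxruntime.transformers import optimizer

# Sketch only: model_type='bert' is an assumption, since the optimizer has no
# dedicated ALBERT type; 12 heads / 768 hidden size match albert-base-v2.
opt_model = optimizer.optimize_model("albert.onnx",
                                     model_type="bert",
                                     num_heads=12,
                                     hidden_size=768)
opt_model.save_model_to_file("albert.opt.onnx")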

Feel free to comment for further clarification.

@tianleiwu (Contributor)

@danielbellhv, you can try the following command instead since you want to run MRPC:
python -m onnxruntime.transformers.benchmark -m albert-base-v2 -i 1 -t 100 -b 1 -s 128 -e onnxruntime --model_class AutoModelForSequenceClassification -p int8 -o -v

Try python -m onnxruntime.transformers.benchmark --help for more information about the parameters.

Related onnx export code can be found in

def export_onnx_model_from_pt(model_name, opset_version, use_external_data_format, model_type, model_class, cache_dir,
