Skip to content

Commit

Permalink
fix domain detection for large model (#565)
Browse files: browse the repository at this point in the history
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
  • Loading branch information
yuwenzho committed Feb 28, 2023
1 parent ef928e9 commit 70a5662
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion neural_compressor/adaptor/onnxrt.py
Expand Up @@ -636,7 +636,12 @@ def _detect_domain(self, model):
# 2. according to input
# typically, NLP models have multiple inputs,
# and the dimension of each input is usually 2 (batch_size, max_seq_len)
- sess = ort.InferenceSession(model.model.SerializeToString())
+ if not model.is_large_model:
+ sess = ort.InferenceSession(model.model.SerializeToString())
+ elif model.model_path is not None: # pragma: no cover
+ sess = ort.InferenceSession(model.model_path)
+ else: # pragma: no cover
+ assert False, "Please use model path instead of onnx model object to quantize."
input_shape_lens = [len(input.shape) for input in sess.get_inputs()]
if len(input_shape_lens) > 1 and all(shape_len == 2 for shape_len in input_shape_lens):
is_nlp = True
Expand Down

0 comments on commit 70a5662

Please sign in to comment.