Support lm-eval for ONNX models (#1103)
* Support lm-eval for ONNX model

Signed-off-by: Mengni Wang <mengni.wang@intel.com>

* fix bug

Signed-off-by: Mengni Wang <mengni.wang@intel.com>

* add ut

Signed-off-by: Mengni Wang <mengni.wang@intel.com>

* Update requirements.txt

* Update huggingface.py

* Update test_evaluation.py

* Update huggingface.py

* Update huggingface.py

* Update test_evaluation.py

* Update test_evaluation.py

---------

Signed-off-by: Mengni Wang <mengni.wang@intel.com>
mengniwang95 committed Jul 6, 2023
1 parent 8252318 commit a944faa
Showing 4 changed files with 170 additions and 14 deletions.
@@ -65,6 +65,7 @@ def evaluate(model,
decontamination_ngrams_path=None,
seed=1234,
user_model=None,
model_format='torch'
):
"""Instantiate and evaluate a model on a list of tasks.
@@ -98,6 +99,8 @@ def evaluate(model,
Model object user provided.
:param output_dir: str
Save the results Path
:param model_format: str
Model format; supports 'torch' and 'onnx'
:return
Dictionary of results
"""
@@ -111,7 +114,7 @@
if isinstance(model, str):
if model_args is None:
model_args = ""
kwargs = {"batch_size": batch_size, "device": device}
kwargs = {"batch_size": batch_size, "device": device, "model_format": model_format}
if user_model:
kwargs["init_empty_weights"] = True
lm = get_model(model).create_from_arg_string(
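With this change, a caller that already has an exported ONNX model directory only needs to pass the new argument; everything else stays the same. A minimal sketch of such a call, mirroring the usage in tests/test_evaluation.py below (the model directory and task are illustrative):

results = evaluate(
    model="hf-causal",
    model_args='pretrained="./gptj-past",tokenizer="./gptj-past",dtype=float32',
    tasks=["piqa"],
    limit=20,
    model_format="onnx",  # defaults to "torch"
)
print(results["results"]["piqa"]["acc"])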
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import os
import math
import torch
import torch.nn.functional as F
@@ -98,6 +99,7 @@ def __init__(
dtype: Optional[Union[str, torch.dtype]] = None,
device: Optional[Union[int, str]] = "cuda",
init_empty_weights: Optional[bool] = False,
model_format: Optional[str] = "torch"
):
"""Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
Args:
@@ -144,6 +146,8 @@ def __init__(
Use `dtype="auto"` to derive the type from the model’s weights.
init_empty_weights (bool, optional, defaults to False):
Initialize model with empty weights if model is not used for inference.
model_format (str, optional, defaults to torch):
The format of the target model; supports 'torch' and 'onnx'
"""
super().__init__()

@@ -174,7 +178,7 @@ def __init__(
)

self._add_special_tokens = add_special_tokens
if re.search("llama", pretrained):
if re.search("llama", pretrained.lower()):
from transformers import LlamaTokenizer # pylint: disable=E0611
self.tokenizer = LlamaTokenizer.from_pretrained(
pretrained,
@@ -196,14 +200,15 @@ def __init__(
max_cpu_memory,
offload_folder,
)
self.model = self._create_auto_model(
pretrained=pretrained,
revision=revision,
subfolder=subfolder,
torch_dtype=_get_dtype(dtype, self._config),
**accelerate_kwargs,
)
self.model.eval()
if model_format == "torch":
self.model = self._create_auto_model(
pretrained=pretrained,
revision=revision,
subfolder=subfolder,
torch_dtype=_get_dtype(dtype, self._config),
**accelerate_kwargs,
)
self.model.eval()
torch.set_grad_enabled(False)

self._device = device
@@ -396,6 +401,43 @@ class AutoCausalLM(HuggingFaceAutoLM):

AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

def __init__(self, *args, pretrained, model_format, **kwargs):
super().__init__(*args, pretrained=pretrained, model_format=model_format, **kwargs)

self.model_format = model_format
if self.model_format == "onnx":
if not os.path.exists(os.path.join(pretrained, "decoder_model.onnx")):
raise ValueError(
"Couldn't find decoder_model.onnx in {}.".format(pretrained)
)

import onnxruntime as ort
from transformers import PretrainedConfig
from optimum.onnxruntime import ORTModelForCausalLM

model_config = PretrainedConfig.from_pretrained(pretrained)
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
if os.path.exists(os.path.join(pretrained, "decoder_with_past_model.onnx")):
sessions = ORTModelForCausalLM.load_model( # pylint: disable=E1123
os.path.join(pretrained, "decoder_model.onnx"),
os.path.join(pretrained, "decoder_with_past_model.onnx"),
session_options=sess_options)
self.model = ORTModelForCausalLM(sessions[0], # pylint: disable=E1121
model_config,
pretrained,
sessions[1],
use_cache=True)
else:
sessions = ORTModelForCausalLM.load_model( # pylint: disable=E1123
os.path.join(pretrained, "decoder_model.onnx"),
session_options=sess_options)
self.model = ORTModelForCausalLM(sessions[0], # pylint: disable=E1121
model_config,
pretrained,
use_cache=False,
use_io_binding=False)
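This constructor does not export anything itself; it expects decoder_model.onnx (and optionally decoder_with_past_model.onnx, which enables the use_cache=True path) to already sit under `pretrained`. A hedged sketch of producing those files with optimum's exporter, matching the command the new tests run via subprocess (model name and output directory are illustrative, and the exact output file names can vary across optimum versions):

import subprocess

# export a decoder-only model; with "-with-past" this is expected to emit
# decoder_model.onnx and decoder_with_past_model.onnx into gptj-past/
subprocess.run(
    "optimum-cli export onnx --model hf-internal-testing/tiny-random-gptj "
    "--task text-generation-with-past gptj-past/",
    shell=True,
    check=True,
)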

def _create_auto_tokenizer(
self,
*,
@@ -416,7 +458,8 @@ def _create_auto_tokenizer(
def _model_call(
self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
) -> TokenSequence:
output = self.model(inputs)
output = self.model(inputs) if self.model_format != "onnx" else \
self.model(inputs, torch.ones(inputs.shape, dtype=torch.int64))
if isinstance(output, tuple):
return output[0]
return output["logits"]
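For the ONNX branch, the forward call also passes an attention mask of ones, since the exported decoder takes the mask as an explicit input rather than deriving it from the input ids. A sketch of the equivalent keyword-argument form, assuming optimum's usual input_ids/attention_mask names:

mask = torch.ones(inputs.shape, dtype=torch.int64)  # attend to every position
output = self.model(input_ids=inputs, attention_mask=mask)
logits = output[0] if isinstance(output, tuple) else output["logits"]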
@@ -464,6 +507,48 @@ class AutoSeq2SeqLM(HuggingFaceAutoLM):

AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM

def __init__(self, *args, pretrained, model_format, **kwargs):
super().__init__(*args, pretrained=pretrained, model_format=model_format, **kwargs)

self.model_format = model_format
if self.model_format == "onnx":
if not os.path.exists(os.path.join(pretrained, "decoder_model.onnx")) or \
not os.path.exists(os.path.join(pretrained, "encoder_model.onnx")):
raise ValueError(
"Please ensure decoder_model.onnx and encoder_model.onnx are under {}.".format(pretrained)
)

import onnxruntime as ort
from transformers import PretrainedConfig
from optimum.onnxruntime import ORTModelForSeq2SeqLM

model_config = PretrainedConfig.from_pretrained(pretrained)
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
if os.path.exists(os.path.join(pretrained, "decoder_with_past_model.onnx")):
sessions = ORTModelForSeq2SeqLM.load_model(
os.path.join(pretrained, 'encoder_model.onnx'),
os.path.join(pretrained, 'decoder_model.onnx'),
os.path.join(pretrained, 'decoder_with_past_model.onnx'))

self.model = ORTModelForSeq2SeqLM(sessions[0],
sessions[1],
model_config,
pretrained,
sessions[2],
use_cache=True)
else:
sessions = ORTModelForSeq2SeqLM.load_model( # pylint: disable=E1120
os.path.join(pretrained, 'encoder_model.onnx'),
os.path.join(pretrained, 'decoder_model.onnx'))

self.model = ORTModelForSeq2SeqLM(sessions[0],
sessions[1],
model_config,
pretrained,
use_cache=False,
use_io_binding=False)
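The seq2seq path mirrors the causal-LM one, except that both encoder_model.onnx and decoder_model.onnx are mandatory and decoder_with_past_model.onnx again only toggles the cached path. A hedged export sketch matching the command the new tests run (model name and output directory are illustrative):

import subprocess

# expected to emit encoder_model.onnx, decoder_model.onnx and
# decoder_with_past_model.onnx into t5-past/
subprocess.run(
    "optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 "
    "--task text2text-generation-with-past t5-past/",
    shell=True,
    check=True,
)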

@property
def max_length(self) -> int:
"""Return the maximum sequence length of the model.
@@ -591,7 +676,16 @@ def _loglikelihood_tokens(
def _model_call(
self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
) -> TokenSequence:
return self.model(**inputs, labels=labels["input_ids"])
if self.model_format == "onnx":
decoder_start_token_id = self._config.decoder_start_token_id
pad_token_id = self._config.pad_token_id
shifted_input_ids = labels["input_ids"].new_zeros(labels["input_ids"].shape)
shifted_input_ids[..., 1:] = labels["input_ids"][..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return self.model(**inputs, decoder_input_ids=shifted_input_ids, labels=labels["input_ids"])
else:
return self.model(**inputs, labels=labels["input_ids"])
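Because the exported ONNX model does not shift labels internally the way Hugging Face seq2seq models typically do when given `labels`, the ONNX branch builds decoder_input_ids itself: shift the labels one position to the right, prepend decoder_start_token_id, and replace the -100 ignore index with the pad token. A toy illustration of that shift (values chosen for demonstration; 0 plays both the start and pad roles, as in T5):

import torch

labels = torch.tensor([[8, 15, -100, -100]])   # target ids, -100 marks padding
shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = 0                            # decoder_start_token_id
shifted.masked_fill_(shifted == -100, 0)       # pad_token_id
# shifted is now tensor([[0, 8, 15, 0]])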

def _model_generate(
self,
3 changes: 2 additions & 1 deletion tests/requirements.txt
@@ -17,4 +17,5 @@ git+https://github.com/EleutherAI/lm-evaluation-harness.git@83dbfbf6070324f3e587
accelerate
evaluate
wget
optimum-intel
optimum
optimum-intel
60 changes: 59 additions & 1 deletion tests/test_evaluation.py
@@ -1,7 +1,7 @@
import os
import shutil
import unittest

import subprocess
import torch
from intel_extension_for_transformers.evaluation.lm_eval import evaluate
from intel_extension_for_transformers.evaluation.hf_eval import summarization_evaluate
@@ -24,6 +24,10 @@ def setUpClass(self):
@classmethod
def tearDownClass(self):
shutil.rmtree("./lm_cache", ignore_errors=True)
shutil.rmtree("./t5", ignore_errors=True)
shutil.rmtree("./t5-past", ignore_errors=True)
shutil.rmtree("./gptj", ignore_errors=True)
shutil.rmtree("./gptj-past", ignore_errors=True)

def test_evaluate_for_casualLM(self):
results = evaluate(
@@ -34,6 +38,33 @@ def test_evaluate_for_casualLM(self):
)
self.assertEqual(results["results"]["piqa"]["acc"], 0.45)

def test_evaluate_for_ort_casualLM(self):
cmd = 'optimum-cli export onnx --model hf-internal-testing/tiny-random-gptj --task text-generation-with-past gptj-past/'
p = subprocess.Popen(cmd, preexec_fn=os.setsid, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, shell=True) # nosec
p.communicate()
results = evaluate(
model="hf-causal",
model_args='pretrained="./gptj-past",tokenizer="./gptj-past",dtype=float32',
tasks=["piqa"],
limit=20,
model_format="onnx"
)
self.assertEqual(results["results"]["piqa"]["acc"], 0.45)

cmd = 'optimum-cli export onnx --model hf-internal-testing/tiny-random-gptj --task text-generation gptj/'
p = subprocess.Popen(cmd, preexec_fn=os.setsid, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, shell=True) # nosec
p.communicate()
results = evaluate(
model="hf-causal",
model_args='pretrained="./gptj",tokenizer="./gptj",dtype=float32',
tasks=["piqa"],
limit=20,
model_format="onnx"
)
self.assertEqual(results["results"]["piqa"]["acc"], 0.45)

def test_evaluate_for_Seq2SeqLM(self):
results = evaluate(
model="hf-seq2seq",
@@ -43,6 +74,33 @@ def test_evaluate_for_Seq2SeqLM(self):
)
self.assertEqual(results["results"]["piqa"]["acc"], 0.60)

def test_evaluate_for_ort_Seq2SeqLM(self):
cmd = 'optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 --task text2text-generation-with-past t5-past/'
p = subprocess.Popen(cmd, preexec_fn=os.setsid, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, shell=True) # nosec
p.communicate()
results = evaluate(
model="hf-seq2seq",
model_args='pretrained="./t5-past",tokenizer="./t5-past",dtype=float32',
tasks=["piqa"],
limit=20,
model_format="onnx"
)
self.assertEqual(results["results"]["piqa"]["acc"], 0.60)

cmd = 'optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 --task text2text-generation t5/'
p = subprocess.Popen(cmd, preexec_fn=os.setsid, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, shell=True) # nosec
p.communicate()
results = evaluate(
model="hf-seq2seq",
model_args='pretrained="./t5",tokenizer="./t5",dtype=float32',
tasks=["piqa"],
limit=20,
model_format="onnx"
)
self.assertEqual(results["results"]["piqa"]["acc"], 0.60)

def test_evaluate_for_JitModel(self):
results = evaluate(
model="hf-causal",