Commit
add neuralchat support ipex.optimize_transformers sq int8 model loading (#764)
changwangss committed Nov 24, 2023
1 parent 3a4586e commit ee85583
Showing 1 changed file with 45 additions and 7 deletions.
52 changes: 45 additions & 7 deletions intel_extension_for_transformers/neural_chat/models/model_utils.py
@@ -20,6 +20,7 @@
 import copy, time
 from datetime import datetime
 import torch
+import warnings
 from queue import Queue
 import re, os
 from threading import Thread
@@ -394,13 +395,50 @@ def load_model(
or re.search("codellama", model_name, re.IGNORECASE)
) and ipex_int8
):
with smart_context_manager(use_deepspeed=use_deepspeed):
import intel_extension_for_pytorch
from optimum.intel.generation.modeling import TSModelForCausalLM
model = TSModelForCausalLM.from_pretrained(
model_name,
file_name="best_model.pt",
)
with smart_context_manager(use_deepspeed=use_deepspeed):
try:
import intel_extension_for_pytorch as ipex
except ImportError:
warnings.warn(
"Please install Intel Extension for PyTorch to accelerate the model inference."
)
assert ipex.__version__ >= "2.1.0+cpu", "Please use Intel Extension for PyTorch >=2.1.0+cpu."
from optimum.intel.generation.modeling import TSModelForCausalLM
model = TSModelForCausalLM.from_pretrained(
model_name,
file_name="best_model.pt",
)
elif(
(re.search("llama", model_name, re.IGNORECASE)
or re.search("opt", model_name, re.IGNORECASE)
or re.search("gpt_neox", model_name, re.IGNORECASE)
or re.search("gptj", model_name, re.IGNORECASE)
) and ipex_int8
):
with smart_context_manager(use_deepspeed=use_deepspeed):
try:
import intel_extension_for_pytorch as ipex
except ImportError:
warnings.warn(
"Please install Intel Extension for PyTorch to accelerate the model inference."
)
assert ipex.__version__ >= "2.1.0+cpu", "Please use Intel Extension for PyTorch >=2.1.0+cpu."
torch._C._jit_set_texpr_fuser_enabled(False)
qconfig = ipex.quantization.default_static_qconfig_mapping
with ipex.OnDevice(dtype=torch.float, device="meta"):
model = AutoModelForCausalLM.from_pretrained(model_name)
model = ipex.optimize_transformers(
model.eval(),
dtype=torch.float,
inplace=True,
quantization_config=qconfig,
deployment_mode=False,
)
if not hasattr(model, "trace_graph"):
print("load_quantized_model")
self_jit = torch.jit.load(os.path.join(model_name, "best_model.pt"))
self_jit = torch.jit.freeze(self_jit.eval())
ipex._set_optimized_model_for_generation(model, optimized_model=self_jit)
else:
raise ValueError(
f"Unsupported model {model_name}, only supports "
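For reference, a minimal usage sketch of how a caller might reach the new branch. Only model_name, ipex_int8, and use_deepspeed are confirmed by the diff above; tokenizer_name and device are assumptions about load_model's full signature, and the checkpoint path is hypothetical.

# Hedged sketch; verify load_model's actual signature in model_utils.py.
# The directory is assumed to hold a SmoothQuant INT8 TorchScript file named
# best_model.pt, which the new branch loads via torch.jit.load and attaches
# with ipex._set_optimized_model_for_generation.
from intel_extension_for_transformers.neural_chat.models.model_utils import load_model

load_model(
    model_name="/path/to/llama-sq-int8",      # hypothetical path; must contain best_model.pt
    tokenizer_name="/path/to/llama-sq-int8",  # assumed parameter
    device="cpu",                             # assumed parameter; this path targets CPU
    ipex_int8=True,                           # routes into the ipex.optimize_transformers branch
)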
