Skip to content

Commit

Permalink
Support IPEX bf16 & fp32 optimization for emebedding model (#1380)
Browse files Browse the repository at this point in the history
* support ipex bf16/fp32 optimization for emebedding model

Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
  • Loading branch information
yuwenzho committed Mar 14, 2024
1 parent df3b4f1 commit b515523
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import generate_qa_prompt, generate_prompt, generate_qa_enterprise
from intel_extension_for_transformers.langchain.embeddings import HuggingFaceEmbeddings, \
HuggingFaceInstructEmbeddings, HuggingFaceBgeEmbeddings
from intel_extension_for_transformers.transformers.utils import CpuInfo
from langchain_community.embeddings import GooglePalmEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from intel_extension_for_transformers.langchain.vectorstores import Chroma, Qdrant
Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(self,
process=True,
append=True,
polish=False,
precision=None,
**kwargs):

self.intent_detector = IntentDetector()
Expand All @@ -93,6 +95,10 @@ def __init__(self,
"accuracy",
"general",
)
allowed_precision: ClassVar[Collection[str]] = (
"fp32",
"bf16",
)
if polish:
self.polisher = QueryPolisher()
else:
Expand All @@ -102,6 +108,9 @@ def __init__(self,
self.retrieval_type)
assert self.mode in allowed_generation_mode, "generation mode of {} not allowed.".format( \
self.mode)
assert precision is None or precision in allowed_precision, \
"embedding precision of '{}' is not allowed. Support {}.".format(
precision, allowed_precision)

if isinstance(input_path, str):
if os.path.exists(input_path):
Expand Down Expand Up @@ -137,6 +146,17 @@ def __init__(self,
logging.error("Please select a proper embedding model.")
logging.error(e)

if precision is not None:
# IPEX BF16 or FP32 optimization for embedding model
import torch
import intel_extension_for_pytorch as ipex
if precision == "bf16" and CpuInfo().bf16:
self.embeddings.client = ipex.optimize(
self.embeddings.client.eval(), dtype=torch.bfloat16, inplace=True)
elif precision == "fp32":
self.embeddings.client = ipex.optimize(
self.embeddings.client.eval(), dtype=torch.float32, inplace=True)

self.document_parser = DocumentParser(max_chuck_size=max_chuck_size, min_chuck_size = min_chuck_size, \
process=self.process)
data_collection = self.document_parser.load(input=self.input_path, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -669,5 +669,50 @@ def test_similarity_search_type_k_2(self):
plugins.retrieval.args = {}
plugins.retrieval.enable = False

class TestEmbeddingPrecision(unittest.TestCase):
def setUp(self):
if os.path.exists("./embedding_precision_bf16"):
shutil.rmtree("./embedding_precision_bf16", ignore_errors=True)
if os.path.exists("./embedding_precision_fp32"):
shutil.rmtree("./embedding_precision_fp32", ignore_errors=True)
return super().setUp()

def tearDown(self) -> None:
if os.path.exists("./embedding_precision_bf16"):
shutil.rmtree("./embedding_precision_bf16", ignore_errors=True)
if os.path.exists("./embedding_precision_fp32"):
shutil.rmtree("./embedding_precision_fp32", ignore_errors=True)
return super().tearDown()

def test_embedding_precision_bf16(self):
plugins.retrieval.args = {}
plugins.retrieval.enable = True
plugins.retrieval.args["input_path"] = "../assets/docs/retrieve_multi_doc"
plugins.retrieval.args["persist_directory"] = "./embedding_precision_bf16"
plugins.retrieval.args["precision"] = 'bf16'
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
plugins=plugins)
chatbot = build_chatbot(config)
response = chatbot.predict("Tell me about Intel Xeon Platinum 8480+ Processor.")
print(response)
self.assertIsNotNone(response)
plugins.retrieval.args = {}
plugins.retrieval.enable = False

def test_embedding_precision_fp32(self):
plugins.retrieval.args = {}
plugins.retrieval.enable = True
plugins.retrieval.args["input_path"] = "../assets/docs/retrieve_multi_doc"
plugins.retrieval.args["persist_directory"] = "./embedding_precision_fp32"
plugins.retrieval.args["precision"] = 'fp32'
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
plugins=plugins)
chatbot = build_chatbot(config)
response = chatbot.predict("Tell me about Intel Xeon Platinum 8480+ Processor.")
print(response)
self.assertIsNotNone(response)
plugins.retrieval.args = {}
plugins.retrieval.enable = False

if __name__ == '__main__':
unittest.main()

0 comments on commit b515523

Please sign in to comment.