[Neuralchat] Improve error code UT coverage (#1132)
Co-authored-by: lvliang-intel <liang1.lv@intel.com>
Liangyx2 and lvliang-intel committed Jan 18, 2024
1 parent fb1c0ec commit dd6dcb4
Showing 10 changed files with 681 additions and 183 deletions.
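Note on the pattern in this commit: the changed code consistently routes failures through set_latest_error()/get_latest_error() and returns None instead of raising, which is the contract the new unit tests exercise. A minimal test sketch of that contract follows for orientation only; the test class, model name, and expected error code are illustrative assumptions, not tests taken from this commit.

# Hypothetical unit-test sketch (not part of this commit): exercises the
# error-code path that build_chatbot now uses instead of raising exceptions.
# Import paths follow the repo layout shown in the diff; the model name is a
# made-up value intended to fall through to ERROR_MODEL_NOT_SUPPORTED.
import unittest

from intel_extension_for_transformers.neural_chat.chatbot import build_chatbot
from intel_extension_for_transformers.neural_chat.config import PipelineConfig
from intel_extension_for_transformers.neural_chat.errorcode import ErrorCodes
from intel_extension_for_transformers.neural_chat.utils.error_utils import get_latest_error


class TestErrorCodes(unittest.TestCase):
    def test_unsupported_model_sets_error_code(self):
        # An unrecognized model name should make build_chatbot return None
        # and record an error code rather than raise.
        chatbot = build_chatbot(PipelineConfig(model_name_or_path="not-a-real-model"))
        self.assertIsNone(chatbot)
        self.assertEqual(get_latest_error(), ErrorCodes.ERROR_MODEL_NOT_SUPPORTED)


if __name__ == "__main__":
    unittest.main()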
89 changes: 28 additions & 61 deletions intel_extension_for_transformers/neural_chat/chatbot.py
@@ -19,13 +19,10 @@
from intel_extension_for_transformers.llm.quantization.optimization import Optimization
from .config import PipelineConfig
from .config import BaseFinetuningConfig
from .config import DeviceOptions
from .plugins import plugins

from .errorcode import ErrorCodes, STORAGE_THRESHOLD_GB
from .utils.error_utils import set_latest_error
import psutil
import torch
from .errorcode import ErrorCodes
from .utils.error_utils import set_latest_error, get_latest_error, clear_latest_error
from .config_logging import configure_logging
logger = configure_logging()

@@ -44,29 +41,10 @@ def build_chatbot(config: PipelineConfig=None):
        pipeline = build_chatbot()
        response = pipeline.predict(query="Tell me about Intel Xeon Scalable Processors.")
    """
    # Check for out of storage
    available_storage = psutil.disk_usage('/').free
    available_storage_gb = available_storage / (1024 ** 3)
    if available_storage_gb < STORAGE_THRESHOLD_GB:
        set_latest_error(ErrorCodes.ERROR_OUT_OF_STORAGE)
        return

    global plugins
    clear_latest_error()
    if not config:
        config = PipelineConfig()
    # Validate input parameters
    if config.device not in [option.name.lower() for option in DeviceOptions]:
        set_latest_error(ErrorCodes.ERROR_DEVICE_NOT_SUPPORTED)
        return

    if config.device == "cuda":
        if not torch.cuda.is_available():
            set_latest_error(ErrorCodes.ERROR_DEVICE_NOT_FOUND)
            return
    elif config.device == "xpu":
        if not torch.xpu.is_available():
            set_latest_error(ErrorCodes.ERROR_DEVICE_NOT_FOUND)
            return

    # create model adapter
    if "llama" in config.model_name_or_path.lower():
@@ -103,6 +81,7 @@ def build_chatbot(config: PipelineConfig=None):
        adapter = BaseModel()
    else:
        set_latest_error(ErrorCodes.ERROR_MODEL_NOT_SUPPORTED)
        logger.error("build_chatbot: unknown model")
        return

    # register plugin instance in model adaptor
@@ -137,12 +116,25 @@ def build_chatbot(config: PipelineConfig=None):
elif plugin_name == "image2image": # pragma: no cover
from .pipeline.plugins.image2image.image2image import Image2Image
plugins[plugin_name]['class'] = Image2Image
else: # pragma: no cover
else:
set_latest_error(ErrorCodes.ERROR_PLUGIN_NOT_SUPPORTED)
logger.error("build_chatbot: unknown plugin")
return
print(f"create {plugin_name} plugin instance...")
print(f"plugin parameters: ", plugin_value['args'])
plugins[plugin_name]["instance"] = plugins[plugin_name]['class'](**plugin_value['args'])
try:
plugins[plugin_name]["instance"] = plugins[plugin_name]['class'](**plugin_value['args'])
except Exception as e:
if "[Rereieval ERROR] Document format not supported" in str(e):
set_latest_error(ErrorCodes.ERROR_RETRIEVAL_DOC_FORMAT_NOT_SUPPORTED)
logger.error("build_chatbot: retrieval plugin init failed")
elif "[SafetyChecker ERROR] Sensitive check file not found" in str(e):
set_latest_error(ErrorCodes.ERROR_SENSITIVE_CHECK_FILE_NOT_FOUND)
logger.error("build_chatbot: safety checker plugin init failed")
else:
set_latest_error(ErrorCodes.ERROR_GENERIC)
logger.error("build_chatbot: plugin init failed")
return
adapter.register_plugin_instance(plugin_name, plugins[plugin_name]["instance"])

parameters = {}
@@ -163,47 +155,19 @@ def build_chatbot(config: PipelineConfig=None):
parameters["hf_access_token"] = config.hf_access_token
parameters["assistant_model"] = config.assistant_model

try:
adapter.load_model(parameters)
except RuntimeError as e:
logger.error(f"Exception: {e}")
if "out of memory" in str(e):
set_latest_error(ErrorCodes.ERROR_OUT_OF_MEMORY)
elif "devices are busy or unavailable" in str(e):
set_latest_error(ErrorCodes.ERROR_DEVICE_BUSY)
elif "tensor does not have a device" in str(e):
set_latest_error(ErrorCodes.ERROR_DEVICE_NOT_FOUND)
else:
set_latest_error(ErrorCodes.ERROR_GENERIC)
adapter.load_model(parameters)
if get_latest_error():
return
except ValueError as e:
logger.error(f"Exception: {e}")
if "load_model: unsupported device" in str(e):
set_latest_error(ErrorCodes.ERROR_DEVICE_NOT_SUPPORTED)
elif "load_model: unsupported model" in str(e):
set_latest_error(ErrorCodes.ERROR_MODEL_NOT_SUPPORTED)
elif "load_model: tokenizer is not found" in str(e):
set_latest_error(ErrorCodes.ERROR_TOKENIZER_NOT_FOUND)
elif "load_model: model name or path is not found" in str(e):
set_latest_error(ErrorCodes.ERROR_MODEL_NOT_FOUND)
elif "load_model: model config is not found" in str(e):
set_latest_error(ErrorCodes.ERROR_MODEL_CONFIG_NOT_FOUND)
else:
set_latest_error(ErrorCodes.ERROR_GENERIC)
return
except Exception as e:
logger.error(f"Exception: {e}")
set_latest_error(ErrorCodes.ERROR_GENERIC)
return
return adapter
else:
return adapter

def finetune_model(config: BaseFinetuningConfig):
"""Finetune the model based on the provided configuration.
Args:
config (BaseFinetuningConfig): Configuration for finetuning the model.
"""

clear_latest_error()
assert config is not None, "BaseFinetuningConfig is needed for finetuning."
from intel_extension_for_transformers.llm.finetuning.finetuning import Finetuning
finetuning = Finetuning(config)
@@ -221,7 +185,9 @@ def finetune_model(config: BaseFinetuningConfig):
            set_latest_error(ErrorCodes.ERROR_TRAIN_FILE_NOT_FOUND)
    except Exception as e:
        logger.error(f"Exception: {e}")
        if config.finetune_args.peft == "lora":
        if "Permission denied" in str(e):
            set_latest_error(ErrorCodes.ERROR_DATASET_CACHE_DIR_NO_WRITE_PERMISSION)
        elif config.finetune_args.peft == "lora":
            set_latest_error(ErrorCodes.ERROR_LORA_FINETUNE_FAIL)
        elif config.finetune_args.peft == "llama_adapter":
            set_latest_error(ErrorCodes.ERROR_LLAMA_ADAPTOR_FINETUNE_FAIL)
@@ -242,6 +208,7 @@ def optimize_model(model, config, use_llm_runtime=False):
        config (OptimizationConfig): The configuration required for optimizing the model.
        use_llm_runtime (bool): A boolean indicating whether to use the LLM runtime graph optimization.
    """
    clear_latest_error()
    optimization = Optimization(optimization_config=config)
    try:
        model = optimization.optimize(model, use_llm_runtime)
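For context, chatbot.py above depends on the latest-error helpers imported from .utils.error_utils, which this diff does not show. The following is only a plausible minimal sketch of the set/get/clear contract those calls assume, not the module's actual implementation.

# Plausible minimal sketch of the error_utils contract used by chatbot.py
# (assumption: the real module may differ, e.g. by adding logging or locking).
latest_error = None

def set_latest_error(error_code):
    """Record the most recent error code."""
    global latest_error
    latest_error = error_code

def get_latest_error():
    """Return the most recent error code, or None if no error was recorded."""
    return latest_error

def clear_latest_error():
    """Reset the recorded error before starting a new operation."""
    global latest_error
    latest_error = None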
37 changes: 28 additions & 9 deletions intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -182,19 +182,25 @@ def predict_stream(self, query, origin_query="", config=None):
                            if response:
                                logging.info("Get response: %s from cache", response)
                                return response['choices'][0]['text'], link
                        if plugin_name == "asr" and not is_audio_file(query):
                        if plugin_name == "asr" and not os.path.exists(query):
                            continue
                        if plugin_name == "retrieval":
                            response, link = plugin_instance.pre_llm_inference_actions(self.model_name, query)
                            if response == "Response with template.":
                                return plugin_instance.response_template, link
                            try:
                                response, link = plugin_instance.pre_llm_inference_actions(self.model_name, query)
                                if response == "Response with template.":
                                    return plugin_instance.response_template, link
                            except Exception as e:
                                if "[Rereieval ERROR] intent detection failed" in str(e):
                                    set_latest_error(ErrorCodes.ERROR_INTENT_DETECT_FAIL)
                                return
                        else:
                            try:
                                response = plugin_instance.pre_llm_inference_actions(query)
                            except Exception as e:
                                if plugin_name == "asr":
                                    if "[ASR ERROR] Audio format not supported" in str(e):
                                        set_latest_error(ErrorCodes.ERROR_AUDIO_FORMAT_NOT_SUPPORTED)
                                return
                            if plugin_name == "safety_checker":
                                sign1=plugin_instance.pre_llm_inference_actions(my_query)
                                if sign1:
Expand Down Expand Up @@ -226,6 +232,7 @@ def predict_stream(self, query, origin_query="", config=None):
                **construct_parameters(query, self.model_name, self.device, self.assistant_model, config))
        except Exception as e:
            set_latest_error(ErrorCodes.ERROR_MODEL_INFERENCE_FAIL)
            return

        def is_generator(obj):
            return isinstance(obj, types.GeneratorType)
@@ -284,14 +291,25 @@ def predict(self, query, origin_query="", config=None):
                            if response:
                                logging.info("Get response: %s from cache", response)
                                return response['choices'][0]['text']
                        if plugin_name == "asr" and not is_audio_file(query):
                        if plugin_name == "asr" and not os.path.exists(query):
                            continue
                        if plugin_name == "retrieval":
                            response, link = plugin_instance.pre_llm_inference_actions(self.model_name, query)
                            if response == "Response with template.":
                                return plugin_instance.response_template
                            try:
                                response, link = plugin_instance.pre_llm_inference_actions(self.model_name, query)
                                if response == "Response with template.":
                                    return plugin_instance.response_template
                            except Exception as e:
                                if "[Rereieval ERROR] intent detection failed" in str(e):
                                    set_latest_error(ErrorCodes.ERROR_INTENT_DETECT_FAIL)
                                return
                        else:
                            response = plugin_instance.pre_llm_inference_actions(query)
                            try:
                                response = plugin_instance.pre_llm_inference_actions(query)
                            except Exception as e:
                                if plugin_name == "asr":
                                    if "[ASR ERROR] Audio format not supported" in str(e):
                                        set_latest_error(ErrorCodes.ERROR_AUDIO_FORMAT_NOT_SUPPORTED)
                                return
                            if plugin_name == "safety_checker" and response:
                                if response:
                                    return "Your query contains sensitive words, please try another query."
@@ -321,6 +339,7 @@ def predict(self, query, origin_query="", config=None):
                **construct_parameters(query, self.model_name, self.device, self.assistant_model, config))
        except Exception as e:
            set_latest_error(ErrorCodes.ERROR_MODEL_INFERENCE_FAIL)
            return

        # plugin post actions
        for plugin_name in get_registered_plugins():