Support security check for both origin query and translated query (#709)

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
lvliang-intel committed Nov 23, 2023
1 parent 452cbd1 commit 6e03862
Showing 5 changed files with 49 additions and 13 deletions.
34 changes: 25 additions & 9 deletions intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -122,12 +122,13 @@ def load_model(self, kwargs: dict):
             hf_access_token=kwargs["hf_access_token"],
             use_llm_runtime=kwargs["use_llm_runtime"])

-    def predict_stream(self, query, config=None):
+    def predict_stream(self, query, origin_query="", config=None):
         """
         Predict using a streaming approach.

         Args:
             query: The input query for prediction.
+            origin_query: The original Chinese query for the safety checker.
             config: Configuration for prediction.
         """
         if not config:
@@ -139,6 +140,9 @@ def predict_stream(self, query, config=None):
         config.use_cache = self.use_cache
         config.ipex_int8 = self.ipex_int8

+        my_query = query
+        my_origin_query = origin_query
+
         if is_audio_file(query):
             if not os.path.exists(query):
                 raise ValueError(f"The audio file path {query} is invalid.")
@@ -169,8 +173,14 @@ def predict_stream(self, query, config=None):
                     return plugin_instance.response_template, link
                 else:
                     response = plugin_instance.pre_llm_inference_actions(query)
-                    if plugin_name == "safety_checker" and response:
-                        return "Your query contains sensitive words, please try another query.", link
+                    if plugin_name == "safety_checker":
+                        sign1 = plugin_instance.pre_llm_inference_actions(my_query)
+                        if sign1:
+                            return "Your query contains sensitive words, please try another query.", link
+                        if my_origin_query != "":
+                            sign2 = plugin_instance.pre_llm_inference_actions(my_origin_query)
+                            if sign2:
+                                return "Your query contains sensitive words, please try another query.", link
                     else:
                         if response != None and response != False:
                             query = response
@@ -196,12 +206,13 @@ def is_generator(obj):

         return response, link

-    def predict(self, query, config=None):
+    def predict(self, query, origin_query="", config=None):
         """
         Predict using a non-streaming approach.

         Args:
             query: The input query for prediction.
+            origin_query: The original Chinese query for the safety checker.
             config: Configuration for prediction.
         """
         if not config:
@@ -244,7 +255,10 @@ def predict(self, query, config=None):
                 else:
                     response = plugin_instance.pre_llm_inference_actions(query)
-                    if plugin_name == "safety_checker" and response:
-                        return "Your query contains sensitive words, please try another query."
+                    if plugin_name == "safety_checker":
+                        if response:
+                            return "Your query contains sensitive words, please try another query."
+                        elif origin_query and plugin_instance.pre_llm_inference_actions(origin_query):
+                            return "Your query contains sensitive words, please try another query."
                     else:
                         if response != None and response != False:
                             query = response
@@ -268,25 +282,27 @@ def predict(self, query, config=None):

         return response

-    def chat_stream(self, query, config=None):
+    def chat_stream(self, query, origin_query="", config=None):
         """
         Chat using a streaming approach.

         Args:
             query: The input query for prediction.
+            origin_query: The original Chinese query for the safety checker.
             config: Configuration for prediction.
         """
-        return self.predict_stream(query=query, config=config)
+        return self.predict_stream(query=query, origin_query=origin_query, config=config)

-    def chat(self, query, config=None):
+    def chat(self, query, origin_query="", config=None):
         """
         Chat using a non-streaming approach.

         Args:
             query: The input query for conversation.
+            origin_query: The original Chinese query for the safety checker.
             config: Configuration for conversation.
         """
-        return self.predict(query=query, config=config)
+        return self.predict(query=query, origin_query=origin_query, config=config)

     def face_animate(self, image_path, audio_path=None, text=None, voice=None) -> str:  # pragma: no cover
         # 1) if there is a driven audio, then image + audio
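Taken together, the base_model.py changes make the safety checker run twice: once on the query that will reach the LLM (typically an English translation) and once on the caller-supplied original-language query. Below is a minimal, self-contained sketch of that control flow; the SafetyChecker class, its word list, and checked_query are hypothetical stand-ins for the plugin's pre_llm_inference_actions interface shown in the diff.

```python
# Minimal sketch of the dual safety check, assuming a plugin whose
# pre_llm_inference_actions(query) returns a truthy value on a hit.
# SafetyChecker and its word list are hypothetical placeholders.
SENSITIVE_RESPONSE = "Your query contains sensitive words, please try another query."

class SafetyChecker:
    sensitive_words = {"blocked_term"}  # placeholder list, not the real dictionary

    def pre_llm_inference_actions(self, query: str) -> bool:
        # True when the query contains any sensitive word.
        return any(word in query for word in self.sensitive_words)

def checked_query(checker: SafetyChecker, query: str, origin_query: str = ""):
    # Check the query sent to the LLM first, then the original-language
    # query if the caller supplied one -- mirroring predict()/predict_stream().
    if checker.pre_llm_inference_actions(query):
        return SENSITIVE_RESPONSE
    if origin_query and checker.pre_llm_inference_actions(origin_query):
        return SENSITIVE_RESPONSE
    return query  # safe to pass downstream

print(checked_query(SafetyChecker(), "What is Intel oneAPI Compiler?", "oneAPI编译器是什么?"))
```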
@@ -51,6 +51,7 @@ class FinetuneRequest(RequestBaseModel):

 class AskDocRequest(RequestBaseModel):
     query: str
+    translated: str
     domain: str
     blob: Optional[str]
     filename: Optional[str]
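For illustration, a payload matching the extended AskDocRequest schema might look like this; only the fields visible in the hunk above are included, and all values are made up.

```python
# Hypothetical AskDocRequest body with the new `translated` field.
payload = {
    "query": "oneAPI编译器是什么?",                  # original-language query
    "translated": "What is Intel oneAPI Compiler?",  # translation checked separately
    "domain": "general",
    "blob": None,
    "filename": None,
}
```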
@@ -200,6 +200,7 @@ async def retrieval_chat(request: Request):
     # parse parameters
     params = await request.json()
     query = params['query']
+    origin_query = params['translated']
     kb_id = params['knowledge_base_id']
     stream = params['stream']
     max_new_tokens = params['max_new_tokens']
@@ -227,11 +228,11 @@ async def retrieval_chat(request: Request):

     # non-stream mode
     if not stream:
-        response = chatbot.predict(query=query, config=config)
+        response = chatbot.predict(query=query, origin_query=origin_query, config=config)
         formatted_response = response.replace('\n', '<br/>')
         return formatted_response
     # stream mode
-    generator, link = chatbot.predict_stream(query=query, config=config)
+    generator, link = chatbot.predict_stream(query=query, origin_query=origin_query, config=config)
     logger.info(f"[askdoc - chat] chatbot predicted: {generator}")
     if isinstance(generator, str):
         def stream_generator():
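A client exercising the stream mode of this endpoint could look roughly like the sketch below. The endpoint path and JSON fields come from this commit and its test; the host, port, and chunked streaming behavior are assumptions.

```python
# Hedged client sketch for the stream mode; host/port and streaming
# behavior are assumptions, field names come from the diff above.
import requests

resp = requests.post(
    "http://127.0.0.1:6000/v1/aiphotos/askdoc/chat",
    json={
        "query": "oneAPI编译器是什么?",
        "translated": "What is Intel oneAPI Compiler?",
        "knowledge_base_id": "default",
        "stream": True,
        "max_new_tokens": 256,
    },
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
```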
@@ -34,5 +34,8 @@ retrieval:
     response_template: "We cannot find suitable content to answer your query, please contact AskGM to find help. Mail: ask.gm.zizhu@intel.com."
     append: True

+safety_checker:
+    enable: true
+
 tasks_list: ['textchat', 'retrieval']

@@ -30,14 +30,18 @@ def setUp(self) -> None:
             command = f'neuralchat_server start \
                         --config_file {yaml_file_path} \
                         --log_file "./neuralchat.log"'
+        elif os.path.exists("./askdoc.yaml"):
+            command = f'neuralchat_server start \
+                        --config_file ./askdoc.yaml \
+                        --log_file "./neuralchat.log"'
         else:
             command = 'sed -i "s|askdoc|ci/server/askdoc|g" ./ci/server/askdoc.yaml && neuralchat_server start \
                         --config_file "./ci/server/askdoc.yaml" \
                         --log_file "./neuralchat.log"'
         try:
             self.server_process = subprocess.Popen(command,
                                     universal_newlines=True, shell=True)  # nosec
-            time.sleep(120)
+            time.sleep(60)
         except subprocess.CalledProcessError as e:
             print("Error while executing command:", e)

@@ -55,14 +59,25 @@ def tearDown(self) -> None:

     def test_askdoc_chat(self):
         url = 'http://127.0.0.1:6000/v1/aiphotos/askdoc/chat'
         request = {
-            "query": "What is Intel oneAPI Compiler?",
+            "query": "oneAPI编译器是什么?",
+            "translated": "What is Intel oneAPI Compiler?",
             "knowledge_base_id": "default",
             "stream": False,
             "max_new_tokens": 256
         }
         res = requests.post(url, json.dumps(request))
         self.assertEqual(res.status_code, 200)

+        request = {
+            "query": "蔡英文是谁?",
+            "translated": "Who is Tsai Ing-wen?",
+            "knowledge_base_id": "default",
+            "stream": False,
+            "max_new_tokens": 256
+        }
+        res = requests.post(url, json.dumps(request))
+        self.assertEqual(res.status_code, 200)
+        self.assertIn('Your query contains sensitive words, please try another query', str(res.text))
+
 if __name__ == "__main__":
     unittest.main()
