diff --git a/.github/workflows/unit-test-neuralchat.yml b/.github/workflows/unit-test-neuralchat.yml index ab2a885767c..38ea76388b5 100644 --- a/.github/workflows/unit-test-neuralchat.yml +++ b/.github/workflows/unit-test-neuralchat.yml @@ -34,6 +34,7 @@ env: CONTAINER_NAME: "utTest" EXTRA_CONTAINER_NAME: "modelTest" CONTAINER_SCAN: "codeScan" + GOOGLE_API_KEY: ${{ vars.GOOGLE_API_KEY }} jobs: neuralchat-unit-test: @@ -84,6 +85,7 @@ jobs: -v /home/itrex-docker/models:/models \ -v /dataset/media:/media \ -v /dataset/tf_dataset2:/tf_dataset2 \ + -e "GOOGLE_API_KEY=${{ vars.GOOGLE_API_KEY }}" \ ${{ env.REPO_NAME }}:${{ env.REPO_TAG }} - name: Binary build diff --git a/intel_extension_for_transformers/neural_chat/server/restful/retrieval_api.py b/intel_extension_for_transformers/neural_chat/server/restful/retrieval_api.py index bf086151ce3..e92e1afcbc5 100644 --- a/intel_extension_for_transformers/neural_chat/server/restful/retrieval_api.py +++ b/intel_extension_for_transformers/neural_chat/server/restful/retrieval_api.py @@ -20,16 +20,18 @@ import os import re import csv +import shutil import datetime +import requests from pathlib import Path from datetime import timedelta, timezone -from typing import Optional, Dict -from fastapi import APIRouter, UploadFile, File, Request, Response, Form +from typing import Optional, Dict, List +from fastapi import APIRouter, UploadFile, File, Request, Response, Form, status, HTTPException from ...config import GenerationConfig from ...cli.log import logger from ...server.restful.request import RetrievalRequest, FeedbackRequest from ...server.restful.response import RetrievalResponse -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse from ...utils.database.mysqldb import MysqlDb from ...plugins import plugins @@ -51,6 +53,68 @@ def get_current_beijing_time(): return beijing_time +def language_detect(text: str): + url = "https://translation.googleapis.com/language/translate/v2/detect" + try: + api_key = os.getenv("GOOGLE_API_KEY") + logger.info(f"[ language_detect ] GOOGLE_API_KEY: {api_key}") + except Exception as e: + logger.info(f"No GOOGLE_API_KEY found. {e}") + params = { + 'key': api_key, + 'q': text + } + + response = requests.post(url, params=params) + if response.status_code == 200: + res = response.json() + return res["data"]["detections"][0][0] + else: + print("Error status:", response.status_code) + print("Error content:", response.json()) + return None + + +def language_translate(text: str, target: str='en'): + url = "https://translation.googleapis.com/language/translate/v2" + api_key = os.getenv("GOOGLE_API_KEY") + logger.info(f"[ language_translate ] GOOGLE_API_KEY: {api_key}") + params = { + 'key': api_key, + 'q': text, + 'target': target + } + + response = requests.post(url, params=params) + if response.status_code == 200: + res = response.json() + return res["data"]["translations"][0] + else: + print("Error status:", response.status_code) + print("Error content:", response.json()) + return None + + +def create_upload_dir(knowledge_base_id: str, user_id: str): + if knowledge_base_id == 'default': + path_prefix = RETRIEVAL_FILE_PATH + 'default' + else: + path_prefix = RETRIEVAL_FILE_PATH+user_id+'-'+knowledge_base_id + upload_path = path_prefix + '/upload_dir' + persist_path = path_prefix + '/persist_dir' + if ( not os.path.exists(upload_path) ) or ( not os.path.exists(persist_path) ): + if knowledge_base_id == 'default': + os.makedirs(Path(path_prefix), exist_ok=True) + os.makedirs(Path(path_prefix) / 'upload_dir', exist_ok=True) + os.makedirs(Path(path_prefix) / 'persist_dir', exist_ok=True) + logger.info(f"Default kb {path_prefix} does not exist, create.") + else: + logger.info(f"kbid [{knowledge_base_id}] does not exist for user {user_id}") + raise Exception(f"Knowledge base id [{knowledge_base_id}] does not exist for user {user_id}, \ + Please check kb_id and save path again.") + return upload_path, persist_path + + class RetrievalAPIRouter(APIRouter): def __init__(self) -> None: @@ -93,20 +157,18 @@ async def retrieval_upload_link(request: Request): if 'knowledge_base_id' in params.keys(): print(f"[askdoc - upload_link] append") knowledge_base_id = params['knowledge_base_id'] - persist_path = RETRIEVAL_FILE_PATH+user_id+'-'+knowledge_base_id + '/persist_dir' - if not os.path.exists(persist_path): - return f"Knowledge base id [{knowledge_base_id}] does not exist for user {user_id}, \ - Please check kb_id and save path again." + upload_path, persist_path = create_upload_dir(knowledge_base_id, user_id) try: print("[askdoc - upload_link] starting to append local db...") instance = plugins['retrieval']["instance"] + print(f"[askdoc - upload_link] persist_path: {persist_path}") instance.append_localdb(append_path=link_list, persist_directory=persist_path) print(f"[askdoc - upload_link] kb appended successfully") except Exception as e: # pragma: no cover logger.info(f"[askdoc - upload_link] create knowledge base fails! {e}") return Response(content="Error occurred while uploading links.", status_code=500) - return {"Succeed"} + return {"status": True} # create new kb with link else: print(f"[askdoc - upload_link] create") @@ -185,51 +247,41 @@ async def retrieval_create(request: Request, @router.post("/v1/askdoc/append") async def retrieval_append(request: Request, - file: UploadFile = File(...), + files: List[UploadFile] = File(...), + # file: UploadFile = File(...), knowledge_base_id: str = Form(...)): global plugins - filename = file.filename - if '/' in filename: - filename = filename.split('/')[-1] - logger.info(f"[askdoc - append] received file: {filename}, kb_id: {knowledge_base_id}") + for file in files: + filename = file.filename + if '/' in filename: + filename = filename.split('/')[-1] + logger.info(f"[askdoc - append] received file: {filename}, kb_id: {knowledge_base_id}") + + user_id = request.client.host + logger.info(f'[askdoc - append] user id is: {user_id}') + + # create local upload dir + upload_path, persist_path = create_upload_dir(knowledge_base_id, user_id) + print(f"[askdoc - upload_link] persist_path: {persist_path}") + cur_time = get_current_beijing_time() + logger.info(f"[askdoc - append] upload path: {upload_path}") + + # save file to local path + save_file_name = upload_path + '/' + cur_time + '-' + filename + with open(save_file_name, 'wb') as fout: + content = await file.read() + fout.write(content) + logger.info(f"[askdoc - append] file saved to local path: {save_file_name}") - user_id = request.client.host - logger.info(f'[askdoc - append] user id is: {user_id}') - if knowledge_base_id == 'default': - path_prefix = RETRIEVAL_FILE_PATH + 'default' - else: - path_prefix = RETRIEVAL_FILE_PATH+user_id+'-'+knowledge_base_id - upload_path = path_prefix + '/upload_dir' - persist_path = path_prefix + '/persist_dir' - if ( not os.path.exists(upload_path) ) or ( not os.path.exists(persist_path) ): - if knowledge_base_id == 'default': - os.makedirs(Path(path_prefix), exist_ok=True) - os.makedirs(Path(path_prefix) / 'upload_dir', exist_ok=True) - os.makedirs(Path(path_prefix) / 'persist_dir', exist_ok=True) - logger.info(f"Default kb {path_prefix} does not exist, create.") - else: - logger.info(f"kbid [{knowledge_base_id}] does not exist for user {user_id}") - return f"Knowledge base id [{knowledge_base_id}] does not exist for user {user_id}, \ - Please check kb_id and save path again." - cur_time = get_current_beijing_time() - logger.info(f"[askdoc - append] upload path: {upload_path}") - - # save file to local path - save_file_name = upload_path + '/' + cur_time + '-' + filename - with open(save_file_name, 'wb') as fout: - content = await file.read() - fout.write(content) - logger.info(f"[askdoc - append] file saved to local path: {save_file_name}") - - try: - # get retrieval instance and reload db with new knowledge base - logger.info("[askdoc - append] starting to append to local db...") - instance = plugins['retrieval']["instance"] - instance.append_localdb(append_path=save_file_name, persist_directory=persist_path) - logger.info(f"[askdoc - append] new file successfully appended to kb") - except Exception as e: # pragma: no cover - logger.info(f"[askdoc - append] create knowledge base fails! {e}") - return "Error occurred while uploading files." + try: + # get retrieval instance and reload db with new knowledge base + logger.info("[askdoc - append] starting to append to local db...") + instance = plugins['retrieval']["instance"] + instance.append_localdb(append_path=save_file_name, persist_directory=persist_path) + logger.info(f"[askdoc - append] new file successfully appended to kb") + except Exception as e: # pragma: no cover + logger.info(f"[askdoc - append] create knowledge base fails! {e}") + return "Error occurred while uploading files." return "Succeed" @@ -244,18 +296,24 @@ async def retrieval_chat(request: Request): # parse parameters params = await request.json() - query = params['query'] - origin_query = params['translated'] + origin_query = params['query'] kb_id = params['knowledge_base_id'] stream = params['stream'] max_new_tokens = params['max_new_tokens'] return_link = params['return_link'] - logger.info(f"[askdoc - chat] kb_id: '{kb_id}', query: '{query}', \ + logger.info(f"[askdoc - chat] kb_id: '{kb_id}', \ origin_query: '{origin_query}', stream mode: '{stream}', \ max_new_tokens: '{max_new_tokens}', \ return_link: '{return_link}'") config = GenerationConfig(max_new_tokens=max_new_tokens) + # detect and translate query + detect_res = language_detect(origin_query) + if detect_res['language'] == 'en': + query = origin_query + else: + query = language_translate(origin_query)['translatedText'] + path_prefix = RETRIEVAL_FILE_PATH cur_path = Path(path_prefix) / "default" / "persist_dir" os.makedirs(path_prefix, exist_ok=True) @@ -350,6 +408,27 @@ def stream_generator(): return StreamingResponse(stream_generator(), media_type="text/event-stream") +@router.post("/v1/askdoc/translate") +async def retrieval_translate(request: Request): + user_id = request.client.host + logger.info(f'[askdoc - translate] user id is: {user_id}') + + # parse parameters + params = await request.json() + content = params['content'] + logger.info(f'[askdoc - translate] origin content: {content}') + + detect_res = language_detect(content) + logger.info(f'[askdoc - translate] detected language: {detect_res["language"]}') + if detect_res['language'] == 'en': + translate_res = language_translate(content, target='zh-CN')['translatedText'] + else: + translate_res = language_translate(content, target='en')['translatedText'] + + logger.info(f'[askdoc - translate] translated result: {translate_res}') + return {"tranlated_content": translate_res} + + @router.post("/v1/askdoc/feedback") def save_chat_feedback_to_db(request: FeedbackRequest) -> None: logger.info(f'[askdoc - feedback] fastrag feedback received.') @@ -433,3 +512,59 @@ def data_generator(): data_generator(), media_type='text/csv', headers={"Content-Disposition": f"attachment;filename=feedback{cur_time_str}.csv"}) + + +@router.post("/v1/askdoc/verify_upload") +async def verify_upload(request: Request): + params = await request.json() + user_id = params['user_id'] + logger.info(f'[askdoc - verify_upload] current user: {user_id}') + + if user_id == "admin": + upload_path = RETRIEVAL_FILE_PATH + 'default/upload_dir' + if not os.path.exists(upload_path): + logger.info(f'[askdoc - verify_upload] currently NOT uploaded') + return {"is_uploaded": False} + else: + logger.info(f'[askdoc - verify_upload] currently ALREADY uploaded') + return {"is_uploaded": True} + else: + return JSONResponse( + content={"message": f"Current user {user_id} is not allowed to access /verify_upload api."}, + status_code=status.HTTP_400_BAD_REQUEST) + + +@router.delete("/v1/askdoc/delete_all") +async def delete_all_files(): + delete_path = RETRIEVAL_FILE_PATH + 'default' + if not os.path.exists(delete_path): + logger.info(f'[askdoc - delete_all] No file/link uploaded. Clear.') + return {"status": True} + else: + # delete all upload files + for filename in os.listdir(delete_path+'/upload_dir'): + file_path = os.path.join(delete_path, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f'Failed to delete {filename}. Reason: {e}' + ) + try: + shutil.rmtree(delete_path+'/upload_dir') + except Exception as e: + raise HTTPException( + status_code=500, + detail=f'Failed to delete {delete_path}/upload_dir. Reason: {e}' + ) + # reload default kb + origin_persist_dir = "/home/sdp/askgm_persist_new" + instance = plugins['retrieval']["instance"] + instance.reload_localdb(local_persist_dir = origin_persist_dir) + print(f"[askdoc - delete_all] Original kb loaded from: {origin_persist_dir}") + + return {"status": True} diff --git a/intel_extension_for_transformers/neural_chat/tests/ci/server/test_askdoc_server.py b/intel_extension_for_transformers/neural_chat/tests/ci/server/test_askdoc_server.py index c8439780ad9..9ffbf86090f 100644 --- a/intel_extension_for_transformers/neural_chat/tests/ci/server/test_askdoc_server.py +++ b/intel_extension_for_transformers/neural_chat/tests/ci/server/test_askdoc_server.py @@ -118,7 +118,7 @@ async def test_append_existing_kb_with_links(self): json={**sample_link_list, "knowledge_base_id": gaudi2_kb_id}, ) assert response.status_code == 200 - assert "Succeed" in response.json() + assert response.json()['status'] == True async def test_append_existing_kb(self): # create oneapi knowledge base @@ -133,7 +133,7 @@ async def test_append_existing_kb(self): with open("./gaudi2.txt", "rb") as file: response = client.post( "/v1/askdoc/append", - files={"file": ("./gaudi2.txt", file, "multipart/form-data")}, + files={"files": ("./gaudi2.txt", file, "multipart/form-data")}, data={"knowledge_base_id": oneapi_kb_id}, ) assert response.status_code == 200 @@ -151,7 +151,6 @@ async def test_non_stream_chat(self): gaudi2_kb_id = response.json()["knowledge_base_id"] query_params = { "query": "How about the benchmark test of Habana Gaudi2?", - "translated": "How about the benchmark test of Habana Gaudi2?", "knowledge_base_id": gaudi2_kb_id, "stream": False, "max_new_tokens": 64, @@ -172,7 +171,6 @@ async def test_stream_chat(self): gaudi2_kb_id = response.json()["knowledge_base_id"] query_params = { "query": "How about the benchmark test of Habana Gaudi2?", - "translated": "How about the benchmark test of Habana Gaudi2?", "knowledge_base_id": gaudi2_kb_id, "stream": True, "max_new_tokens": 64,