Skip to content

Commit

Permalink
[NeuralChat] Support language detection & translation for RAG chat (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
letonghan committed Mar 25, 2024
1 parent e8c77e7 commit 99df35d
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 57 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/unit-test-neuralchat.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ env:
CONTAINER_NAME: "utTest"
EXTRA_CONTAINER_NAME: "modelTest"
CONTAINER_SCAN: "codeScan"
GOOGLE_API_KEY: ${{ vars.GOOGLE_API_KEY }}

jobs:
neuralchat-unit-test:
Expand Down Expand Up @@ -84,6 +85,7 @@ jobs:
-v /home/itrex-docker/models:/models \
-v /dataset/media:/media \
-v /dataset/tf_dataset2:/tf_dataset2 \
-e "GOOGLE_API_KEY=${{ vars.GOOGLE_API_KEY }}" \
${{ env.REPO_NAME }}:${{ env.REPO_TAG }}
- name: Binary build
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@
import os
import re
import csv
import shutil
import datetime
import requests
from pathlib import Path
from datetime import timedelta, timezone
from typing import Optional, Dict
from fastapi import APIRouter, UploadFile, File, Request, Response, Form
from typing import Optional, Dict, List
from fastapi import APIRouter, UploadFile, File, Request, Response, Form, status, HTTPException
from ...config import GenerationConfig
from ...cli.log import logger
from ...server.restful.request import RetrievalRequest, FeedbackRequest
from ...server.restful.response import RetrievalResponse
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, JSONResponse
from ...utils.database.mysqldb import MysqlDb
from ...plugins import plugins

Expand All @@ -51,6 +53,68 @@ def get_current_beijing_time():
return beijing_time


def language_detect(text: str):
url = "https://translation.googleapis.com/language/translate/v2/detect"
try:
api_key = os.getenv("GOOGLE_API_KEY")
logger.info(f"[ language_detect ] GOOGLE_API_KEY: {api_key}")
except Exception as e:
logger.info(f"No GOOGLE_API_KEY found. {e}")
params = {
'key': api_key,
'q': text
}

response = requests.post(url, params=params)
if response.status_code == 200:
res = response.json()
return res["data"]["detections"][0][0]
else:
print("Error status:", response.status_code)
print("Error content:", response.json())
return None


def language_translate(text: str, target: str='en'):
url = "https://translation.googleapis.com/language/translate/v2"
api_key = os.getenv("GOOGLE_API_KEY")
logger.info(f"[ language_translate ] GOOGLE_API_KEY: {api_key}")
params = {
'key': api_key,
'q': text,
'target': target
}

response = requests.post(url, params=params)
if response.status_code == 200:
res = response.json()
return res["data"]["translations"][0]
else:
print("Error status:", response.status_code)
print("Error content:", response.json())
return None


def create_upload_dir(knowledge_base_id: str, user_id: str):
if knowledge_base_id == 'default':
path_prefix = RETRIEVAL_FILE_PATH + 'default'
else:
path_prefix = RETRIEVAL_FILE_PATH+user_id+'-'+knowledge_base_id
upload_path = path_prefix + '/upload_dir'
persist_path = path_prefix + '/persist_dir'
if ( not os.path.exists(upload_path) ) or ( not os.path.exists(persist_path) ):
if knowledge_base_id == 'default':
os.makedirs(Path(path_prefix), exist_ok=True)
os.makedirs(Path(path_prefix) / 'upload_dir', exist_ok=True)
os.makedirs(Path(path_prefix) / 'persist_dir', exist_ok=True)
logger.info(f"Default kb {path_prefix} does not exist, create.")
else:
logger.info(f"kbid [{knowledge_base_id}] does not exist for user {user_id}")
raise Exception(f"Knowledge base id [{knowledge_base_id}] does not exist for user {user_id}, \
Please check kb_id and save path again.")
return upload_path, persist_path


class RetrievalAPIRouter(APIRouter):

def __init__(self) -> None:
Expand Down Expand Up @@ -93,20 +157,18 @@ async def retrieval_upload_link(request: Request):
if 'knowledge_base_id' in params.keys():
print(f"[askdoc - upload_link] append")
knowledge_base_id = params['knowledge_base_id']
persist_path = RETRIEVAL_FILE_PATH+user_id+'-'+knowledge_base_id + '/persist_dir'
if not os.path.exists(persist_path):
return f"Knowledge base id [{knowledge_base_id}] does not exist for user {user_id}, \
Please check kb_id and save path again."
upload_path, persist_path = create_upload_dir(knowledge_base_id, user_id)

try:
print("[askdoc - upload_link] starting to append local db...")
instance = plugins['retrieval']["instance"]
print(f"[askdoc - upload_link] persist_path: {persist_path}")
instance.append_localdb(append_path=link_list, persist_directory=persist_path)
print(f"[askdoc - upload_link] kb appended successfully")
except Exception as e: # pragma: no cover
logger.info(f"[askdoc - upload_link] create knowledge base fails! {e}")
return Response(content="Error occurred while uploading links.", status_code=500)
return {"Succeed"}
return {"status": True}
# create new kb with link
else:
print(f"[askdoc - upload_link] create")
Expand Down Expand Up @@ -185,51 +247,41 @@ async def retrieval_create(request: Request,

@router.post("/v1/askdoc/append")
async def retrieval_append(request: Request,
file: UploadFile = File(...),
files: List[UploadFile] = File(...),
# file: UploadFile = File(...),
knowledge_base_id: str = Form(...)):
global plugins
filename = file.filename
if '/' in filename:
filename = filename.split('/')[-1]
logger.info(f"[askdoc - append] received file: {filename}, kb_id: {knowledge_base_id}")
for file in files:
filename = file.filename
if '/' in filename:
filename = filename.split('/')[-1]
logger.info(f"[askdoc - append] received file: {filename}, kb_id: {knowledge_base_id}")

user_id = request.client.host
logger.info(f'[askdoc - append] user id is: {user_id}')

# create local upload dir
upload_path, persist_path = create_upload_dir(knowledge_base_id, user_id)
print(f"[askdoc - upload_link] persist_path: {persist_path}")
cur_time = get_current_beijing_time()
logger.info(f"[askdoc - append] upload path: {upload_path}")

# save file to local path
save_file_name = upload_path + '/' + cur_time + '-' + filename
with open(save_file_name, 'wb') as fout:
content = await file.read()
fout.write(content)
logger.info(f"[askdoc - append] file saved to local path: {save_file_name}")

user_id = request.client.host
logger.info(f'[askdoc - append] user id is: {user_id}')
if knowledge_base_id == 'default':
path_prefix = RETRIEVAL_FILE_PATH + 'default'
else:
path_prefix = RETRIEVAL_FILE_PATH+user_id+'-'+knowledge_base_id
upload_path = path_prefix + '/upload_dir'
persist_path = path_prefix + '/persist_dir'
if ( not os.path.exists(upload_path) ) or ( not os.path.exists(persist_path) ):
if knowledge_base_id == 'default':
os.makedirs(Path(path_prefix), exist_ok=True)
os.makedirs(Path(path_prefix) / 'upload_dir', exist_ok=True)
os.makedirs(Path(path_prefix) / 'persist_dir', exist_ok=True)
logger.info(f"Default kb {path_prefix} does not exist, create.")
else:
logger.info(f"kbid [{knowledge_base_id}] does not exist for user {user_id}")
return f"Knowledge base id [{knowledge_base_id}] does not exist for user {user_id}, \
Please check kb_id and save path again."
cur_time = get_current_beijing_time()
logger.info(f"[askdoc - append] upload path: {upload_path}")

# save file to local path
save_file_name = upload_path + '/' + cur_time + '-' + filename
with open(save_file_name, 'wb') as fout:
content = await file.read()
fout.write(content)
logger.info(f"[askdoc - append] file saved to local path: {save_file_name}")

try:
# get retrieval instance and reload db with new knowledge base
logger.info("[askdoc - append] starting to append to local db...")
instance = plugins['retrieval']["instance"]
instance.append_localdb(append_path=save_file_name, persist_directory=persist_path)
logger.info(f"[askdoc - append] new file successfully appended to kb")
except Exception as e: # pragma: no cover
logger.info(f"[askdoc - append] create knowledge base fails! {e}")
return "Error occurred while uploading files."
try:
# get retrieval instance and reload db with new knowledge base
logger.info("[askdoc - append] starting to append to local db...")
instance = plugins['retrieval']["instance"]
instance.append_localdb(append_path=save_file_name, persist_directory=persist_path)
logger.info(f"[askdoc - append] new file successfully appended to kb")
except Exception as e: # pragma: no cover
logger.info(f"[askdoc - append] create knowledge base fails! {e}")
return "Error occurred while uploading files."
return "Succeed"


Expand All @@ -244,18 +296,24 @@ async def retrieval_chat(request: Request):

# parse parameters
params = await request.json()
query = params['query']
origin_query = params['translated']
origin_query = params['query']
kb_id = params['knowledge_base_id']
stream = params['stream']
max_new_tokens = params['max_new_tokens']
return_link = params['return_link']
logger.info(f"[askdoc - chat] kb_id: '{kb_id}', query: '{query}', \
logger.info(f"[askdoc - chat] kb_id: '{kb_id}', \
origin_query: '{origin_query}', stream mode: '{stream}', \
max_new_tokens: '{max_new_tokens}', \
return_link: '{return_link}'")
config = GenerationConfig(max_new_tokens=max_new_tokens)

# detect and translate query
detect_res = language_detect(origin_query)
if detect_res['language'] == 'en':
query = origin_query
else:
query = language_translate(origin_query)['translatedText']

path_prefix = RETRIEVAL_FILE_PATH
cur_path = Path(path_prefix) / "default" / "persist_dir"
os.makedirs(path_prefix, exist_ok=True)
Expand Down Expand Up @@ -350,6 +408,27 @@ def stream_generator():
return StreamingResponse(stream_generator(), media_type="text/event-stream")


@router.post("/v1/askdoc/translate")
async def retrieval_translate(request: Request):
user_id = request.client.host
logger.info(f'[askdoc - translate] user id is: {user_id}')

# parse parameters
params = await request.json()
content = params['content']
logger.info(f'[askdoc - translate] origin content: {content}')

detect_res = language_detect(content)
logger.info(f'[askdoc - translate] detected language: {detect_res["language"]}')
if detect_res['language'] == 'en':
translate_res = language_translate(content, target='zh-CN')['translatedText']
else:
translate_res = language_translate(content, target='en')['translatedText']

logger.info(f'[askdoc - translate] translated result: {translate_res}')
return {"tranlated_content": translate_res}


@router.post("/v1/askdoc/feedback")
def save_chat_feedback_to_db(request: FeedbackRequest) -> None:
logger.info(f'[askdoc - feedback] fastrag feedback received.')
Expand Down Expand Up @@ -433,3 +512,59 @@ def data_generator():
data_generator(),
media_type='text/csv',
headers={"Content-Disposition": f"attachment;filename=feedback{cur_time_str}.csv"})


@router.post("/v1/askdoc/verify_upload")
async def verify_upload(request: Request):
params = await request.json()
user_id = params['user_id']
logger.info(f'[askdoc - verify_upload] current user: {user_id}')

if user_id == "admin":
upload_path = RETRIEVAL_FILE_PATH + 'default/upload_dir'
if not os.path.exists(upload_path):
logger.info(f'[askdoc - verify_upload] currently NOT uploaded')
return {"is_uploaded": False}
else:
logger.info(f'[askdoc - verify_upload] currently ALREADY uploaded')
return {"is_uploaded": True}
else:
return JSONResponse(
content={"message": f"Current user {user_id} is not allowed to access /verify_upload api."},
status_code=status.HTTP_400_BAD_REQUEST)


@router.delete("/v1/askdoc/delete_all")
async def delete_all_files():
delete_path = RETRIEVAL_FILE_PATH + 'default'
if not os.path.exists(delete_path):
logger.info(f'[askdoc - delete_all] No file/link uploaded. Clear.')
return {"status": True}
else:
# delete all upload files
for filename in os.listdir(delete_path+'/upload_dir'):
file_path = os.path.join(delete_path, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f'Failed to delete {filename}. Reason: {e}'
)
try:
shutil.rmtree(delete_path+'/upload_dir')
except Exception as e:
raise HTTPException(
status_code=500,
detail=f'Failed to delete {delete_path}/upload_dir. Reason: {e}'
)
# reload default kb
origin_persist_dir = "/home/sdp/askgm_persist_new"
instance = plugins['retrieval']["instance"]
instance.reload_localdb(local_persist_dir = origin_persist_dir)
print(f"[askdoc - delete_all] Original kb loaded from: {origin_persist_dir}")

return {"status": True}
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def test_append_existing_kb_with_links(self):
json={**sample_link_list, "knowledge_base_id": gaudi2_kb_id},
)
assert response.status_code == 200
assert "Succeed" in response.json()
assert response.json()['status'] == True

async def test_append_existing_kb(self):
# create oneapi knowledge base
Expand All @@ -133,7 +133,7 @@ async def test_append_existing_kb(self):
with open("./gaudi2.txt", "rb") as file:
response = client.post(
"/v1/askdoc/append",
files={"file": ("./gaudi2.txt", file, "multipart/form-data")},
files={"files": ("./gaudi2.txt", file, "multipart/form-data")},
data={"knowledge_base_id": oneapi_kb_id},
)
assert response.status_code == 200
Expand All @@ -151,7 +151,6 @@ async def test_non_stream_chat(self):
gaudi2_kb_id = response.json()["knowledge_base_id"]
query_params = {
"query": "How about the benchmark test of Habana Gaudi2?",
"translated": "How about the benchmark test of Habana Gaudi2?",
"knowledge_base_id": gaudi2_kb_id,
"stream": False,
"max_new_tokens": 64,
Expand All @@ -172,7 +171,6 @@ async def test_stream_chat(self):
gaudi2_kb_id = response.json()["knowledge_base_id"]
query_params = {
"query": "How about the benchmark test of Habana Gaudi2?",
"translated": "How about the benchmark test of Habana Gaudi2?",
"knowledge_base_id": gaudi2_kb_id,
"stream": True,
"max_new_tokens": 64,
Expand Down

0 comments on commit 99df35d

Please sign in to comment.