release test (#166)
yaojin3616 committed Nov 27, 2023
2 parents 14b4b3e + e66b65d commit 129634d
Showing 16 changed files with 301 additions and 162 deletions.
4 changes: 2 additions & 2 deletions src/backend/bisheng/api/utils.py
@@ -86,14 +86,14 @@ def build_flow(graph_data: dict,
}
yield str(StreamData(event='log', data=log_dict))
# # if a file is present, do not process it here, to avoid duplicate handling
if not process_file and chat_id is not None:
if not process_file and vertex.base_type == 'documentloaders':
template_dict = {
key: value
for key, value in vertex.data['node']['template'].items()
if isinstance(value, dict)
}
for key, value in template_dict.items():
if value.get('type') == 'file':
if value.get('type') == 'fileNode':
# filter out the file
vertex.params[key] = ''

19 changes: 4 additions & 15 deletions src/backend/bisheng/api/v1/chat.py
@@ -9,7 +9,6 @@
from bisheng.database.base import get_session
from bisheng.database.models.flow import Flow
from bisheng.database.models.message import ChatMessage, ChatMessageRead
from bisheng.database.models.user import User
from bisheng.utils.logger import logger
from bisheng.utils.util import get_cache_key
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketException, status
@@ -103,12 +102,10 @@ async def chat(flow_id: str,
Authorize.jwt_required(auth_from='websocket', websocket=websocket)
payload = json.loads(Authorize.get_jwt_subject())
user_id = payload.get('user_id')
db_user = User(user_id=user_id, user_name='')
"""Websocket endpoint for chat."""
if chat_id:
with next(get_session()) as session:
db_flow = session.get(Flow, flow_id)
db_user = session.get(User, user_id) # so that nodes can check the user's permissions
if not db_flow:
await websocket.accept()
message = '该技能已被删除'
@@ -130,18 +127,9 @@
graph_data = json.loads(flow_data_store.hget(flow_data_key, 'graph_data'))

try:
process_file = False if chat_id else True
graph = build_flow_no_yield(graph_data=graph_data,
artifacts={},
process_file=process_file,
flow_id=UUID(flow_id).hex,
chat_id=chat_id, user_name=db_user.user_name)
langchain_object = graph.build()
for node in langchain_object:
key_node = get_cache_key(flow_id, chat_id, node.id)
chat_manager.set_cache(key_node, node._built_object)
chat_manager.set_cache(get_cache_key(flow_id, chat_id), node._built_object)
await chat_manager.handle_websocket(flow_id, chat_id, websocket, user_id)
chat_manager.set_cache(get_cache_key(flow_id, chat_id), None)
await chat_manager.handle_websocket(flow_id, chat_id, websocket, user_id,
gragh_data=graph_data)
except WebSocketException as exc:
logger.error(exc)
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
@@ -314,6 +302,7 @@ async def event_stream(flow_id, chat_id: str):
yield str(StreamData(event='message', data=input_keys_response))
# We need to reset the chat history
chat_manager.chat_history.empty_history(flow_id, chat_id)
chat_manager.set_cache(get_cache_key(flow_id=flow_id, chat_id=chat_id), None)
flow_data_store.hsetkey(flow_data_key, 'status', BuildStatus.SUCCESS.value, expire)
except Exception as exc:
logger.exception(exc)
5 changes: 3 additions & 2 deletions src/backend/bisheng/api/v1/endpoints.py
@@ -125,8 +125,9 @@ async def process_flow(
async def create_upload_file(file: UploadFile, flow_id: str):
# Cache file
try:
file_path = save_uploaded_file(file.file, folder_name=flow_id)

file_path = save_uploaded_file(file.file, folder_name=flow_id, file_name=file.filename)
if not isinstance(file_path, str):
file_path = str(file_path)
return UploadFileResponse(
flowId=flow_id,
file_path=file_path,
8 changes: 4 additions & 4 deletions src/backend/bisheng/api/v1/knowledge.py
@@ -55,10 +55,10 @@ async def upload_file(*, file: UploadFile = File(...)):
try:
file_name = file.filename
# cache the file locally
file_path = save_uploaded_file(file.file, 'bisheng').as_posix()
# upload to minio

return UploadFileResponse(file_path=file_path + '_' + file_name,)
file_path = save_uploaded_file(file.file, 'bisheng', file_name)
if not isinstance(file_path, str):
file_path = str(file_path) + '_' + file_name
return UploadFileResponse(file_path=file_path)
except Exception as exc:
logger.error(f'Error saving file: {exc}')
raise HTTPException(status_code=500, detail=str(exc)) from exc
3 changes: 1 addition & 2 deletions src/backend/bisheng/api/v1/schemas.py
@@ -1,7 +1,6 @@
import json
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

@@ -130,7 +129,7 @@ class UploadFileResponse(BaseModel):
"""Upload file response schema."""

flowId: Optional[str]
file_path: Path
file_path: str


class StreamData(BaseModel):
20 changes: 14 additions & 6 deletions src/backend/bisheng/cache/utils.py
@@ -10,6 +10,8 @@
from typing import Any, Dict

from appdirs import user_cache_dir
from bisheng.settings import settings
from bisheng.utils.minio_client import mino_client, tmp_bucket

CACHE: Dict[str, Any] = {}

Expand Down Expand Up @@ -138,7 +140,7 @@ def save_binary_file(content: str, file_name: str, accepted_types: list[str]) ->


@create_cache_folder
def save_uploaded_file(file, folder_name):
def save_uploaded_file(file, folder_name, file_name):
"""
Save an uploaded file to the specified folder with a hash of its content as the file name.
@@ -166,15 +168,21 @@ def save_uploaded_file(file, folder_name):

# Use the hex digest of the hash as the file name
hex_dig = sha256_hash.hexdigest()
file_name = hex_dig
md5_name = hex_dig

# Reset the file cursor to the beginning of the file
file.seek(0)

# Save the file with the hash as its name
file_path = folder_path / file_name
with open(file_path, 'wb') as new_file:
while chunk := file.read(8192):
new_file.write(chunk)
if settings.get_knowledge().get('minio'):
# store to OSS
file_byte = file.read()
mino_client.upload_tmp(file_name, file_byte)
file_path = mino_client.get_share_link(file_name, tmp_bucket)
else:
file_path = folder_path / md5_name
with open(file_path, 'wb') as new_file:
while chunk := file.read(8192):
new_file.write(chunk)

return file_path
87 changes: 59 additions & 28 deletions src/backend/bisheng/chat/handlers.py
@@ -7,7 +7,10 @@
from bisheng.database.base import get_session
from bisheng.database.models.model_deploy import ModelDeploy
from bisheng.database.models.recall_chunk import RecallChunk
from bisheng.database.models.report import Report
from bisheng.utils.docx_temp import test_replace_string
from bisheng.utils.logger import logger
from bisheng.utils.minio_client import mino_client
from bisheng.utils.util import get_cache_key
from bisheng_langchain.chains.autogen.auto_gen import AutoGenChain
from langchain.docstore.document import Document
@@ -26,14 +29,17 @@ async def dispatch_task(self, session: ChatManager,
client_id: str, chat_id: str,
action: str, payload: dict, user_id):
with session.cache_manager.set_client_id(client_id, chat_id):
if not action:
action = 'default'
if action not in self.handler_dict:
raise Exception(f'unknown action {action}')

await self.handler_dict[action](session, client_id, chat_id, payload, user_id)

async def process_report(self, session: ChatManager,
client_id: str, chat_id: str,
payload: Dict, user_id=None):
chat_inputs = payload.pop('inputs', '')
chat_inputs = payload.pop('inputs', {})
chat_inputs.pop('data') if 'data' in chat_inputs else {}
chat_inputs.pop('id') if 'id' in chat_inputs else ''
key = get_cache_key(client_id, chat_id)
@@ -42,20 +48,40 @@ async def process_report(self, session: ChatManager,
for k, value in artifacts.items():
if k in chat_inputs:
chat_inputs[k] = value
chat_inputs = ChatMessage(message=chat_inputs, category='question',
type='bot', user_id=user_id,)
session.chat_history.add_message(client_id, chat_id, chat_inputs)
chat_message = ChatMessage(message=chat_inputs, category='question', type='bot',
user_id=user_id)
session.chat_history.add_message(client_id, chat_id, chat_message)

# process message
chat_inputs = {'inputs': chat_inputs, 'is_begin': False}
result = await self.process_message(session, client_id, chat_id, chat_inputs, user_id)
# build report
db_session = next(get_session())
template = db_session.exec(select(Report).where(
Report.flow_id == client_id).order_by(Report.id.desc())).first()
if not template:
logger.error('template not support')
return
start_resp = ChatResponse(type='start', user_id=user_id)
await session.send_json(client_id, chat_id, start_resp)
template_muban = mino_client.get_share_link(template.object_name)
test_replace_string(template_muban, result, 'report.docx')
file = mino_client.get_share_link('report.docx')
response = ChatResponse(type='end', intermediate_steps=json.dumps(result),
files=[{'file_url': file, 'file_name': 'report.docx'}],
user_id=user_id)
await session.send_json(client_id, chat_id, response)
close_resp = ChatResponse(type='close', category='system', user_id=user_id)
await session.send_json(client_id, chat_id, close_resp)

async def process_message(self, session: ChatManager,
client_id: str, chat_id: str,
payload: Dict, user_id=None):
# Process the graph data and chat message
chat_inputs = payload.pop('inputs', '')
node_id = chat_inputs.pop('id') if 'id' in chat_inputs else ''
chat_inputs = payload.pop('inputs', {})
chat_inputs.pop('id') if 'id' in chat_inputs else ''
is_begin = payload.get('is_begin', True)
key = get_cache_key(client_id, chat_id, node_id)
key = get_cache_key(client_id, chat_id)

artifacts = session.in_memory_cache.get(key + '_artifacts')
if artifacts:
@@ -67,8 +93,6 @@ async def process_message(self, session: ChatManager,
if is_begin:
# when process_message is auto-triggered from a file, the question is already saved
session.chat_history.add_message(client_id, chat_id, chat_inputs)
start_resp = ChatResponse(type='begin', user_id=user_id)
await session.send_json(client_id, chat_id, start_resp)
start_resp = ChatResponse(type='start', user_id=user_id)
await session.send_json(client_id, chat_id, start_resp)

@@ -102,7 +126,7 @@ async def process_message(self, session: ChatManager,
# Send a response back to the frontend, if needed
intermediate_steps = intermediate_steps or ''
# history = self.chat_history.get_history(client_id, chat_id, filter_messages=False)
await self.intermediate_logs(client_id, chat_id, user_id, intermediate_steps)
await self.intermediate_logs(session, client_id, chat_id, user_id, intermediate_steps)
source = True if source_doucment and chat_id else False
if source:
for doc in source_doucment:
@@ -118,35 +142,38 @@ async def process_message(self, session: ChatManager,
category='divider', user_id=user_id)
await session.send_json(client_id, chat_id, response)
else:
start_resp.category = 'answer'
await session.send_json(client_id, chat_id, start_resp)
response = ChatResponse(message=result if is_begin else '', type='end',
intermediate_steps=result if not is_begin else '',
category='answer', user_id=user_id,
source=source)
await session.send_json(client_id, chat_id, response)
# normal flow
if is_begin:
start_resp.category = 'answer'
await session.send_json(client_id, chat_id, start_resp)
response = ChatResponse(message=result, type='end',
category='answer', user_id=user_id,
source=source)
await session.send_json(client_id, chat_id, response)

# loop finished
close_resp = ChatResponse(type='close', user_id=user_id)
await session.send_json(client_id, chat_id, close_resp)
if is_begin:
close_resp = ChatResponse(type='close', user_id=user_id)
await session.send_json(client_id, chat_id, close_resp)

if source:
# process the recalled chunks
await self.process_source_document(source_doucment, chat_id, response.message_id,
result,)
await self.process_source_document(source_doucment, chat_id,
response.message_id, result,)
return result

async def process_file(self, session: ChatManager,
client_id: str, chat_id: str,
payload: dict, user_id: int):
_, file_name = payload['inputs']['data'][0]['vaule'].split('_', 1)
file_name = payload['inputs']['data'][0]['value']
batch_question = payload['inputs']['questions']
# if L3
file = ChatMessage(is_bot=False,
files=[{'file_name': file_name}],
type='end', user_id=user_id)
session.chat_history.add_message(client_id, chat_id, file)
start_resp = ChatResponse(type='begin', category='system', user_id=user_id)
start_resp = ChatResponse(type='start', category='system', user_id=user_id)
await session.send_json(client_id, chat_id, start_resp)

if not batch_question:
# no question
@@ -158,9 +185,8 @@ async def process_file(self, session: ChatManager,
await session.send_json(client_id, chat_id, start_resp)
return

step_resp = ChatResponse(type='end',
intermediate_steps='File parsing complete, analysis starting',
category='system', user_id=user_id)
step_resp = ChatResponse(intermediate_steps='File parsing complete, analysis starting',
type='end', category='system', user_id=user_id)
await session.send_json(client_id, chat_id, step_resp)

key = get_cache_key(client_id, chat_id)
@@ -171,7 +197,7 @@ async def process_file(self, session: ChatManager,
for question in batch_question:
if not question:
continue
payload = {'inputs': {input_key: question}}
payload = {'inputs': {input_key: question}, 'is_begin': False}
start_resp.category == 'question'
await session.send_json(client_id, chat_id, start_resp)
step_resp = ChatResponse(type='end',
@@ -180,6 +206,11 @@ async def process_file(self, session: ChatManager,
user_id=user_id)
await session.send_json(client_id, chat_id, step_resp)
result = await self.process_message(session, client_id, chat_id, payload, user_id)
response_step = ChatResponse(intermediate_steps=result, type='start', category='answer',
user_id=user_id)
await session.send_json(client_id, chat_id, response_step)
response_step.type = 'end'
await session.send_json(client_id, chat_id, response_step)
report = f"""{report}### {question} \n {result} \n """

start_resp.category = 'report'
@@ -195,7 +226,7 @@ async def process_autogen(self, session: ChatManager,
key = get_cache_key(client_id, chat_id)
langchain_object = session.in_memory_cache.get(key)
logger.info(f'reciever_human_interactive langchain={langchain_object}')
action = payload['inputs'].get('action')
action = payload.get('action')
if action.lower() == 'stop':
if hasattr(langchain_object, 'stop'):
logger.info('reciever_human_interactive langchain_objct')