diff --git a/src/backend/bisheng/api/v1/chat.py b/src/backend/bisheng/api/v1/chat.py index 4bee8fec7..f8df13d86 100644 --- a/src/backend/bisheng/api/v1/chat.py +++ b/src/backend/bisheng/api/v1/chat.py @@ -73,6 +73,16 @@ def del_chat_id(*, chat_id: str, Authorize: AuthJWT = Depends()): return resp_200(message='删除成功') +@router.delete('/chat/message/{message_id}', status_code=200) +def del_message_id(*, message_id: str, Authorize: AuthJWT = Depends()): + Authorize.jwt_required() + payload = json.loads(Authorize.get_jwt_subject()) + # 获取一条消息 + ChatMessageDao.delete_by_message_id(payload.get('user_id'), message_id) + + return resp_200(message='删除成功') + + @router.post('/liked', status_code=200) def like_response(*, data: ChatInput, Authorize: AuthJWT = Depends()): Authorize.jwt_required() @@ -109,9 +119,9 @@ def get_chatlist_list(*, smt = (select(ChatMessage.flow_id, ChatMessage.chat_id, func.max(ChatMessage.create_time).label('create_time'), func.max(ChatMessage.update_time).label('update_time')).where( - ChatMessage.user_id == payload.get('user_id')).group_by( - ChatMessage.flow_id, - ChatMessage.chat_id).order_by(func.max(ChatMessage.create_time).desc())) + ChatMessage.user_id == payload.get('user_id')).group_by( + ChatMessage.flow_id, + ChatMessage.chat_id).order_by(func.max(ChatMessage.create_time).desc())) with session_getter() as session: db_message = session.exec(smt).all() flow_ids = [message.flow_id for message in db_message] @@ -145,7 +155,7 @@ def get_chatlist_list(*, else: # 通过接口创建的会话记录,不关联技能或者助手 logger.debug(f'unknown message.flow_id={message.flow_id}') - return resp_200(chat_list[(page-1)*limit:page*limit]) + return resp_200(chat_list[(page - 1) * limit:page * limit]) # 获取所有已上线的技能和助手 @@ -153,7 +163,7 @@ def get_chatlist_list(*, response_model=UnifiedResponseModel[List[FlowGptsOnlineList]], status_code=200) def get_online_chat(*, - keyword: Optional[str]=None, + keyword: Optional[str] = None, page: Optional[int] = 1, limit: Optional[int] = 10, Authorize: AuthJWT = Depends()): @@ -167,7 +177,8 @@ def get_online_chat(*, all_assistant = AssistantDao.get_all_online_assistants() flows = FlowDao.get_all_online_flows(keyword) else: - assistants = AssistantService.get_assistant(user, keyword, AssistantStatus.ONLINE.value, 0, 0) + assistants = AssistantService.get_assistant(user, keyword, AssistantStatus.ONLINE.value, 0, + 0) all_assistant = assistants.data.get('data') flows = FlowDao.get_user_access_online_flows(user_id, keyword=keyword) for one in all_assistant: diff --git a/src/backend/bisheng/chat/handlers.py b/src/backend/bisheng/chat/handlers.py index d4099db62..1a9cc7cf6 100644 --- a/src/backend/bisheng/chat/handlers.py +++ b/src/backend/bisheng/chat/handlers.py @@ -10,6 +10,7 @@ from bisheng.utils.docx_temp import test_replace_string from bisheng.utils.logger import logger from bisheng.utils.minio_client import MinioClient +from bisheng.utils.threadpool import thread_pool from bisheng.utils.util import get_cache_key from bisheng_langchain.chains.autogen.auto_gen import AutoGenChain from sqlmodel import select @@ -269,11 +270,21 @@ async def process_autogen(self, session: ChatManager, client_id: str, chat_id: s logger.info(f'reciever_human_interactive langchain={langchain_object}') action = payload.get('action') if action.lower() == 'stop': - if hasattr(langchain_object, 'stop'): - logger.info('reciever_human_interactive langchain_objct') - await langchain_object.stop() + if isinstance(langchain_object, AutoGenChain): + if hasattr(langchain_object, 'stop'): + logger.info('reciever_human_interactive langchain_objct') + await langchain_object.stop() + else: + logger.error(f'act=auto_gen act={action}') else: - logger.error(f'act=auto_gen act={action}') + # 普通技能的stop + thread_pool.cancel_task([key]) # 将进行中的任务进行cancel + message = payload.get('inputs') or '手动停止' + res = ChatResponse(type='end', message=message) + close = ChatResponse(type='close') + await session.send_json(client_id, chat_id, res) + await session.send_json(client_id, chat_id, close) + elif action.lower() == 'continue': # autgen_user 对话的时候,进程 wait() 需要换新 if hasattr(langchain_object, 'input'): @@ -327,9 +338,9 @@ async def intermediate_logs(self, session: ChatManager, client_id, chat_id, user if 'source_documents' in s: answer = eval(s.split(':', 1)[1]) if 'result' in answer: - finally_log += 'Answer: ' + answer.get('result') + "\n\n" + finally_log += 'Answer: ' + answer.get('result') + '\n\n' else: - finally_log += s + "\n\n" + finally_log += s + '\n\n' msg = ChatResponse(intermediate_steps=finally_log, type='end', user_id=user_id) steps.append(msg) else: diff --git a/src/backend/bisheng/database/models/message.py b/src/backend/bisheng/database/models/message.py index ac3639ebe..cca5525cf 100644 --- a/src/backend/bisheng/database/models/message.py +++ b/src/backend/bisheng/database/models/message.py @@ -109,6 +109,20 @@ def delete_by_user_chat_id(cls, user_id: int, chat_id: str): session.commit() return True + @classmethod + def delete_by_message_id(cls, user_id: int, message_id: str): + if user_id is None or message_id is None: + logger.info('delete_param_error user_id={} chat_id={}', user_id, message_id) + return False + + statement = delete(ChatMessage).where(ChatMessage.chat_id == message_id, + ChatMessage.user_id == user_id) + + with session_getter() as session: + session.exec(statement) + session.commit() + return True + @classmethod def insert_one(cls, message: ChatMessage) -> ChatMessage: with session_getter() as session: diff --git a/src/backend/bisheng/database/models/user.py b/src/backend/bisheng/database/models/user.py index 9ee622353..f384051cf 100644 --- a/src/backend/bisheng/database/models/user.py +++ b/src/backend/bisheng/database/models/user.py @@ -138,6 +138,7 @@ def add_user_and_default_role(cls, user: User) -> User: session.refresh(user) db_user_role = UserRole(user_id=user.user_id, role_id=DefaultRole) session.add(db_user_role) + session.commit() session.refresh(user) return user