Skip to content

Commit

Permalink
Feat/0.3.1.4 (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
zgqgit committed Jun 17, 2024
2 parents eeaaf35 + 7865be8 commit 6b055e8
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 12 deletions.
23 changes: 17 additions & 6 deletions src/backend/bisheng/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ def del_chat_id(*, chat_id: str, Authorize: AuthJWT = Depends()):
return resp_200(message='删除成功')


@router.delete('/chat/message/{message_id}', status_code=200)
def del_message_id(*, message_id: str, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())
# 获取一条消息
ChatMessageDao.delete_by_message_id(payload.get('user_id'), message_id)

return resp_200(message='删除成功')


@router.post('/liked', status_code=200)
def like_response(*, data: ChatInput, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
Expand Down Expand Up @@ -109,9 +119,9 @@ def get_chatlist_list(*,
smt = (select(ChatMessage.flow_id, ChatMessage.chat_id,
func.max(ChatMessage.create_time).label('create_time'),
func.max(ChatMessage.update_time).label('update_time')).where(
ChatMessage.user_id == payload.get('user_id')).group_by(
ChatMessage.flow_id,
ChatMessage.chat_id).order_by(func.max(ChatMessage.create_time).desc()))
ChatMessage.user_id == payload.get('user_id')).group_by(
ChatMessage.flow_id,
ChatMessage.chat_id).order_by(func.max(ChatMessage.create_time).desc()))
with session_getter() as session:
db_message = session.exec(smt).all()
flow_ids = [message.flow_id for message in db_message]
Expand Down Expand Up @@ -145,15 +155,15 @@ def get_chatlist_list(*,
else:
# 通过接口创建的会话记录,不关联技能或者助手
logger.debug(f'unknown message.flow_id={message.flow_id}')
return resp_200(chat_list[(page-1)*limit:page*limit])
return resp_200(chat_list[(page - 1) * limit:page * limit])


# 获取所有已上线的技能和助手
@router.get('/chat/online',
response_model=UnifiedResponseModel[List[FlowGptsOnlineList]],
status_code=200)
def get_online_chat(*,
keyword: Optional[str]=None,
keyword: Optional[str] = None,
page: Optional[int] = 1,
limit: Optional[int] = 10,
Authorize: AuthJWT = Depends()):
Expand All @@ -167,7 +177,8 @@ def get_online_chat(*,
all_assistant = AssistantDao.get_all_online_assistants()
flows = FlowDao.get_all_online_flows(keyword)
else:
assistants = AssistantService.get_assistant(user, keyword, AssistantStatus.ONLINE.value, 0, 0)
assistants = AssistantService.get_assistant(user, keyword, AssistantStatus.ONLINE.value, 0,
0)
all_assistant = assistants.data.get('data')
flows = FlowDao.get_user_access_online_flows(user_id, keyword=keyword)
for one in all_assistant:
Expand Down
23 changes: 17 additions & 6 deletions src/backend/bisheng/chat/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from bisheng.utils.docx_temp import test_replace_string
from bisheng.utils.logger import logger
from bisheng.utils.minio_client import MinioClient
from bisheng.utils.threadpool import thread_pool
from bisheng.utils.util import get_cache_key
from bisheng_langchain.chains.autogen.auto_gen import AutoGenChain
from sqlmodel import select
Expand Down Expand Up @@ -269,11 +270,21 @@ async def process_autogen(self, session: ChatManager, client_id: str, chat_id: s
logger.info(f'reciever_human_interactive langchain={langchain_object}')
action = payload.get('action')
if action.lower() == 'stop':
if hasattr(langchain_object, 'stop'):
logger.info('reciever_human_interactive langchain_objct')
await langchain_object.stop()
if isinstance(langchain_object, AutoGenChain):
if hasattr(langchain_object, 'stop'):
logger.info('reciever_human_interactive langchain_objct')
await langchain_object.stop()
else:
logger.error(f'act=auto_gen act={action}')
else:
logger.error(f'act=auto_gen act={action}')
# 普通技能的stop
thread_pool.cancel_task([key]) # 将进行中的任务进行cancel
message = payload.get('inputs') or '手动停止'
res = ChatResponse(type='end', message=message)
close = ChatResponse(type='close')
await session.send_json(client_id, chat_id, res)
await session.send_json(client_id, chat_id, close)

elif action.lower() == 'continue':
# autgen_user 对话的时候,进程 wait() 需要换新
if hasattr(langchain_object, 'input'):
Expand Down Expand Up @@ -327,9 +338,9 @@ async def intermediate_logs(self, session: ChatManager, client_id, chat_id, user
if 'source_documents' in s:
answer = eval(s.split(':', 1)[1])
if 'result' in answer:
finally_log += 'Answer: ' + answer.get('result') + "\n\n"
finally_log += 'Answer: ' + answer.get('result') + '\n\n'
else:
finally_log += s + "\n\n"
finally_log += s + '\n\n'
msg = ChatResponse(intermediate_steps=finally_log, type='end', user_id=user_id)
steps.append(msg)
else:
Expand Down
14 changes: 14 additions & 0 deletions src/backend/bisheng/database/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ def delete_by_user_chat_id(cls, user_id: int, chat_id: str):
session.commit()
return True

@classmethod
def delete_by_message_id(cls, user_id: int, message_id: str):
if user_id is None or message_id is None:
logger.info('delete_param_error user_id={} chat_id={}', user_id, message_id)
return False

statement = delete(ChatMessage).where(ChatMessage.chat_id == message_id,
ChatMessage.user_id == user_id)

with session_getter() as session:
session.exec(statement)
session.commit()
return True

@classmethod
def insert_one(cls, message: ChatMessage) -> ChatMessage:
with session_getter() as session:
Expand Down
1 change: 1 addition & 0 deletions src/backend/bisheng/database/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def add_user_and_default_role(cls, user: User) -> User:
session.refresh(user)
db_user_role = UserRole(user_id=user.user_id, role_id=DefaultRole)
session.add(db_user_role)
session.commit()
session.refresh(user)
return user

Expand Down

0 comments on commit 6b055e8

Please sign in to comment.