Commit
websocket reuse (#242)
yaojin3616 committed Jan 10, 2024
2 parents fdb0376 + bda9edf commit 4e7d1f5
Showing 10 changed files with 376 additions and 209 deletions.
123 changes: 79 additions & 44 deletions src/backend/bisheng/api/v1/callback.py
@@ -15,11 +15,16 @@
 class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
     """Callback handler for streaming LLM responses."""
 
-    def __init__(self, websocket: WebSocket):
+    def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str):
         self.websocket = websocket
+        self.flow_id = flow_id
+        self.chat_id = chat_id
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        resp = ChatResponse(message=token, type='stream', intermediate_steps='')
+        resp = ChatResponse(message=token,
+                            type='stream',
+                            flow_id=self.flow_id,
+                            chat_id=self.chat_id)
         await self.websocket.send_json(resp.dict())
 
     async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str],
@@ -50,11 +55,10 @@ async def on_chain_error(self, error: Union[Exception, KeyboardInterrupt],
     async def on_tool_start(self, serialized: Dict[str, Any], input_str: str,
                             **kwargs: Any) -> Any:
         """Run when tool starts running."""
-        resp = ChatResponse(
-            message='',
-            type='stream',
-            intermediate_steps=f'Tool input: {input_str}',
-        )
+        resp = ChatResponse(type='stream',
+                            intermediate_steps=f'Tool input: {input_str}',
+                            flow_id=self.flow_id,
+                            chat_id=self.chat_id)
         await self.websocket.send_json(resp.dict())
 
     async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
@@ -68,11 +72,10 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
         intermediate_steps = f'{observation_prefix}{result}'
 
         # Create a ChatResponse instance.
-        resp = ChatResponse(
-            message='',
-            type='stream',
-            intermediate_steps=intermediate_steps,
-        )
+        resp = ChatResponse(type='stream',
+                            intermediate_steps=intermediate_steps,
+                            flow_id=self.flow_id,
+                            chat_id=self.chat_id)
 
         try:
             # This is to emulate the stream of tokens
@@ -92,8 +95,17 @@ async def on_text(self, text: str, **kwargs: Any) -> Any:
         sender = kwargs.get('sender')
         receiver = kwargs.get('receiver')
         if kwargs.get('sender'):
-            log = ChatResponse(message=text, type='end', sender=sender, receiver=receiver)
-            start = ChatResponse(type='start', sender=sender, receiver=receiver)
+            log = ChatResponse(message=text,
+                               type='end',
+                               sender=sender,
+                               receiver=receiver,
+                               flow_id=self.flow_id,
+                               chat_id=self.chat_id)
+            start = ChatResponse(type='start',
+                                 sender=sender,
+                                 receiver=receiver,
+                                 flow_id=self.flow_id,
+                                 chat_id=self.chat_id)
 
             if receiver and receiver.get('is_self'):
                 await self.websocket.send_json(log.dict())
@@ -102,21 +114,31 @@ async def on_text(self, text: str, **kwargs: Any) -> Any:
                 await self.websocket.send_json(start.dict())
         elif 'category' in kwargs:
             if 'autogen' == kwargs['category']:
-                log = ChatResponse(message=text, type='stream')
+                log = ChatResponse(message=text,
+                                   type='stream',
+                                   flow_id=self.flow_id,
+                                   chat_id=self.chat_id)
                 await self.websocket.send_json(log.dict())
                 if kwargs.get('type'):
                     # for compatibility
-                    start = ChatResponse(type='start', category=kwargs.get('type'))
+                    start = ChatResponse(type='start',
+                                         category=kwargs.get('type'),
+                                         flow_id=self.flow_id,
+                                         chat_id=self.chat_id)
                     end = ChatResponse(type='end',
                                        intermediate_steps=text,
-                                       category=kwargs.get('type'))
+                                       category=kwargs.get('type'),
+                                       flow_id=self.flow_id,
+                                       chat_id=self.chat_id)
                     await self.websocket.send_json(start.dict())
                     await self.websocket.send_json(end.dict())
             else:
                 log = ChatResponse(message=text,
                                    intermediate_steps=kwargs['log'],
                                    type=kwargs['type'],
-                                   category=kwargs['category'])
+                                   category=kwargs['category'],
+                                   flow_id=self.flow_id,
+                                   chat_id=self.chat_id)
                 await self.websocket.send_json(log.dict())
         logger.debug(f'on_text text={text} kwargs={kwargs}')
 
@@ -127,19 +149,24 @@ async def on_agent_action(self, action: AgentAction, **kwargs: Any):
         if '\n' in log:
             logs = log.split('\n')
             for log in logs:
-                resp = ChatResponse(message='', type='stream', intermediate_steps=log)
+                resp = ChatResponse(type='stream',
+                                    intermediate_steps=log,
+                                    flow_id=self.flow_id,
+                                    chat_id=self.chat_id)
                 await self.websocket.send_json(resp.dict())
         else:
-            resp = ChatResponse(message='', type='stream', intermediate_steps=log)
+            resp = ChatResponse(type='stream',
+                                intermediate_steps=log,
+                                flow_id=self.flow_id,
+                                chat_id=self.chat_id)
             await self.websocket.send_json(resp.dict())
 
     async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
-        resp = ChatResponse(
-            message='',
-            type='stream',
-            intermediate_steps=finish.log,
-        )
+        resp = ChatResponse(flow_id=self.flow_id,
+                            chat_id=self.chat_id,
+                            type='stream',
+                            intermediate_steps=finish.log)
         await self.websocket.send_json(resp.dict())
 
     async def on_retriever_start(self, serialized: Dict[str, Any], query: str,
@@ -163,11 +190,16 @@ async def on_chat_model_start(self, serialized: Dict[str, Any],
 class StreamingLLMCallbackHandler(BaseCallbackHandler):
     """Callback handler for streaming LLM responses."""
 
-    def __init__(self, websocket: WebSocket):
+    def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str):
         self.websocket = websocket
+        self.flow_id = flow_id
+        self.chat_id = chat_id
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        resp = ChatResponse(message=token, type='stream', intermediate_steps='')
+        resp = ChatResponse(message=token,
+                            type='stream',
+                            flow_id=self.flow_id,
+                            chat_id=self.chat_id)
 
         loop = asyncio.get_event_loop()
         coroutine = self.websocket.send_json(resp.dict())
@@ -180,34 +212,38 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
         if '\n' in log:
             logs = log.split('\n')
             for log in logs:
-                resp = ChatResponse(message='', type='stream', intermediate_steps=log)
+                resp = ChatResponse(type='stream',
+                                    intermediate_steps=log,
+                                    flow_id=self.flow_id,
+                                    chat_id=self.chat_id)
                 loop = asyncio.get_event_loop()
                 coroutine = self.websocket.send_json(resp.dict())
                 asyncio.run_coroutine_threadsafe(coroutine, loop)
         else:
-            resp = ChatResponse(message='', type='stream', intermediate_steps=log)
+            resp = ChatResponse(type='stream',
+                                intermediate_steps=log,
+                                flow_id=self.flow_id,
+                                chat_id=self.chat_id)
             loop = asyncio.get_event_loop()
             coroutine = self.websocket.send_json(resp.dict())
             asyncio.run_coroutine_threadsafe(coroutine, loop)
 
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
-        resp = ChatResponse(
-            message='',
-            type='stream',
-            intermediate_steps=finish.log,
-        )
+        resp = ChatResponse(type='stream',
+                            intermediate_steps=finish.log,
+                            flow_id=self.flow_id,
+                            chat_id=self.chat_id)
         loop = asyncio.get_event_loop()
         coroutine = self.websocket.send_json(resp.dict())
         asyncio.run_coroutine_threadsafe(coroutine, loop)
 
     def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
         """Run when tool starts running."""
-        resp = ChatResponse(
-            message='',
-            type='stream',
-            intermediate_steps=f'Tool input: {input_str}',
-        )
+        resp = ChatResponse(type='stream',
+                            intermediate_steps=f'Tool input: {input_str}',
+                            flow_id=self.flow_id,
+                            chat_id=self.chat_id)
         loop = asyncio.get_event_loop()
         coroutine = self.websocket.send_json(resp.dict())
         asyncio.run_coroutine_threadsafe(coroutine, loop)
@@ -223,11 +259,10 @@ def on_tool_end(self, output: str, **kwargs: Any) -> Any:
         intermediate_steps = f'{observation_prefix}{result}'
 
         # Create a ChatResponse instance.
-        resp = ChatResponse(
-            message='',
-            type='stream',
-            intermediate_steps=intermediate_steps,
-        )
+        resp = ChatResponse(type='stream',
+                            intermediate_steps=intermediate_steps,
+                            flow_id=self.flow_id,
+                            chat_id=self.chat_id)
 
         # Try to send the response, handle potential errors.
 
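Both handler classes now stamp every outgoing ChatResponse with flow_id and chat_id, which is what lets several conversations share one websocket. A minimal sketch of the server-side wiring under that assumption (the helper name and the conversations mapping are hypothetical, not part of the repo):

    from bisheng.api.v1.callback import AsyncStreamingLLMCallbackHandler
    from fastapi import WebSocket


    def handlers_for(websocket: WebSocket, conversations: list[tuple[str, str]]):
        # One tagged handler per (flow_id, chat_id); all of them write to the
        # same socket, so frames from different conversations can interleave
        # safely and still be told apart on the receiving end.
        return {
            (flow_id, chat_id): AsyncStreamingLLMCallbackHandler(websocket, flow_id, chat_id)
            for flow_id, chat_id in conversations
        }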
4 changes: 2 additions & 2 deletions src/backend/bisheng/api/v1/chat.py
@@ -161,10 +161,10 @@ async def chat(
                                      user_id,
                                      gragh_data=graph_data)
     except WebSocketException as exc:
-        logger.error(f'Websocket exrror: {exc}')
+        logger.error(f'Websocket exrror: {str(exc)}')
         await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
     except Exception as exc:
-        logger.error(f'Error in chat websocket: {exc}')
+        logger.exception(f'Error in chat websocket: {str(exc)}')
         messsage = exc.detail if isinstance(exc, HTTPException) else str(exc)
         if 'Could not validate credentials' in str(exc):
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason='Unauthorized')
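The second except branch switches from logger.error to logger.exception, which records the active traceback along with the message. A standalone illustration of the difference, assuming loguru (the logging library used elsewhere in this commit):

    from loguru import logger

    try:
        1 / 0
    except Exception as exc:
        # logger.error writes only the formatted message...
        logger.error(f'Websocket error: {str(exc)}')
        # ...while logger.exception also appends the current stack trace,
        # which is why it suits broad `except Exception` handlers.
        logger.exception(f'Error in chat websocket: {str(exc)}')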
30 changes: 20 additions & 10 deletions src/backend/bisheng/api/v1/qa.py
@@ -1,28 +1,38 @@
+import asyncio
 import json
 from typing import List
 
 from bisheng.api.v1.schemas import UnifiedResponseModel, resp_200
-from bisheng.database.base import get_session
+from bisheng.database.base import get_session, session_getter
 from bisheng.database.models.knowledge_file import KnowledgeFile
 from bisheng.database.models.recall_chunk import RecallChunk
 from bisheng.utils.minio_client import MinioClient
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException
 from sqlmodel import Session, select
 
 # build router
 router = APIRouter(prefix='/qa', tags=['QA'])
 
 
 @router.get('/keyword', response_model=UnifiedResponseModel[List[str]], status_code=200)
-def get_answer_keyword(message_id: int, session: Session = Depends(get_session)):
+async def get_answer_keyword(message_id: int):
     # fetch the matched keys
-    chunks = session.exec(select(RecallChunk).where(RecallChunk.message_id == message_id)).first()
-    # keywords
-    if chunks:
-        keywords = chunks.keywords
-        return resp_200(json.loads(keywords))
-    else:
-        return []
+    conter = 3
+    while True:
+        with session_getter() as session:
+            chunks = session.exec(
+                select(RecallChunk).where(RecallChunk.message_id == message_id)).first()
+        # keywords
+        if chunks:
+            keywords = chunks.keywords
+            return resp_200(json.loads(keywords))
+        else:
+            # retry after a delay
+            if conter <= 0:
+                break
+            await asyncio.sleep(1)
+            conter -= 1
+    raise HTTPException(status_code=500, detail='后台处理中,稍后再试')
 
 
 @router.get('/chunk', status_code=200)
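Instead of returning an empty list when the recall chunks are not yet persisted, the rewritten endpoint polls up to three times, sleeping one second between attempts, and only then fails with a 500. The same pattern as a reusable sketch (the helper name and signature are hypothetical):

    import asyncio
    from typing import Callable, Optional, TypeVar

    T = TypeVar('T')


    async def poll_until_ready(fetch: Callable[[], Optional[T]],
                               retries: int = 3,
                               delay: float = 1.0) -> T:
        # fetch returns a result, or None while the background job is running
        for _ in range(retries + 1):
            result = fetch()
            if result is not None:
                return result
            await asyncio.sleep(delay)  # give the writer time to catch up
        raise TimeoutError('result not ready; still processing in the background')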
2 changes: 2 additions & 0 deletions src/backend/bisheng/api/v1/schemas.py
@@ -110,6 +110,8 @@ class ChatMessage(BaseModel):
     receiver: dict = None
     liked: int = 0
     extra: str = '{}'
+    flow_id: str = None
+    chat_id: str = None
 
 
 class ChatResponse(ChatMessage):
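With flow_id and chat_id on the base ChatMessage, every frame a client receives identifies its conversation. A minimal sketch of the field layout (pydantic v1 style, mirroring the optional-with-None defaults above; the field subset is illustrative):

    from typing import Optional

    from pydantic import BaseModel


    class ChatMessage(BaseModel):
        message: Optional[str] = ''
        type: str = 'stream'
        flow_id: Optional[str] = None
        chat_id: Optional[str] = None


    frame = ChatMessage(message='hello', flow_id='flow-a', chat_id='chat-1')
    # a client can key its dispatch table on the pair:
    route_key = (frame.flow_id, frame.chat_id)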
3 changes: 3 additions & 0 deletions src/backend/bisheng/cache/redis.py
@@ -2,6 +2,7 @@
 
 import redis
 from bisheng.settings import settings
+from loguru import logger
 from redis import ConnectionPool
 
 
@@ -17,6 +18,8 @@ def set(self, key, value, expiration=3600):
                 result = self.connection.setex(key, expiration, pickled)
                 if not result:
                     raise ValueError('RedisCache could not set the value.')
+            else:
+                logger.error('pickle error, value={}', value)
         except TypeError as exc:
             raise TypeError('RedisCache only accepts values that can be pickled. ') from exc
         finally:
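The added log line uses loguru's brace-style positional arguments rather than an f-string, so the value is only rendered when the record is actually emitted. A standalone example:

    from loguru import logger

    value = {'key': object()}
    # loguru substitutes '{}' placeholders from the trailing arguments,
    # much like logging's '%s' style, instead of formatting eagerly.
    logger.error('pickle error, value={}', value)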
5 changes: 4 additions & 1 deletion src/backend/bisheng/chat/handlers.py
@@ -26,7 +26,7 @@ def __init__(self) -> None:
 
     async def dispatch_task(self, session: ChatManager, client_id: str, chat_id: str, action: str,
                             payload: dict, user_id):
-        logger.info(f'dispatch_task payload={payload}')
+        logger.info(f'dispatch_task payload={payload.get("inputs")}')
         start_time = time.time()
         with session.cache_manager.set_client_id(client_id, chat_id):
             if not action:
@@ -36,6 +36,7 @@ async def dispatch_task(self, session: ChatManager, client_id: str, chat_id: str
 
             await self.handler_dict[action](session, client_id, chat_id, payload, user_id)
         logger.info(f'dispatch_task done timecost={time.time() - start_time}')
+        return client_id, chat_id
 
     async def process_report(self,
                              session: ChatManager,
@@ -136,6 +137,8 @@ async def process_message(self,
                 langchain_object=langchain_object,
                 chat_inputs=chat_inputs,
                 websocket=session.active_connections[get_cache_key(client_id, chat_id)],
+                flow_id=client_id,
+                chat_id=chat_id,
             )
         except Exception as e:
             # Log stack trace
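Since the callbacks now tag every frame with flow_id and chat_id, a client reading the shared socket can demultiplex interleaved streams. A minimal sketch of that consumer, assuming frames are the dicts produced by ChatResponse.dict() (the receive loop and queue layout are hypothetical, not part of the repo):

    import asyncio
    from collections import defaultdict
    from typing import Any, Awaitable, Callable

    # one queue per conversation, keyed by (flow_id, chat_id)
    queues: defaultdict = defaultdict(asyncio.Queue)


    async def demux(receive_json: Callable[[], Awaitable[dict]]) -> None:
        # receive_json yields one ChatResponse dict per call from the shared socket
        while True:
            frame: dict[str, Any] = await receive_json()
            key = (frame.get('flow_id'), frame.get('chat_id'))
            await queues[key].put(frame)  # route to the right conversation's queue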