Skip to content

Commit

Permalink
add new version 0.1.10
Browse files Browse the repository at this point in the history
autogen
  • Loading branch information
yaojin3616 committed Nov 22, 2023
2 parents 81536a4 + 3d5fa72 commit 7b32122
Show file tree
Hide file tree
Showing 89 changed files with 2,906 additions and 593 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ concurrency:

jobs:
build_bisheng_langchain:
strategy:
fail-fast: false
runs-on: ubuntu-latest
#if: startsWith(github.event.ref, 'refs/tags')
steps:
Expand Down Expand Up @@ -49,8 +47,6 @@ jobs:
twine upload dist/* -u ${{ secrets.PYPI_USER }} -p ${{ secrets.PYPI_PASSWORD }} --repository pypi
set -e
build_bisheng:
needs: build_bisheng_langchain
runs-on: ubuntu-latest
Expand Down
4 changes: 2 additions & 2 deletions docker/bisheng/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ redis_url: "redis://redis:6379/0"

environment:
env: dev
uns_support: ['png','jpg','jpeg','bmp','doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx', 'txt', 'md', 'html', 'pdf']
uns_support: ['png','jpg','jpeg','bmp','doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx', 'txt', 'md', 'html', 'pdf', 'csv', 'tiff']

# admin 用户配置
admin:
user_name: "admin"
password: "1234"

agents:
agents:
ZeroShotAgent:
documentation: "https://python.langchain.com/docs/modules/agents/how_to/custom_mrkl_agent"
JsonAgent:
Expand Down
47 changes: 47 additions & 0 deletions src/backend/bisheng/api/v1/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain.schema import AgentFinish, LLMResult
from langchain.schema.agent import AgentAction
from langchain.schema.document import Document
from langchain.schema.messages import BaseMessage


# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
Expand Down Expand Up @@ -86,6 +87,26 @@ async def on_text(self, text: str, **kwargs: Any) -> Any:
# This runs when first sending the prompt
# to the LLM, adding it will send the final prompt
# to the frontend
sender = kwargs.get('sender')
receiver = kwargs.get('receiver')
if kwargs.get('sender'):
log = ChatResponse(message=text, type='end',
sender=sender, receiver=receiver)
start = ChatResponse(type='start', sender=sender, receiver=receiver)

if receiver and receiver.get('is_self'):
await self.websocket.send_json(log.dict())
else:
await self.websocket.send_json(log.dict())
await self.websocket.send_json(start.dict())
elif kwargs.get('type'):
start = ChatResponse(type='start', category=kwargs.get('type'))
end = ChatResponse(type='end', intermediate_steps=text, category=kwargs.get('type'))
await self.websocket.send_json(start.dict())
await self.websocket.send_json(end.dict())
elif 'category' in kwargs:
log = ChatResponse(message=text, type='stream')
await self.websocket.send_json(log.dict())

async def on_agent_action(self, action: AgentAction, **kwargs: Any):
log = f'Thought: {action.log}'
Expand Down Expand Up @@ -118,6 +139,14 @@ async def on_retriever_end(self, result: List[Document], **kwargs: Any) -> Any:
# todo 判断技能权限
logger.debug(f'retriver_result result={result}')

async def on_chat_model_start(self, serialized: Dict[str, Any],
messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
# """Run when retriever end running."""
# content = messages[0][0] if isinstance(messages[0][0], str) else messages[0][0].get('content')
# stream = ChatResponse(message=f'{content}', type='stream')
# await self.websocket.send_json(stream.dict())
logger.debug(f'chat_message result={messages}')


class StreamingLLMCallbackHandler(BaseCallbackHandler):
"""Callback handler for streaming LLM responses."""
Expand Down Expand Up @@ -213,3 +242,21 @@ def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any],
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain ends running."""
logger.debug(f'on_chain_end outputs={outputs}')

def on_chat_model_start(self, serialized: Dict[str, Any],
messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
"""Run when retriever end running."""
sender = kwargs['sender']
receiver = kwargs['receiver']
content = messages[0][0] if isinstance(messages[0][0], str) else messages[0][0].get('content')
end = ChatResponse(message=f'{content}', type='end', sender=sender, recevier=receiver)
start = ChatResponse(type='start', sender=sender, recevier=receiver)
loop = asyncio.get_event_loop()
coroutine2 = self.websocket.send_json(end.dict())
coroutine3 = self.websocket.send_json(start.dict())
asyncio.run_coroutine_threadsafe(coroutine2, loop)
asyncio.run_coroutine_threadsafe(coroutine3, loop)
logger.debug(f'on_chat result={messages}')

def on_text(self, text: str, **kwargs) -> Any:
logger.info(text)
18 changes: 17 additions & 1 deletion src/backend/bisheng/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from fastapi.params import Depends
from fastapi.responses import StreamingResponse
from fastapi_jwt_auth import AuthJWT
from sqlalchemy import func
from sqlalchemy import delete, func
from sqlmodel import Session, select

router = APIRouter(tags=['Chat'])
Expand Down Expand Up @@ -46,6 +46,22 @@ def get_chatmessage(*,
return [jsonable_encoder(message) for message in db_message]


@router.delete('/chat/{chat_id}', status_code=200)
def del_chat_id(*,
session: Session = Depends(get_session),
chat_id: str,
Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())

statement = delete(ChatMessage).where(ChatMessage.chat_id == chat_id,
ChatMessage.user_id == payload.get('user_id'))

session.exec(statement)
session.commit()
return {'status_code': 200, 'status_message': 'success'}


@router.get('/chat/list', response_model=List[ChatList], status_code=200)
def get_chatlist_list(*, session: Session = Depends(get_session), Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
Expand Down
31 changes: 18 additions & 13 deletions src/backend/bisheng/api/v1/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,26 +114,32 @@ async def process_knowledge(*,
for path in file_path:
filepath, file_name = path.split('_', 1)
md5_ = filepath.rsplit('/', 1)[1]
db_file = KnowledgeFile(knowledge_id=knowledge_id,
file_name=file_name,
status=1,
md5=md5_,
# 是否包含重复文件
repeat = session.exec(select(KnowledgeFile
).where(KnowledgeFile.md5 == md5_, KnowledgeFile.status == 2,
KnowledgeFile.knowledge_id == knowledge_id)).all()
status = 3 if repeat else 1
remark = 'file repeat' if repeat else ''
db_file = KnowledgeFile(knowledge_id=knowledge_id, file_name=file_name,
status=status, md5=md5_, remark=remark,
user_id=payload.get('user_id'))

session.add(db_file)
session.commit()
session.refresh(db_file)
files.append(db_file)
file_paths.append(filepath)
logger.info(f'fileName={file_name} col={collection_name}')

asyncio.create_task(
addEmbedding(collection_name=collection_name,
model=knowledge.model,
chunk_size=chunk_size,
separator=separator,
chunk_overlap=chunk_overlap,
file_paths=file_paths,
knowledge_files=files))
if not repeat:
asyncio.create_task(
addEmbedding(collection_name=collection_name,
model=knowledge.model,
chunk_size=chunk_size,
separator=separator,
chunk_overlap=chunk_overlap,
file_paths=file_paths,
knowledge_files=files))

knowledge.update_time = db_file.create_time
session.add(knowledge)
Expand Down Expand Up @@ -383,7 +389,6 @@ async def addEmbedding(collection_name, model: str, chunk_size: int, separator:
setattr(db_file, 'remark', str(e)[:500])
session.add(db_file)
session.commit()
raise e


def _read_chunk_text(input_file, file_name, size, chunk_overlap, separator):
Expand Down
6 changes: 4 additions & 2 deletions src/backend/bisheng/api/v1/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,22 @@ class ChatMessage(BaseModel):
"""Chat message schema."""

is_bot: bool = False
message: Union[str, None, dict] = None
message: Union[str, None, dict] = ''
type: str = 'human'
category: str = 'processing'
intermediate_steps: str = None
files: list = []
user_id: int = None
message_id: int = None
source: bool = False
sender: str = None
receiver: dict = None


class ChatResponse(ChatMessage):
"""Chat response schema."""

intermediate_steps: str
intermediate_steps: str = ''
type: str
is_bot: bool = True
files: list = []
Expand Down
2 changes: 1 addition & 1 deletion src/backend/bisheng/api/v1/skillcenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def create_template(*, session: Session = Depends(get_session), template: Templa
# 校验name
name_repeat = session.exec(select(Template).where(Template.name == db_template.name)).first()
if name_repeat:
raise HTTPException(status_code=500, detail='模板名称重复,请重新检查')
raise HTTPException(status_code=500, detail='Repeat name, please choose another name')
# 增加 order_num x,x+65535
max_order = session.exec(select(Template).order_by(Template.order_num.desc()).limit(1)).first()
# 如果没有数据,就从 65535 开始
Expand Down
20 changes: 14 additions & 6 deletions src/backend/bisheng/api/v1/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,27 @@

@router.post('/user/regist', response_model=UserRead, status_code=201)
async def regist(*, session: Session = Depends(get_session), user: UserCreate):
db_user = User.from_orm(user)
# check if admin user
admin = session.exec(select(User).where(User.user_id == 1)).all()
if not admin:
db_user_role = UserRole(user_id=db_user.user_id, role_id=1)
db_user.user_id = 1

# check if user already exist
db_user = session.exec(select(User).where(User.user_name == user.user_name)).all()
if db_user:
name_user = session.exec(select(User).where(User.user_name == user.user_name)).all()
if name_user:
raise HTTPException(status_code=500, detail='账号已存在')
else:
try:
user.password = md5_hash(user.password)
db_user = User.from_orm(user)
db_user.password = md5_hash(user.password)
session.add(db_user)
session.flush()
session.refresh(db_user)
# 默认加入普通用户
db_user_role = UserRole(user_id=db_user.user_id, role_id=2)
session.add(db_user_role)
if db_user != 1:
db_user_role = UserRole(user_id=db_user.user_id, role_id=2)
session.add(db_user_role)
session.commit()
except Exception as e:
session.rollback()
Expand Down
2 changes: 1 addition & 1 deletion src/backend/bisheng/api/v2/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def union_websocket(flow_id: str,
chat_manager.set_cache(key_node, node._built_object)
chat_manager.set_cache(get_cache_key(flow_id, chat_id), node._built_object)
await chat_manager.handle_websocket(flow_id, chat_id, websocket,
settings.get_from_db('default_operator'))
settings.get_from_db('default_operator').get('user'))
except Exception as exc:
logger.error(exc)
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
10 changes: 6 additions & 4 deletions src/backend/bisheng/api/v2/filelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@ def create_knowledge(
db_knowldge = Knowledge.from_orm(knowledge)
know = session.exec(
select(Knowledge).where(Knowledge.name == knowledge.name,
knowledge.user_id == settings.get_from_db('default_operator'))
knowledge.user_id == settings.get_from_db(
'default_operator').get('user'))
).all()
if know:
raise HTTPException(status_code=500, detail='知识库名称重复')
if not db_knowldge.collection_name:
# 默认collectionName
db_knowldge.collection_name = f'col_{int(time.time())}_{str(uuid4())[:8]}'
db_knowldge.user_id = settings.get_from_db('default_operator')
db_knowldge.user_id = settings.get_from_db('default_operator').get('user')
session.add(db_knowldge)
session.commit()
session.refresh(db_knowldge)
Expand All @@ -61,7 +62,8 @@ def update_knowledge(

know = session.exec(
select(Knowledge).where(Knowledge.name == knowledge.name,
knowledge.user_id == settings.get_from_db('default_operator')
knowledge.user_id == settings.get_from_db(
'default_operator').get('user')
)).all()
if know:
raise HTTPException(status_code=500, detail='知识库名称重复')
Expand All @@ -81,7 +83,7 @@ def get_knowledge(
page_num: Optional[str],
):
""" 读取所有知识库信息. """
default_user_id = settings.get_from_db('default_operator')
default_user_id = settings.get_from_db('default_operator').get('user')
try:
sql = select(Knowledge)
count_sql = select(func.count(Knowledge.id))
Expand Down
3 changes: 0 additions & 3 deletions src/backend/bisheng/cache/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import abc

from bisheng import settings
from fastapi import logger


class BaseCache(abc.ABC):
"""
Expand Down
2 changes: 1 addition & 1 deletion src/backend/bisheng/cache/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def set_client_id(self, client_id: str, chat_id: str):
previous_client_id = self.current_client_id
previous_chat_id = self.current_chat_id
self.current_client_id = client_id
self.current_chat_id= chat_id
self.current_chat_id = chat_id
self.current_cache = self._cache.setdefault(get_cache_key(client_id, chat_id), {})
try:
yield
Expand Down
Loading

0 comments on commit 7b32122

Please sign in to comment.