Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/0.3.1.5 #656

Merged
merged 5 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/backend/bisheng/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from bisheng.api.v1 import (assistant_router, chat_router, component_router, endpoints_router,
finetune_router, flows_router, group_router, knowledge_router,
qa_router, report_router, server_router, skillcenter_router,
user_router, validate_router, variable_router)
user_router, validate_router, variable_router, audit_router)
from bisheng.api.v2 import chat_router_rpc, knowledge_router_rpc, rpc_router_rpc
from fastapi import APIRouter

Expand All @@ -22,6 +22,7 @@
router.include_router(component_router)
router.include_router(assistant_router)
router.include_router(group_router)
router.include_router(audit_router)

router_rpc = APIRouter(prefix='/api/v2', )
router_rpc.include_router(knowledge_router_rpc)
Expand Down
47 changes: 45 additions & 2 deletions src/backend/bisheng/api/services/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,27 @@ def update_flow_list(cls, assistant_id: UUID, flow_list: List[str],
return resp_200()

@classmethod
def get_gpts_tools(cls, user_id: Any, is_preset: Optional[bool] = None) -> List[GptsToolsTypeRead]:
def get_gpts_tools(cls, user: UserPayload, is_preset: Optional[bool] = None) -> List[GptsToolsTypeRead]:
""" 获取用户可见的工具列表 """
# 获取用户可见的工具类别
all_tool_type = GptsToolsDao.get_tool_type(user_id, is_preset)
tool_type_ids_extra = []
if not is_preset:
# 获取自定义工具列表时,需要包含用户可用的工具列表
user_role = UserRoleDao.get_user_roles(user.user_id)
if user_role:
role_ids = [role.role_id for role in user_role]
role_access = RoleAccessDao.get_role_access(role_ids, AccessType.GPTS_TOOL_READ)
if role_access:
tool_type_ids_extra = [int(access.third_id) for access in role_access]
# 获取用户可见的所有工具列表
if is_preset is None:
all_tool_type = GptsToolsDao.get_user_tool_type(user.user_id, tool_type_ids_extra)
elif is_preset:
# 获取预置工具列表
all_tool_type = GptsToolsDao.get_preset_tool_type()
else:
# 获取用户可见的自定义工具列表
all_tool_type = GptsToolsDao.get_user_tool_type(user.user_id, tool_type_ids_extra, False)
tool_type_id = [one.id for one in all_tool_type]
res = []
tool_type_children = {}
Expand Down Expand Up @@ -348,8 +365,26 @@ def add_gpts_tools(cls, user: UserPayload, req: GptsToolsTypeRead) -> UnifiedRes

# 添加工具类别和对应的 工具列表
res = GptsToolsDao.insert_tool_type(req)

cls.add_gpts_tools_hook(user, res)
return resp_200(data=res)

@classmethod
def add_gpts_tools_hook(cls, user: UserPayload, gpts_tool_type: GptsToolsTypeRead) -> bool:
""" 添加自定义工具后的hook函数 """
# 查询下用户所在的用户组
user_group = UserGroupDao.get_user_group(user.user_id)
if user_group:
# 批量将自定义工具插入到关联表里
batch_resource = []
for one in user_group:
batch_resource.append(GroupResource(
group_id=one.group_id,
third_id=gpts_tool_type.id,
type=ResourceTypeEnum.GPTS_TOOL.value))
GroupResourceDao.insert_group_batch(batch_resource)
return True

@classmethod
def update_gpts_tools(cls, user: UserPayload, req: GptsToolsTypeRead) -> UnifiedResponseModel:
"""
Expand Down Expand Up @@ -430,8 +465,16 @@ def delete_gpts_tools(cls, user: UserPayload, tool_type_id: int) -> UnifiedRespo
if exist_tool_type.is_preset:
return ToolTypeIsPresetError.return_resp()
GptsToolsDao.delete_tool_type(tool_type_id)
cls.delete_gpts_tool_hook(user, exist_tool_type)
return resp_200()

@classmethod
def delete_gpts_tool_hook(cls, user: UserPayload, gpts_tool_type) -> bool:
""" 删除自定义工具后的hook函数 """
logger.info(f"delete_gpts_tool_hook id: {gpts_tool_type.id}, user: {user.user_id}")
GroupResourceDao.delete_group_resource_by_third_id(gpts_tool_type.id, ResourceTypeEnum.GPTS_TOOL)
return True

@classmethod
def get_models(cls) -> UnifiedResponseModel:
llm_list = cls.get_gpts_conf('llms')
Expand Down
52 changes: 33 additions & 19 deletions src/backend/bisheng/api/services/role_group_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bisheng.api.services.user_service import UserPayload
from bisheng.database.models.assistant import AssistantDao
from bisheng.database.models.flow import FlowDao
from bisheng.database.models.gpts_tools import GptsToolsDao
from bisheng.database.models.group import Group, GroupCreate, GroupDao, GroupRead, DefaultGroup
from bisheng.database.models.group_resource import GroupResourceDao, ResourceTypeEnum
from bisheng.database.models.knowledge import KnowledgeDao
Expand Down Expand Up @@ -165,6 +166,9 @@ def get_group_resources(self, group_id: int, resource_type: ResourceTypeEnum, na
return self.get_group_knowledge(group_id, name, page_size, page_num)
elif resource_type.value == ResourceTypeEnum.ASSISTANT.value:
return self.get_group_assistant(group_id, name, page_size, page_num)
elif resource_type.value == ResourceTypeEnum.GPTS_TOOL.value:
return self.get_group_tool(group_id, name, page_size, page_num)
logger.warning('not support resource type: %s', resource_type)
return [], 0

def get_user_map(self, user_ids: set[int]):
Expand All @@ -173,13 +177,11 @@ def get_user_map(self, user_ids: set[int]):
return user_map

def get_group_flow(self, group_id: int, keyword: str, page_size: int, page_num: int) -> (List[Any], int):
# 默认分组的话,直接搜索和查询对应的资源数据表即可
if group_id == DefaultGroup:
resource_list = []
else:
resource_list = GroupResourceDao.get_group_resource(group_id, ResourceTypeEnum.FLOW)
if not resource_list:
return [], 0
""" 获取用户组下的知识库列表 """
# 查询用户组下的技能ID列表
resource_list = GroupResourceDao.get_group_resource(group_id, ResourceTypeEnum.FLOW)
if not resource_list:
return [], 0
res = []
flow_ids = [UUID(resource.third_id) for resource in resource_list]
data, total = FlowDao.filter_flows_by_ids(flow_ids, keyword, page_num, page_size)
Expand All @@ -195,12 +197,9 @@ def get_group_flow(self, group_id: int, keyword: str, page_size: int, page_num:
def get_group_knowledge(self, group_id: int, keyword: str, page_size: int, page_num: int) -> (List[Any], int):
""" 获取用户组下的知识库列表 """
# 查询用户组下的知识库ID列表
if group_id == DefaultGroup:
resource_list = []
else:
resource_list = GroupResourceDao.get_group_resource(group_id, ResourceTypeEnum.KNOWLEDGE)
if not resource_list:
return [], 0
resource_list = GroupResourceDao.get_group_resource(group_id, ResourceTypeEnum.KNOWLEDGE)
if not resource_list:
return [], 0
res = []
knowledge_ids = [int(resource.third_id) for resource in resource_list]
# 查询知识库
Expand All @@ -216,16 +215,31 @@ def get_group_knowledge(self, group_id: int, keyword: str, page_size: int, page_
def get_group_assistant(self, group_id: int, keyword: str, page_size: int, page_num: int) -> (List[Any], int):
""" 获取用户组下的助手列表 """
# 查询用户组下的助手ID列表
if group_id == DefaultGroup:
resource_list = []
else:
resource_list = GroupResourceDao.get_group_resource(group_id, ResourceTypeEnum.ASSISTANT)
if not resource_list:
return [], 0
resource_list = GroupResourceDao.get_group_resource(group_id, ResourceTypeEnum.ASSISTANT)
if not resource_list:
return [], 0
res = []
assistant_ids = [UUID(resource.third_id) for resource in resource_list] # 查询助手
data, total = AssistantDao.filter_assistant_by_id(assistant_ids, keyword, page_num, page_size)
for one in data:
simple_one = AssistantService.return_simple_assistant_info(one)
res.append(simple_one)
return res, total

def get_group_tool(self, group_id: int, keyword: str, page_size: int, page_num: int) -> (List[Any], int):
""" 获取用户组下的工具列表 """
# 查询用户组下的工具ID列表
resource_list = GroupResourceDao.get_group_resource(group_id, ResourceTypeEnum.GPTS_TOOL)
if not resource_list:
return [], 0
res = []
tool_ids = [int(resource.third_id) for resource in resource_list]
# 查询工具
data, total = GptsToolsDao.filter_tool_types_by_ids(tool_ids, keyword, page_num, page_size)
db_user_ids = {one.user_id for one in data}
user_map = self.get_user_map(db_user_ids)
for one in data:
one_dict = jsonable_encoder(one)
one_dict["user_name"] = user_map.get(one.user_id, one.user_id)
res.append(one_dict)
return res, total
4 changes: 3 additions & 1 deletion src/backend/bisheng/api/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from bisheng.api.v1.usergroup import router as group_router
from bisheng.api.v1.validate import router as validate_router
from bisheng.api.v1.variable import router as variable_router
from bisheng.api.v1.audit import router as audit_router

__all__ = [
'chat_router', 'endpoints_router', 'validate_router', 'flows_router', 'skillcenter_router',
'knowledge_router', 'server_router', 'user_router', 'qa_router', 'variable_router',
'report_router', 'finetune_router', 'component_router', 'assistant_router', 'group_router'
'report_router', 'finetune_router', 'component_router', 'assistant_router', 'group_router',
'audit_router'
]
6 changes: 2 additions & 4 deletions src/backend/bisheng/api/v1/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,9 @@ async def chat(*,


@router.get('/tool_list', response_model=UnifiedResponseModel)
def get_tool_list(*, is_preset: Optional[bool] = None, Authorize: AuthJWT = Depends()):
def get_tool_list(*, is_preset: Optional[bool] = None, login_user: UserPayload = Depends(get_login_user)):
"""查询所有可见的tool 列表"""
Authorize.jwt_required()
current_user = json.loads(Authorize.get_jwt_subject())
return resp_200(AssistantService.get_gpts_tools(current_user.get('user_id'), is_preset))
return resp_200(AssistantService.get_gpts_tools(login_user, is_preset))


@router.post('/tool_schema', response_model=UnifiedResponseModel)
Expand Down
42 changes: 42 additions & 0 deletions src/backend/bisheng/api/v1/audit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from datetime import datetime
from typing import Optional, List

from fastapi import APIRouter, Query, Depends

from bisheng.api.JWT import get_login_user
from bisheng.api.services.user_service import UserPayload
from bisheng.api.v1.schemas import UnifiedResponseModel, resp_200

router = APIRouter(prefix='/audit', tags=['AuditLog'])


@router.get('', response_model=UnifiedResponseModel)
def get_audit_logs(*,
group_ids: Optional[List[str]] = Query(default=None, description='分组id列表'),
operator_id: Optional[int] = Query(default=None, description='操作人id'),
start_time: Optional[datetime] = Query(default=None, description='开始时间'),
end_time: Optional[datetime] = Query(default=None, description='结束时间'),
system_id: Optional[str] = Query(default=None, description='系统模块'),
event_type: Optional[str] = Query(default=None, description='操作行为'),
page: Optional[int] = Query(default=0, description='页码'),
limit: Optional[int] = Query(default=0, description='每页条数'),
login_user: UserPayload = Depends(get_login_user)):
return resp_200(data={
'data': [
{
"id": "xxxx",
"operator_id": 1, # 操作用户的ID
"operator_name": "xxx", # 操作用户的用户名
"group_ids": [1, 2, 3], # 所属的分组列表
"system_ids": "chat", # 系统模块
"event_type": "create_chat", # 操作行为
"object_type": "flow", # 操作对象类型
"object_id": 1, # 操作对象唯一标识
"object_name": "xxx", # 操作对象名称
"note": "备注", # 备注
"ip_address": "1.1.1.1", # 操作时客户端的IP地址
"create_time": "2023-01-01 00:00:00", # 操作时间
}
],
"total": 10,
})
11 changes: 10 additions & 1 deletion src/backend/bisheng/database/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import hashlib
import json
import os
import uuid
from contextlib import contextmanager
from typing import List

Expand Down Expand Up @@ -118,7 +119,8 @@ def init_default_data():
session.commit()
# 修改表单数据表
sql_query = text(
'UPDATE `t_variable_value` a SET a.version_id=(SELECT `id` from `flowversion` WHERE flow_id=a.flow_id and is_current=1)' # noqa
'UPDATE `t_variable_value` a SET a.version_id=(SELECT `id` from `flowversion` WHERE flow_id=a.flow_id and is_current=1)'
# noqa
)
session.execute(sql_query)
session.commit()
Expand Down Expand Up @@ -159,3 +161,10 @@ def read_from_conf(file_path: str) -> str:
content = f.read()

return content


def generate_uuid() -> str:
"""
生成uuid的字符串
"""
return uuid.uuid4().hex
78 changes: 78 additions & 0 deletions src/backend/bisheng/database/models/audit_log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from ast import Dict
from datetime import datetime
from typing import Dict, List, Optional

from bisheng.database.base import session_getter, generate_uuid
from bisheng.database.models.base import SQLModelSerializable
from sqlalchemy import Column, DateTime, delete, text, update, Text, func, or_
from sqlmodel import Field, select


class AuditLogBase(SQLModelSerializable):
"""
审计日志表
"""
operator_id: int = Field(index=True, description="操作用户的ID")
operator_name: Optional[str] = Field(index=True, description="用户名")
group_ids: Optional[List[int]] = Field(index=True, description="所属用户组的ID列表")
system_id: Optional[str] = Field(index=True, description="系统模块")
event_type: Optional[str] = Field(index=True, description="操作行为")
object_type: Optional[str] = Field(index=True, description="操作对象类型")
object_id: Optional[int] = Field(index=True, description="操作对象ID")
object_name: Optional[str] = Field(index=True, description="操作对象名称")
note: Optional[str] = Field(sa_column=Column(Text(255)), description="操作备注")
ip_address: Optional[str] = Field(index=True, description="操作时客户端的IP地址")
create_time: Optional[datetime] = Field(sa_column=Column(
DateTime, nullable=False, index=True, server_default=text('CURRENT_TIMESTAMP')), description="操作时间")
update_time: Optional[datetime] = Field(
sa_column=Column(DateTime,
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
onupdate=text('CURRENT_TIMESTAMP')), description="操作时间")


class AuditLog(AuditLogBase, table=True):
# id = 2 表示默认用户组
id: str = Field(default_factory=generate_uuid, primary_key=True, index=True, description="主键,uuid格式")


class AuditLogDao(AuditLogBase):

@classmethod
def get_audit_logs(cls, group_ids: List[int], operator_id: int = 0, start_time: datetime = None,
end_time: datetime = None, system_id: str = None, event_type: str = None,
page: int = 0, limit: int = 0) -> (List[AuditLog], int):
"""
通过用户组来筛选日志
"""
statement = select(AuditLog)
count_statement = select(func.count(AuditLog.id))
if group_ids:
group_filters = []
for one in group_ids:
group_filters.append(func.json_array_contains(AuditLog.group_ids, one))
statement = statement.where(or_(*group_filters))
count_statement = count_statement.where(or_(*group_filters))
if operator_id:
statement = statement.where(AuditLog.operator_id == operator_id)
count_statement = count_statement.where(AuditLog.operator_id == operator_id)
if start_time and end_time:
statement = statement.where(AuditLog.create_time >= start_time).where(AuditLog.create_time <= end_time)
count_statement = count_statement.where(AuditLog.create_time >= start_time).where(
AuditLog.create_time <= end_time)
if system_id:
statement = statement.where(AuditLog.system_id == system_id)
count_statement = count_statement.where(AuditLog.system_id == system_id)
if event_type:
statement = statement.where(AuditLog.event_type == event_type)
count_statement = count_statement.where(AuditLog.event_type == event_type)
if page and limit:
statement = statement.offset((page - 1) * limit).limit(limit).order_by(AuditLog.create_time.desc())
with session_getter() as session:
return session.exec(statement).all(), session.scalar(count_statement)

@classmethod
def insert_audit_logs(cls, audit_logs: List[AuditLog]):
with session_getter() as session:
session.add_all(audit_logs)
session.commit()
Loading