diff --git a/src/backend/bisheng/api/errcode/finetune.py b/src/backend/bisheng/api/errcode/finetune.py index 9895d325..f034b037 100644 --- a/src/backend/bisheng/api/errcode/finetune.py +++ b/src/backend/bisheng/api/errcode/finetune.py @@ -55,3 +55,8 @@ class InvalidExtraParamsError(BaseErrorCode): class TrainFileNotExistError(BaseErrorCode): Code: int = 10120 Msg: str = '训练文件不存在' + + +class GetGPUInfoError(BaseErrorCode): + Code: int = 10125 + Msg: str = '获取GPU信息失败' diff --git a/src/backend/bisheng/api/errcode/server.py b/src/backend/bisheng/api/errcode/server.py index 6213a186..a365084e 100644 --- a/src/backend/bisheng/api/errcode/server.py +++ b/src/backend/bisheng/api/errcode/server.py @@ -2,6 +2,6 @@ # RT服务相关的返回错误码,功能模块代码:100 -class NotFoundServerError(BaseErrorCode): +class NoSftServerError(BaseErrorCode): Code: int = 10001 - Msg: str = '未找到RT服务' + Msg: str = '未找到SFT服务' diff --git a/src/backend/bisheng/api/services/finetune.py b/src/backend/bisheng/api/services/finetune.py index f9423356..5e4a0c91 100644 --- a/src/backend/bisheng/api/services/finetune.py +++ b/src/backend/bisheng/api/services/finetune.py @@ -6,20 +6,21 @@ from uuid import UUID from bisheng.api.errcode.finetune import (CancelJobError, ChangeModelNameError, CreateFinetuneError, - DeleteJobError, ExportJobError, InvalidExtraParamsError, - JobStatusError, NotFoundJobError, TrainDataNoneError, - UnExportJobError) + DeleteJobError, ExportJobError, GetGPUInfoError, + InvalidExtraParamsError, JobStatusError, NotFoundJobError, + TrainDataNoneError, UnExportJobError) from bisheng.api.errcode.model_deploy import NotFoundModelError -from bisheng.api.errcode.server import NotFoundServerError +from bisheng.api.errcode.server import NoSftServerError from bisheng.api.services.rt_backend import RTBackend from bisheng.api.services.sft_backend import SFTBackend -from bisheng.api.utils import parse_server_host +from bisheng.api.utils import parse_gpus, parse_server_host from bisheng.api.v1.schemas import FinetuneInfoResponse, UnifiedResponseModel, resp_200 from bisheng.cache import InMemoryCache from bisheng.database.models.finetune import (Finetune, FinetuneChangeModelName, FinetuneDao, FinetuneExtraParams, FinetuneList, FinetuneStatus) from bisheng.database.models.model_deploy import ModelDeploy, ModelDeployDao -from bisheng.database.models.server import ServerDao +from bisheng.database.models.server import Server, ServerDao +from bisheng.database.models.sft_model import SftModelDao from bisheng.utils.logger import logger from bisheng.utils.minio_client import MinioClient from pydantic import ValidationError @@ -103,6 +104,17 @@ def validate_status(cls, finetune: Finetune, new_status: int) -> UnifiedResponse return JobStatusError.return_resp('发布完成只能变为训练成功') return None + @classmethod + def get_sft_server(cls, server_id: int) -> Server | None: + server = cls.get_server_by_cache(server_id) + if not server: + logger.warning('not found rt server data by id: %s', server_id) + return None + if not server.sft_endpoint: + logger.warning('not found sft endpoint by id: %s', server_id) + return None + return server + @classmethod def create_job(cls, finetune: Finetune) -> UnifiedResponseModel[Finetune]: # 校验额外参数 @@ -110,10 +122,10 @@ def create_job(cls, finetune: Finetune) -> UnifiedResponseModel[Finetune]: if validate_ret is not None: return validate_ret - # 查找RT服务是否存在 - server = ServerDao.find_server(finetune.server) + # 查找SFT服务是否存在 + server = cls.get_sft_server(finetune.server) if not server: - return NotFoundServerError.return_resp() + return NoSftServerError.return_resp() # 查找基础模型是否存在 base_model = ModelDeployDao.find_model(finetune.base_model) @@ -124,7 +136,7 @@ def create_job(cls, finetune: Finetune) -> UnifiedResponseModel[Finetune]: logger.info(f'start create sft job: {finetune.id.hex}') # 拼接指令所需的command参数 command_params = cls.parse_command_params(finetune, base_model) - sft_ret = SFTBackend.create_job(host=parse_server_host(server.endpoint), + sft_ret = SFTBackend.create_job(host=parse_server_host(server.sft_endpoint), job_id=finetune.id.hex, params=command_params) if not sft_ret[0]: logger.error(f'create sft job error: job_id: {finetune.id.hex}, err: {sft_ret[1]}') @@ -147,14 +159,14 @@ def cancel_job(cls, job_id: UUID, user: Any) -> UnifiedResponseModel[Finetune]: if validate_ret is not None: return validate_ret - # 查找RT服务是否存在 - server = ServerDao.find_server(finetune.server) + # 查找SFT服务是否存在 + server = cls.get_sft_server(finetune.server) if not server: - return NotFoundServerError.return_resp() + return NoSftServerError.return_resp() # 调用SFT-backend的API取消任务 logger.info(f'start cancel job_id: {job_id}, user: {user.get("user_name")}') - sft_ret = SFTBackend.cancel_job(host=parse_server_host(server.endpoint), job_id=job_id.hex) + sft_ret = SFTBackend.cancel_job(host=parse_server_host(server.sft_endpoint), job_id=job_id.hex) if not sft_ret[0]: logger.error(f'cancel sft job error: job_id: {job_id}, err: {sft_ret[1]}') return CancelJobError.return_resp() @@ -170,16 +182,16 @@ def delete_job(cls, job_id: UUID, user: Any) -> UnifiedResponseModel[Finetune]: finetune = FinetuneDao.find_job(job_id) if not finetune: return NotFoundJobError.return_resp() - # 查找RT服务是否存在 - server = ServerDao.find_server(finetune.server) + # 查找SFT服务是否存在 + server = cls.get_sft_server(finetune.server) if not server: - return NotFoundServerError.return_resp() + return NoSftServerError.return_resp() - model_name = cls.delete_published_model(finetune, server.endpoint) + model_name = cls.delete_published_model(finetune, server.sft_endpoint) # 调用接口删除训练任务 logger.info(f'start delete sft job: {job_id}, user: {user.get("user_name")}') - sft_ret = SFTBackend.delete_job(host=parse_server_host(server.endpoint), job_id=job_id.hex, + sft_ret = SFTBackend.delete_job(host=parse_server_host(server.sft_endpoint), job_id=job_id.hex, model_name=model_name) if not sft_ret[0]: logger.error(f'delete sft job error: job_id: {job_id}, err: {sft_ret[1]}') @@ -250,14 +262,14 @@ def publish_job(cls, job_id: UUID, user: Any) -> UnifiedResponseModel[Finetune]: if validate_ret is not None: return validate_ret - # 查找RT服务是否存在 - server = ServerDao.find_server(finetune.server) + # 查找SFT服务是否存在 + server = cls.get_sft_server(finetune.server) if not server: - return NotFoundServerError.return_resp() + return NoSftServerError.return_resp() # 调用SFT-backend的API接口 logger.info(f'start export sft job: {job_id}, user: {user.get("user_name")}') - sft_ret = SFTBackend.publish_job(host=parse_server_host(server.endpoint), job_id=job_id.hex, + sft_ret = SFTBackend.publish_job(host=parse_server_host(server.sft_endpoint), job_id=job_id.hex, model_name=finetune.model_name) if not sft_ret[0]: logger.error(f'export sft job error: job_id: {job_id}, err: {sft_ret[1]}') @@ -268,6 +280,10 @@ def publish_job(cls, job_id: UUID, user: Any) -> UnifiedResponseModel[Finetune]: server=str(server.id), endpoint=f'http://{server.endpoint}/v2.1/models') published_model = ModelDeployDao.insert_one(published_model) + + # 记录可用于训练的模型名称 + SftModelDao.insert_sft_model(published_model.model) + # 更新训练任务状态 logger.info('update sft job data') finetune.status = new_status @@ -288,17 +304,18 @@ def cancel_publish_job(cls, job_id: UUID, user: Any) -> UnifiedResponseModel[Fin return validate_ret # 查找RT服务是否存在 - server = ServerDao.find_server(finetune.server) + server = cls.get_sft_server(finetune.server) if not server: - return NotFoundServerError.return_resp() + return NoSftServerError.return_resp() # 调用SFT-backend的API接口 logger.info(f'start cancel export sft job: {job_id}, user: {user.get("user_name")}') - sft_ret = SFTBackend.publish_job(host=parse_server_host(server.endpoint), job_id=job_id.hex, - model_name=finetune.model_name) + sft_ret = SFTBackend.cancel_publish_job(host=parse_server_host(server.sft_endpoint), job_id=job_id.hex, + model_name=finetune.model_name) if not sft_ret[0]: logger.error(f'cancel export sft job error: job_id: {job_id}, err: {sft_ret[1]}') return UnExportJobError.return_resp() + SftModelDao.delete_sft_model(finetune.model_name) # 删除发布的模型信息 logger.info(f'delete published model: {finetune.model_id}') ModelDeployDao.delete_model_by_id(finetune.model_id) @@ -323,7 +340,7 @@ def get_server_by_cache(cls, server_id: int): @classmethod def get_all_job(cls, req_data: FinetuneList) -> UnifiedResponseModel[List[FinetuneInfoResponse]]: - job_list = FinetuneDao.find_jobs(req_data) + job_list, total = FinetuneDao.find_jobs(req_data) ret = [] for job in job_list: tmp = FinetuneInfoResponse(**job.dict()) @@ -333,22 +350,17 @@ def get_all_job(cls, req_data: FinetuneList) -> UnifiedResponseModel[List[Finetu ret.append(tmp) # 异步线程更新任务状态 asyncio.get_event_loop().run_in_executor(sync_job_thread_pool, cls.sync_all_job_status, job_list) - return resp_200(data=ret) + return resp_200(data={'data': ret, 'total': total}) @classmethod def sync_all_job_status(cls, job_list: List[Finetune]) -> None: # 异步线程更新批量任务的状态 - server_cache = {} for finetune in job_list: - if finetune.server in server_cache.keys(): - server = server_cache.get(finetune.server) - else: - server = ServerDao.find_server(finetune.server) - server_cache[finetune.server] = server + server = cls.get_server_by_cache(finetune.server) if not server: logger.error(f'server not found: {finetune.server}') continue - cls.sync_job_status(finetune, server.endpoint) + cls.sync_job_status(finetune, server.sft_endpoint) @classmethod def get_job_info(cls, job_id: UUID) -> UnifiedResponseModel: @@ -357,10 +369,11 @@ def get_job_info(cls, job_id: UUID) -> UnifiedResponseModel: finetune = FinetuneDao.find_job(job_id) if not finetune: return NotFoundJobError.return_resp() - # 查找对应的RT服务 - server = ServerDao.find_server(finetune.server) + # 查找对应的SFT服务 + server = cls.get_sft_server(finetune.server) if not server: - return NotFoundServerError.return_resp() + return NoSftServerError.return_resp() + base_model_name = '' if finetune.base_model != 0: base_model = ModelDeployDao.find_model(finetune.base_model) @@ -368,7 +381,7 @@ def get_job_info(cls, job_id: UUID) -> UnifiedResponseModel: base_model_name = base_model.model # 同步任务执行情况 - cls.sync_job_status(finetune, server.endpoint) + cls.sync_job_status(finetune, server.sft_endpoint) # 获取日志文件 log_data = None @@ -379,9 +392,9 @@ def get_job_info(cls, job_id: UUID) -> UnifiedResponseModel: return resp_200(data={ 'finetune': FinetuneInfoResponse(**finetune.dict(), base_model_name=base_model_name), - 'log': log_data, + 'log': log_data if finetune.status != FinetuneStatus.FAILED.value else finetune.reason, 'loss_data': res_data, # like [{"step": 10, "loss": 0.5}, {"step": 20, "loss": 0.3}] - 'report': finetune.report, + 'report': finetune.report if finetune.report else None, }) @classmethod @@ -477,14 +490,39 @@ def change_published_model_name(cls, finetune: Finetune, model_name: str) -> boo logger.error(f'published model not found, job_id: {finetune.id.hex}, model_id: {finetune.model_id}') return False + server = cls.get_sft_server(finetune.server) + if not server: + logger.error(f'change model server not found, job_id: {finetune.id.hex}, server_id: {finetune.server}') + return False # 调用接口修改已发布模型的名称 - sft_ret = SFTBackend.change_model_name(parse_server_host(published_model.endpoint), finetune.id.hex, + sft_ret = SFTBackend.change_model_name(parse_server_host(server.sft_endpoint), finetune.id.hex, published_model.model, model_name) if not sft_ret[0]: logger.error(f'change model name error: job_id: {finetune.id.hex}, err: {sft_ret[1]}') return False + # 修改可预训练的模型名称 + SftModelDao.change_sft_model(published_model.model, model_name) + # 更新已发布模型的model_name published_model.model = model_name ModelDeployDao.update_model(published_model) return True + + @classmethod + def get_gpu_info(cls) -> UnifiedResponseModel: + """ 获取GPU信息 """ + all_server = ServerDao.find_all_server() + res = [] + for server in all_server: + if not server.sft_endpoint: + continue + sft_ret = SFTBackend.get_gpu_info(parse_server_host(server.sft_endpoint)) + if not sft_ret[0]: + logger.error(f'get gpu info error: server_id: {server.id}, err: {sft_ret[1]}') + return GetGPUInfoError.return_resp() + gpu_info = parse_gpus(sft_ret[1]) + for one in gpu_info: + one['server'] = server.server + res.append(one) + return resp_200(data=res) diff --git a/src/backend/bisheng/api/services/sft_backend.py b/src/backend/bisheng/api/services/sft_backend.py index 17433f43..870734cd 100644 --- a/src/backend/bisheng/api/services/sft_backend.py +++ b/src/backend/bisheng/api/services/sft_backend.py @@ -54,7 +54,7 @@ def publish_job(cls, host: str, job_id: str, model_name: str) -> (bool, str | Di """ 发布训练任务 从训练路径到处到正式路径""" uri = '/v2.1/sft/job/publish' url = '/v2.1/models/sft_elem/infer' - res = requests.post(url, json={'uri': uri, 'job_id': job_id, 'model_name': model_name}) + res = requests.post(f'{host}{url}', json={'uri': uri, 'job_id': job_id, 'model_name': model_name}) return cls.handle_response(res) @classmethod @@ -118,3 +118,10 @@ def change_model_name(cls, host, job_id: str, old_model_name: str, model_name: s json={'uri': uri, 'job_id': job_id, 'old_model_name': old_model_name, 'model_name': model_name}) return cls.handle_response(res) + + @classmethod + def get_gpu_info(cls, host) -> (bool, str): + """ 获取GPU信息 """ + url = '/v2.1/sft/gpu' + res = requests.get(f'{host}{url}') + return cls.handle_response(res) diff --git a/src/backend/bisheng/api/utils.py b/src/backend/bisheng/api/utils.py index 2b4c3051..d2e68159 100644 --- a/src/backend/bisheng/api/utils.py +++ b/src/backend/bisheng/api/utils.py @@ -1,3 +1,6 @@ +import xml.dom.minidom +from typing import Dict, List + from bisheng.api.v1.schemas import StreamData from bisheng.database.base import session_getter from bisheng.database.models.role_access import AccessType, RoleAccess @@ -228,8 +231,8 @@ def access_check(payload: dict, owner_user_id: int, target_id: int, type: Access def get_L2_param_from_flow( - flow_data: dict, - flow_id: str, + flow_data: dict, + flow_id: str, ): graph = Graph.from_payload(flow_data) node_id = [] @@ -292,3 +295,26 @@ def parse_server_host(endpoint: str): """ 将数据库中的endpoints解析为http请求的host """ endpoint = endpoint.replace('http://', '').split('/')[0] return f'http://{endpoint}' + + +# 将 nvidia-smi -q -x 的输出解析为可视化数据 +def parse_gpus(gpu_str: str) -> List[Dict]: + dom_tree = xml.dom.minidom.parseString(gpu_str) + collections = dom_tree.documentElement + gpus = collections.getElementsByTagName('gpu') + res = [] + for one in gpus: + fb_mem_elem = one.getElementsByTagName('fb_memory_usage')[0] + gpu_uuid_elem = one.getElementsByTagName('uuid')[0] + gpu_id_elem = one.getElementsByTagName('minor_number')[0] + gpu_total_mem = fb_mem_elem.getElementsByTagName('total')[0] + free_mem = fb_mem_elem.getElementsByTagName('free')[0] + gpu_utility_elem = one.getElementsByTagName('utilization')[0].getElementsByTagName('gpu_util')[0] + res.append({ + 'gpu_uuid': gpu_uuid_elem.firstChild.data, + 'gpu_id': gpu_id_elem.firstChild.data, + 'gpu_total_mem': '%.2f G' % (float(gpu_total_mem.firstChild.data.split(' ')[0]) / 1024), + 'gpu_used_mem': '%.2f G' % (float(free_mem.firstChild.data.split(' ')[0]) / 1024), + 'gpu_utility': round(float(gpu_utility_elem.firstChild.data.split(' ')[0]) * 100, 2) + }) + return res diff --git a/src/backend/bisheng/api/v1/finetune.py b/src/backend/bisheng/api/v1/finetune.py index c933e3e4..abd990ff 100644 --- a/src/backend/bisheng/api/v1/finetune.py +++ b/src/backend/bisheng/api/v1/finetune.py @@ -77,6 +77,7 @@ async def get_job(*, server: int = Query(default=None, description='关联的RT服务ID'), status: str = Query(default='', title='多个以英文逗号,分隔', description='训练任务的状态,1: 训练中 2: 训练失败 3: 任务中止 4: 训练成功 5: 发布完成'), + model_name: Optional[str] = Query(default='', description='模型名称,模糊搜索'), page: Optional[int] = Query(default=1, description='页码'), limit: Optional[int] = Query(default=10, description='每页条数'), Authorize: AuthJWT = Depends()): @@ -85,7 +86,7 @@ async def get_job(*, status_list = [] if status.strip(): status_list = [int(one) for one in status.strip().split(',')] - req_data = FinetuneList(server=server, status=status_list, page=page, limit=limit) + req_data = FinetuneList(server=server, status=status_list, model_name=model_name, page=page, limit=limit) return FinetuneService.get_all_job(req_data) @@ -159,3 +160,11 @@ async def get_download_url(*, return resp_200(data={ 'url': download_url }) + + +@router.get('/gpu', response_model=UnifiedResponseModel) +async def get_gpu_info(*, + Authorize: AuthJWT = Depends()): + # get login user + Authorize.jwt_required() + return FinetuneService.get_gpu_info() diff --git a/src/backend/bisheng/api/v1/server.py b/src/backend/bisheng/api/v1/server.py index 9dd2be2e..271aea4b 100644 --- a/src/backend/bisheng/api/v1/server.py +++ b/src/backend/bisheng/api/v1/server.py @@ -7,9 +7,11 @@ import requests from bisheng.api.v1.schemas import UnifiedResponseModel, resp_200 from bisheng.database.base import session_getter -from bisheng.database.models.model_deploy import (ModelDeploy, ModelDeployQuery, ModelDeployRead, +from bisheng.database.models.model_deploy import (ModelDeploy, ModelDeployDao, ModelDeployInfo, + ModelDeployQuery, ModelDeployRead, ModelDeployUpdate) from bisheng.database.models.server import Server, ServerCreate, ServerRead +from bisheng.database.models.sft_model import SftModelDao from bisheng.utils.logger import logger from fastapi import APIRouter, HTTPException from sqlalchemy import delete @@ -77,26 +79,28 @@ async def list(*, query: ModelDeployQuery = None): servers = session.exec(select(Server)).all() id2server = {server.id: server for server in servers} name2server = {server.server: server for server in servers} + all_sft_model = SftModelDao.get_all_sft_model() + sft_model_dict = {one.model_name: True for one in all_sft_model} for server in servers: await update_model(server.endpoint, server.id) sql = select(ModelDeploy) if query and query.server: sql = sql.where(ModelDeploy.server == str(name2server.get(query.server).id)) - # get with session_getter() as session: db_model = session.exec(sql.order_by(ModelDeploy.model)).all() + res = [] for model in db_model: model.server = id2server.get(int(model.server)).server - return resp_200(data=db_model) + res.append(ModelDeployInfo(**model.dict(), sft_support=sft_model_dict.get(model.model, False))) + return resp_200(data=res) except Exception as exc: - logger.error(f'Error add server: {exc}') + logger.error(f'Error add server: {exc}', exc_info=True) raise HTTPException(status_code=500, detail=str(exc)) from exc @router.get('/model/{deploy_id}', response_model=UnifiedResponseModel[ModelDeployRead], status_code=201) async def get_model_deploy(*, deploy_id: int): try: - model_deploy = ModelDeployDao.find_model(deploy_id) if not ModelDeployDao: raise HTTPException(status_code=404, detail='配置不存在') diff --git a/src/backend/bisheng/database/models/finetune.py b/src/backend/bisheng/database/models/finetune.py index d64e358e..fde6d50a 100644 --- a/src/backend/bisheng/database/models/finetune.py +++ b/src/backend/bisheng/database/models/finetune.py @@ -6,7 +6,7 @@ from bisheng.database.base import session_getter from bisheng.database.models.base import SQLModelSerializable from pydantic import BaseModel, validator -from sqlmodel import JSON, TEXT, Column, DateTime, Field, select, text, update +from sqlmodel import JSON, TEXT, Column, DateTime, Field, func, select, text, update class TrainMethod(Enum): @@ -83,6 +83,7 @@ class Finetune(FinetuneBase, table=True): class FinetuneList(BaseModel): server: Optional[int] = Field(description='关联的RT服务ID') status: Optional[List[int]] = Field(description='训练任务的状态') + model_name: Optional[str] = Field(description='模型名称, 模糊搜索') page: Optional[int] = Field(default=1, description='页码') limit: Optional[int] = Field(default=10, description='每页条数') @@ -172,13 +173,20 @@ def delete_job(cls, job: Finetune) -> bool: return True @classmethod - def find_jobs(cls, finetune_list: FinetuneList) -> List[Finetune]: + def find_jobs(cls, finetune_list: FinetuneList) -> (List[Finetune], int): offset = (finetune_list.page - 1) * finetune_list.limit with session_getter() as session: statement = select(Finetune) + count_statement = session.query(func.count(Finetune.id)) if finetune_list.server: statement = statement.where(Finetune.server == finetune_list.server) + count_statement = count_statement.filter(Finetune.server == finetune_list.server) if finetune_list.status: statement = statement.where(Finetune.status.in_(finetune_list.status)) + count_statement = count_statement.filter(Finetune.status.in_(finetune_list.status)) + if finetune_list.model_name: + statement = statement.where(Finetune.model_name.like(f'%{finetune_list.model_name}%')) + count_statement = count_statement.filter(Finetune.model_name.like(f'%{finetune_list.model_name}%')) statement = statement.offset(offset).limit(finetune_list.limit).order_by(Finetune.create_time.desc()) - return session.exec(statement).all() + all_jobs = session.exec(statement).all() + return all_jobs, count_statement.scalar() diff --git a/src/backend/bisheng/database/models/model_deploy.py b/src/backend/bisheng/database/models/model_deploy.py index 7154d7d6..dffee077 100644 --- a/src/backend/bisheng/database/models/model_deploy.py +++ b/src/backend/bisheng/database/models/model_deploy.py @@ -84,3 +84,7 @@ class ModelDeployCreate(ModelDeployBase): class ModelDeployUpdate(SQLModelSerializable): id: int config: Optional[str] = None + + +class ModelDeployInfo(ModelDeploy): + sft_support: bool = Field(default=False, description='是否支持微调训练') diff --git a/src/backend/bisheng/database/models/server.py b/src/backend/bisheng/database/models/server.py index 1b6ebb77..41623277 100644 --- a/src/backend/bisheng/database/models/server.py +++ b/src/backend/bisheng/database/models/server.py @@ -9,6 +9,7 @@ class ServerBase(SQLModelSerializable): endpoint: str = Field(index=False, unique=True) + sft_endpoint: str = Field(default='', index=False, description='Finetune服务地址') server: str = Field(index=True) remark: Optional[str] = Field(index=False) create_time: Optional[datetime] = Field(sa_column=Column( @@ -32,6 +33,12 @@ def find_server(cls, server_id: int) -> Server | None: statement = select(Server).where(Server.id == server_id) return session.exec(statement).first() + @classmethod + def find_all_server(cls): + with session_getter() as session: + statement = select(Server) + return session.exec(statement).all() + class ServerRead(ServerBase): id: Optional[int] diff --git a/src/backend/bisheng/database/models/sft_model.py b/src/backend/bisheng/database/models/sft_model.py new file mode 100644 index 00000000..af32216a --- /dev/null +++ b/src/backend/bisheng/database/models/sft_model.py @@ -0,0 +1,61 @@ +from datetime import datetime +from typing import Optional + +from bisheng.database.base import session_getter +from bisheng.database.models.base import SQLModelSerializable +from sqlalchemy import Column, DateTime, delete, text, update +from sqlmodel import Field, select + + +# 可用于训练的model列表 +class SftModelBase(SQLModelSerializable): + id: str = Field(default=None, nullable=False, primary_key=True, description='唯一ID') + create_time: Optional[datetime] = Field(sa_column=Column( + DateTime, nullable=False, index=True, server_default=text('CURRENT_TIMESTAMP'))) + update_time: Optional[datetime] = Field(sa_column=Column( + DateTime, nullable=False, server_default=text('CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP'))) + + +class SftModel(SftModelBase, table=True): + model_name: str = Field(index=True, description='可用于微调训练的模型名称') + + +class SftModelDao(SftModel): + + @classmethod + def get_sft_model(cls, model_name: str) -> SftModel | None: + with session_getter() as session: + statement = select(SftModel).where(SftModel.model_name == model_name) + return session.exec(statement).first() + + @classmethod + def get_all_sft_model(cls): + with session_getter() as session: + statement = select(SftModel) + return session.exec(statement).all() + + @classmethod + def insert_sft_model(cls, model_name: str) -> SftModel: + with session_getter() as session: + model = SftModel(model_name=model_name) + session.add(model) + session.commit() + session.refresh(model) + return model + + @classmethod + def delete_sft_model(cls, model_name: str) -> bool: + with session_getter() as session: + statement = delete(SftModel).where(SftModel.model_name == model_name) + session.exec(statement) + session.commit() + return True + + @classmethod + def change_sft_model(cls, old_model_name, model_name) -> bool: + with session_getter() as session: + update_statement = update(SftModel).where(SftModel.model_name == old_model_name).values( + model_name=model_name) + update_ret = session.exec(update_statement) + session.commit() + return update_ret.rowcount != 0 diff --git a/src/backend/bisheng/main.py b/src/backend/bisheng/main.py index 1a20e668..834fc5c7 100644 --- a/src/backend/bisheng/main.py +++ b/src/backend/bisheng/main.py @@ -20,7 +20,7 @@ def handle_http_exception(req: Request, exc: HTTPException) -> ORJSONResponse: msg = {'status_code': exc.status_code, 'status_message': exc.detail} - logger.error(f'{req.method} {req.url} {exc.status_code} {exc.detail}') + logger.error(f'{req.method} {req.url} {exc.status_code} {exc.detail}', exc_info=True) return ORJSONResponse(content=msg) diff --git a/src/frontend/README.md b/src/frontend/README.md index f8271c20..1ab0a76e 100644 --- a/src/frontend/README.md +++ b/src/frontend/README.md @@ -45,9 +45,10 @@ You can learn more in the [Create React App documentation](https://facebook.gith To learn React, check out the [React documentation](https://reactjs.org/). +TODO +搜索:创建旁边加搜索; +- 知识库页面:搜索知识库名称字段 +- 知识库内文件列表页:搜索文件名称 +- 用户列表页:搜索用户名 +- 角色列表页:搜索角色名 -onlyoffice & LibreOffice Online - -2.2 改动记录 -lodash 替换为 lodash-es -技能列表数据拆分 diff --git a/src/frontend/package-lock.json b/src/frontend/package-lock.json index ac589906..5843487a 100644 --- a/src/frontend/package-lock.json +++ b/src/frontend/package-lock.json @@ -668,6 +668,8 @@ }, "node_modules/@emotion/weak-memoize": { "version": "0.3.1", + "resolved": "https://registry.npmmirror.com/@emotion/weak-memoize/-/weak-memoize-0.3.1.tgz", + "integrity": "sha512-EsBwpc7hBUJWAsNPBmJy4hxWx12v6bshQsldrVmjxJoc3isbxhOrF2IcCpaXxfvq03NwkI7sbsOLXbYuqF/8Ww==" "license": "MIT" }, "node_modules/@esbuild/darwin-x64": { @@ -2573,6 +2575,8 @@ }, "node_modules/@swc/core": { "version": "1.4.0", + "resolved": "https://registry.npmmirror.com/@swc/core/-/core-1.4.0.tgz", + "integrity": "sha512-wc5DMI5BJftnK0Fyx9SNJKkA0+BZSJQx8430yutWmsILkHMBD3Yd9GhlMaxasab9RhgKqZp7Ht30hUYO5ZDvQg==", "dev": true, "hasInstallScript": true, "license": "Apache-2.0", @@ -2610,6 +2614,8 @@ }, "node_modules/@swc/core-darwin-x64": { "version": "1.4.0", + "resolved": "https://registry.npmmirror.com/@swc/core-darwin-x64/-/core-darwin-x64-1.4.0.tgz", + "integrity": "sha512-f8v58u2GsGak8EtZFN9guXqE0Ep10Suny6xriaW2d8FGqESPyNrnBzli3aqkSeQk5gGqu2zJ7WiiKp3XoUOidA==", "cpu": [ "x64" ], diff --git a/src/frontend/public/locales/en/bs.json b/src/frontend/public/locales/en/bs.json index fbdc0562..92fe35c5 100644 --- a/src/frontend/public/locales/en/bs.json +++ b/src/frontend/public/locales/en/bs.json @@ -91,7 +91,8 @@ "createFailureTitle": "Creation failed", "createdBy": "Created by", "offline": "Offline", - "online": "Online" + "online": "Online", + "guideWords": "Guide Words" }, "chat": { "newChat": "New Chat", @@ -148,10 +149,9 @@ "modelFineTune": "Model Finetune", "refreshButton": "Refresh", "gpuResourceUsage": "GPU Resource Usage", - "rtServiceManagement": "RT Service Management", "modelCollectionCaption": "Model Collection", "machine": "Machine", - "serviceAddress": "Service Address", + "serviceAddress": "Server Address", "status": "Status", "online": "Online", "offline": "Offline", @@ -161,7 +161,7 @@ "totalMemory": "Total Memory", "freeMemory": "Free Memory", "gpuUtilization": "GPU Utilization", - "machineName": "Machine Name", + "machineName": "Server Name", "addOne": "Add One" }, "flow": { @@ -296,7 +296,7 @@ "inProgress": "进行中", "failedAborted": "失败/终止", "createTrainingTask": "创建训练任务", - "rtServiceManagement": "RT服务管理", + "rtServiceManagement": "服务管理", "confirmCancelPublish": "该模型正处于上线状态,是否仍然取消发布", "confirmDeleteModel": "确认要删除模型 {{name}} 吗?", "confirmDeleteOnlineModel": "该模型已上线,请将模型下线后再删除/确认要删除模型 {{name}} 吗?", @@ -367,5 +367,6 @@ "operations": "Operations", "previousPage": "Previous Page", "nextPage": "Next Page", - "formatError": "Format Error" + "formatError": "Format Error", + "port": "PORT" } \ No newline at end of file diff --git a/src/frontend/public/locales/zh/bs.json b/src/frontend/public/locales/zh/bs.json index 806fc296..a4762572 100644 --- a/src/frontend/public/locales/zh/bs.json +++ b/src/frontend/public/locales/zh/bs.json @@ -88,7 +88,8 @@ "createFailureTitle": "创建失败", "createdBy": "创建用户", "offline": "下线", - "online": "上线" + "online": "上线", + "guideWords": "引导词" }, "chat": { "newChat": "新建会话", @@ -144,7 +145,6 @@ "modelFineTune": "模型Finetune", "refreshButton": "刷新", "gpuResourceUsage": "GPU资源使用情况", - "rtServiceManagement": "RT服务管理", "modelCollectionCaption": "模型集合", "machine": "机器", "serviceAddress": "服务地址", @@ -157,7 +157,7 @@ "totalMemory": "总显存", "freeMemory": "空余显存", "gpuUtilization": "GPU利用率", - "machineName": "机器名", + "machineName": "服务名", "addOne": "加一条" }, "flow": { @@ -290,9 +290,9 @@ "all": "全部", "successful": "成功", "inProgress": "进行中", - "failedAborted": "失败/终止", + "failedAborted": "失败/中止", "createTrainingTask": "创建训练任务", - "rtServiceManagement": "RT服务管理", + "rtServiceManagement": "服务管理", "confirmCancelPublish": "该模型正处于上线状态,是否仍然取消发布", "confirmDeleteModel": "确认要删除模型 {{name}} 吗?", "confirmDeleteOnlineModel": "该模型已上线,请将模型下线后再删除/确认要删除模型 {{name}} 吗?", @@ -311,7 +311,7 @@ "dataset": "数据集", "cancelPublish": "取消发布", "publish": "发布", - "stop": "终止", + "stop": "中止", "uploadDataset": "上传数据集", "downloadSampleFile": "下载示例文件", "customSampleSize": "自定义样本数(可选)", @@ -336,8 +336,8 @@ "finetuneModelName": "Finetune模型名称", "trainingMethod": "训练方法", "fullFineTune": "全量微调", - "freeze": "freeze", - "lora": "lora", + "freeze": "freeze微调", + "lora": "lora微调", "parameterConfiguration": "参数配置", "parameterConfigurationTooltip": "参数配置建议及实验参考数据详见产品文档。", "parameter": "参数", @@ -364,5 +364,6 @@ "operations": "操作", "previousPage": "上一页", "nextPage": "下一页", - "formatError": "格式错误" + "formatError": "格式错误", + "port": "服务端口" } \ No newline at end of file diff --git a/src/frontend/src/components/PaginationComponent/index.tsx b/src/frontend/src/components/PaginationComponent/index.tsx new file mode 100644 index 00000000..00e7b36c --- /dev/null +++ b/src/frontend/src/components/PaginationComponent/index.tsx @@ -0,0 +1,85 @@ +import { Pagination, PaginationContent, PaginationEllipsis, PaginationItem, PaginationLink, PaginationNext, PaginationPrevious } from '../ui/pagination'; + +const PaginationComponent = ({ page, pageSize, total, maxVisiblePages = 5, onChange }) => { + const totalPages = Math.ceil(total / pageSize); + + const handlePageChange = (newPage) => { + if (newPage >= 1 && newPage <= totalPages && newPage !== page) { + onChange(newPage); + } + }; + + const renderPaginationItems = () => { + const items = []; + + // Previous Button + items.push( + + handlePageChange(page - 1)} /> + + ); + + // Page Buttons + if (totalPages <= maxVisiblePages) { + // If total pages are less than or equal to maxVisiblePages, show all pages + for (let i = 1; i <= totalPages; i++) { + items.push( + + handlePageChange(i)} isActive={i === page}> + {i} + + + ); + } + } else { + // If total pages are more than maxVisiblePages, show at most maxVisiblePages pages + const startPage = Math.max(1, page - Math.floor(maxVisiblePages / 2)); + const endPage = Math.min(totalPages, startPage + maxVisiblePages - 1); + + if (startPage > 1) { + // Display ellipsis if there are pages before the startPage + items.push( + + + + ); + } + + for (let i = startPage; i <= endPage; i++) { + items.push( + + handlePageChange(i)} isActive={i === page}> + {i} + + + ); + } + + if (endPage < totalPages) { + // Display ellipsis if there are pages after the endPage + items.push( + + + + ); + } + } + + // Next Button + items.push( + + handlePageChange(page + 1)} /> + + ); + + return items; + }; + + return ( + + {renderPaginationItems()} + + ); +}; + +export default PaginationComponent; diff --git a/src/frontend/src/components/ui/button.tsx b/src/frontend/src/components/ui/button.tsx index ecfd1cd1..77a4600a 100644 --- a/src/frontend/src/components/ui/button.tsx +++ b/src/frontend/src/components/ui/button.tsx @@ -4,27 +4,27 @@ import * as React from "react"; import { cn } from "../../utils"; const buttonVariants = cva( - "inline-flex items-center justify-center rounded-md text-sm font-medium transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:opacity-50 disabled:pointer-events-none ring-offset-background", + "inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50", { variants: { variant: { - default: "bg-primary text-primary-foreground hover:bg-primary/90", + default: + "bg-primary text-primary-foreground shadow hover:bg-primary/90", destructive: - "bg-destructive text-destructive-foreground hover:bg-destructive/90", + "bg-destructive text-destructive-foreground shadow-sm hover:bg-destructive/90", outline: - "border border-input hover:bg-accent hover:text-accent-foreground", - primary: - "border bg-background text-secondary-foreground hover:bg-background/80 dark:hover:bg-background/10 hover:shadow-sm", + "border border-input bg-background shadow-sm hover:bg-accent hover:text-accent-foreground", secondary: - "border border-muted bg-muted text-secondary-foreground hover:bg-secondary/80", + "bg-secondary text-secondary-foreground shadow-sm hover:bg-secondary/80", ghost: "hover:bg-accent hover:text-accent-foreground", - link: "underline-offset-4 hover:underline text-primary", + link: "text-primary underline-offset-4 hover:underline", bs: "bg-[#6366f1] text-[#fff] hover:bg-[#4f46e5]", }, size: { - default: "h-10 py-2 px-4", - sm: "h-9 px-8 rounded-md", - lg: "h-11 px-8 rounded-md", + default: "h-9 px-4 py-2", + sm: "h-8 rounded-md px-3 text-xs", + lg: "h-10 rounded-md px-8", + icon: "h-9 w-9", }, }, defaultVariants: { @@ -32,26 +32,27 @@ const buttonVariants = cva( size: "default", }, } -); - +) + export interface ButtonProps extends React.ButtonHTMLAttributes, VariantProps { - asChild?: boolean; + asChild?: boolean } - + const Button = React.forwardRef( ({ className, variant, size, asChild = false, ...props }, ref) => { - const Comp = asChild ? Slot : "button"; + const Comp = asChild ? Slot : "button" return ( - ); + ) } -); -Button.displayName = "Button"; - +) +Button.displayName = "Button" + export { Button, buttonVariants }; + diff --git a/src/frontend/src/components/ui/pagination.tsx b/src/frontend/src/components/ui/pagination.tsx new file mode 100644 index 00000000..90fd0b3d --- /dev/null +++ b/src/frontend/src/components/ui/pagination.tsx @@ -0,0 +1,120 @@ +import * as React from "react" +import { + ChevronLeftIcon, + ChevronRightIcon, + DotsHorizontalIcon, +} from "@radix-ui/react-icons" +import { cn } from "../../utils" +import { ButtonProps, buttonVariants } from "./button" + +const Pagination = ({ className, ...props }: React.ComponentProps<"nav">) => ( +