diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 0ae300dd7..6b30c53e4 100755 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -24,13 +24,14 @@ # from apps.system.models.user import SQLModel # noqa # from apps.settings.models.setting_models import SQLModel -from apps.chat.models.chat_model import SQLModel -from apps.terminology.models.terminology_model import SQLModel +#from apps.chat.models.chat_model import SQLModel +#from apps.terminology.models.terminology_model import SQLModel #from apps.custom_prompt.models.custom_prompt_model import SQLModel -from apps.data_training.models.data_training_model import SQLModel +#from apps.data_training.models.data_training_model import SQLModel # from apps.dashboard.models.dashboard_model import SQLModel from common.core.config import settings # noqa #from apps.datasource.models.datasource import SQLModel +from apps.system.models.system_model import SQLModel target_metadata = SQLModel.metadata diff --git a/backend/alembic/versions/066_update_assistant_model.py b/backend/alembic/versions/066_update_assistant_model.py new file mode 100644 index 000000000..9e7ea17ce --- /dev/null +++ b/backend/alembic/versions/066_update_assistant_model.py @@ -0,0 +1,30 @@ +"""066_update_assistant_model + +Revision ID: 8adc3a4919be +Revises: 8ff90df7871d +Create Date: 2026-04-28 15:55:42.757276 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '8adc3a4919be' +down_revision = '8ff90df7871d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('sys_assistant', sa.Column('enable_custom_model', sa.Boolean(), nullable=True)) + op.add_column('sys_assistant', sa.Column('custom_model', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + op.drop_column('sys_assistant', 'custom_model') + op.drop_column('sys_assistant', 'enable_custom_model') + # ### end Alembic commands ### diff --git a/backend/apps/ai_model/model_factory.py b/backend/apps/ai_model/model_factory.py index 03479fd8e..887005ce7 100644 --- a/backend/apps/ai_model/model_factory.py +++ b/backend/apps/ai_model/model_factory.py @@ -14,6 +14,8 @@ from common.utils.utils import prepare_model_arg from langchain_community.llms import VLLMOpenAI from langchain_openai import AzureChatOpenAI + + # from langchain_community.llms import Tongyi, VLLM class LLMConfig(BaseModel): @@ -24,16 +26,17 @@ class LLMConfig(BaseModel): api_key: Optional[str] = None api_base_url: Optional[str] = None additional_params: Dict[str, Any] = {} + class Config: frozen = True def __hash__(self): if hasattr(self, 'additional_params') and isinstance(self.additional_params, dict): - hashable_params = frozenset((k, tuple(v) if isinstance(v, (list, dict)) else v) - for k, v in self.additional_params.items()) + hashable_params = frozenset((k, tuple(v) if isinstance(v, (list, dict)) else v) + for k, v in self.additional_params.items()) else: hashable_params = None - + return hash(( self.model_id, self.model_type, @@ -61,6 +64,7 @@ def llm(self) -> BaseChatModel: """Return the langchain LLM instance""" return self._llm + class OpenAIvLLM(BaseLLM): def _init_llm(self) -> VLLMOpenAI: return VLLMOpenAI( @@ -71,6 +75,7 @@ def _init_llm(self) -> VLLMOpenAI: **self.config.additional_params, ) + class OpenAIAzureLLM(BaseLLM): def _init_llm(self) -> AzureChatOpenAI: api_version = self.config.additional_params.get("api_version") @@ -88,6 +93,8 @@ def _init_llm(self) -> AzureChatOpenAI: streaming=True, **self.config.additional_params, ) + + class OpenAILLM(BaseLLM): def _init_llm(self) -> BaseChatModel: return BaseChatOpenAI( @@ -138,11 +145,15 @@ def register_llm(cls, model_type: str, llm_class: Type[BaseLLM]): return config """ -async def get_default_config() -> LLMConfig: +async def get_default_config(custom_model_id: Optional[int] = None) -> LLMConfig: with Session(engine) as session: - db_model = session.exec( - select(AiModelDetail).where(AiModelDetail.default_model == True) - ).first() + db_model: AiModelDetail | None = None + if custom_model_id: + db_model = session.get(AiModelDetail, custom_model_id) + if not db_model: + db_model = session.exec( + select(AiModelDetail).where(AiModelDetail.default_model == True) + ).first() if not db_model: raise Exception("The system default model has not been set") @@ -150,14 +161,14 @@ async def get_default_config() -> LLMConfig: if db_model.config: try: config_raw = json.loads(db_model.config) - additional_params = {item["key"]: prepare_model_arg(item.get('val')) for item in config_raw if "key" in item and "val" in item} + additional_params = {item["key"]: prepare_model_arg(item.get('val')) for item in config_raw if + "key" in item and "val" in item} except Exception: pass if not db_model.api_domain.startswith("http"): db_model.api_domain = await sqlbot_decrypt(db_model.api_domain) if db_model.api_key: db_model.api_key = await sqlbot_decrypt(db_model.api_key) - # 构造 LLMConfig return LLMConfig( diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 4193029d9..006b8bb9a 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -6,6 +6,7 @@ import warnings from concurrent.futures import ThreadPoolExecutor, Future from datetime import datetime +from dis import specialized from typing import Any, List, Optional, Union, Dict, Iterator import orjson @@ -174,7 +175,13 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C @classmethod async def create(cls, *args, **kwargs): - config: LLMConfig = await get_default_config() + specialized_model_id = None + if args[3]: + if args[3].enable_custom_model: + if args[3].custom_model: + specialized_model_id = args[3].custom_model + print("use custom model: id[" + args[3].custom_model + "]") + config: LLMConfig = await get_default_config(specialized_model_id) instance = cls(*args, **kwargs, config=config) chat_params: list[SysArgModel] = await get_groups(args[0], "chat") diff --git a/backend/apps/swagger/locales/en.json b/backend/apps/swagger/locales/en.json index 0120903eb..512f7d2cc 100644 --- a/backend/apps/swagger/locales/en.json +++ b/backend/apps/swagger/locales/en.json @@ -117,6 +117,8 @@ "assistant_type": "Assistant Type (0: Basic, 1: Advanced, 4: Page)", "assistant_configuration": "Configuration", "assistant_description": "Description", + "assistant_enableCustomModel": "Use specified model", + "assistant_customModel": "Large Language Model", "system_embedded_api": "Page Embedded API", "embedded_resetsecret_api": "Reset Secret", diff --git a/backend/apps/swagger/locales/zh.json b/backend/apps/swagger/locales/zh.json index da4d60551..b9552c4d7 100644 --- a/backend/apps/swagger/locales/zh.json +++ b/backend/apps/swagger/locales/zh.json @@ -117,6 +117,8 @@ "assistant_type": "助手类型(0: 基础, 1: 高级, 4: 页面)", "assistant_configuration": "配置", "assistant_description": "描述", + "assistant_enableCustomModel": "使用指定大模型", + "assistant_customModel": "大语言模型", "system_embedded_api": "页面嵌入式api", "embedded_resetsecret_api": "重置 Secret", diff --git a/backend/apps/system/models/system_model.py b/backend/apps/system/models/system_model.py index 472f18681..73fa34435 100644 --- a/backend/apps/system/models/system_model.py +++ b/backend/apps/system/models/system_model.py @@ -53,6 +53,8 @@ class AssistantBaseModel(SQLModel): app_id: Optional[str] = Field(default=None, max_length=255, nullable=True) app_secret: Optional[str] = Field(default=None, max_length=255, nullable=True) oid: Optional[int] = Field(nullable=True, sa_type=BigInteger(), default=1) + enable_custom_model: Optional[bool] = Field(default=False, nullable=True) + custom_model: Optional[str] = Field(default=None, max_length=255, nullable=True) class AssistantModel(SnowflakeBase, AssistantBaseModel, table=True): __tablename__ = "sys_assistant" diff --git a/backend/apps/system/schemas/system_schema.py b/backend/apps/system/schemas/system_schema.py index 87b1a41d2..75eebc39c 100644 --- a/backend/apps/system/schemas/system_schema.py +++ b/backend/apps/system/schemas/system_schema.py @@ -111,6 +111,8 @@ class AssistantBase(BaseModel): configuration: Optional[str] = Field(default=None, description=f"{PLACEHOLDER_PREFIX}assistant_configuration") description: Optional[str] = Field(default=None, description=f"{PLACEHOLDER_PREFIX}assistant_description") oid: Optional[int] = Field(default=1, description=f"{PLACEHOLDER_PREFIX}oid") + enable_custom_model: Optional[bool] = Field(default=False, description=f"{PLACEHOLDER_PREFIX}oid") + custom_model: Optional[str] = Field(description=f"{PLACEHOLDER_PREFIX}oid") class AssistantDTO(AssistantBase, BaseCreatorDTO): diff --git a/backend/templates/template.yaml b/backend/templates/template.yaml index 8d2f5c5fe..a447c8287 100644 --- a/backend/templates/template.yaml +++ b/backend/templates/template.yaml @@ -87,7 +87,7 @@ template: generate_rules: | 以下是你必须遵守的规则和可以参考的基础示例: - + 你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json index b8cf83296..df8d29c06 100644 --- a/frontend/src/i18n/en.json +++ b/frontend/src/i18n/en.json @@ -665,6 +665,7 @@ "application_name": "Application name", "application_description": "Application description", "cross_domain_settings": "Cross-domain settings", + "enableCustomModel": "Use specified model", "third_party_address": "Please enter the embedded third party address,multiple items separated by semicolons", "set_to_private": "Set as private", "set_to_public": "Set as public", diff --git a/frontend/src/i18n/ko-KR.json b/frontend/src/i18n/ko-KR.json index 0270dd84a..bb269ca1f 100644 --- a/frontend/src/i18n/ko-KR.json +++ b/frontend/src/i18n/ko-KR.json @@ -665,6 +665,7 @@ "application_name": "애플리케이션 이름", "application_description": "애플리케이션 설명", "cross_domain_settings": "교차 도메인 설정", + "enableCustomModel": "지정된 모델 사용", "third_party_address": "임베디드할 제3자 주소를 입력하십시오, 여러 항목을 세미콜론으로 구분", "set_to_private": "비공개로 설정", "set_to_public": "공개로 설정", diff --git a/frontend/src/i18n/zh-CN.json b/frontend/src/i18n/zh-CN.json index 7570c2e2d..902b06421 100644 --- a/frontend/src/i18n/zh-CN.json +++ b/frontend/src/i18n/zh-CN.json @@ -665,6 +665,7 @@ "application_name": "应用名称", "application_description": "应用描述", "cross_domain_settings": "跨域设置", + "enableCustomModel": "使用指定大模型", "third_party_address": "请输入嵌入的第三方地址,多个以分号分割", "set_to_private": "设为私有", "set_to_public": "设为公共", diff --git a/frontend/src/i18n/zh-TW.json b/frontend/src/i18n/zh-TW.json index 715cf77ff..bd2aa57cb 100644 --- a/frontend/src/i18n/zh-TW.json +++ b/frontend/src/i18n/zh-TW.json @@ -665,6 +665,7 @@ "application_name": "應用名稱", "application_description": "應用描述", "cross_domain_settings": "跨網域設定", + "enableCustomModel": "使用指定模型", "third_party_address": "請輸入嵌入的第三方位址,多個以分號分割", "set_to_private": "設為私有", "set_to_public": "設為公共", diff --git a/frontend/src/views/system/embedded/iframe.vue b/frontend/src/views/system/embedded/iframe.vue index b9df371a3..3e0630196 100644 --- a/frontend/src/views/system/embedded/iframe.vue +++ b/frontend/src/views/system/embedded/iframe.vue @@ -20,6 +20,7 @@ import { getList, updateAssistant, saveAssistant, delOne, dsApi } from '@/api/em import { useI18n } from 'vue-i18n' import { cloneDeep } from 'lodash-es' import { useUserStore } from '@/stores/user.ts' +import { modelApi } from '@/api/system.ts' const userStore = useUserStore() defineProps({ @@ -59,6 +60,8 @@ const defaultEmbedded = { description: '', configuration: '', domain: '', + enable_custom_model: false, + custom_model: '', } const currentEmbedded = reactive(cloneDeep(defaultEmbedded)) @@ -98,10 +101,33 @@ const dsListOptions = ref([]) const embeddedListWithSearch = computed(() => { if (!keywords.value) return embeddedList.value return embeddedList.value.filter((ele: any) => - ele.name.toLowerCase().includes(keywords.value.toLowerCase()) + ele.name.toLowerCase().includes(keywords.value.toLowerCase()), ) }) +interface Model { + name: string + model_type: string + base_model: string + id: string + default_model: boolean + supplier: number +} + +const modelList =ref>([]) + +const searchModels = () => { + searchLoading.value = true + modelApi + .queryAll() + .then((res: any) => { + modelList.value = res + }) + .finally(() => { + searchLoading.value = false + }) +} + const userTypeList = [ { name: t('embedded.basic_application'), @@ -147,6 +173,9 @@ const handleBaseEmbedded = (row: any) => { } getDsList() ruleConfigvVisible.value = true + + searchModels() + dialogTitle.value = row?.id ? t('embedded.edit_basic_applications') : t('embedded.create_basic_application') @@ -166,6 +195,9 @@ const handleAdvancedEmbedded = (row: any) => { Object.assign(urlForm, tempData) } ruleConfigvVisible.value = true + + searchModels() + dialogTitle.value = row?.id ? t('embedded.edit_advanced_applications') : t('embedded.creating_advanced_applications') @@ -275,8 +307,8 @@ const validateUrl = (_: any, value: any, callback: any) => { if (value === '') { callback( new Error( - t('datasource.please_enter') + t('common.empty') + t('embedded.cross_domain_settings') - ) + t('datasource.please_enter') + t('common.empty') + t('embedded.cross_domain_settings'), + ), ) } else { // var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line @@ -320,7 +352,7 @@ const dsRules = { const validatePass = (_: any, value: any, callback: any) => { if (value === '') { callback( - new Error(t('datasource.please_enter') + t('common.empty') + t('embedded.interface_url')) + new Error(t('datasource.please_enter') + t('common.empty') + t('embedded.interface_url')), ) } else { // var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line @@ -438,6 +470,9 @@ const saveEmbedded = () => { if (!currentEmbedded.id) { delete obj.id } + if (obj.custom_model == undefined){ + obj.custom_model = '' + } req(obj).then(() => { ElMessage({ type: 'success', @@ -491,29 +526,29 @@ const handleEmbedded = (row: any) => { } const copyJsCode = () => { copy(jsCodeElement.value) - .then(function () { + .then(function() { ElMessage.success(t('embedded.copy_successful')) }) - .catch(function () { + .catch(function() { ElMessage.error(t('embedded.copy_failed')) }) } const copyJsCodeFull = () => { copy(jsCodeElementFull.value) - .then(function () { + .then(function() { ElMessage.success(t('embedded.copy_successful')) }) - .catch(function () { + .catch(function() { ElMessage.error(t('embedded.copy_failed')) }) } const copyCode = () => { copy(scriptElement.value) - .then(function () { + .then(function() { ElMessage.success(t('embedded.copy_successful')) }) - .catch(function () { + .catch(function() { ElMessage.error(t('embedded.copy_failed')) }) } @@ -713,7 +748,7 @@ const saveHandler = () => {
- + @@ -991,7 +1045,7 @@ const saveHandler = () => {