diff --git a/README.md b/README.md index aafdfcb..251a41f 100644 --- a/README.md +++ b/README.md @@ -37,125 +37,6 @@ yarn add @yokowu/modelkit-ui 需要实现以下4个接口,其中 `listModel` 和 `checkModel` 已提供业务逻辑,在handler中调用即可: -#### ListModel 接口 -- **请求参数** (`domain.ModelListReq`): - ```go - type ModelListReq struct { - Provider string `json:"provider" validate:"required,oneof=SiliconFlow OpenAI Ollama DeepSeek Moonshot AzureOpenAI BaiZhiCloud Hunyuan BaiLian Volcengine Gemini ZhiPu"` - BaseURL string `json:"base_url" validate:"required"` - APIKey string `json:"api_key"` - APIHeader string `json:"api_header"` - Type string `json:"type" validate:"required,oneof=chat embedding rerank"` - } - ``` -- **响应参数** (`domain.ModelListResp`): - ```go - type ModelListResp struct { - Models []ModelListItem `json:"models"` - } - type ModelListItem struct { - Model string `json:"model"` - } - ``` - -#### CheckModel 接口 -- **请求参数** (`domain.CheckModelReq`): - ```go - type CheckModelReq struct { - Provider string `json:"provider" validate:"required,oneof=OpenAI Ollama DeepSeek SiliconFlow Moonshot Other AzureOpenAI BaiZhiCloud Hunyuan BaiLian Volcengine Gemini ZhiPu"` - Model string `json:"model" validate:"required"` - BaseURL string `json:"base_url" validate:"required"` - APIKey string `json:"api_key"` - APIHeader string `json:"api_header"` - APIVersion string `json:"api_version"` // for azure openai - Type string `json:"type" validate:"required,oneof=chat embedding rerank"` - } - ``` -- **响应参数** (`domain.CheckModelResp`): - ```go - type CheckModelResp struct { - Error string `json:"error"` - Content string `json:"content"` - } - ``` - -#### CreateModel 接口 -- **请求参数** (`CreateModelReq`): - ```go - type CreateModelReq struct { - APIBase string `json:"api_base" validate:"required"` - APIHeader string `json:"api_header"` - APIKey string `json:"api_key"` - APIVersion string `json:"api_version"` - ModelName string `json:"model_name" validate:"required"` - ModelType string `json:"model_type" validate:"oneof=llm coder embedding audio reranker"` - Param *ModelParam `json:"param"` - Provider string `json:"provider" validate:"required,oneof=SiliconFlow OpenAI Ollama DeepSeek Moonshot AzureOpenAI BaiZhiCloud Hunyuan BaiLian Volcengine Other"` - ShowName string `json:"show_name"` - } - - type ModelParam struct { - ContextWindow int `json:"context_window"` - MaxTokens int `json:"max_tokens"` - R1Enabled bool `json:"r1_enabled"` - SupportComputerUse bool `json:"support_computer_use"` - SupportImages bool `json:"support_images"` - SupportPromptCache bool `json:"support_prompt_cache"` - } - ``` -- **响应参数** (`CreateModelResp`): - ```go - type CreateModelResp struct { - Model Model `json:"model"` - } - ``` - -#### UpdateModel 接口 -- **请求参数** (`UpdateModelReq`): - ```go - type UpdateModelReq struct { - ID string `json:"id" validate:"required"` - APIBase string `json:"api_base"` - APIHeader string `json:"api_header"` - APIKey string `json:"api_key"` - APIVersion string `json:"api_version"` - ModelName string `json:"model_name"` - Param *ModelParam `json:"param"` - Provider string `json:"provider" validate:"oneof=SiliconFlow OpenAI Ollama DeepSeek Moonshot AzureOpenAI BaiZhiCloud Hunyuan BaiLian Volcengine Other"` - ShowName string `json:"show_name"` - Status string `json:"status" validate:"oneof=active inactive"` - } - ``` -- **响应参数** (`UpdateModelResp`): - ```go - type UpdateModelResp struct { - Model Model `json:"model"` - } - ``` - -#### 通用Model结构 -```go -type Model struct { - ID string `json:"id"` - APIBase string `json:"api_base"` - APIHeader string `json:"api_header"` - APIKey string `json:"api_key"` - APIVersion string `json:"api_version"` - CreatedAt int64 `json:"created_at"` - Input int `json:"input"` - IsActive bool `json:"is_active"` - IsInternal bool `json:"is_internal"` - ModelName string `json:"model_name"` - ModelType string `json:"model_type"` - Output int `json:"output"` - Param *ModelParam `json:"param"` - Provider string `json:"provider"` - ShowName string `json:"show_name"` - Status string `json:"status"` - UpdatedAt int64 `json:"updated_at"` -} -``` - ### 3. 后端使用方式 在handler中调用 `listModel` 与 `checkModel` 业务逻辑: @@ -340,26 +221,17 @@ function App() { - 说明更改的目的和影响 - 关联相关的Issue(如果有) -### 代码规范 - -- **Go代码**: 遵循 `gofmt` 和 `golint` 标准 -- **TypeScript/React代码**: 遵循 ESLint 和 Prettier 配置 -- **提交信息**: 使用 [Conventional Commits](https://www.conventionalcommits.org/) 格式 -- **测试**: 新功能必须包含相应的单元测试 - ### 开发环境设置 1. **后端开发** ```bash go mod tidy - go run main.go ``` 2. **前端开发** ```bash cd ui/ModelModal - npm install - npm run dev + pnpm install ``` ## 📄 许可证 diff --git a/consts/model.go b/consts/model.go index 15d05aa..1bb417a 100644 --- a/consts/model.go +++ b/consts/model.go @@ -213,3 +213,7 @@ func ParseModelProvider(s string) ModelProvider { return ModelProviderOther } } + + +var ApiKeyBalanceKeyWords = []string{"quota", "billing", "balance", "payment required"} + diff --git a/test/main.go b/test/main.go index 53e6182..fec7900 100644 --- a/test/main.go +++ b/test/main.go @@ -38,7 +38,6 @@ func (p *ModelKit) GetModelList(c echo.Context) error { Data: nil, }) } - fmt.Println("list model req:", req) resp, err := usecase.ModelList(c.Request().Context(), &req) if err != nil { @@ -65,7 +64,6 @@ func (p *ModelKit) CheckModel(c echo.Context) error { Message: "参数绑定失败: " + err.Error(), }) } - fmt.Println("check model req:", req) resp, err := usecase.CheckModel(c.Request().Context(), &req) if err != nil { diff --git a/ui/ModelModal/src/ModelModal.tsx b/ui/ModelModal/src/ModelModal.tsx index fa6782e..1bbbb99 100644 --- a/ui/ModelModal/src/ModelModal.tsx +++ b/ui/ModelModal/src/ModelModal.tsx @@ -135,8 +135,47 @@ export const ModelModal: React.FC = ({ api_header: value.api_header || header, }) .then((res) => { - if (res.error) { - messageHandler.error("获取模型失败 " + res.error); + // 替换host即可成功请求的情况, 替换host继续请求 + if (res.error && res.error.includes("请将host替换为host.docker.internal")) { + // 解析base_url,将host替换为host.docker.internal + const url = new URL(value.base_url); + url.hostname = 'host.docker.internal'; + value.base_url = url.toString(); + modelService.listModel({ + model_type, + api_key: value.api_key, + base_url: value.base_url, + provider: value.provider as Exclude, + api_header: value.api_header || header, + }).then((res) => { + if (res.error) { + messageHandler.error("获取模型失败"); + setModelLoading(false); + } else { + setModelUserList( + (res.models || []) + .filter((item): item is { model: string } => !!item.model) + .sort((a, b) => a.model!.localeCompare(b.model!)) + ); + if ( + data && + (res.models || []).find((it) => it.model === data.model_name) + ) { + setValue('model_name', data.model_name!); + } else { + setValue('model_name', res.models?.[0]?.model || ''); + } + setSuccess(true); + } + }). + finally(() => { + setModelLoading(false); + }). + catch((res) => { + setModelLoading(false); + }); + } else if (res.error) { + messageHandler.error("获取模型失败"); setModelLoading(false); } else { setModelUserList( @@ -181,10 +220,28 @@ export const ModelModal: React.FC = ({ } ) .then((res) => { - if (res.error) { + // 错误处理 + if (res.error && res.error.includes("API地址末尾添加/v1, host替换为host.docker.internal")){ + // 解析base_url,将host替换为host.docker.internal + const url = new URL(value.base_url); + url.hostname = 'host.docker.internal'; + value.base_url = url.toString(); + value.base_url = value.base_url + '/v1'; + } else if (res.error && res.error.includes("请在API地址末尾添加/v1")) { + value.base_url = value.base_url + '/v1'; + } else if (res.error && res.error.includes("请将host替换为host.docker.internal")) { + // 解析base_url,将host替换为host.docker.internal + const url = new URL(value.base_url); + url.hostname = 'host.docker.internal'; + value.base_url = url.toString(); + } else if (res.error) { messageHandler.error("模型检查失败 " + res.error); setLoading(false); - } else if (data) { + return; + } + // end + + if (data) { modelService.updateModel({ api_key: value.api_key, model_type, @@ -207,7 +264,7 @@ export const ModelModal: React.FC = ({ }) .then((res) => { if (res.error) { - messageHandler.error("修改模型失败 " + res.error); + messageHandler.error("修改模型失败"); setLoading(false); } else { messageHandler.success('修改成功'); @@ -218,6 +275,7 @@ export const ModelModal: React.FC = ({ setLoading(false); }) .catch((res) => { + messageHandler.error("修改模型失败"); setLoading(false); }); } else { @@ -241,7 +299,7 @@ export const ModelModal: React.FC = ({ }) .then((res) => { if (res.error) { - messageHandler.error("添加模型失败 " + res.error); + messageHandler.error("添加模型失败"); setLoading(false); } else { messageHandler.success('添加成功'); @@ -252,11 +310,13 @@ export const ModelModal: React.FC = ({ setLoading(false); }) .catch((res) => { + messageHandler.error("添加模型失败"); setLoading(false); }); } }) .catch((res) => { + messageHandler.error("检查模型失败"); setLoading(false); }); }; @@ -514,6 +574,11 @@ export const ModelModal: React.FC = ({ /> )} /> + {providerBrand === 'Other' && ( + + 模型供应商必须支持与 OpenAI 兼容的 API 格式 + + )} = ({ sx={{ fontSize: 14, lineHeight: '32px', mt: 2 }} > - API Secret + API Key {providers[providerBrand].secretRequired && ( {' '} @@ -545,7 +610,7 @@ export const ModelModal: React.FC = ({ ) } > - 查看文档 + 获取API Key )} @@ -555,7 +620,7 @@ export const ModelModal: React.FC = ({ rules={{ required: { value: providers[providerBrand].secretRequired, - message: 'API Secret 不能为空', + message: 'API Key 不能为空', }, }} render={({ field }) => ( diff --git a/ui/ModelModal/src/constants/providers.ts b/ui/ModelModal/src/constants/providers.ts index 5f1e9af..45b391e 100644 --- a/ui/ModelModal/src/constants/providers.ts +++ b/ui/ModelModal/src/constants/providers.ts @@ -27,7 +27,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://docs.bigmodel.cn/', + modelDocumentUrl: 'https://open.bigmodel.cn/usercenter/apikeys', defaultBaseUrl: 'https://open.bigmodel.cn/api/paas/v4', }, DeepSeek: { @@ -41,7 +41,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://platform.deepseek.com/api-docs/', + modelDocumentUrl: 'https://platform.deepseek.com/api_keys', defaultBaseUrl: 'https://api.deepseek.com/v1', }, Hunyuan: { @@ -55,7 +55,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://cloud.tencent.com/document/product/1729/111007', + modelDocumentUrl: 'https://console.cloud.tencent.com/hunyuan/api-key', defaultBaseUrl: 'https://api.hunyuan.cloud.tencent.com/v1', }, BaiLian: { @@ -69,7 +69,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://help.aliyun.com/zh/model-studio/getting-started/', + modelDocumentUrl: 'https://bailian.console.aliyun.com/?tab=model#/api-key', defaultBaseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1', }, Volcengine: { @@ -83,7 +83,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://www.volcengine.com/docs/82379/1182403', + modelDocumentUrl: 'https://console.volcengine.com/ark/region:ark+cn-beijing/apiKey', defaultBaseUrl: 'https://ark.cn-beijing.volces.com/api/v3', }, OpenAI: { @@ -97,7 +97,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://platform.openai.com/docs', + modelDocumentUrl: 'https://platform.openai.com/api-keys', defaultBaseUrl: 'https://api.openai.com/v1', }, Ollama: { @@ -125,7 +125,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://docs.siliconflow.cn/', + modelDocumentUrl: 'https://cloud.siliconflow.cn/me/account/ak', defaultBaseUrl: 'https://api.siliconflow.cn/v1', }, Moonshot: { @@ -139,7 +139,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://platform.moonshot.cn/docs/', + modelDocumentUrl: 'https://platform.moonshot.cn/console/api-keys', defaultBaseUrl: 'https://api.moonshot.cn/v1', }, AzureOpenAI: { @@ -153,7 +153,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://learn.microsoft.com/en-us/azure/ai-services/openai/', + modelDocumentUrl: 'https://portal.azure.com/#view/Microsoft_Azure_ProjectOxford/CognitiveServicesHub/~/OpenAI', defaultBaseUrl: 'https://.openai.azure.com', }, Gemini: { @@ -167,7 +167,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://ai.google.dev/gemini-api/docs', + modelDocumentUrl: 'https://aistudio.google.com/app/apikey', defaultBaseUrl: 'https://generativelanguage.googleapis.com', }, Qiniu: { @@ -181,7 +181,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://developer.qiniu.com/aitokenapi', + modelDocumentUrl: 'https://portal.qiniu.com/ai-inference/api-key', defaultBaseUrl: 'https://api.qnaigc.com/v1', }, // NewAPI: { @@ -235,7 +235,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://platform.lingyiwanwu.com/docs', + modelDocumentUrl: 'https://platform.lingyiwanwu.com/apikeys', defaultBaseUrl: 'https://api.lingyiwanwu.com/v1', }, // Baichuan: { @@ -310,7 +310,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://www.ctyun.cn/products/ctxirang', + modelDocumentUrl: 'https://huiju.ctyun.cn/service/serviceGroup', defaultBaseUrl: 'https://wishub-x1.ctyun.cn/v1', }, TencentTI: { @@ -324,7 +324,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://cloud.tencent.com/document/product/1772', + modelDocumentUrl: 'https://console.cloud.tencent.com/lkeap/api', defaultBaseUrl: 'https://api.lkeap.cloud.tencent.com/v1', }, BaiDuQianFan: { @@ -338,7 +338,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://cloud.baidu.com/doc/index.html', + modelDocumentUrl: 'https://console.bce.baidu.com/iam/#/iam/apikey/list', defaultBaseUrl: 'https://qianfan.baidubce.com/v2', }, ModelScope: { @@ -352,7 +352,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://modelscope.cn/docs/model-service/API-Inference/intro', + modelDocumentUrl: 'https://modelscope.cn/my/myaccesstoken', defaultBaseUrl: 'https://api-inference.modelscope.cn/v1', }, Infini: { @@ -366,7 +366,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://docs.infini-ai.com/gen-studio/api/maas.html#/operations/chatCompletions', + modelDocumentUrl: 'https://cloud.infini-ai.com/iam/secret/key', defaultBaseUrl: 'https://cloud.infini-ai.com/maas/v1', }, StepFun: { @@ -380,7 +380,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://platform.stepfun.com/docs/overview/concept', + modelDocumentUrl: 'https://platform.stepfun.com/interface-key', defaultBaseUrl: 'https://api.stepfun.com/v1', }, LanYun: { @@ -394,7 +394,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://archive.lanyun.net/#/maas/', + modelDocumentUrl: 'https://maas.lanyun.net/#/system/apiKey', defaultBaseUrl: 'https://maas-api.lanyun.net/v1', }, AlayaNew: { @@ -408,7 +408,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://docs.alayanew.com/docs/modelService/interview?utm_source=cherrystudio', + modelDocumentUrl: ' https://www.alayanew.com/backend/register', defaultBaseUrl: 'https://deepseek.alayanew.com/v1', }, PPIO: { @@ -422,7 +422,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://docs.cherry-ai.com/pre-basic/providers/ppio?invited_by=JYT9GD&utm_source=github_cherry-studio', + modelDocumentUrl: 'https://ppio.com/settings/key-management', defaultBaseUrl: 'https://api.ppinfra.com/v3/openai', }, AiHubMix: { @@ -436,7 +436,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://doc.aihubmix.com/', + modelDocumentUrl: 'https://aihubmix.com', defaultBaseUrl: 'https://aihubmix.com/v1', }, OcoolAI: { @@ -450,7 +450,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://docs.ocoolai.com/', + modelDocumentUrl: 'https://one.ocoolai.com/token', defaultBaseUrl: 'https://api.ocoolai.com/v1', }, DMXAPI: { @@ -464,7 +464,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://dmxapi.cn/models.html#code-block', + modelDocumentUrl: 'https://www.dmxapi.cn/register', defaultBaseUrl: 'https://www.dmxapi.cn/v1', }, BurnCloud: { @@ -478,7 +478,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://ai.burncloud.com/docs', + modelDocumentUrl: 'https://ai.burncloud.com/console/token', defaultBaseUrl: 'https://ai.burncloud.com/v1', }, // Grok: { @@ -502,7 +502,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://docs.api.nvidia.com/nim/reference/llm-apis', + modelDocumentUrl: 'https://build.nvidia.com/?integrate_nim=true&hosted_api=true&modal=integrate-nim', defaultBaseUrl: 'https://integrate.api.nvidia.com/v1', }, TokenFlux: { @@ -530,7 +530,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://302ai.apifox.cn/api-147522039', + modelDocumentUrl: 'https://dash.302.ai/apis/list', defaultBaseUrl: 'https://api.302.ai/v1', }, Cephalon: { @@ -544,7 +544,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://cephalon.cloud/apitoken/1864244127731589124', + modelDocumentUrl: 'https://cephalon.cloud/api', defaultBaseUrl: 'https://cephalon.cloud/user-center/v1/model', }, OpenRouter: { @@ -558,7 +558,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://openrouter.ai/docs/quick-start', + modelDocumentUrl: 'https://openrouter.ai/settings/keys', defaultBaseUrl: 'https://openrouter.ai/api/v1', }, Fireworks: { @@ -572,7 +572,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: true, embedding: false, rerank: false, - modelDocumentUrl: 'https://docs.fireworks.ai/getting-started/introduction', + modelDocumentUrl: 'https://app.fireworks.ai/settings/users/api-keys', defaultBaseUrl: 'https://api.fireworks.ai/inference/v1', }, Mistral: { @@ -586,7 +586,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://docs.mistral.ai', + modelDocumentUrl: 'https://console.mistral.ai/api-keys/', defaultBaseUrl: 'https://api.mistral.ai/v1', }, Perplexity: { @@ -600,7 +600,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://docs.perplexity.ai/home', + modelDocumentUrl: 'https://www.perplexity.ai/settings/api', defaultBaseUrl: 'https://api.perplexity.ai', }, Hyperbolic: { @@ -614,7 +614,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { code: false, embedding: false, rerank: false, - modelDocumentUrl: 'https://docs.hyperbolic.xyz', + modelDocumentUrl: 'https://app.hyperbolic.xyz/settings', defaultBaseUrl: 'https://api.hyperbolic.xyz/v1', }, // VoyageAI: { @@ -642,7 +642,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = { cn: '其他', icon: 'icon-a-AIshezhi', urlWrite: true, - secretRequired: true, + secretRequired: false, customHeader: false, chat: true, code: true, diff --git a/usecase/modelkit.go b/usecase/modelkit.go index c6b3513..db46b12 100644 --- a/usecase/modelkit.go +++ b/usecase/modelkit.go @@ -31,37 +31,8 @@ import ( "github.com/chaitin/ModelKit/utils" ) -// reqModelListApi 获取OpenAI兼容API的模型列表 -// 使用泛型和接口抽象来支持不同供应商的响应格式 -func reqModelListApi[T domain.ModelResponseParser](req *domain.ModelListReq, httpClient *http.Client, responseType T) ([]domain.ModelListItem, error) { - u, err := url.Parse(req.BaseURL) - if err != nil { - return nil, err - } - u.Path = path.Join(u.Path, "/models") - - client := request.NewClient(u.Scheme, u.Host, httpClient.Timeout, request.WithClient(httpClient)) - query, err := utils.GetQuery(req) - if err != nil { - return nil, err - } - resp, err := request.Get[T]( - client, u.Path, - request.WithHeader( - request.Header{ - "Authorization": fmt.Sprintf("Bearer %s", req.APIKey), - }, - ), - request.WithQuery(query), - ) - if err != nil { - return nil, err - } - - return (*resp).ParseModels(), nil -} - func ModelList(ctx context.Context, req *domain.ModelListReq) (*domain.ModelListResp, error) { + log.Printf("ModelList req: provider=%s, baseURL=%s", req.Provider, req.BaseURL) httpClient := &http.Client{ Timeout: time.Second * 30, Transport: &http.Transport{ @@ -84,23 +55,29 @@ func ModelList(ctx context.Context, req *domain.ModelListReq) (*domain.ModelList }, nil // 以下模型供应商需要特殊处理 case consts.ModelProviderOllama: - // get from ollama http://10.10.16.24:11434/api/tags - u, err := url.Parse(req.BaseURL) + resp, err := ollamaListModel(req.BaseURL, httpClient, req.APIHeader) + // 尝试通过替换baseURL的host为host.docker.internal解决ollama list err if err != nil { - return &domain.ModelListResp{ - Error: err.Error(), - }, nil - } - u.Path = "/api/tags" - client := request.NewClient(u.Scheme, u.Host, httpClient.Timeout, request.WithClient(httpClient)) - - h := request.Header{} - if req.APIHeader != "" { - headers := request.GetHeaderMap(req.APIHeader) - maps.Copy(h, headers) + newBaseURL, err := baseURLReplaceHost(req.BaseURL) + if err != nil { + return &domain.ModelListResp{ + Error: err.Error(), + }, nil + } + // 替换host后与原始host相同,无需继续请求 + if newBaseURL == req.BaseURL { + return resp, nil + } + resp, err = ollamaListModel(newBaseURL, httpClient, req.APIHeader) + // 替换后可以成功请求 + if err == nil { + return &domain.ModelListResp{ + Error: fmt.Errorf("请将host替换为host.docker.internal").Error(), + }, nil + } } - - return request.Get[domain.ModelListResp](client, u.Path, request.WithHeader(h)) + // end + return resp, nil case consts.ModelProviderGemini: client, err := generativeGenai.NewClient(ctx, option.WithAPIKey(req.APIKey)) if err != nil { @@ -171,9 +148,11 @@ func ModelList(ctx context.Context, req *domain.ModelListReq) (*domain.ModelList } func CheckModel(ctx context.Context, req *domain.CheckModelReq) (*domain.CheckModelResp, error) { + log.Printf("CheckModel req: provider=%s, model=%s, baseURL=%s", req.Provider, req.Model, req.BaseURL) checkResp := &domain.CheckModelResp{} modelType := consts.ParseModelType(req.Type) + // embedding 与 rerank 模型检查 if modelType == consts.ModelTypeEmbedding || modelType == consts.ModelTypeRerank { url := req.BaseURL reqBody := map[string]any{} @@ -227,34 +206,45 @@ func CheckModel(ctx context.Context, req *domain.CheckModelReq) (*domain.CheckMo } return checkResp, nil } + // end provider := consts.ParseModelProvider(req.Provider) - chatModel, err := GetChatModel(ctx, &domain.ModelMetadata{ - Provider: provider, - ModelName: req.Model, - APIKey: req.APIKey, - APIHeader: req.APIHeader, - BaseURL: req.BaseURL, - APIVersion: req.APIVersion, - ModelType: modelType, - }) - if err != nil { + + resp, err := getChatModelGenerateChat(ctx, provider, modelType, req.BaseURL, req) + // 其他模型供应商,尝试修复baseURL + if err != nil && provider == consts.ModelProviderOther { + res, err := fixProviderOtherCheckErr(ctx, req, provider, modelType) + if err != nil { + checkResp.Error = err.Error() + return checkResp, nil + } + if res != "" { + checkResp.Error = res + return checkResp, nil + } + } + // end + if err != nil && provider != consts.ModelProviderOther { + // 检查错误信息中是否包含余额相关关键词 + errorMsg := strings.ToLower(err.Error()) + for _, keyword := range consts.ApiKeyBalanceKeyWords { + if strings.Contains(errorMsg, keyword) { + checkResp.Error = "API Key余额不足" + return checkResp, nil + } + } checkResp.Error = err.Error() return checkResp, nil } - resp, err := chatModel.Generate(ctx, []*schema.Message{ - schema.SystemMessage("You are a helpful assistant."), - schema.UserMessage("hi"), - }) if err != nil { checkResp.Error = err.Error() return checkResp, nil } - content := resp.Content - if content == "" { + + if resp.Content == "" { checkResp.Error = "生成内容失败" return checkResp, nil } - checkResp.Content = content + checkResp.Content = resp.Content return checkResp, nil } @@ -341,3 +331,160 @@ func GetChatModel(ctx context.Context, model *domain.ModelMetadata) (model.BaseC return chatModel, nil } } + +// 以下是辅助函数,用于处理模型列表和检查相关的功能 + +func ollamaListModel(baseURL string, httpClient *http.Client, apiHeader string) (*domain.ModelListResp, error) { + // get from ollama http://10.10.16.24:11434/api/tags + u, err := url.Parse(baseURL) + if err != nil { + return &domain.ModelListResp{ + Error: err.Error(), + }, nil + } + u.Path = "/api/tags" + client := request.NewClient(u.Scheme, u.Host, httpClient.Timeout, request.WithClient(httpClient)) + + h := request.Header{} + if apiHeader != "" { + headers := request.GetHeaderMap(apiHeader) + maps.Copy(h, headers) + } + return request.Get[domain.ModelListResp](client, u.Path, request.WithHeader(h)) +} + +// 通过修复baseURL尝试修复其它供应商check err, 返回用于提示用户如何修复错误 +func fixProviderOtherCheckErr(ctx context.Context, req *domain.CheckModelReq, provider consts.ModelProvider, modelType consts.ModelType) (string, error) { + log.Println("尝试修复") + // 尝试添加v1 + fixedBaseURL, err := baseURLAddV1(req.BaseURL) + // baseurl解析错误,直接返回 + if err != nil { + log.Printf("baseurl解析错误: %v", err) + return "", err + } + + // baseurl被修改,重新请求 + if fixedBaseURL != req.BaseURL { + _, err := getChatModelGenerateChat(ctx, provider, modelType, fixedBaseURL, req) + // 添加v1有效, 提示用户 + if err == nil { + log.Println("添加v1有效") + return "请在API地址末尾添加/v1", nil + } + log.Println("添加v1无效", err) + } + + // url末尾添加v1无效,尝试替换host为host.docker.internal + fixedBaseURL, err = baseURLReplaceHost(req.BaseURL) + // baseurl解析错误,直接返回 + if err != nil { + return "", err + } + + if fixedBaseURL != req.BaseURL { + _, err := getChatModelGenerateChat(ctx, provider, modelType, fixedBaseURL, req) + // 替换host有效, 提示用户 + if err == nil { + return "请将host替换为host.docker.internal", nil + } + } + + // 替换host也无效,尝试添加v1与替换host + fixedBaseURL, err = baseURLAddV1(req.BaseURL) + // baseurl解析错误,直接返回 + if err != nil { + return "", err + } + fixedBaseURL, err = baseURLReplaceHost(fixedBaseURL) + // baseurl解析错误,直接返回 + if err != nil { + return "", err + } + // baseurl被修改,重新请求 + if fixedBaseURL != req.BaseURL { + _, err := getChatModelGenerateChat(ctx, provider, modelType, fixedBaseURL, req) + // 添加v1与替换host有效, 提示用户 + if err == nil { + return "API地址末尾添加/v1, host替换为host.docker.internal", nil + } + } + return "", nil +} + +func getChatModelGenerateChat(ctx context.Context, provider consts.ModelProvider, modelType consts.ModelType, baseURL string, req *domain.CheckModelReq) (*schema.Message, error) { + chatModel, err := GetChatModel(ctx, &domain.ModelMetadata{ + Provider: provider, + ModelName: req.Model, + APIKey: req.APIKey, + APIHeader: req.APIHeader, + BaseURL: baseURL, + APIVersion: req.APIVersion, + ModelType: modelType, + }) + if err != nil { + return nil, err + } + + return chatModel.Generate(ctx, []*schema.Message{ + schema.SystemMessage("You are a helpful assistant."), + schema.UserMessage("hi"), + }) +} + +// baseURL添加/v1 +func baseURLAddV1(inputURL string) (string, error) { + rawURL, err := url.Parse(inputURL) + if err != nil { + return "", err + } + // 没有path, 则添加/v1 + if rawURL.Path == "" { + rawURL.Path = "/v1" + } + return rawURL.String(), nil +} + +// baseURL的host换成host.docker.internal +func baseURLReplaceHost(inputURL string) (string, error) { + rawURL, err := url.Parse(inputURL) + if err != nil { + return "", err + } + hostAddress := "host.docker.internal" + + if rawURL.Host != hostAddress { + rawURL.Host = hostAddress + } + return rawURL.String(), nil +} + +// reqModelListApi 获取OpenAI兼容API的模型列表 +// 使用泛型和接口抽象来支持不同供应商的响应格式 +func reqModelListApi[T domain.ModelResponseParser](req *domain.ModelListReq, httpClient *http.Client, responseType T) ([]domain.ModelListItem, error) { + u, err := url.Parse(req.BaseURL) + if err != nil { + return nil, err + } + u.Path = path.Join(u.Path, "/models") + + client := request.NewClient(u.Scheme, u.Host, httpClient.Timeout, request.WithClient(httpClient)) + query, err := utils.GetQuery(req) + if err != nil { + return nil, err + } + resp, err := request.Get[T]( + client, u.Path, + request.WithHeader( + request.Header{ + "Authorization": fmt.Sprintf("Bearer %s", req.APIKey), + }, + ), + request.WithQuery(query), + ) + if err != nil { + return nil, err + } + + return (*resp).ParseModels(), nil +}