Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/ui_example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ function App() {
modelService={localModelService}
language="zh-CN"
messageComponent={messageComponent}
is_close_model_remark={true}
/>


Expand Down
11 changes: 6 additions & 5 deletions ui/ModelModal/src/ModelModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
modelService,
language = 'zh-CN',
messageComponent,
is_close_model_remark = false,
}: ModelModalProps) => {
const theme = useTheme();

Expand Down Expand Up @@ -578,7 +579,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
/>
)}
/>
{(modelUserList.length !== 0 || providerBrand === 'Other') && (
{(modelUserList.length !== 0 || providerBrand === 'Other') && !is_close_model_remark && (
<>
<Box sx={{ fontSize: 14, lineHeight: '32px', mt: 2 }}>
模型备注
Expand All @@ -591,7 +592,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
name='show_name'
rules={{
required: {
value: true,
value: !is_close_model_remark,
message: '模型备注不能为空',
},
}}
Expand Down Expand Up @@ -661,7 +662,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
)}
/>
<Box sx={{ fontSize: 12, color: 'error.main', mt: 1 }}>
需要与模型供应商提供的名称完全一致,不要随便填写
需要与模型供应商提供的名称完全一致
</Box>
</>
) : modelUserList.length === 0 ? (
Expand Down Expand Up @@ -751,8 +752,8 @@ export const ModelModal: React.FC<ModelModalProps> = ({

</>
)}
{/* 高级设置部分 - 在选择了模型或者是其它供应商时显示 */}
{(modelUserList.length !== 0 || providerBrand === 'Other') && (
{/* 高级设置部分 - 在选择了模型或者是其它供应商时显示,但不包括embedding、rerank、reranker类型 */}
{(modelUserList.length !== 0 || providerBrand === 'Other') && !['embedding', 'rerank', 'reranker'].includes(model_type) && (
<Box sx={{ mt: 2 }}>
<Accordion
sx={{
Expand Down
2 changes: 1 addition & 1 deletion ui/ModelModal/src/constants/locale.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export const LOCALE_MESSAGES: Record<'zh-CN' | 'en-US', Record<string, string>>
'urlRequired': 'URL 不能为空',
'secretRequired': 'API Secret 不能为空',
'modelNameRequired': '模型名称不能为空',
'modelNameHint': '需要与模型供应商提供的名称完全一致,不要随便填写',
'modelNameHint': '需要与模型供应商提供的名称完全一致',

// 成功/错误消息
'addSuccess': '添加成功',
Expand Down
2 changes: 1 addition & 1 deletion ui/ModelModal/src/constants/providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ export const DEFAULT_MODEL_PROVIDERS: ModelProviderMap = {
urlWrite: false,
secretRequired: true,
customHeader: false,
chat: true,
chat: false,
code: true,
embedding: false,
rerank: false,
Expand Down
1 change: 1 addition & 0 deletions ui/ModelModal/src/types/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,5 @@ export interface ModelModalProps {
modelService: ModelService;
language?: 'zh-CN' | 'en-US';
messageComponent?: MessageComponent;
is_close_model_remark?: boolean;
}
26 changes: 13 additions & 13 deletions usecase/modelkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ import (
func reqModelListApi[T domain.ModelResponseParser](req *domain.ModelListReq, httpClient *http.Client, responseType T) ([]domain.ModelListItem, error) {
u, err := url.Parse(req.BaseURL)
if err != nil {
return nil, fmt.Errorf("解析BaseURL失败: %w", err)
return nil, err
}
u.Path = path.Join(u.Path, "/models")

client := request.NewClient(u.Scheme, u.Host, httpClient.Timeout, request.WithClient(httpClient))
query, err := utils.GetQuery(req)
if err != nil {
return nil, fmt.Errorf("获取查询参数失败: %w", err)
return nil, err
}
resp, err := request.Get[T](
client, u.Path,
Expand All @@ -55,7 +55,7 @@ func reqModelListApi[T domain.ModelResponseParser](req *domain.ModelListReq, htt
request.WithQuery(query),
)
if err != nil {
return nil, fmt.Errorf("请求模型列表API失败: %w", err)
return nil, err
}

return (*resp).ParseModels(), nil
Expand Down Expand Up @@ -87,7 +87,7 @@ func ModelList(ctx context.Context, req *domain.ModelListReq) (*domain.ModelList
u, err := url.Parse(req.BaseURL)
if err != nil {
return &domain.ModelListResp{
Error: fmt.Errorf("Ollama解析BaseURL失败: %w", err).Error(),
Error: err.Error(),
}, nil
}
u.Path = "/api/tags"
Expand All @@ -104,7 +104,7 @@ func ModelList(ctx context.Context, req *domain.ModelListReq) (*domain.ModelList
client, err := generativeGenai.NewClient(ctx, option.WithAPIKey(req.APIKey))
if err != nil {
return &domain.ModelListResp{
Error: fmt.Errorf("创建Gemini客户端失败: %w", err).Error(),
Error: err.Error(),
}, nil
}
defer func() {
Expand Down Expand Up @@ -148,7 +148,7 @@ func ModelList(ctx context.Context, req *domain.ModelListReq) (*domain.ModelList
models, err := reqModelListApi(req, httpClient, &domain.GithubResp{})
if err != nil {
return &domain.ModelListResp{
Error: fmt.Errorf("获取Github模型列表失败: %w", err).Error(),
Error: err.Error(),
}, nil
}
return &domain.ModelListResp{
Expand All @@ -160,7 +160,7 @@ func ModelList(ctx context.Context, req *domain.ModelListReq) (*domain.ModelList

if err != nil {
return &domain.ModelListResp{
Error: fmt.Errorf("获取OpenAI兼容模型列表失败: %w", err).Error(),
Error: err.Error(),
}, nil
}
return &domain.ModelListResp{
Expand Down Expand Up @@ -288,15 +288,15 @@ func GetChatModel(ctx context.Context, model *domain.ModelMetadata) (model.BaseC
Temperature: temperature,
})
if err != nil {
return nil, fmt.Errorf("创建DeepSeek聊天模型失败: %w", err)
return nil, err
}
return chatModel, nil
case consts.ModelProviderGemini:
client, err := genai.NewClient(ctx, &genai.ClientConfig{
APIKey: model.APIKey,
})
if err != nil {
return nil, fmt.Errorf("创建Genai客户端失败: %w", err)
return nil, err
}

chatModel, err := gemini.NewChatModel(ctx, &gemini.Config{
Expand All @@ -308,13 +308,13 @@ func GetChatModel(ctx context.Context, model *domain.ModelMetadata) (model.BaseC
},
})
if err != nil {
return nil, fmt.Errorf("创建Gemini聊天模型失败: %w", err)
return nil, err
}
return chatModel, nil
case consts.ModelProviderOllama:
baseUrl, err := utils.URLRemovePath(config.BaseURL)
if err != nil {
return nil, fmt.Errorf("解析Ollama URL失败: %w", err)
return nil, err
}

chatModel, err := ollama.NewChatModel(ctx, &ollama.ChatModelConfig{
Expand All @@ -326,14 +326,14 @@ func GetChatModel(ctx context.Context, model *domain.ModelMetadata) (model.BaseC
},
})
if err != nil {
return nil, fmt.Errorf("创建Ollama聊天模型失败: %w", err)
return nil, err
}
return chatModel, nil
// 兼容 openai api
default:
chatModel, err := openai.NewChatModel(ctx, config)
if err != nil {
return nil, fmt.Errorf("创建OpenAI兼容聊天模型失败: %w", err)
return nil, err
}
return chatModel, nil
}
Expand Down