Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions usecase/modelkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ func ModelList(ctx context.Context, req *domain.ModelListReq) (*domain.ModelList
resp, err := ollamaListModel(req.BaseURL, httpClient, req.APIHeader)
// 尝试通过替换baseURL的host为host.docker.internal解决ollama list err
if err != nil {
newBaseURL, err := baseURLReplaceHost(req.BaseURL)
if err != nil {
newBaseURL, replaceHostErr := baseURLReplaceHost(req.BaseURL)
if replaceHostErr != nil {
return &domain.ModelListResp{
Error: err.Error(),
}, nil
Expand All @@ -68,16 +68,16 @@ func ModelList(ctx context.Context, req *domain.ModelListReq) (*domain.ModelList
if newBaseURL == req.BaseURL {
return resp, nil
}
resp, err = ollamaListModel(newBaseURL, httpClient, req.APIHeader)
_, listErr := ollamaListModel(newBaseURL, httpClient, req.APIHeader)
// 替换后可以成功请求
if err == nil {
if listErr == nil {
return &domain.ModelListResp{
Error: fmt.Errorf("请将host替换为host.docker.internal").Error(),
}, nil
}
}
// end
return resp, nil
return resp, err
case consts.ModelProviderGemini:
client, err := generativeGenai.NewClient(ctx, option.WithAPIKey(req.APIKey))
if err != nil {
Expand Down Expand Up @@ -207,12 +207,13 @@ func CheckModel(ctx context.Context, req *domain.CheckModelReq) (*domain.CheckMo
return checkResp, nil
}
// end
// end
provider := consts.ParseModelProvider(req.Provider)

resp, err := getChatModelGenerateChat(ctx, provider, modelType, req.BaseURL, req)
// 其他模型供应商,尝试修复baseURL
if err != nil && provider == consts.ModelProviderOther {
res, err := fixProviderOtherCheckErr(ctx, req, provider, modelType)
res, err := tryFixBaseURL(ctx, req, provider, modelType)
if err != nil {
checkResp.Error = err.Error()
return checkResp, nil
Expand Down Expand Up @@ -354,13 +355,12 @@ func ollamaListModel(baseURL string, httpClient *http.Client, apiHeader string)
}

// 通过修复baseURL尝试修复其它供应商check err, 返回用于提示用户如何修复错误
func fixProviderOtherCheckErr(ctx context.Context, req *domain.CheckModelReq, provider consts.ModelProvider, modelType consts.ModelType) (string, error) {
func tryFixBaseURL(ctx context.Context, req *domain.CheckModelReq, provider consts.ModelProvider, modelType consts.ModelType) (string, error) {
log.Println("尝试修复")
// 尝试添加v1
fixedBaseURL, err := baseURLAddV1(req.BaseURL)
// baseurl解析错误,直接返回
if err != nil {
log.Printf("baseurl解析错误: %v", err)
return "", err
}

Expand All @@ -372,9 +372,10 @@ func fixProviderOtherCheckErr(ctx context.Context, req *domain.CheckModelReq, pr
log.Println("添加v1有效")
return "请在API地址末尾添加/v1", nil
}
log.Println("添加v1无效", err)
log.Println("添加v1无效", err, fixedBaseURL)
}

log.Println("尝试替换host")
// url末尾添加v1无效,尝试替换host为host.docker.internal
fixedBaseURL, err = baseURLReplaceHost(req.BaseURL)
// baseurl解析错误,直接返回
Expand All @@ -388,9 +389,11 @@ func fixProviderOtherCheckErr(ctx context.Context, req *domain.CheckModelReq, pr
if err == nil {
return "请将host替换为host.docker.internal", nil
}
log.Println("替换host无效", err, fixedBaseURL)
}

// 替换host也无效,尝试添加v1与替换host
log.Println("尝试添加v1与替换host")
fixedBaseURL, err = baseURLAddV1(req.BaseURL)
// baseurl解析错误,直接返回
if err != nil {
Expand All @@ -408,6 +411,7 @@ func fixProviderOtherCheckErr(ctx context.Context, req *domain.CheckModelReq, pr
if err == nil {
return "API地址末尾添加/v1, host替换为host.docker.internal", nil
}
log.Println("添加v1与替换host无效", err, fixedBaseURL)
}
return "", nil
}
Expand Down Expand Up @@ -453,8 +457,12 @@ func baseURLReplaceHost(inputURL string) (string, error) {
}
hostAddress := "host.docker.internal"

if rawURL.Host != hostAddress {
rawURL.Host = hostAddress
if rawURL.Hostname() != hostAddress {
if rawURL.Port() != "" {
rawURL.Host = hostAddress + ":" + rawURL.Port()
} else {
rawURL.Host = hostAddress
}
}
return rawURL.String(), nil
}
Expand Down