diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index f9f5abd5a..fed451f67 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "log/slog" "strings" @@ -35,11 +36,6 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro return nil, errors.New("model configuration is required") } - if cfg.Provider != "openai" { - slog.Error("OpenAI client creation failed", "error", "model type must be 'openai'", "actual_type", cfg.Provider) - return nil, errors.New("model type must be 'openai'") - } - var globalOptions options.ModelOptions for _, opt := range opts { opt(&globalOptions) @@ -53,13 +49,31 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro } authToken := env.Get(ctx, key) if authToken == "" { - return nil, errors.New("OPENAI_API_KEY environment variable is required") + return nil, fmt.Errorf("%s environment variable is required", key) + } + + if cfg.Provider == "azure" { + openaiConfig = openai.DefaultAzureConfig(authToken, cfg.BaseURL) + } else { + openaiConfig = openai.DefaultConfig(authToken) } - openaiConfig = openai.DefaultConfig(authToken) if cfg.BaseURL != "" { openaiConfig.BaseURL = cfg.BaseURL } + + // TODO: Move this logic to ProviderAliases as a config function + if cfg.ProviderOpts != nil { + switch cfg.Provider { //nolint:gocritic + case "azure": + if apiVersion, exists := cfg.ProviderOpts["api_version"]; exists { + slog.Debug("Setting API version", "api_version", apiVersion) + if apiVersionStr, ok := apiVersion.(string); ok { + openaiConfig.APIVersion = apiVersionStr + } + } + } + } } else { authToken := desktop.GetToken(ctx) if authToken == "" { diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go index d346e7226..39f7fdc1e 100644 --- a/pkg/model/provider/provider.go +++ b/pkg/model/provider/provider.go @@ -16,6 +16,26 @@ import ( "github.com/docker/cagent/pkg/tools" ) +// Alias defines the configuration for a provider alias +type Alias struct { + ApiType string // The actual API type to use (openai, anthropic, etc.) + BaseURL string // Default base URL for the provider + TokenEnvVar string // Environment variable name for the API token +} + +// ProviderAliases maps provider names to their corresponding configurations +var ProviderAliases = map[string]Alias{ + "requesty": { + ApiType: "openai", + BaseURL: "https://router.requesty.ai/v1", + TokenEnvVar: "REQUESTY_API_KEY", + }, + "azure": { + ApiType: "openai", + TokenEnvVar: "AZURE_API_KEY", + }, +} + // Provider defines the interface for model providers type Provider interface { // ID returns the model provider ID @@ -37,21 +57,69 @@ type Provider interface { func New(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) { slog.Debug("Creating model provider", "type", cfg.Provider, "model", cfg.Model) - switch cfg.Provider { + // Apply provider alias defaults to the config + enhancedCfg := applyProviderDefaults(cfg) + apiType := "" + if alias, exists := ProviderAliases[cfg.Provider]; exists { + apiType = alias.ApiType + } + + // Resolve the actual API type from aliases or direct specification + providerType := resolveProviderType(cfg.Provider, apiType) + + switch providerType { case "openai": - return openai.NewClient(ctx, cfg, env, opts...) + return openai.NewClient(ctx, enhancedCfg, env, opts...) case "anthropic": - return anthropic.NewClient(ctx, cfg, env, opts...) + return anthropic.NewClient(ctx, enhancedCfg, env, opts...) case "google": - return gemini.NewClient(ctx, cfg, env, opts...) + return gemini.NewClient(ctx, enhancedCfg, env, opts...) case "dmr": - return dmr.NewClient(ctx, cfg, opts...) + return dmr.NewClient(ctx, enhancedCfg, opts...) default: - slog.Error("Unknown provider type", "type", cfg.Provider) - return nil, fmt.Errorf("unknown provider type: %s", cfg.Provider) + slog.Error("Unknown provider type", "type", providerType) + return nil, fmt.Errorf("unknown provider type: %s", providerType) } } + +// applyProviderDefaults applies default configuration from provider aliases to the model config +// This sets default base URLs and token keys if not already specified +func applyProviderDefaults(cfg *latest.ModelConfig) *latest.ModelConfig { + // Create a copy to avoid modifying the original + enhancedCfg := *cfg + + // Check if provider has alias configuration + if alias, exists := ProviderAliases[cfg.Provider]; exists { + // Set default base URL if not already specified + if enhancedCfg.BaseURL == "" && alias.BaseURL != "" { + enhancedCfg.BaseURL = alias.BaseURL + } + + // Set default token key if not already specified + if enhancedCfg.TokenKey == "" && alias.TokenEnvVar != "" { + enhancedCfg.TokenKey = alias.TokenEnvVar + } + } + + return &enhancedCfg +} + +// resolveProviderType resolves the actual API type from the provider name and optional apiType +func resolveProviderType(provider, apiType string) string { + // If apiType is explicitly provided, use it + if apiType != "" { + return apiType + } + + // Check if provider has an alias mapping + if resolved, exists := ProviderAliases[provider]; exists { + return resolved.ApiType + } + + // Fall back to the provider name itself + return provider +} diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index c39ea5740..e66115d76 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go @@ -120,7 +120,7 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) { id = actualID } - parts := strings.Split(id, "/") + parts := strings.SplitN(id, "/", 2) if len(parts) != 2 { return nil, fmt.Errorf("invalid model ID: %q", id) } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 8782b5180..3412b7d0b 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -85,14 +85,23 @@ func checkRequiredEnvVars(ctx context.Context, cfg *latest.Config, env environme // Models if runtimeConfig.ModelsGateway == "" { - for _, model := range cfg.Models { - switch model.Provider { - case "openai": - requiredEnv["OPENAI_API_KEY"] = true - case "anthropic": - requiredEnv["ANTHROPIC_API_KEY"] = true - case "google": - requiredEnv["GOOGLE_API_KEY"] = true + for name := range cfg.Models { + model := cfg.Models[name] + // Use the token environment variable from the alias if available + if alias, exists := provider.ProviderAliases[model.Provider]; exists { + if alias.TokenEnvVar != "" { + requiredEnv[alias.TokenEnvVar] = true + } + } else { + // Fallback to hardcoded mappings for unknown providers + switch model.Provider { + case "openai": + requiredEnv["OPENAI_API_KEY"] = true + case "anthropic": + requiredEnv["ANTHROPIC_API_KEY"] = true + case "google": + requiredEnv["GOOGLE_API_KEY"] = true + } } }