Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions pkg/model/provider/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"strings"

Expand Down Expand Up @@ -35,11 +36,6 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
return nil, errors.New("model configuration is required")
}

if cfg.Provider != "openai" {
slog.Error("OpenAI client creation failed", "error", "model type must be 'openai'", "actual_type", cfg.Provider)
return nil, errors.New("model type must be 'openai'")
}

var globalOptions options.ModelOptions
for _, opt := range opts {
opt(&globalOptions)
Expand All @@ -53,13 +49,31 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
}
authToken := env.Get(ctx, key)
if authToken == "" {
return nil, errors.New("OPENAI_API_KEY environment variable is required")
return nil, fmt.Errorf("%s environment variable is required", key)
}

if cfg.Provider == "azure" {
openaiConfig = openai.DefaultAzureConfig(authToken, cfg.BaseURL)
} else {
openaiConfig = openai.DefaultConfig(authToken)
}

openaiConfig = openai.DefaultConfig(authToken)
if cfg.BaseURL != "" {
openaiConfig.BaseURL = cfg.BaseURL
}

// TODO: Move this logic to ProviderAliases as a config function
if cfg.ProviderOpts != nil {
switch cfg.Provider { //nolint:gocritic
case "azure":
if apiVersion, exists := cfg.ProviderOpts["api_version"]; exists {
slog.Debug("Setting API version", "api_version", apiVersion)
if apiVersionStr, ok := apiVersion.(string); ok {
openaiConfig.APIVersion = apiVersionStr
}
}
}
}
} else {
authToken := desktop.GetToken(ctx)
if authToken == "" {
Expand Down
82 changes: 75 additions & 7 deletions pkg/model/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,26 @@ import (
"github.com/docker/cagent/pkg/tools"
)

// Alias defines the configuration for a provider alias
type Alias struct {
ApiType string // The actual API type to use (openai, anthropic, etc.)
BaseURL string // Default base URL for the provider
TokenEnvVar string // Environment variable name for the API token
}

// ProviderAliases maps provider names to their corresponding configurations
var ProviderAliases = map[string]Alias{
"requesty": {
ApiType: "openai",
BaseURL: "https://router.requesty.ai/v1",
TokenEnvVar: "REQUESTY_API_KEY",
},
"azure": {
ApiType: "openai",
TokenEnvVar: "AZURE_API_KEY",
},
}

// Provider defines the interface for model providers
type Provider interface {
// ID returns the model provider ID
Expand All @@ -37,21 +57,69 @@ type Provider interface {
func New(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) {
slog.Debug("Creating model provider", "type", cfg.Provider, "model", cfg.Model)

switch cfg.Provider {
// Apply provider alias defaults to the config
enhancedCfg := applyProviderDefaults(cfg)
apiType := ""
if alias, exists := ProviderAliases[cfg.Provider]; exists {
apiType = alias.ApiType
}

// Resolve the actual API type from aliases or direct specification
providerType := resolveProviderType(cfg.Provider, apiType)

switch providerType {
case "openai":
return openai.NewClient(ctx, cfg, env, opts...)
return openai.NewClient(ctx, enhancedCfg, env, opts...)

case "anthropic":
return anthropic.NewClient(ctx, cfg, env, opts...)
return anthropic.NewClient(ctx, enhancedCfg, env, opts...)

case "google":
return gemini.NewClient(ctx, cfg, env, opts...)
return gemini.NewClient(ctx, enhancedCfg, env, opts...)

case "dmr":
return dmr.NewClient(ctx, cfg, opts...)
return dmr.NewClient(ctx, enhancedCfg, opts...)

default:
slog.Error("Unknown provider type", "type", cfg.Provider)
return nil, fmt.Errorf("unknown provider type: %s", cfg.Provider)
slog.Error("Unknown provider type", "type", providerType)
return nil, fmt.Errorf("unknown provider type: %s", providerType)
}
}

// applyProviderDefaults applies default configuration from provider aliases to the model config
// This sets default base URLs and token keys if not already specified
func applyProviderDefaults(cfg *latest.ModelConfig) *latest.ModelConfig {
// Create a copy to avoid modifying the original
enhancedCfg := *cfg

// Check if provider has alias configuration
if alias, exists := ProviderAliases[cfg.Provider]; exists {
// Set default base URL if not already specified
if enhancedCfg.BaseURL == "" && alias.BaseURL != "" {
enhancedCfg.BaseURL = alias.BaseURL
}

// Set default token key if not already specified
if enhancedCfg.TokenKey == "" && alias.TokenEnvVar != "" {
enhancedCfg.TokenKey = alias.TokenEnvVar
}
}

return &enhancedCfg
}

// resolveProviderType resolves the actual API type from the provider name and optional apiType
func resolveProviderType(provider, apiType string) string {
// If apiType is explicitly provided, use it
if apiType != "" {
return apiType
}

// Check if provider has an alias mapping
if resolved, exists := ProviderAliases[provider]; exists {
return resolved.ApiType
}

// Fall back to the provider name itself
return provider
}
2 changes: 1 addition & 1 deletion pkg/modelsdev/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) {
id = actualID
}

parts := strings.Split(id, "/")
parts := strings.SplitN(id, "/", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid model ID: %q", id)
}
Expand Down
25 changes: 17 additions & 8 deletions pkg/teamloader/teamloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,23 @@ func checkRequiredEnvVars(ctx context.Context, cfg *latest.Config, env environme

// Models
if runtimeConfig.ModelsGateway == "" {
for _, model := range cfg.Models {
switch model.Provider {
case "openai":
requiredEnv["OPENAI_API_KEY"] = true
case "anthropic":
requiredEnv["ANTHROPIC_API_KEY"] = true
case "google":
requiredEnv["GOOGLE_API_KEY"] = true
for name := range cfg.Models {
model := cfg.Models[name]
// Use the token environment variable from the alias if available
if alias, exists := provider.ProviderAliases[model.Provider]; exists {
if alias.TokenEnvVar != "" {
requiredEnv[alias.TokenEnvVar] = true
}
} else {
// Fallback to hardcoded mappings for unknown providers
switch model.Provider {
case "openai":
requiredEnv["OPENAI_API_KEY"] = true
case "anthropic":
requiredEnv["ANTHROPIC_API_KEY"] = true
case "google":
requiredEnv["GOOGLE_API_KEY"] = true
}
}
}

Expand Down