Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions internal/models/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall
return createGoogleProvider(ctx, config, modelName)
case "ollama":
return createOllamaProvider(ctx, config, modelName)
case "azure":
return createAzureOpenAIProvider(ctx, config, modelName)
default:
return nil, fmt.Errorf("unsupported provider: %s", provider)
}
Expand All @@ -101,6 +103,50 @@ func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) error {
return nil
}

func createAzureOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
apiKey := config.ProviderAPIKey
if apiKey == "" {
apiKey = os.Getenv("AZURE_OPENAI_API_KEY")
}
if apiKey == "" {
return nil, fmt.Errorf("Azure OpenAI API key not provided. Use --provider-api-key flag or AZURE_OPENAI_API_KEY environment variable")
}

azureConfig := &openai.ChatModelConfig{
APIKey: apiKey,
Model: modelName,
ByAzure: true, // Indicate this is an Azure OpenAI model
APIVersion: "2025-01-01-preview", // Default Azure OpenAI API version
}

if config.ProviderURL != "" {
azureConfig.BaseURL = config.ProviderURL
} else {
azureConfig.BaseURL = os.Getenv("AZURE_OPENAI_BASE_URL")
}
if azureConfig.BaseURL == "" {
return nil, fmt.Errorf("Azure OpenAI Base URL not provided. Use --provider-url flag or AZURE_OPENAI_BASE_URL environment variable")
}

if config.MaxTokens > 0 {
azureConfig.MaxTokens = &config.MaxTokens
}

if config.Temperature != nil {
azureConfig.Temperature = config.Temperature
}

if config.TopP != nil {
azureConfig.TopP = config.TopP
}

if len(config.StopSequences) > 0 {
azureConfig.Stop = config.StopSequences
}

return openai.NewChatModel(ctx, azureConfig)
}

func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
apiKey := config.ProviderAPIKey
if apiKey == "" {
Expand Down