-
Notifications
You must be signed in to change notification settings - Fork 5
support Custom Endpoints #26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ package cmd | |
|
|
||
| import ( | ||
| "fmt" | ||
| "net/url" | ||
| "os" | ||
|
|
||
| "github.com/AlecAivazis/survey/v2" | ||
|
|
@@ -26,8 +27,14 @@ var getCmd = &cobra.Command{ | |
| fmt.Println("Error getting model:", err) | ||
| os.Exit(1) | ||
| } | ||
| endpoint, err := config.GetEndpoint() | ||
| if err != nil { | ||
| fmt.Println("Error getting endpoint:", err) | ||
| os.Exit(1) | ||
| } | ||
| fmt.Printf("Active Provider: %s\n", provider) | ||
| fmt.Printf("Model: %s\n", model) | ||
| fmt.Printf("Endpoint: %s\n", endpoint) | ||
| }, | ||
| } | ||
|
|
||
|
|
@@ -39,13 +46,40 @@ var setCmd = &cobra.Command{ | |
| }, | ||
| } | ||
|
|
||
| func validateEndpointURL(val interface{}) error { | ||
| endpoint, ok := val.(string) | ||
| if !ok { | ||
| return fmt.Errorf("endpoint must be a string") | ||
| } | ||
|
|
||
| // Empty string is valid (uses default) | ||
| if endpoint == "" { | ||
| return nil | ||
| } | ||
|
|
||
| parsedURL, err := url.Parse(endpoint) | ||
| if err != nil { | ||
| return fmt.Errorf("invalid URL format: %w", err) | ||
| } | ||
|
|
||
| if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { | ||
| return fmt.Errorf("endpoint must use http or https protocol") | ||
| } | ||
|
|
||
| if parsedURL.Host == "" { | ||
| return fmt.Errorf("endpoint must have a valid host") | ||
| } | ||
|
|
||
| return nil | ||
|
Comment on lines
+49
to
+73
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The For example, you could use a regular expression to validate the URL format or check the URL length to prevent buffer overflow issues. |
||
| } | ||
|
|
||
| func runInteractiveConfig() { | ||
| currentProvider := config.GetProvider() | ||
| currentModel, _ := config.GetModel() | ||
|
|
||
| providerPrompt := &survey.Select{ | ||
| Message: "Choose a provider:", | ||
| Options: []string{"openai", "openrouter", "copilot"}, | ||
| Options: []string{"openai", "copilot"}, | ||
| Default: currentProvider, | ||
| } | ||
| var selectedProvider string | ||
|
|
@@ -87,9 +121,8 @@ func runInteractiveConfig() { | |
|
|
||
| // Dynamically generate available models for OpenAI | ||
| availableModels := map[string][]string{ | ||
| "openai": {}, | ||
| "openrouter": {}, | ||
| "copilot": {"gpt-4o"}, // TODO: update if copilot models are dynamic | ||
| "openai": {}, | ||
| "copilot": {"gpt-4o"}, // TODO: update if copilot models are dynamic | ||
| } | ||
|
|
||
| modelDisplayToID := map[string]string{} | ||
|
|
@@ -99,12 +132,6 @@ func runInteractiveConfig() { | |
| availableModels["openai"] = append(availableModels["openai"], display) | ||
| modelDisplayToID[display] = string(id) | ||
| } | ||
| } else if selectedProvider == "openrouter" { | ||
| for id, m := range models.OpenRouterModels { | ||
| display := fmt.Sprintf("%s (%s)", m.Name, string(id)) | ||
| availableModels["openrouter"] = append(availableModels["openrouter"], display) | ||
| modelDisplayToID[display] = string(id) | ||
| } | ||
| } | ||
|
|
||
| modelPrompt := &survey.Select{ | ||
|
|
@@ -115,7 +142,7 @@ func runInteractiveConfig() { | |
| // Try to set the default to the current model if possible | ||
| isValidDefault := false | ||
| currentDisplay := "" | ||
| if selectedProvider == "openai" || selectedProvider == "openrouter" { | ||
| if selectedProvider == "openai" { | ||
| for display, id := range modelDisplayToID { | ||
| if id == currentModel || display == currentModel { | ||
| isValidDefault = true | ||
|
|
@@ -144,7 +171,7 @@ func runInteractiveConfig() { | |
| } | ||
|
|
||
| selectedModel := selectedDisplay | ||
| if selectedProvider == "openai" || selectedProvider == "openrouter" { | ||
| if selectedProvider == "openai" { | ||
| selectedModel = modelDisplayToID[selectedDisplay] | ||
| } | ||
|
|
||
|
|
@@ -156,6 +183,33 @@ func runInteractiveConfig() { | |
| } | ||
| fmt.Printf("Model set to: %s\n", selectedModel) | ||
| } | ||
|
|
||
| // Get current endpoint | ||
| currentEndpoint, _ := config.GetEndpoint() | ||
|
|
||
| // Endpoint configuration prompt | ||
| endpointPrompt := &survey.Input{ | ||
| Message: "Enter custom endpoint URL (leave empty for default):", | ||
| Default: currentEndpoint, | ||
| } | ||
| var endpoint string | ||
| err = survey.AskOne(endpointPrompt, &endpoint, survey.WithValidator(validateEndpointURL)) | ||
| if err != nil { | ||
| fmt.Println(err.Error()) | ||
| return | ||
| } | ||
|
|
||
| // Only set endpoint if it's different from current | ||
| if endpoint != currentEndpoint && endpoint != "" { | ||
| err := config.SetEndpoint(selectedProvider, endpoint) | ||
| if err != nil { | ||
| fmt.Printf("Error setting endpoint: %v\n", err) | ||
| return | ||
| } | ||
| fmt.Printf("Endpoint set to: %s\n", endpoint) | ||
| } else if endpoint == "" { | ||
| fmt.Println("Using default endpoint for provider") | ||
| } | ||
| } | ||
|
|
||
| func init() { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ package config | |
| import ( | ||
| "encoding/json" | ||
| "fmt" | ||
| "net/url" | ||
| "os" | ||
| "path/filepath" | ||
| "runtime" | ||
|
|
@@ -12,8 +13,9 @@ import ( | |
| ) | ||
|
|
||
| type ProviderConfig struct { | ||
| APIKey string `mapstructure:"api_key"` | ||
| Model string `mapstructure:"model"` | ||
| APIKey string `mapstructure:"api_key"` | ||
| Model string `mapstructure:"model"` | ||
| EndpointURL string `mapstructure:"endpoint_url"` | ||
| } | ||
|
|
||
| type Config struct { | ||
|
|
@@ -124,6 +126,28 @@ func GetModel() (string, error) { | |
| return providerConfig.Model, nil | ||
| } | ||
|
|
||
| func GetEndpoint() (string, error) { | ||
| providerConfig, err := GetActiveProviderConfig() | ||
| if err != nil { | ||
| return "", err | ||
| } | ||
|
|
||
| // If custom endpoint is configured, use it | ||
| if providerConfig.EndpointURL != "" { | ||
| return providerConfig.EndpointURL, nil | ||
| } | ||
|
|
||
| // Return default endpoints based on provider | ||
| switch cfg.ActiveProvider { | ||
| case "openai": | ||
| return "https://api.openai.com/v1", nil | ||
| case "copilot": | ||
| return "https://api.githubcopilot.com", nil | ||
| default: | ||
| return "", fmt.Errorf("no default endpoint available for provider '%s'", cfg.ActiveProvider) | ||
| } | ||
| } | ||
|
|
||
| func SetProvider(provider string) error { | ||
| if cfg == nil { | ||
| InitConfig() | ||
|
|
@@ -150,6 +174,41 @@ func SetAPIKey(provider, apiKey string) error { | |
| return viper.WriteConfig() | ||
| } | ||
|
|
||
| func validateEndpointURL(endpoint string) error { | ||
| if endpoint == "" { | ||
| return nil // Empty endpoint is valid (will use default) | ||
| } | ||
|
|
||
| parsedURL, err := url.Parse(endpoint) | ||
| if err != nil { | ||
| return fmt.Errorf("invalid URL format: %w", err) | ||
| } | ||
|
|
||
| if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { | ||
| return fmt.Errorf("endpoint must use http or https protocol") | ||
| } | ||
|
|
||
| if parsedURL.Host == "" { | ||
| return fmt.Errorf("endpoint must have a valid host") | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
Comment on lines
+177
to
+196
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Comment on lines
+177
to
+196
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The For example, you could use a regular expression to validate the URL format or check the URL length to prevent buffer overflow issues. |
||
|
|
||
| func SetEndpoint(provider, endpoint string) error { | ||
| if cfg == nil { | ||
| InitConfig() | ||
| } | ||
|
|
||
| // Validate endpoint URL | ||
| if err := validateEndpointURL(endpoint); err != nil { | ||
| return err | ||
| } | ||
|
|
||
| viper.Set(fmt.Sprintf("providers.%s.endpoint_url", provider), endpoint) | ||
| return viper.WriteConfig() | ||
| } | ||
|
|
||
| func LoadGitHubToken() (string, error) { | ||
| if token := os.Getenv("GITHUB_TOKEN"); token != "" { | ||
| return token, nil | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block retrieves the endpoint, but the error is only handled by printing to stderr and exiting. It would be better to propagate this error so that the calling function can handle it gracefully, potentially providing more context to the user or attempting a fallback mechanism.
Consider returning the error to allow the caller to handle it.