Skip to content

Commit

Permalink
feat: add proxysettings for azureopenai and openai (#987)
Browse files Browse the repository at this point in the history
Signed-off-by: tanujd11 <dwiveditanuj41@gmail.com>
Co-authored-by: Aris Boutselis <arisboutselis08@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
  • Loading branch information
3 people committed Feb 28, 2024
1 parent aab8d77 commit 307710e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cmd/serve/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ var ServeCmd = &cobra.Command{
model := os.Getenv("K8SGPT_MODEL")
baseURL := os.Getenv("K8SGPT_BASEURL")
engine := os.Getenv("K8SGPT_ENGINE")
proxyEndpoint := os.Getenv("K8SGPT_PROXY_ENDPOINT")
// If the envs are set, allocate in place to the aiProvider
// else exit with error
envIsSet := backend != "" || password != "" || model != ""
Expand All @@ -83,6 +84,7 @@ var ServeCmd = &cobra.Command{
Model: model,
BaseURL: baseURL,
Engine: engine,
ProxyEndpoint: proxyEndpoint,
Temperature: temperature(),
}

Expand Down
17 changes: 17 additions & 0 deletions pkg/ai/azureopenai.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package ai
import (
"context"
"errors"
"net/http"
"net/url"

"github.com/sashabaranov/go-openai"
)
Expand All @@ -21,6 +23,7 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
token := config.GetPassword()
baseURL := config.GetBaseURL()
engine := config.GetEngine()
proxyEndpoint := config.GetProxyEndpoint()
defaultConfig := openai.DefaultAzureConfig(token, baseURL)

defaultConfig.AzureModelMapperFunc = func(model string) string {
Expand All @@ -31,6 +34,20 @@ func (c *AzureAIClient) Configure(config IAIConfig) error {
return azureModelMapping[model]

}

if proxyEndpoint != "" {
proxyUrl, err := url.Parse(proxyEndpoint)
if err != nil {
return err
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyUrl),
}

defaultConfig.HTTPClient = &http.Client{
Transport: transport,
}
}
client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating Azure OpenAI client")
Expand Down
7 changes: 7 additions & 0 deletions pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type IAIConfig interface {
GetPassword() string
GetModel() string
GetBaseURL() string
GetProxyEndpoint() string
GetEndpointName() string
GetEngine() string
GetTemperature() float32
Expand Down Expand Up @@ -92,6 +93,8 @@ type AIProvider struct {
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"`
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"`
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
Expand All @@ -104,6 +107,10 @@ func (p *AIProvider) GetBaseURL() string {
return p.BaseURL
}

func (p *AIProvider) GetProxyEndpoint() string {
return p.ProxyEndpoint
}

func (p *AIProvider) GetEndpointName() string {
return p.EndpointName
}
Expand Down
17 changes: 17 additions & 0 deletions pkg/ai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package ai
import (
"context"
"errors"
"net/http"
"net/url"

"github.com/sashabaranov/go-openai"
)
Expand All @@ -41,12 +43,27 @@ const (
func (c *OpenAIClient) Configure(config IAIConfig) error {
token := config.GetPassword()
defaultConfig := openai.DefaultConfig(token)
proxyEndpoint := config.GetProxyEndpoint()

baseURL := config.GetBaseURL()
if baseURL != "" {
defaultConfig.BaseURL = baseURL
}

if proxyEndpoint != "" {
proxyUrl, err := url.Parse(proxyEndpoint)
if err != nil {
return err
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyUrl),
}

defaultConfig.HTTPClient = &http.Client{
Transport: transport,
}
}

client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating OpenAI client")
Expand Down

0 comments on commit 307710e

Please sign in to comment.