Skip to content

Commit

Permalink
feat: add azure openai provider
Browse files Browse the repository at this point in the history
Signed-off-by: Aris Boutselis <aris.boutselis@senseon.io>
  • Loading branch information
Aris Boutselis committed Apr 25, 2023
1 parent 2102f06 commit 63e517e
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 2 deletions.
8 changes: 8 additions & 0 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ var (
backend string
password string
model string
baseURL string
engine string
)

// authCmd represents the auth command
Expand Down Expand Up @@ -73,6 +75,8 @@ var AuthCmd = &cobra.Command{
Name: backend,
Model: model,
Password: password,
BaseURL: baseURL,
Engine: engine,
}

if providerIndex == -1 {
Expand Down Expand Up @@ -100,4 +104,8 @@ func init() {
AuthCmd.Flags().StringVarP(&model, "model", "m", "gpt-3.5-turbo", "Backend AI model")
// add flag for password
AuthCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
// add flag for url
AuthCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for azure open ai engine/deployment name
AuthCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
}
5 changes: 4 additions & 1 deletion cmd/serve/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ var ServeCmd = &cobra.Command{
backend = os.Getenv("K8SGPT_BACKEND")
password := os.Getenv("K8SGPT_PASSWORD")
model := os.Getenv("K8SGPT_MODEL")
baseURL := os.Getenv("K8SGPT_BASEURL")
engine := os.Getenv("K8SGPT_ENGINE")
// If the envs are set, alocate in place to the aiProvider
// else exit with error
if backend != "" || password != "" || model != "" {
envIsSet := backend != "" || password != "" || model != "" || baseURL != "" || engine != ""
if envIsSet {
aiProvider = &ai.AIProvider{
Name: backend,
Password: password,
Expand Down
94 changes: 94 additions & 0 deletions pkg/ai/azureopenai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package ai

import (
"context"
"encoding/base64"
"errors"
"fmt"
"strings"

"github.com/k8sgpt-ai/k8sgpt/pkg/util"

"github.com/fatih/color"
"github.com/spf13/viper"

"github.com/sashabaranov/go-openai"
)

type AzureAIClient struct {
client *openai.Client
language string
model string
}

func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
token := config.GetPassword()
baseURL := config.GetBaseURL()
engine := config.GetEngine()
defaultConfig := openai.DefaultAzureConfig(token, baseURL, engine)
client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating Azure OpenAI client")
}
c.language = lang
c.client = client
c.model = config.GetModel()
return nil
}

func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
// Create a completion request
resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
Model: c.model,
Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Content: fmt.Sprintf(default_prompt, c.language, prompt),
},
},
})
if err != nil {
return "", err
}
return resp.Choices[0].Message.Content, nil
}

func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, nocache bool) (string, error) {
inputKey := strings.Join(prompt, " ")
// Check for cached data
sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey))
cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc)
// find in viper cache
if viper.IsSet(cacheKey) && !nocache {
// retrieve data from cache
response := viper.GetString(cacheKey)
if response == "" {
color.Red("error retrieving cached data")
return "", nil
}
output, err := base64.StdEncoding.DecodeString(response)
if err != nil {
color.Red("error decoding cached data: %v", err)
return "", nil
}
return string(output), nil
}

response, err := a.GetCompletion(ctx, inputKey)
if err != nil {
return "", err
}

if !viper.IsSet(cacheKey) || nocache {
viper.Set(cacheKey, base64.StdEncoding.EncodeToString([]byte(response)))
if err := viper.WriteConfig(); err != nil {
color.Red("error writing config: %v", err)
return "", nil
}
}
return response, nil
}

func (a *AzureAIClient) GetName() string {
return "azureopenai"
}
16 changes: 15 additions & 1 deletion pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@ type IAI interface {
type IAIConfig interface {
GetPassword() string
GetModel() string
GetBaseURL() string
GetEngine() string
}

func NewClient(provider string) IAI {
switch provider {
case "openai":
return &OpenAIClient{}
case "azureopenai":
return &AzureAIClient{}
case "noopai":
return &NoOpAIClient{}
default:
Expand All @@ -34,7 +38,9 @@ type AIConfiguration struct {
type AIProvider struct {
Name string `mapstructure:"name"`
Model string `mapstructure:"model"`
Password string `mapstructure:"password"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
}

func (p *AIProvider) GetPassword() string {
Expand All @@ -44,3 +50,11 @@ func (p *AIProvider) GetPassword() string {
func (p *AIProvider) GetModel() string {
return p.Model
}

func (p *AIProvider) GetBaseURL() string {
return p.BaseURL
}

func (p *AIProvider) GetEngine() string {
return p.Engine
}

0 comments on commit 63e517e

Please sign in to comment.