Skip to content

Commit

Permalink
Add support for azure open ai api adaptations (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
liavyona committed Jun 8, 2023
1 parent 43a6574 commit 75a0ab3
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 8 deletions.
8 changes: 7 additions & 1 deletion libaiac/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,13 @@ func (conv *Conversation) Send(ctx context.Context, prompt string, msgs ...Messa
Content: prompt,
})

err = conv.client.NewRequest("POST", "/chat/completions").
var apiVersion string
if len(conv.client.apiVersion) > 0 {
apiVersion = fmt.Sprintf("?api-version=%s", conv.client.apiVersion)
}

err = conv.client.NewRequest("POST",
fmt.Sprintf("/chat/completions%s", apiVersion)).
JSONBody(map[string]interface{}{
"model": conv.model.Name,
"messages": conv.messages,
Expand Down
46 changes: 40 additions & 6 deletions libaiac/libaiac.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@ import (
// Version contains aiac's version string
var Version = "development"

const OpenAIBackend = "https://api.openai.com/v1"

// Client is a structure used to continuously generate IaC code via OpenAPI/ChatGPT
type Client struct {
*requests.HTTPClient
apiKey string
apiKey string
apiVersion string
}

var (
Expand All @@ -45,21 +48,52 @@ var (
ErrRequestFailed = errors.New("request failed")
)

type NewClientOptions struct {
// APIKey is the OpenAI API key to use for requests. This is required.
ApiKey string

// ChatGPTURL is the URL to use for ChatGPT requests. This is optional nd by default to openai backend.
URL string

// APIVersion is the version of the OpenAI API to use. This is optional and by default to non specified.
APIVersion string
}

// NewClient creates a new instance of the Client struct, with the provided
// input options. Neither the OpenAI API nor ChatGPT are yet contacted at this
// point.
func NewClient(apiKey string) *Client {
if apiKey == "" {
func NewClient(opts *NewClientOptions) *Client {
if opts == nil {
return nil
}

if opts.ApiKey == "" {
return nil
}

if opts.URL == "" {
opts.URL = OpenAIBackend
}

var authHeaderKey string
var authHeaderVal string

if opts.URL == OpenAIBackend {
authHeaderKey = "Authorization"
authHeaderVal = fmt.Sprintf("Bearer %s", opts.ApiKey)
} else {
authHeaderKey = "api-key"
authHeaderVal = opts.ApiKey
}

cli := &Client{
apiKey: strings.TrimPrefix(apiKey, "Bearer "),
apiKey: strings.TrimPrefix(opts.ApiKey, "Bearer "),
apiVersion: opts.APIVersion,
}

cli.HTTPClient = requests.NewClient("https://api.openai.com/v1").
cli.HTTPClient = requests.NewClient(opts.URL).
Accept("application/json").
Header("Authorization", fmt.Sprintf("Bearer %s", cli.apiKey)).
Header(authHeaderKey, authHeaderVal).
ErrorHandler(func(
httpStatus int,
contentType string,
Expand Down
8 changes: 7 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ type flags struct {
} `cmd:"" help:"List supported models"`
Get struct {
APIKey string `help:"OpenAI API key" required:"" env:"OPENAI_API_KEY"`
URL string `help:"OpenAI API url. Can be Azure Open AI service" default:"https://api.openai.com/v1" env:"OPENAI_API_URL"`
APIVersion string `help:"OpenAI API version" default:"" env:"OPENAI_API_VERSION"`
OutputFile string `help:"Output file to push resulting code to" optional:"" type:"path" short:"o"` //nolint: lll
ReadmeFile string `help:"Readme file to push entire Markdown output to" optional:"" type:"path" short:"r"` //nolint: lll
Quiet bool `help:"Non-interactive mode, print/save output and exit" default:"false" short:"q"` //nolint: lll
Expand Down Expand Up @@ -107,7 +109,11 @@ func generateCode(cli flags) error { //nolint: funlen, cyclop
cli.Get.Model = libaiac.ModelGPT35Turbo
}

client := libaiac.NewClient(cli.Get.APIKey)
client := libaiac.NewClient(&libaiac.NewClientOptions{
ApiKey: cli.Get.APIKey,
URL: cli.Get.URL,
APIVersion: cli.Get.APIVersion,
})

spin := spinner.New(
spinner.CharSets[11],
Expand Down

0 comments on commit 75a0ab3

Please sign in to comment.