diff --git a/libaiac/chat.go b/libaiac/chat.go index ecfd0f2..5b7be0f 100644 --- a/libaiac/chat.go +++ b/libaiac/chat.go @@ -64,7 +64,13 @@ func (conv *Conversation) Send(ctx context.Context, prompt string, msgs ...Messa Content: prompt, }) - err = conv.client.NewRequest("POST", "/chat/completions"). + var apiVersion string + if len(conv.client.apiVersion) > 0 { + apiVersion = fmt.Sprintf("?api-version=%s", conv.client.apiVersion) + } + + err = conv.client.NewRequest("POST", + fmt.Sprintf("/chat/completions%s", apiVersion)). JSONBody(map[string]interface{}{ "model": conv.model.Name, "messages": conv.messages, diff --git a/libaiac/libaiac.go b/libaiac/libaiac.go index 7c0fc7a..10ec308 100644 --- a/libaiac/libaiac.go +++ b/libaiac/libaiac.go @@ -16,10 +16,13 @@ import ( // Version contains aiac's version string var Version = "development" +const OpenAIBackend = "https://api.openai.com/v1" + // Client is a structure used to continuously generate IaC code via OpenAPI/ChatGPT type Client struct { *requests.HTTPClient - apiKey string + apiKey string + apiVersion string } var ( @@ -45,21 +48,52 @@ var ( ErrRequestFailed = errors.New("request failed") ) +type NewClientOptions struct { + // APIKey is the OpenAI API key to use for requests. This is required. + ApiKey string + + // ChatGPTURL is the URL to use for ChatGPT requests. This is optional nd by default to openai backend. + URL string + + // APIVersion is the version of the OpenAI API to use. This is optional and by default to non specified. + APIVersion string +} + // NewClient creates a new instance of the Client struct, with the provided // input options. Neither the OpenAI API nor ChatGPT are yet contacted at this // point. -func NewClient(apiKey string) *Client { - if apiKey == "" { +func NewClient(opts *NewClientOptions) *Client { + if opts == nil { return nil } + if opts.ApiKey == "" { + return nil + } + + if opts.URL == "" { + opts.URL = OpenAIBackend + } + + var authHeaderKey string + var authHeaderVal string + + if opts.URL == OpenAIBackend { + authHeaderKey = "Authorization" + authHeaderVal = fmt.Sprintf("Bearer %s", opts.ApiKey) + } else { + authHeaderKey = "api-key" + authHeaderVal = opts.ApiKey + } + cli := &Client{ - apiKey: strings.TrimPrefix(apiKey, "Bearer "), + apiKey: strings.TrimPrefix(opts.ApiKey, "Bearer "), + apiVersion: opts.APIVersion, } - cli.HTTPClient = requests.NewClient("https://api.openai.com/v1"). + cli.HTTPClient = requests.NewClient(opts.URL). Accept("application/json"). - Header("Authorization", fmt.Sprintf("Bearer %s", cli.apiKey)). + Header(authHeaderKey, authHeaderVal). ErrorHandler(func( httpStatus int, contentType string, diff --git a/main.go b/main.go index 86ae401..3ee2fbb 100644 --- a/main.go +++ b/main.go @@ -24,6 +24,8 @@ type flags struct { } `cmd:"" help:"List supported models"` Get struct { APIKey string `help:"OpenAI API key" required:"" env:"OPENAI_API_KEY"` + URL string `help:"OpenAI API url. Can be Azure Open AI service" default:"https://api.openai.com/v1" env:"OPENAI_API_URL"` + APIVersion string `help:"OpenAI API version" default:"" env:"OPENAI_API_VERSION"` OutputFile string `help:"Output file to push resulting code to" optional:"" type:"path" short:"o"` //nolint: lll ReadmeFile string `help:"Readme file to push entire Markdown output to" optional:"" type:"path" short:"r"` //nolint: lll Quiet bool `help:"Non-interactive mode, print/save output and exit" default:"false" short:"q"` //nolint: lll @@ -107,7 +109,11 @@ func generateCode(cli flags) error { //nolint: funlen, cyclop cli.Get.Model = libaiac.ModelGPT35Turbo } - client := libaiac.NewClient(cli.Get.APIKey) + client := libaiac.NewClient(&libaiac.NewClientOptions{ + ApiKey: cli.Get.APIKey, + URL: cli.Get.URL, + APIVersion: cli.Get.APIVersion, + }) spin := spinner.New( spinner.CharSets[11],