diff --git a/client.go b/client.go index 6f334c6a..3cee6aeb 100644 --- a/client.go +++ b/client.go @@ -31,11 +31,9 @@ type Client struct { Uploads *UploadService } -// NewClient generates a new client with the default option read from the -// environment (OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID). The option -// passed in as arguments are applied after these default arguments, and all option -// will be passed down to the services and requests that this client makes. -func NewClient(opts ...option.RequestOption) (r *Client) { +// DefaultClientOptions read from the environment (OPENAI_API_KEY, OPENAI_ORG_ID, +// OPENAI_PROJECT_ID). This should be used to initialize new clients. +func DefaultClientOptions() []option.RequestOption { defaults := []option.RequestOption{option.WithEnvironmentProduction()} if o, ok := os.LookupEnv("OPENAI_API_KEY"); ok { defaults = append(defaults, option.WithAPIKey(o)) @@ -46,7 +44,15 @@ func NewClient(opts ...option.RequestOption) (r *Client) { if o, ok := os.LookupEnv("OPENAI_PROJECT_ID"); ok { defaults = append(defaults, option.WithProject(o)) } - opts = append(defaults, opts...) + return defaults +} + +// NewClient generates a new client with the default option read from the +// environment (OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID). The option +// passed in as arguments are applied after these default arguments, and all option +// will be passed down to the services and requests that this client makes. +func NewClient(opts ...option.RequestOption) (r *Client) { + opts = append(DefaultClientOptions(), opts...) r = &Client{Options: opts} diff --git a/internal/requestconfig/requestconfig.go b/internal/requestconfig/requestconfig.go index ce5b93a3..b5b0d782 100644 --- a/internal/requestconfig/requestconfig.go +++ b/internal/requestconfig/requestconfig.go @@ -22,6 +22,7 @@ import ( "github.com/openai/openai-go/internal/apierror" "github.com/openai/openai-go/internal/apiform" "github.com/openai/openai-go/internal/apiquery" + "github.com/openai/openai-go/internal/param" ) func getDefaultHeaders() map[string]string { @@ -77,7 +78,17 @@ func getPlatformProperties() map[string]string { } } -func NewRequestConfig(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...func(*RequestConfig) error) (*RequestConfig, error) { +type RequestOption interface { + Apply(*RequestConfig) error +} + +type RequestOptionFunc func(*RequestConfig) error +type PreRequestOptionFunc func(*RequestConfig) error + +func (s RequestOptionFunc) Apply(r *RequestConfig) error { return s(r) } +func (s PreRequestOptionFunc) Apply(r *RequestConfig) error { return s(r) } + +func NewRequestConfig(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...RequestOption) (*RequestConfig, error) { var reader io.Reader contentType := "application/json" @@ -174,10 +185,17 @@ func NewRequestConfig(ctx context.Context, method string, u string, body interfa return &cfg, nil } +func UseDefaultParam[T any](dst *param.Field[T], src *T) { + if !dst.Present && src != nil { + dst.Value = *src + dst.Present = true + } +} + // RequestConfig represents all the state related to one request. // // Editing the variables inside RequestConfig directly is unstable api. Prefer -// composing func(\*RequestConfig) error instead if possible. +// composing the RequestOption instead if possible. type RequestConfig struct { MaxRetries int RequestTimeout time.Duration @@ -519,7 +537,7 @@ func (cfg *RequestConfig) Execute() (err error) { return nil } -func ExecuteNewRequest(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...func(*RequestConfig) error) error { +func ExecuteNewRequest(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...RequestOption) error { cfg, err := NewRequestConfig(ctx, method, u, body, dst, opts...) if err != nil { return err @@ -555,12 +573,27 @@ func (cfg *RequestConfig) Clone(ctx context.Context) *RequestConfig { return new } -func (cfg *RequestConfig) Apply(opts ...func(*RequestConfig) error) error { +func (cfg *RequestConfig) Apply(opts ...RequestOption) error { for _, opt := range opts { - err := opt(cfg) + err := opt.Apply(cfg) if err != nil { return err } } return nil } + +func PreRequestOptions(opts ...RequestOption) (RequestConfig, error) { + cfg := RequestConfig{} + for _, opt := range opts { + if _, ok := opt.(PreRequestOptionFunc); !ok { + continue + } + + err := opt.Apply(&cfg) + if err != nil { + return cfg, err + } + } + return cfg, nil +} diff --git a/option/requestoption.go b/option/requestoption.go index c37aa232..a4350826 100644 --- a/option/requestoption.go +++ b/option/requestoption.go @@ -21,7 +21,7 @@ import ( // options pattern in our [README]. // // [README]: https://pkg.go.dev/github.com/openai/openai-go#readme-requestoptions -type RequestOption = func(*requestconfig.RequestConfig) error +type RequestOption = requestconfig.RequestOption // WithBaseURL returns a RequestOption that sets the BaseURL for the client. func WithBaseURL(base string) RequestOption { @@ -29,22 +29,22 @@ func WithBaseURL(base string) RequestOption { if err != nil { log.Fatalf("failed to parse BaseURL: %s\n", err) } - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { if u.Path != "" && !strings.HasSuffix(u.Path, "/") { u.Path += "/" } r.BaseURL = u return nil - } + }) } // WithHTTPClient returns a RequestOption that changes the underlying [http.Client] used to make this // request, which by default is [http.DefaultClient]. func WithHTTPClient(client *http.Client) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.HTTPClient = client return nil - } + }) } // MiddlewareNext is a function which is called by a middleware to pass an HTTP request @@ -59,10 +59,10 @@ type Middleware = func(*http.Request, MiddlewareNext) (*http.Response, error) // WithMiddleware returns a RequestOption that applies the given middleware // to the requests made. Each middleware will execute in the order they were given. func WithMiddleware(middlewares ...Middleware) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.Middlewares = append(r.Middlewares, middlewares...) return nil - } + }) } // WithMaxRetries returns a RequestOption that sets the maximum number of retries that the client @@ -74,68 +74,68 @@ func WithMaxRetries(retries int) RequestOption { if retries < 0 { panic("option: cannot have fewer than 0 retries") } - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.MaxRetries = retries return nil - } + }) } // WithHeader returns a RequestOption that sets the header value to the associated key. It overwrites // any value if there was one already present. func WithHeader(key, value string) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.Request.Header.Set(key, value) return nil - } + }) } // WithHeaderAdd returns a RequestOption that adds the header value to the associated key. It appends // onto any existing values. func WithHeaderAdd(key, value string) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.Request.Header.Add(key, value) return nil - } + }) } // WithHeaderDel returns a RequestOption that deletes the header value(s) associated with the given key. func WithHeaderDel(key string) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.Request.Header.Del(key) return nil - } + }) } // WithQuery returns a RequestOption that sets the query value to the associated key. It overwrites // any value if there was one already present. func WithQuery(key, value string) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { query := r.Request.URL.Query() query.Set(key, value) r.Request.URL.RawQuery = query.Encode() return nil - } + }) } // WithQueryAdd returns a RequestOption that adds the query value to the associated key. It appends // onto any existing values. func WithQueryAdd(key, value string) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { query := r.Request.URL.Query() query.Add(key, value) r.Request.URL.RawQuery = query.Encode() return nil - } + }) } // WithQueryDel returns a RequestOption that deletes the query value(s) associated with the key. func WithQueryDel(key string) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { query := r.Request.URL.Query() query.Del(key) r.Request.URL.RawQuery = query.Encode() return nil - } + }) } // WithJSONSet returns a RequestOption that sets the body's JSON value associated with the key. @@ -143,7 +143,7 @@ func WithQueryDel(key string) RequestOption { // // [sjson format]: https://github.com/tidwall/sjson func WithJSONSet(key string, value interface{}) RequestOption { - return func(r *requestconfig.RequestConfig) (err error) { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) (err error) { if buffer, ok := r.Body.(*bytes.Buffer); ok { b := buffer.Bytes() b, err = sjson.SetBytes(b, key, value) @@ -155,7 +155,7 @@ func WithJSONSet(key string, value interface{}) RequestOption { } return fmt.Errorf("cannot use WithJSONSet on a body that is not serialized as *bytes.Buffer") - } + }) } // WithJSONDel returns a RequestOption that deletes the body's JSON value associated with the key. @@ -163,7 +163,7 @@ func WithJSONSet(key string, value interface{}) RequestOption { // // [sjson format]: https://github.com/tidwall/sjson func WithJSONDel(key string) RequestOption { - return func(r *requestconfig.RequestConfig) (err error) { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) (err error) { if buffer, ok := r.Body.(*bytes.Buffer); ok { b := buffer.Bytes() b, err = sjson.DeleteBytes(b, key) @@ -175,24 +175,24 @@ func WithJSONDel(key string) RequestOption { } return fmt.Errorf("cannot use WithJSONDel on a body that is not serialized as *bytes.Buffer") - } + }) } // WithResponseBodyInto returns a RequestOption that overwrites the deserialization target with // the given destination. If provided, we don't deserialize into the default struct. func WithResponseBodyInto(dst any) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.ResponseBodyInto = dst return nil - } + }) } // WithResponseInto returns a RequestOption that copies the [*http.Response] into the given address. func WithResponseInto(dst **http.Response) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.ResponseInto = dst return nil - } + }) } // WithRequestBody returns a RequestOption that provides a custom serialized body with the given @@ -200,7 +200,7 @@ func WithResponseInto(dst **http.Response) RequestOption { // // body accepts an io.Reader or raw []bytes. func WithRequestBody(contentType string, body any) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { if reader, ok := body.(io.Reader); ok { r.Body = reader return r.Apply(WithHeader("Content-Type", contentType)) @@ -212,17 +212,17 @@ func WithRequestBody(contentType string, body any) RequestOption { } return fmt.Errorf("body must be a byte slice or implement io.Reader") - } + }) } // WithRequestTimeout returns a RequestOption that sets the timeout for // each request attempt. This should be smaller than the timeout defined in // the context, which spans all retries. func WithRequestTimeout(dur time.Duration) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.RequestTimeout = dur return nil - } + }) } // WithEnvironmentProduction returns a RequestOption that sets the current @@ -234,24 +234,24 @@ func WithEnvironmentProduction() RequestOption { // WithAPIKey returns a RequestOption that sets the client setting "api_key". func WithAPIKey(value string) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.APIKey = value return r.Apply(WithHeader("authorization", fmt.Sprintf("Bearer %s", r.APIKey))) - } + }) } // WithOrganization returns a RequestOption that sets the client setting "organization". func WithOrganization(value string) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.Organization = value return r.Apply(WithHeader("OpenAI-Organization", value)) - } + }) } // WithProject returns a RequestOption that sets the client setting "project". func WithProject(value string) RequestOption { - return func(r *requestconfig.RequestConfig) error { + return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error { r.Project = value return r.Apply(WithHeader("OpenAI-Project", value)) - } + }) }