diff --git a/.golangci.yml b/.golangci.yml index 142ba345..db2b6870 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -23,8 +23,10 @@ linters: - nestif - nilerr # nilerr crashes on this repo - nlreturn + - noinlineerr - nonamedreturns - paralleltest + - recvcheck - testpackage - thelper - tparallel @@ -33,6 +35,7 @@ linters: - whitespace - wrapcheck - wsl + - wsl_v5 settings: dupl: threshold: 200 diff --git a/client/request.go b/client/request.go index bd35f46d..8f8bab03 100644 --- a/client/request.go +++ b/client/request.go @@ -30,23 +30,11 @@ import ( "strings" "time" - "github.com/go-openapi/strfmt" - "github.com/go-openapi/runtime" + "github.com/go-openapi/strfmt" ) -// NewRequest creates a new swagger http client request -func newRequest(method, pathPattern string, writer runtime.ClientRequestWriter) *request { - return &request{ - pathPattern: pathPattern, - method: method, - writer: writer, - header: make(http.Header), - query: make(url.Values), - timeout: DefaultTimeout, - getBody: getRequestBuffer, - } -} +var _ runtime.ClientRequest = new(request) // ensure compliance to the interface // Request represents a swagger client request. // @@ -67,40 +55,156 @@ type request struct { query url.Values formFields url.Values fileFields map[string][]runtime.NamedReadCloser - payload interface{} + payload any timeout time.Duration buf *bytes.Buffer getBody func(r *request) []byte } -var ( - // ensure interface compliance - _ runtime.ClientRequest = new(request) -) - -func (r *request) isMultipart(mediaType string) bool { - if len(r.fileFields) > 0 { - return true +// NewRequest creates a new swagger http client request +func newRequest(method, pathPattern string, writer runtime.ClientRequestWriter) *request { + return &request{ + pathPattern: pathPattern, + method: method, + writer: writer, + header: make(http.Header), + query: make(url.Values), + timeout: DefaultTimeout, + getBody: getRequestBuffer, } - - return runtime.MultipartFormMime == mediaType } // BuildHTTP creates a new http request based on the data from the params func (r *request) BuildHTTP(mediaType, basePath string, producers map[string]runtime.Producer, registry strfmt.Registry) (*http.Request, error) { return r.buildHTTP(mediaType, basePath, producers, registry, nil) } -func escapeQuotes(s string) string { - return strings.NewReplacer("\\", "\\\\", `"`, "\\\"").Replace(s) + +func (r *request) GetMethod() string { + return r.method } -func logClose(err error, pw *io.PipeWriter) { - log.Println(err) - closeErr := pw.CloseWithError(err) - if closeErr != nil { - log.Println(closeErr) +func (r *request) GetPath() string { + path := r.pathPattern + for k, v := range r.pathParams { + path = strings.ReplaceAll(path, "{"+k+"}", v) + } + return path +} + +func (r *request) GetBody() []byte { + return r.getBody(r) +} + +// SetHeaderParam adds a header param to the request +// when there is only 1 value provided for the varargs, it will set it. +// when there are several values provided for the varargs it will add it (no overriding) +func (r *request) SetHeaderParam(name string, values ...string) error { + if r.header == nil { + r.header = make(http.Header) + } + r.header[http.CanonicalHeaderKey(name)] = values + return nil +} + +// GetHeaderParams returns the all headers currently set for the request +func (r *request) GetHeaderParams() http.Header { + return r.header +} + +// SetQueryParam adds a query param to the request +// when there is only 1 value provided for the varargs, it will set it. +// when there are several values provided for the varargs it will add it (no overriding) +func (r *request) SetQueryParam(name string, values ...string) error { + if r.query == nil { + r.query = make(url.Values) + } + r.query[name] = values + return nil +} + +// GetQueryParams returns a copy of all query params currently set for the request +func (r *request) GetQueryParams() url.Values { + var result = make(url.Values) + for key, value := range r.query { + result[key] = append([]string{}, value...) + } + return result +} + +// SetFormParam adds a forn param to the request +// when there is only 1 value provided for the varargs, it will set it. +// when there are several values provided for the varargs it will add it (no overriding) +func (r *request) SetFormParam(name string, values ...string) error { + if r.formFields == nil { + r.formFields = make(url.Values) + } + r.formFields[name] = values + return nil +} + +// SetPathParam adds a path param to the request +func (r *request) SetPathParam(name string, value string) error { + if r.pathParams == nil { + r.pathParams = make(map[string]string) + } + + r.pathParams[name] = value + return nil +} + +// SetFileParam adds a file param to the request +func (r *request) SetFileParam(name string, files ...runtime.NamedReadCloser) error { + for _, file := range files { + if actualFile, ok := file.(*os.File); ok { + fi, err := os.Stat(actualFile.Name()) + if err != nil { + return err + } + if fi.IsDir() { + return fmt.Errorf("%q is a directory, only files are supported", file.Name()) + } + } + } + + if r.fileFields == nil { + r.fileFields = make(map[string][]runtime.NamedReadCloser) } + if r.formFields == nil { + r.formFields = make(url.Values) + } + + r.fileFields[name] = files + return nil +} + +func (r *request) GetFileParam() map[string][]runtime.NamedReadCloser { + return r.fileFields +} + +// SetBodyParam sets a body parameter on the request. +// This does not yet serialze the object, this happens as late as possible. +func (r *request) SetBodyParam(payload any) error { + r.payload = payload + return nil +} + +func (r *request) GetBodyParam() any { + return r.payload +} + +// SetTimeout sets the timeout for a request +func (r *request) SetTimeout(timeout time.Duration) error { + r.timeout = timeout + return nil +} + +func (r *request) isMultipart(mediaType string) bool { + if len(r.fileFields) > 0 { + return true + } + + return runtime.MultipartFormMime == mediaType } func (r *request) buildHTTP(mediaType, basePath string, producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter) (*http.Request, error) { //nolint:gocyclo,maintidx @@ -349,27 +453,8 @@ DoneChoosingBodySource: return req, nil } -func mangleContentType(mediaType, boundary string) string { - if strings.ToLower(mediaType) == runtime.URLencodedFormMime { - return fmt.Sprintf("%s; boundary=%s", mediaType, boundary) - } - return "multipart/form-data; boundary=" + boundary -} - -func (r *request) GetMethod() string { - return r.method -} - -func (r *request) GetPath() string { - path := r.pathPattern - for k, v := range r.pathParams { - path = strings.ReplaceAll(path, "{"+k+"}", v) - } - return path -} - -func (r *request) GetBody() []byte { - return r.getBody(r) +func escapeQuotes(s string) string { + return strings.NewReplacer("\\", "\\\\", `"`, "\\\"").Replace(s) } func getRequestBuffer(r *request) []byte { @@ -379,105 +464,17 @@ func getRequestBuffer(r *request) []byte { return r.buf.Bytes() } -// SetHeaderParam adds a header param to the request -// when there is only 1 value provided for the varargs, it will set it. -// when there are several values provided for the varargs it will add it (no overriding) -func (r *request) SetHeaderParam(name string, values ...string) error { - if r.header == nil { - r.header = make(http.Header) - } - r.header[http.CanonicalHeaderKey(name)] = values - return nil -} - -// GetHeaderParams returns the all headers currently set for the request -func (r *request) GetHeaderParams() http.Header { - return r.header -} - -// SetQueryParam adds a query param to the request -// when there is only 1 value provided for the varargs, it will set it. -// when there are several values provided for the varargs it will add it (no overriding) -func (r *request) SetQueryParam(name string, values ...string) error { - if r.query == nil { - r.query = make(url.Values) - } - r.query[name] = values - return nil -} - -// GetQueryParams returns a copy of all query params currently set for the request -func (r *request) GetQueryParams() url.Values { - var result = make(url.Values) - for key, value := range r.query { - result[key] = append([]string{}, value...) - } - return result -} - -// SetFormParam adds a forn param to the request -// when there is only 1 value provided for the varargs, it will set it. -// when there are several values provided for the varargs it will add it (no overriding) -func (r *request) SetFormParam(name string, values ...string) error { - if r.formFields == nil { - r.formFields = make(url.Values) - } - r.formFields[name] = values - return nil -} - -// SetPathParam adds a path param to the request -func (r *request) SetPathParam(name string, value string) error { - if r.pathParams == nil { - r.pathParams = make(map[string]string) +func logClose(err error, pw *io.PipeWriter) { + log.Println(err) + closeErr := pw.CloseWithError(err) + if closeErr != nil { + log.Println(closeErr) } - - r.pathParams[name] = value - return nil } -// SetFileParam adds a file param to the request -func (r *request) SetFileParam(name string, files ...runtime.NamedReadCloser) error { - for _, file := range files { - if actualFile, ok := file.(*os.File); ok { - fi, err := os.Stat(actualFile.Name()) - if err != nil { - return err - } - if fi.IsDir() { - return fmt.Errorf("%q is a directory, only files are supported", file.Name()) - } - } - } - - if r.fileFields == nil { - r.fileFields = make(map[string][]runtime.NamedReadCloser) - } - if r.formFields == nil { - r.formFields = make(url.Values) +func mangleContentType(mediaType, boundary string) string { + if strings.ToLower(mediaType) == runtime.URLencodedFormMime { + return fmt.Sprintf("%s; boundary=%s", mediaType, boundary) } - - r.fileFields[name] = files - return nil -} - -func (r *request) GetFileParam() map[string][]runtime.NamedReadCloser { - return r.fileFields -} - -// SetBodyParam sets a body parameter on the request. -// This does not yet serialze the object, this happens as late as possible. -func (r *request) SetBodyParam(payload interface{}) error { - r.payload = payload - return nil -} - -func (r *request) GetBodyParam() interface{} { - return r.payload -} - -// SetTimeout sets the timeout for a request -func (r *request) SetTimeout(timeout time.Duration) error { - r.timeout = timeout - return nil + return "multipart/form-data; boundary=" + boundary } diff --git a/client/request_test.go b/client/request_test.go index 8689b37d..742cc1e8 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -54,7 +54,7 @@ func TestBuildRequest_SetHeaders(t *testing.T) { // multi value _ = r.SetHeaderParam("X-Accepts", "json", "xml", "yaml") - assert.EqualValues(t, []string{"json", "xml", "yaml"}, r.header["X-Accepts"]) + assert.Equal(t, []string{"json", "xml", "yaml"}, r.header["X-Accepts"]) } func TestBuildRequest_SetPath(t *testing.T) { @@ -514,6 +514,7 @@ func TestBuildRequest_BuildHTTP_Files_URLEncoded(t *testing.T) { type contentTypeProvider struct { runtime.NamedReadCloser + contentType string } @@ -713,7 +714,7 @@ func TestGetBodyCallsBeforeRoundTrip(t *testing.T) { require.NoError(t, e) require.Len(t, bodyContent, int(req.ContentLength)) - require.EqualValues(t, "\"test body\"\n", string(bodyContent)) + require.Equal(t, "\"test body\"\n", string(bodyContent)) // Read the body a second time before sending the request body, e = req.GetBody() @@ -721,7 +722,7 @@ func TestGetBodyCallsBeforeRoundTrip(t *testing.T) { bodyContent, e = io.ReadAll(io.Reader(body)) require.NoError(t, e) require.Len(t, bodyContent, int(req.ContentLength)) - require.EqualValues(t, "\"test body\"\n", string(bodyContent)) + require.Equal(t, "\"test body\"\n", string(bodyContent)) }, } @@ -752,5 +753,5 @@ func TestGetBodyCallsBeforeRoundTrip(t *testing.T) { require.NoError(t, err) actual := res.(string) - require.EqualValues(t, "test result", actual) + require.Equal(t, "test result", actual) } diff --git a/client/response_test.go b/client/response_test.go index acbbf716..9368b4c2 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -34,7 +34,7 @@ func TestResponse(t *testing.T) { under.Body = io.NopCloser(bytes.NewBufferString("some content")) var resp runtime.ClientResponse = response{under} - assert.EqualValues(t, under.StatusCode, resp.Code()) + assert.Equal(t, under.StatusCode, resp.Code()) assert.Equal(t, under.Status, resp.Message()) assert.Equal(t, "blahblah", resp.GetHeader("blah")) assert.Equal(t, []string{"blahblah"}, resp.GetHeaders("blah")) diff --git a/client/runtime.go b/client/runtime.go index 5bd4d75d..2b29875e 100644 --- a/client/runtime.go +++ b/client/runtime.go @@ -32,13 +32,12 @@ import ( "sync" "time" - "github.com/go-openapi/strfmt" - "github.com/opentracing/opentracing-go" - "github.com/go-openapi/runtime" "github.com/go-openapi/runtime/logger" "github.com/go-openapi/runtime/middleware" "github.com/go-openapi/runtime/yamlpc" + "github.com/go-openapi/strfmt" + "github.com/opentracing/opentracing-go" ) const ( @@ -46,6 +45,9 @@ const ( schemeHTTPS = "https" ) +// DefaultTimeout the default request timeout +var DefaultTimeout = 30 * time.Second + // TLSClientOptions to configure client authentication with mutual TLS type TLSClientOptions struct { // Certificate is the path to a PEM-encoded certificate to be used for @@ -193,13 +195,6 @@ func TLSClientAuth(opts TLSClientOptions) (*tls.Config, error) { return cfg, nil } -func basePool(pool *x509.CertPool) *x509.CertPool { - if pool == nil { - return x509.NewCertPool() - } - return pool -} - // TLSTransport creates a http client transport suitable for mutual tls auth func TLSTransport(opts TLSClientOptions) (http.RoundTripper, error) { cfg, err := TLSClientAuth(opts) @@ -219,9 +214,6 @@ func TLSClient(opts TLSClientOptions) (*http.Client, error) { return &http.Client{Transport: transport}, nil } -// DefaultTimeout the default request timeout -var DefaultTimeout = 30 * time.Second - // Runtime represents an API client that uses the transport // to make http requests based on a swagger specification. type Runtime struct { @@ -318,42 +310,6 @@ func (r *Runtime) WithOpenTelemetry(opts ...OpenTelemetryOpt) runtime.ClientTran return newOpenTelemetryTransport(r, r.Host, opts) } -func (r *Runtime) pickScheme(schemes []string) string { - if v := r.selectScheme(r.schemes); v != "" { - return v - } - if v := r.selectScheme(schemes); v != "" { - return v - } - return schemeHTTP -} - -func (r *Runtime) selectScheme(schemes []string) string { - schLen := len(schemes) - if schLen == 0 { - return "" - } - - scheme := schemes[0] - // prefer https, but skip when not possible - if scheme != schemeHTTPS && schLen > 1 { - for _, sch := range schemes { - if sch == schemeHTTPS { - scheme = sch - break - } - } - } - return scheme -} - -func transportOrDefault(left, right http.RoundTripper) http.RoundTripper { - if left == nil { - return right - } - return left -} - // EnableConnectionReuse drains the remaining body from a response // so that go will reuse the TCP connections. // @@ -376,57 +332,7 @@ func (r *Runtime) EnableConnectionReuse() { ) } -// takes a client operation and creates equivalent http.Request -func (r *Runtime) createHttpRequest(operation *runtime.ClientOperation) (*request, *http.Request, error) { //nolint:revive,stylecheck - params, _, auth := operation.Params, operation.Reader, operation.AuthInfo - - request := newRequest(operation.Method, operation.PathPattern, params) - - var accept []string - accept = append(accept, operation.ProducesMediaTypes...) - if err := request.SetHeaderParam(runtime.HeaderAccept, accept...); err != nil { - return nil, nil, err - } - - if auth == nil && r.DefaultAuthentication != nil { - auth = runtime.ClientAuthInfoWriterFunc(func(req runtime.ClientRequest, reg strfmt.Registry) error { - if req.GetHeaderParams().Get(runtime.HeaderAuthorization) != "" { - return nil - } - return r.DefaultAuthentication.AuthenticateRequest(req, reg) - }) - } - // if auth != nil { - // if err := auth.AuthenticateRequest(request, r.Formats); err != nil { - // return nil, err - // } - //} - - // TODO: pick appropriate media type - cmt := r.DefaultMediaType - for _, mediaType := range operation.ConsumesMediaTypes { - // Pick first non-empty media type - if mediaType != "" { - cmt = mediaType - break - } - } - - if _, ok := r.Producers[cmt]; !ok && cmt != runtime.MultipartFormMime && cmt != runtime.URLencodedFormMime { - return nil, nil, fmt.Errorf("none of producers: %v registered. try %s", r.Producers, cmt) - } - - req, err := request.buildHTTP(cmt, r.BasePath, r.Producers, r.Formats, auth) - if err != nil { - return nil, nil, err - } - req.URL.Scheme = r.pickScheme(operation.Schemes) - req.URL.Host = r.Host - req.Host = r.Host - return request, req, nil -} - -func (r *Runtime) CreateHttpRequest(operation *runtime.ClientOperation) (req *http.Request, err error) { //nolint:revive,stylecheck +func (r *Runtime) CreateHttpRequest(operation *runtime.ClientOperation) (req *http.Request, err error) { //nolint:revive _, req, err = r.createHttpRequest(operation) return } @@ -550,3 +456,96 @@ func (r *Runtime) SetResponseReader(f ClientResponseFunc) { } r.response = f } + +func (r *Runtime) pickScheme(schemes []string) string { + if v := r.selectScheme(r.schemes); v != "" { + return v + } + if v := r.selectScheme(schemes); v != "" { + return v + } + return schemeHTTP +} + +func (r *Runtime) selectScheme(schemes []string) string { + schLen := len(schemes) + if schLen == 0 { + return "" + } + + scheme := schemes[0] + // prefer https, but skip when not possible + if scheme != schemeHTTPS && schLen > 1 { + for _, sch := range schemes { + if sch == schemeHTTPS { + scheme = sch + break + } + } + } + return scheme +} + +func transportOrDefault(left, right http.RoundTripper) http.RoundTripper { + if left == nil { + return right + } + return left +} + +// takes a client operation and creates equivalent http.Request +func (r *Runtime) createHttpRequest(operation *runtime.ClientOperation) (*request, *http.Request, error) { //nolint:revive + params, _, auth := operation.Params, operation.Reader, operation.AuthInfo + + request := newRequest(operation.Method, operation.PathPattern, params) + + var accept []string + accept = append(accept, operation.ProducesMediaTypes...) + if err := request.SetHeaderParam(runtime.HeaderAccept, accept...); err != nil { + return nil, nil, err + } + + if auth == nil && r.DefaultAuthentication != nil { + auth = runtime.ClientAuthInfoWriterFunc(func(req runtime.ClientRequest, reg strfmt.Registry) error { + if req.GetHeaderParams().Get(runtime.HeaderAuthorization) != "" { + return nil + } + return r.DefaultAuthentication.AuthenticateRequest(req, reg) + }) + } + // if auth != nil { + // if err := auth.AuthenticateRequest(request, r.Formats); err != nil { + // return nil, err + // } + //} + + // TODO: pick appropriate media type + cmt := r.DefaultMediaType + for _, mediaType := range operation.ConsumesMediaTypes { + // Pick first non-empty media type + if mediaType != "" { + cmt = mediaType + break + } + } + + if _, ok := r.Producers[cmt]; !ok && cmt != runtime.MultipartFormMime && cmt != runtime.URLencodedFormMime { + return nil, nil, fmt.Errorf("none of producers: %v registered. try %s", r.Producers, cmt) + } + + req, err := request.buildHTTP(cmt, r.BasePath, r.Producers, r.Formats, auth) + if err != nil { + return nil, nil, err + } + req.URL.Scheme = r.pickScheme(operation.Schemes) + req.URL.Host = r.Host + req.Host = r.Host + return request, req, nil +} + +func basePool(pool *x509.CertPool) *x509.CertPool { + if pool == nil { + return x509.NewCertPool() + } + return pool +} diff --git a/client/runtime_test.go b/client/runtime_test.go index b01b7e26..ac753fbb 100644 --- a/client/runtime_test.go +++ b/client/runtime_test.go @@ -130,7 +130,7 @@ func TestRuntime_Concurrent(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) actual := res.([]task) - assert.EqualValues(t, result, actual) + assert.Equal(t, result, actual) } func TestRuntime_Canary(t *testing.T) { @@ -176,7 +176,7 @@ func TestRuntime_Canary(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) actual := res.([]task) - assert.EqualValues(t, result, actual) + assert.Equal(t, result, actual) } type tasks struct { @@ -228,7 +228,7 @@ func TestRuntime_XMLCanary(t *testing.T) { require.NoError(t, err) assert.IsType(t, tasks{}, res) actual := res.(tasks) - assert.EqualValues(t, result, actual) + assert.Equal(t, result, actual) } func TestRuntime_TextCanary(t *testing.T) { @@ -270,7 +270,7 @@ func TestRuntime_TextCanary(t *testing.T) { require.NoError(t, err) assert.IsType(t, "", res) actual := res.(string) - assert.EqualValues(t, result, actual) + assert.Equal(t, result, actual) } func TestRuntime_CSVCanary(t *testing.T) { @@ -315,7 +315,7 @@ func TestRuntime_CSVCanary(t *testing.T) { require.NoError(t, err) assert.IsType(t, bytes.Buffer{}, res) actual := res.(bytes.Buffer) - assert.EqualValues(t, result, actual.String()) + assert.Equal(t, result, actual.String()) } type roundTripperFunc func(*http.Request) (*http.Response, error) @@ -373,7 +373,7 @@ func TestRuntime_CustomTransport(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) actual := res.([]task) - assert.EqualValues(t, result, actual) + assert.Equal(t, result, actual) } func TestRuntime_CustomCookieJar(t *testing.T) { @@ -484,7 +484,7 @@ func TestRuntime_AuthCanary(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) actual := res.([]task) - assert.EqualValues(t, result, actual) + assert.Equal(t, result, actual) } func TestRuntime_PickConsumer(t *testing.T) { @@ -536,7 +536,7 @@ func TestRuntime_PickConsumer(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) actual := res.([]task) - assert.EqualValues(t, result, actual) + assert.Equal(t, result, actual) } func TestRuntime_ContentTypeCanary(t *testing.T) { @@ -589,7 +589,7 @@ func TestRuntime_ContentTypeCanary(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) actual := res.([]task) - assert.EqualValues(t, result, actual) + assert.Equal(t, result, actual) } func TestRuntime_ChunkedResponse(t *testing.T) { @@ -644,7 +644,7 @@ func TestRuntime_ChunkedResponse(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) actual := res.([]task) - assert.EqualValues(t, result, actual) + assert.Equal(t, result, actual) } func TestRuntime_DebugValue(t *testing.T) { @@ -894,7 +894,7 @@ func TestRuntime_AuthHeaderParamDetected(t *testing.T) { require.NoError(t, err) assert.IsType(t, []task{}, res) actual := res.([]task) - assert.EqualValues(t, result, actual) + assert.Equal(t, result, actual) } func TestRuntime_Timeout(t *testing.T) { //nolint:maintidx // linter evaluates the total lines of code, which is misleading @@ -1274,7 +1274,7 @@ func assertResult(result []task) func(testing.TB, interface{}) { assert.IsType(t, []task{}, res) actual, ok := res.([]task) require.True(t, ok) - assert.EqualValues(t, result, actual) + assert.Equal(t, result, actual) } } diff --git a/client/runtime_tls_test.go b/client/runtime_tls_test.go index 44c80e4b..98f33b9d 100644 --- a/client/runtime_tls_test.go +++ b/client/runtime_tls_test.go @@ -177,7 +177,7 @@ func TestRuntimeManualCertificateValidation(t *testing.T) { assert.IsType(t, []task{}, resp) assert.Truef(t, certVerifyCalled, "the client cert verification has not been called") - assert.EqualValues(t, result, received) + assert.Equal(t, result, received) } func testTLSServer(t testing.TB, fixtures *tlsFixtures, expectedResult []task) (string, func()) { diff --git a/client_response.go b/client_response.go index f89a304b..b9929946 100644 --- a/client_response.go +++ b/client_response.go @@ -20,7 +20,8 @@ import ( "io" ) -// A ClientResponse represents a client response +// A ClientResponse represents a client response. +// // This bridges between responses obtained from different transports type ClientResponse interface { Code() int @@ -44,6 +45,13 @@ type ClientResponseReader interface { ReadResponse(ClientResponse, Consumer) (interface{}, error) } +// APIError wraps an error model and captures the status code +type APIError struct { + OperationName string + Response interface{} + Code int +} + // NewAPIError creates a new API error func NewAPIError(opName string, payload interface{}, code int) *APIError { return &APIError{ @@ -53,13 +61,6 @@ func NewAPIError(opName string, payload interface{}, code int) *APIError { } } -// APIError wraps an error model and captures the status code -type APIError struct { - OperationName string - Response interface{} - Code int -} - func (o *APIError) Error() string { var resp []byte if err, ok := o.Response.(error); ok { diff --git a/csv_test.go b/csv_test.go index a9d8edfe..ba08d1be 100644 --- a/csv_test.go +++ b/csv_test.go @@ -609,7 +609,8 @@ func (r *csvEmptyReader) Read(_ []byte) (int, error) { type readerFromDummy struct { err error - b bytes.Buffer + + b bytes.Buffer } func (r *readerFromDummy) ReadFrom(rdr io.Reader) (int64, error) { @@ -629,8 +630,9 @@ func (w *writerToDummy) WriteTo(writer io.Writer) (int64, error) { } type csvWriterDummy struct { - err error *csv.Writer + + err error } func (w *csvWriterDummy) Write(record []string) error { diff --git a/internal/testing/simplepetstore/api_test.go b/internal/testing/simplepetstore/api_test.go index f3181272..dce37d83 100644 --- a/internal/testing/simplepetstore/api_test.go +++ b/internal/testing/simplepetstore/api_test.go @@ -93,7 +93,7 @@ func TestSimplePetstoreDeletePet(t *testing.T) { rw := httptest.NewRecorder() handler.ServeHTTP(rw, r) assert.Equal(t, http.StatusNoContent, rw.Code) - assert.Equal(t, "", rw.Body.String()) + assert.Empty(t, rw.Body.String()) r, err = runtime.JSONRequest(http.MethodGet, "/api/pets/1", nil) require.NoError(t, err) diff --git a/middleware/context.go b/middleware/context.go index 8718ee40..ecff9e3d 100644 --- a/middleware/context.go +++ b/middleware/context.go @@ -218,7 +218,7 @@ func NewRoutableContext(spec *loads.Document, routableAPI RoutableAPI, routes Ro // If a nil Router is provided, the DefaultRouter (denco-based) will be used. func NewRoutableContextWithAnalyzedSpec(spec *loads.Document, an *analysis.Spec, routableAPI RoutableAPI, routes Router) *Context { // Either there are no spec doc and analysis, or both of them. - if !((spec == nil && an == nil) || (spec != nil && an != nil)) { + if (spec != nil || an != nil) && (spec == nil || an == nil) { panic(errors.New(http.StatusInternalServerError, "routable context requires either both spec doc and analysis, or none of them")) } @@ -677,6 +677,15 @@ func (c *Context) APIHandler(builder Builder, opts ...UIOption) http.Handler { return Spec(specPath, c.spec.Raw(), Redoc(redocOpts, c.RoutesHandler(b)), specOpts...) } +// RoutesHandler returns a handler to serve the API, just the routes and the contract defined in the swagger spec +func (c *Context) RoutesHandler(builder Builder) http.Handler { + b := builder + if b == nil { + b = PassthroughBuilder + } + return NewRouter(c, b(NewOperationExecutor(c))) +} + func (c Context) uiOptionsForHandler(opts []UIOption) (string, uiOptions, []SpecOption) { var title string sp := c.spec.Spec() @@ -708,15 +717,6 @@ func (c Context) uiOptionsForHandler(opts []UIOption) (string, uiOptions, []Spec return pth, uiOpts, []SpecOption{WithSpecDocument(doc)} } -// RoutesHandler returns a handler to serve the API, just the routes and the contract defined in the swagger spec -func (c *Context) RoutesHandler(builder Builder) http.Handler { - b := builder - if b == nil { - b = PassthroughBuilder - } - return NewRouter(c, b(NewOperationExecutor(c))) -} - func cantFindProducer(format string) string { return "can't find a producer for " + format } diff --git a/middleware/debug_test.go b/middleware/debug_test.go index 088f412e..b89d3a60 100644 --- a/middleware/debug_test.go +++ b/middleware/debug_test.go @@ -16,6 +16,7 @@ import ( type customLogger struct { logger.StandardLogger + lg *log.Logger } diff --git a/middleware/denco/router.go b/middleware/denco/router.go index 67b2f668..7c499752 100644 --- a/middleware/denco/router.go +++ b/middleware/denco/router.go @@ -148,7 +148,7 @@ func (bc baseCheck) Base() int { } func (bc *baseCheck) SetBase(base int) { - *bc |= baseCheck(base) << flagsBits + *bc |= baseCheck(base) << flagsBits //nolint:gosec // integer conversion is ok } func (bc baseCheck) Check() byte { @@ -196,7 +196,7 @@ func (da *doubleArray) lookup(path string, params []Param, idx int) (*node, []Pa indices := make([]uint64, 0, 1) for i := 0; i < len(path); i++ { if da.bc[idx].IsAnyParam() { - indices = append(indices, (uint64(i)<= len(da.bc) || da.bc[idx].Check() != c { @@ -209,7 +209,7 @@ func (da *doubleArray) lookup(path string, params []Param, idx int) (*node, []Pa BACKTRACKING: for j := len(indices) - 1; j >= 0; j-- { - i, idx := int(indices[j]>>indexOffset), int(indices[j]&indexMask) + i, idx := int(indices[j]>>indexOffset), int(indices[j]&indexMask) //nolint:gosec // integer conversion is okay if da.bc[idx].IsSingleParam() { nextIdx := nextIndex(da.bc[idx].Base(), ParamCharacter) if nextIdx >= len(da.bc) { @@ -436,6 +436,7 @@ func NewRecord(key string, value interface{}) Record { // record represents a record that use to build the Double-Array. type record struct { Record + paramNames []string } diff --git a/middleware/parameter.go b/middleware/parameter.go index f854a406..59337aec 100644 --- a/middleware/parameter.go +++ b/middleware/parameter.go @@ -66,85 +66,6 @@ func (p *untypedParamBinder) Type() reflect.Type { return p.typeForSchema(p.parameter.Type, p.parameter.Format, p.parameter.Items) } -func (p *untypedParamBinder) typeForSchema(tpe, format string, items *spec.Items) reflect.Type { - switch tpe { - case "boolean": - return reflect.TypeOf(true) - - case typeString: - if tt, ok := p.formats.GetType(format); ok { - return tt - } - return reflect.TypeOf("") - - case "integer": - switch format { - case "int8": - return reflect.TypeOf(int8(0)) - case "int16": - return reflect.TypeOf(int16(0)) - case "int32": - return reflect.TypeOf(int32(0)) - case "int64": - return reflect.TypeOf(int64(0)) - default: - return reflect.TypeOf(int64(0)) - } - - case "number": - switch format { - case "float": - return reflect.TypeOf(float32(0)) - case "double": - return reflect.TypeOf(float64(0)) - } - - case typeArray: - if items == nil { - return nil - } - itemsType := p.typeForSchema(items.Type, items.Format, items.Items) - if itemsType == nil { - return nil - } - return reflect.MakeSlice(reflect.SliceOf(itemsType), 0, 0).Type() - - case "file": - return reflect.TypeOf(&runtime.File{}).Elem() - - case "object": - return reflect.TypeOf(map[string]interface{}{}) - } - return nil -} - -func (p *untypedParamBinder) allowsMulti() bool { - return p.parameter.In == "query" || p.parameter.In == "formData" -} - -func (p *untypedParamBinder) readValue(values runtime.Gettable, target reflect.Value) ([]string, bool, bool, error) { - name, in, cf, tpe := p.parameter.Name, p.parameter.In, p.parameter.CollectionFormat, p.parameter.Type - if tpe == typeArray { - if cf == "multi" { - if !p.allowsMulti() { - return nil, false, false, errors.InvalidCollectionFormat(name, in, cf) - } - vv, hasKey, _ := values.GetOK(name) - return vv, false, hasKey, nil - } - - v, hk, hv := values.GetOK(name) - if !hv { - return nil, false, hk, nil - } - d, c, e := p.readFormattedSliceFieldValue(v[len(v)-1], target) - return d, c, hk, e - } - - vv, hk, _ := values.GetOK(name) - return vv, false, hk, nil -} - func (p *untypedParamBinder) Bind(request *http.Request, routeParams RouteParams, consumer runtime.Consumer, target reflect.Value) error { // fmt.Println("binding", p.name, "as", p.Type()) switch p.parameter.In { @@ -268,6 +189,85 @@ func (p *untypedParamBinder) Bind(request *http.Request, routeParams RouteParams } } +func (p *untypedParamBinder) typeForSchema(tpe, format string, items *spec.Items) reflect.Type { + switch tpe { + case "boolean": + return reflect.TypeOf(true) + + case typeString: + if tt, ok := p.formats.GetType(format); ok { + return tt + } + return reflect.TypeOf("") + + case "integer": + switch format { + case "int8": + return reflect.TypeOf(int8(0)) + case "int16": + return reflect.TypeOf(int16(0)) + case "int32": + return reflect.TypeOf(int32(0)) + case "int64": + return reflect.TypeOf(int64(0)) + default: + return reflect.TypeOf(int64(0)) + } + + case "number": + switch format { + case "float": + return reflect.TypeOf(float32(0)) + case "double": + return reflect.TypeOf(float64(0)) + } + + case typeArray: + if items == nil { + return nil + } + itemsType := p.typeForSchema(items.Type, items.Format, items.Items) + if itemsType == nil { + return nil + } + return reflect.MakeSlice(reflect.SliceOf(itemsType), 0, 0).Type() + + case "file": + return reflect.TypeOf(&runtime.File{}).Elem() + + case "object": + return reflect.TypeOf(map[string]interface{}{}) + } + return nil +} + +func (p *untypedParamBinder) allowsMulti() bool { + return p.parameter.In == "query" || p.parameter.In == "formData" +} + +func (p *untypedParamBinder) readValue(values runtime.Gettable, target reflect.Value) ([]string, bool, bool, error) { + name, in, cf, tpe := p.parameter.Name, p.parameter.In, p.parameter.CollectionFormat, p.parameter.Type + if tpe == typeArray { + if cf == "multi" { + if !p.allowsMulti() { + return nil, false, false, errors.InvalidCollectionFormat(name, in, cf) + } + vv, hasKey, _ := values.GetOK(name) + return vv, false, hasKey, nil + } + + v, hk, hv := values.GetOK(name) + if !hv { + return nil, false, hk, nil + } + d, c, e := p.readFormattedSliceFieldValue(v[len(v)-1], target) + return d, c, hk, e + } + + vv, hk, _ := values.GetOK(name) + return vv, false, hk, nil +} + func (p *untypedParamBinder) bindValue(data []string, hasKey bool, target reflect.Value) error { if p.parameter.Type == typeArray { return p.setSliceFieldValue(target, p.parameter.Default, data, hasKey) diff --git a/middleware/router.go b/middleware/router.go index 249d29f5..96b748e6 100644 --- a/middleware/router.go +++ b/middleware/router.go @@ -22,18 +22,16 @@ import ( "regexp" "strings" - "github.com/go-openapi/runtime/logger" - "github.com/go-openapi/runtime/security" - "github.com/go-openapi/swag" - "github.com/go-openapi/analysis" "github.com/go-openapi/errors" "github.com/go-openapi/loads" - "github.com/go-openapi/spec" - "github.com/go-openapi/strfmt" - "github.com/go-openapi/runtime" + "github.com/go-openapi/runtime/logger" "github.com/go-openapi/runtime/middleware/denco" + "github.com/go-openapi/runtime/security" + "github.com/go-openapi/spec" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" ) // RouteParam is a object to capture route params in a framework agnostic way. @@ -336,6 +334,7 @@ type routeEntry struct { // MatchedRoute represents the route that was matched in this request type MatchedRoute struct { routeEntry + Params RouteParams Consumer runtime.Consumer Producer runtime.Producer @@ -490,6 +489,20 @@ func (d *defaultRouteBuilder) AddRoute(method, path string, operation *spec.Oper } } +func (d *defaultRouteBuilder) Build() *defaultRouter { + routers := make(map[string]*denco.Router) + for method, records := range d.records { + router := denco.New() + _ = router.Build(records) + routers[method] = router + } + return &defaultRouter{ + spec: d.spec, + routers: routers, + debugLogf: d.debugLogf, + } +} + func (d *defaultRouteBuilder) buildAuthenticators(operation *spec.Operation) RouteAuthenticators { requirements := d.analyzer.SecurityRequirementsFor(operation) auths := make([]RouteAuthenticator, 0, len(requirements)) @@ -516,17 +529,3 @@ func (d *defaultRouteBuilder) buildAuthenticators(operation *spec.Operation) Rou } return auths } - -func (d *defaultRouteBuilder) Build() *defaultRouter { - routers := make(map[string]*denco.Router) - for method, records := range d.records { - router := denco.New() - _ = router.Build(records) - routers[method] = router - } - return &defaultRouter{ - spec: d.spec, - routers: routers, - debugLogf: d.debugLogf, - } -} diff --git a/middleware/router_test.go b/middleware/router_test.go index 9b6a47a8..6d6a56a6 100644 --- a/middleware/router_test.go +++ b/middleware/router_test.go @@ -257,7 +257,7 @@ func TestExtractCompositParameters(t *testing.T) { } for _, tc := range cases { names, values := decodeCompositParams(tc.name, tc.value, tc.pattern, nil, nil) - assert.EqualValues(t, tc.names, names) - assert.EqualValues(t, tc.values, values) + assert.Equal(t, tc.names, names) + assert.Equal(t, tc.values, values) } } diff --git a/middleware/ui_options.go b/middleware/ui_options.go index b86efa00..7a5fb16f 100644 --- a/middleware/ui_options.go +++ b/middleware/ui_options.go @@ -168,6 +168,6 @@ func serveUI(pth string, assets []byte, next http.Handler) http.Handler { rw.Header().Set(contentTypeHeader, "text/plain") rw.WriteHeader(http.StatusNotFound) - _, _ = rw.Write([]byte(fmt.Sprintf("%q not found", pth))) + _, _ = fmt.Fprintf(rw, "%q not found", pth) }) } diff --git a/middleware/untyped/api.go b/middleware/untyped/api.go index 6dfbe66e..a98d6690 100644 --- a/middleware/untyped/api.go +++ b/middleware/untyped/api.go @@ -34,6 +34,22 @@ const ( mediumPreallocatedSlots = 30 ) +// API represents an untyped mux for a swagger spec +type API struct { + spec *loads.Document + analyzer *analysis.Spec + DefaultProduces string + DefaultConsumes string + consumers map[string]runtime.Consumer + producers map[string]runtime.Producer + authenticators map[string]runtime.Authenticator + authorizer runtime.Authorizer + operations map[string]map[string]runtime.OperationHandler + ServeError func(http.ResponseWriter, *http.Request, error) + Models map[string]func() any + formats strfmt.Registry +} + // NewAPI creates the default untyped API func NewAPI(spec *loads.Document) *API { var an *analysis.Spec @@ -48,26 +64,11 @@ func NewAPI(spec *loads.Document) *API { authenticators: make(map[string]runtime.Authenticator), operations: make(map[string]map[string]runtime.OperationHandler), ServeError: errors.ServeError, - Models: make(map[string]func() interface{}), + Models: make(map[string]func() any), formats: strfmt.NewFormats(), } - return api.WithJSONDefaults() -} -// API represents an untyped mux for a swagger spec -type API struct { - spec *loads.Document - analyzer *analysis.Spec - DefaultProduces string - DefaultConsumes string - consumers map[string]runtime.Consumer - producers map[string]runtime.Producer - authenticators map[string]runtime.Authenticator - authorizer runtime.Authorizer - operations map[string]map[string]runtime.OperationHandler - ServeError func(http.ResponseWriter, *http.Request, error) - Models map[string]func() interface{} - formats strfmt.Registry + return api.WithJSONDefaults() } // WithJSONDefaults loads the json defaults for this api diff --git a/security/authenticator.go b/security/authenticator.go index 14735ce5..def3746e 100644 --- a/security/authenticator.go +++ b/security/authenticator.go @@ -31,7 +31,7 @@ const ( ) // HttpAuthenticator is a function that authenticates a HTTP request -func HttpAuthenticator(handler func(*http.Request) (bool, interface{}, error)) runtime.Authenticator { //nolint:revive,stylecheck +func HttpAuthenticator(handler func(*http.Request) (bool, interface{}, error)) runtime.Authenticator { //nolint:revive return runtime.AuthenticatorFunc(func(params interface{}) (bool, interface{}, error) { if request, ok := params.(*http.Request); ok { return handler(request) diff --git a/security/basic_auth_test.go b/security/basic_auth_test.go index 162cd352..65aad8a2 100644 --- a/security/basic_auth_test.go +++ b/security/basic_auth_test.go @@ -67,7 +67,7 @@ func TestBasicAuth(t *testing.T) { ok, usr, err := ba.Authenticate(req) require.Error(t, err) assert.True(t, ok) - assert.Equal(t, "", usr) + assert.Empty(t, usr) assert.NotEmpty(t, FailedBasicAuth(req)) assert.Equal(t, DefaultRealmName, FailedBasicAuth(req)) @@ -103,7 +103,7 @@ func TestBasicAuth(t *testing.T) { ok, usr, err := br.Authenticate(req) require.Error(t, err) assert.True(t, ok) - assert.Equal(t, "", usr) + assert.Empty(t, usr) assert.Equal(t, "realm", FailedBasicAuth(req)) }) @@ -117,7 +117,7 @@ func TestBasicAuth(t *testing.T) { ok, usr, err := br.Authenticate(req) require.Error(t, err) assert.True(t, ok) - assert.Equal(t, "", usr) + assert.Empty(t, usr) assert.Equal(t, DefaultRealmName, FailedBasicAuth(req)) }) } @@ -155,7 +155,7 @@ func TestBasicAuthCtx(t *testing.T) { ok, usr, err := ba.Authenticate(req) require.Error(t, err) assert.True(t, ok) - assert.Equal(t, "", usr) + assert.Empty(t, usr) assert.Equal(t, wisdom, req.Context().Value(original)) assert.Nil(t, req.Context().Value(extra)) @@ -193,7 +193,7 @@ func TestBasicAuthCtx(t *testing.T) { ok, usr, err := br.Authenticate(req) require.Error(t, err) assert.True(t, ok) - assert.Equal(t, "", usr) + assert.Empty(t, usr) assert.Equal(t, "realm", FailedBasicAuth(req)) }) @@ -207,7 +207,7 @@ func TestBasicAuthCtx(t *testing.T) { ok, usr, err := br.Authenticate(req) require.Error(t, err) assert.True(t, ok) - assert.Equal(t, "", usr) + assert.Empty(t, usr) assert.Equal(t, DefaultRealmName, FailedBasicAuth(req)) }) }