diff --git a/openapi3filter/req_resp_encoder.go b/openapi3filter/req_resp_encoder.go index dd410f588..36b7db6fd 100644 --- a/openapi3filter/req_resp_encoder.go +++ b/openapi3filter/req_resp_encoder.go @@ -16,8 +16,34 @@ func encodeBody(body interface{}, mediaType string) ([]byte, error) { return encoder(body) } -type bodyEncoder func(body interface{}) ([]byte, error) +type BodyEncoder func(body interface{}) ([]byte, error) -var bodyEncoders = map[string]bodyEncoder{ +var bodyEncoders = map[string]BodyEncoder{ "application/json": json.Marshal, } + +func RegisterBodyEncoder(contentType string, encoder BodyEncoder) { + if contentType == "" { + panic("contentType is empty") + } + if encoder == nil { + panic("encoder is not defined") + } + bodyEncoders[contentType] = encoder +} + +// This call is not thread-safe: body encoders should not be created/destroyed by multiple goroutines. +func UnregisterBodyEncoder(contentType string) { + if contentType == "" { + panic("contentType is empty") + } + delete(bodyEncoders, contentType) +} + +// RegisteredBodyEncoder returns the registered body encoder for the given content type. +// +// If no encoder was registered for the given content type, nil is returned. +// This call is not thread-safe: body encoders should not be created/destroyed by multiple goroutines. +func RegisteredBodyEncoder(contentType string) BodyEncoder { + return bodyEncoders[contentType] +} diff --git a/openapi3filter/req_resp_encoder_test.go b/openapi3filter/req_resp_encoder_test.go new file mode 100644 index 000000000..11fe2afa9 --- /dev/null +++ b/openapi3filter/req_resp_encoder_test.go @@ -0,0 +1,43 @@ +package openapi3filter + +import ( + "fmt" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRegisterAndUnregisterBodyEncoder(t *testing.T) { + var encoder BodyEncoder + encoder = func(body interface{}) (data []byte, err error) { + return []byte(strings.Join(body.([]string), ",")), nil + } + contentType := "text/csv" + h := make(http.Header) + h.Set(headerCT, contentType) + + originalEncoder := RegisteredBodyEncoder(contentType) + require.Nil(t, originalEncoder) + + RegisterBodyEncoder(contentType, encoder) + require.Equal(t, fmt.Sprintf("%v", encoder), fmt.Sprintf("%v", RegisteredBodyEncoder(contentType))) + + body := []string{"foo", "bar"} + got, err := encodeBody(body, contentType) + + require.NoError(t, err) + require.Equal(t, []byte("foo,bar"), got) + + UnregisterBodyEncoder(contentType) + + originalEncoder = RegisteredBodyEncoder(contentType) + require.Nil(t, originalEncoder) + + _, err = encodeBody(body, contentType) + require.Equal(t, &ParseError{ + Kind: KindUnsupportedFormat, + Reason: prefixUnsupportedCT + ` "text/csv"`, + }, err) +}