Skip to content

Commit

Permalink
Add validation of parameters of all types declared by JSON schema (#69)…
Browse files Browse the repository at this point in the history
… (#76)
  • Loading branch information
nikolaas authored and fenollp committed Mar 21, 2019
1 parent 27f3ea7 commit 9acc5e9
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 303 deletions.
151 changes: 61 additions & 90 deletions openapi3filter/param_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,42 @@ import (
"github.com/getkin/kin-openapi/openapi3"
)

const errMsgInvalidSerializationF = "%s parameter %q has an invalid serialization method: style=%q, explode=%v"

// ParseErrorKind describes a kind of ParseError.
// The type simplifies comparison of errors.
type ParseErrorKind int

const (
errMsgInvalidValue = "an invalid value"
errMsgInvalidSerializationF = "%s parameter %q has an invalid serialization method: style=%q, explode=%v"
// KindOther describes an untyped parsing error.
KindOther ParseErrorKind = iota
// KindInvalidFormat describes an error that happens when a value does not conform a format
// that is required by a serialization method.
KindInvalidFormat
// KindInvalidInt describes an error that happens when a value is an invalid integer.
KindInvalidInt
// KindInvalidNumber describes an error that happens when a value is an invalid number.
KindInvalidNumber
// KindInvalidBool describes an error that happens when a value is an invalid boolean.
KindInvalidBool
)

// ParseError describes errors which happens while parse operation's parameters.
type ParseError struct {
Value interface{}
Kind ParseErrorKind
Path []interface{}
Value interface{}
Reason string
Cause error
}

func (e *ParseError) Error() string {
var msg []string
if e.Path != nil {
msg = append(msg, fmt.Sprintf("failed to parse value %v at path %v", e.Value, e.Path))
} else {
msg = append(msg, fmt.Sprintf("failed to parse value %v", e.Value))
msg = append(msg, fmt.Sprintf("path %v", e.Path))
}
if e.Value != nil {
msg = append(msg, fmt.Sprintf("value %v", e.Value))
}
if e.Reason != "" {
msg = append(msg, e.Reason)
Expand All @@ -40,6 +57,7 @@ func (e *ParseError) Error() string {
}

// decodeParameter returns a value of an operation's parameter from HTTP request.
// The function returns ParseError when HTTP request contains an invalid value of a parameter.
func decodeParameter(param *openapi3.Parameter, input *RequestValidationInput) (interface{}, error) {
var decoder interface {
DecodePrimitive(param *openapi3.Parameter) (interface{}, error)
Expand Down Expand Up @@ -103,13 +121,9 @@ func (d *pathParamDecoder) DecodePrimitive(param *openapi3.Parameter) (interface
}
src, err := cutPrefix(raw, prefix)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
val, err := parsePrimitive(src, param.Schema)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
return nil, err
}
return val, nil
return parsePrimitive(src, param.Schema)
}

func (d *pathParamDecoder) DecodeArray(param *openapi3.Parameter) ([]interface{}, error) {
Expand Down Expand Up @@ -148,13 +162,9 @@ func (d *pathParamDecoder) DecodeArray(param *openapi3.Parameter) ([]interface{}
}
src, err := cutPrefix(raw, prefix)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
val, err := parseArray(strings.Split(src, delim), param.Schema)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
return nil, err
}
return val, nil
return parseArray(strings.Split(src, delim), param.Schema)
}

func (d *pathParamDecoder) DecodeObject(param *openapi3.Parameter) (map[string]interface{}, error) {
Expand Down Expand Up @@ -201,17 +211,13 @@ func (d *pathParamDecoder) DecodeObject(param *openapi3.Parameter) (map[string]i
}
src, err := cutPrefix(raw, prefix)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
return nil, err
}
props, err := propsFromString(src, propsDelim, valueDelim)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
val, err := makeObject(props, param.Schema)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
return nil, err
}
return val, nil
return makeObject(props, param.Schema)
}

// paramKey returns a key to get a raw value of a path parameter.
Expand All @@ -233,7 +239,11 @@ func cutPrefix(raw, prefix string) (string, error) {
return raw, nil
}
if len(raw) < len(prefix) || raw[:len(prefix)] != prefix {
return "", &ParseError{Value: raw, Reason: fmt.Sprintf("a value must be prefixed with %q", prefix)}
return "", &ParseError{
Kind: KindInvalidFormat,
Value: raw,
Reason: fmt.Sprintf("a value must be prefixed with %q", prefix),
}
}
return raw[len(prefix):], nil
}
Expand All @@ -257,11 +267,7 @@ func (d *queryParamDecoder) DecodePrimitive(param *openapi3.Parameter) (interfac
// HTTP request does not contain a value of the target query parameter.
return nil, nil
}
val, err := parsePrimitive(values[0], param.Schema)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
return val, nil
return parsePrimitive(values[0], param.Schema)
}

func (d *queryParamDecoder) DecodeArray(param *openapi3.Parameter) ([]interface{}, error) {
Expand Down Expand Up @@ -290,11 +296,7 @@ func (d *queryParamDecoder) DecodeArray(param *openapi3.Parameter) ([]interface{
}
values = strings.Split(values[0], delim)
}
val, err := parseArray(values, param.Schema)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
return val, nil
return parseArray(values, param.Schema)
}

func (d *queryParamDecoder) DecodeObject(param *openapi3.Parameter) (map[string]interface{}, error) {
Expand Down Expand Up @@ -347,16 +349,12 @@ func (d *queryParamDecoder) DecodeObject(param *openapi3.Parameter) (map[string]

props, err := propsFn(d.input.GetQueryParams())
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
return nil, err
}
if props == nil {
return nil, nil
}
val, err := makeObject(props, param.Schema)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
return val, nil
return makeObject(props, param.Schema)
}

// headerParamDecoder decodes values of header parameters.
Expand All @@ -374,11 +372,7 @@ func (d *headerParamDecoder) DecodePrimitive(param *openapi3.Parameter) (interfa
}

raw := d.input.Request.Header.Get(http.CanonicalHeaderKey(param.Name))
val, err := parsePrimitive(raw, param.Schema)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
return val, nil
return parsePrimitive(raw, param.Schema)
}

func (d *headerParamDecoder) DecodeArray(param *openapi3.Parameter) ([]interface{}, error) {
Expand All @@ -391,16 +385,11 @@ func (d *headerParamDecoder) DecodeArray(param *openapi3.Parameter) ([]interface
}

raw := d.input.Request.Header.Get(http.CanonicalHeaderKey(param.Name))

val, err := parseArray(strings.Split(raw, ","), param.Schema)
if raw == "" {
// HTTP request does not contains a corresponding header
return nil, nil
}
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
return val, nil
return parseArray(strings.Split(raw, ","), param.Schema)
}

func (d *headerParamDecoder) DecodeObject(param *openapi3.Parameter) (map[string]interface{}, error) {
Expand All @@ -423,13 +412,9 @@ func (d *headerParamDecoder) DecodeObject(param *openapi3.Parameter) (map[string
}
props, err := propsFromString(raw, ",", valueDelim)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
val, err := makeObject(props, param.Schema)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
return nil, err
}
return val, nil
return makeObject(props, param.Schema)
}

// cookieParamDecoder decodes values of cookie parameters.
Expand All @@ -454,11 +439,7 @@ func (d *cookieParamDecoder) DecodePrimitive(param *openapi3.Parameter) (interfa
if err != nil {
return nil, fmt.Errorf("decode param %q: %s", param.Name, err)
}
val, err := parsePrimitive(cookie.Value, param.Schema)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
return val, nil
return parsePrimitive(cookie.Value, param.Schema)
}

func (d *cookieParamDecoder) DecodeArray(param *openapi3.Parameter) ([]interface{}, error) {
Expand All @@ -478,11 +459,7 @@ func (d *cookieParamDecoder) DecodeArray(param *openapi3.Parameter) ([]interface
if err != nil {
return nil, fmt.Errorf("decode param %q: %s", param.Name, err)
}
val, err := parseArray(strings.Split(cookie.Value, ","), param.Schema)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
return val, nil
return parseArray(strings.Split(cookie.Value, ","), param.Schema)
}

func (d *cookieParamDecoder) DecodeObject(param *openapi3.Parameter) (map[string]interface{}, error) {
Expand All @@ -504,13 +481,9 @@ func (d *cookieParamDecoder) DecodeObject(param *openapi3.Parameter) (map[string
}
props, err := propsFromString(cookie.Value, ",", ",")
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
}
val, err := makeObject(props, param.Schema)
if err != nil {
return nil, &RequestError{Input: d.input, Parameter: param, Reason: errMsgInvalidValue, Err: err}
return nil, err
}
return val, nil
return makeObject(props, param.Schema)
}

// propsFromString returns a properties map that is created by splitting a source string by propDelim and valueDelim.
Expand All @@ -527,6 +500,7 @@ func propsFromString(src, propDelim, valueDelim string) (map[string]string, erro
// to an array with an even number of items.
if len(pairs)%2 != 0 {
return nil, &ParseError{
Kind: KindInvalidFormat,
Value: src,
Reason: fmt.Sprintf("a value must be a list of object's properties in format \"name%svalue\" separated by %s", valueDelim, propDelim),
}
Expand All @@ -543,6 +517,7 @@ func propsFromString(src, propDelim, valueDelim string) (map[string]string, erro
prop := strings.Split(pair, valueDelim)
if len(prop) != 2 {
return nil, &ParseError{
Kind: KindInvalidFormat,
Value: src,
Reason: fmt.Sprintf("a value must be a list of object's properties in format \"name%svalue\" separated by %s", valueDelim, propDelim),
}
Expand All @@ -561,12 +536,7 @@ func makeObject(props map[string]string, schema *openapi3.SchemaRef) (map[string
value, err := parsePrimitive(props[propName], propSchema)
if err != nil {
if v, ok := err.(*ParseError); ok {
return nil, &ParseError{
Value: v.Value,
Reason: v.Reason,
Cause: v.Cause,
Path: []interface{}{propName},
}
return nil, &ParseError{Path: []interface{}{propName}, Cause: v}
}
return nil, err
}
Expand All @@ -584,12 +554,7 @@ func parseArray(raw []string, schemaRef *openapi3.SchemaRef) ([]interface{}, err
item, err := parsePrimitive(v, schemaRef.Value.Items)
if err != nil {
if v, ok := err.(*ParseError); ok {
return nil, &ParseError{
Value: v.Value,
Reason: v.Reason,
Cause: v.Cause,
Path: []interface{}{i},
}
return nil, &ParseError{Path: []interface{}{i}, Cause: v}
}
return nil, err
}
Expand All @@ -606,16 +571,22 @@ func parsePrimitive(raw string, schema *openapi3.SchemaRef) (interface{}, error)
return nil, nil
}
switch schema.Value.Type {
case "integer", "number":
case "integer":
v, err := strconv.ParseFloat(raw, 64)
if err != nil {
return nil, &ParseError{Kind: KindInvalidInt, Value: raw, Reason: "an invalid interger", Cause: err}
}
return v, nil
case "number":
v, err := strconv.ParseFloat(raw, 64)
if err != nil {
return nil, &ParseError{Value: raw, Cause: err}
return nil, &ParseError{Kind: KindInvalidNumber, Value: raw, Reason: "an invalid number", Cause: err}
}
return v, nil
case "boolean":
v, err := strconv.ParseBool(raw)
if err != nil {
return nil, &ParseError{Value: raw, Cause: err}
return nil, &ParseError{Kind: KindInvalidBool, Value: raw, Reason: "an invalid number", Cause: err}
}
return v, nil
case "string":
Expand Down
Loading

0 comments on commit 9acc5e9

Please sign in to comment.