diff --git a/pkg/app/master/probe/http/swagger.go b/pkg/app/master/probe/http/swagger.go index 83b57431..5d90ca11 100644 --- a/pkg/app/master/probe/http/swagger.go +++ b/pkg/app/master/probe/http/swagger.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "encoding/json" "fmt" "io" "net/http" @@ -274,6 +275,135 @@ func addPathOp(m *map[string]*openapi3.Operation, op *openapi3.Operation, name s } } +// collectParameters merges path-level and operation-level parameters. +// Operation-level parameters override path-level ones with the same (in,name). +func collectParameters(pathItem *openapi3.PathItem, op *openapi3.Operation) []*openapi3.Parameter { + var result []*openapi3.Parameter + // index to handle overrides + key := func(p *openapi3.Parameter) string { return p.In + "\x00" + p.Name } + seen := map[string]bool{} + + if pathItem != nil { + for _, pref := range pathItem.Parameters { + if pref == nil || pref.Value == nil { + continue + } + p := pref.Value + result = append(result, p) + seen[key(p)] = true + } + } + + if op != nil { + for _, pref := range op.Parameters { + if pref == nil || pref.Value == nil { + continue + } + p := pref.Value + k := key(p) + if seen[k] { + for i := range result { + // override by replacing prior entry + if key(result[i]) == k { + result[i] = p + // OpenAPI params are unique per operation by (in,name). An op-level param + // overrides at most one path-level entry, so replace once and stop. + break + } + } + } else { + result = append(result, p) + seen[k] = true + } + } + } + + return result +} + +func paramStringForSchema(sref *openapi3.SchemaRef) string { + if sref == nil || sref.Value == nil { + return "x" + } + s := sref.Value + + if len(s.Enum) > 0 { + if v, ok := s.Enum[0].(string); ok { + return v + } + return fmt.Sprint(s.Enum[0]) + } + + switch s.Type { + case "integer", "number": + return "1" + case "boolean": + return "true" + case "array": + return paramStringForSchema(s.Items) + case "object": + return "x" + default: + return "x" + } +} + +func substitutePathParams(apiPath string, params []*openapi3.Parameter) string { + if !strings.Contains(apiPath, "{") { + return apiPath + } + + for _, p := range params { + if p == nil || p.In != "path" { + continue + } + placeholder := "{" + p.Name + "}" + if strings.Contains(apiPath, placeholder) { + apiPath = strings.ReplaceAll(apiPath, placeholder, url.PathEscape(paramStringForSchema(p.Schema))) + } + } + + // fallback: strip any remaining braces + if strings.Contains(apiPath, "{") { + apiPath = strings.ReplaceAll(apiPath, "{", "") + apiPath = strings.ReplaceAll(apiPath, "}", "") + } + + return apiPath +} + +func buildQueryAndHeaders(params []*openapi3.Parameter) (string, map[string]string) { + var parts []string + headers := make(map[string]string) + for _, p := range params { + if p == nil { + continue + } + + getStringValue := func(pref *openapi3.SchemaRef) string { + if pref != nil && pref.Value != nil && pref.Value.Type == "object" { + // generate a small object and JSON-stringify it + obj, _ := genSchemaObject(pref.Value, false) + if data, err := json.Marshal(obj); err == nil { + return string(data) + } + return "{}" + } + return paramStringForSchema(pref) + } + + switch p.In { + case "query": + v := getStringValue(p.Schema) + parts = append(parts, url.QueryEscape(p.Name)+"="+url.QueryEscape(v)) + case "header": + v := getStringValue(p.Schema) + headers[p.Name] = v + } + } + return strings.Join(parts, "&"), headers +} + func genSchemaObject(schema *openapi3.Schema, minimal bool) (interface{}, bool) { //todo: also need 'max' as a param to generate as many fields as possible @@ -330,7 +460,7 @@ func genSchemaObject(schema *openapi3.Schema, minimal bool) (interface{}, bool) stringVal, _ = schema.Example.(string) } else if schema.Default != nil { stringVal, _ = schema.Default.(string) - } else if schema.Enum != nil && len(schema.Enum) > 0 { + } else if len(schema.Enum) > 0 { stringVal, _ = schema.Enum[0].(string) } @@ -393,16 +523,9 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri } for apiPath, pathInfo := range spec.Paths { - //very primitive way to set the path params (will break for numeric values) - if strings.Contains(apiPath, "{") { - apiPath = strings.ReplaceAll(apiPath, "{", "") + rawRoute := apiPath + // Path param substitution is handled per operation below. - if strings.Contains(apiPath, "}") { - apiPath = strings.ReplaceAll(apiPath, "}", "") - } - } - - endpoint := fmt.Sprintf("%s%s%s", addr, prefix, apiPath) ops := pathOps(pathInfo) for apiMethod, apiInfo := range ops { if apiInfo == nil { @@ -412,6 +535,19 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri continue } + // Build endpoint for this operation using dummy path/query params + params := collectParameters(pathInfo, apiInfo) + finalPath := substitutePathParams(rawRoute, params) + endpoint := fmt.Sprintf("%s%s%s", addr, prefix, finalPath) + qstr, hdrs := buildQueryAndHeaders(params) + if qstr != "" { + if strings.Contains(endpoint, "?") { + endpoint = endpoint + "&" + qstr + } else { + endpoint = endpoint + "?" + qstr + } + } + var bodyBytes []byte var contentType string var formFieldName string @@ -510,8 +646,8 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri } } - //make a call (no params for now) - if p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyBytes) { + //make a call + if p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyBytes, hdrs) { if formFieldName != "" { //trying again with a different generated body (simple hacky version) //retrying only for form data for now @@ -532,7 +668,7 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri "op": op, }).Debug("retrying.form.submit.image p.apiSpecEndpointCall") - if p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyForm.Bytes()) && + if p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyForm.Bytes(), hdrs) && formFieldName != "" { strBody := strings.NewReader(data.DefaultTextJSON) var bodyForm *bytes.Buffer @@ -545,7 +681,7 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri log.WithFields(log.Fields{ "op": op, }).Debug("retrying.form.submit.json p.apiSpecEndpointCall") - p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyForm.Bytes()) + p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyForm.Bytes(), hdrs) } } } @@ -561,7 +697,7 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri "data": string(bodyBytes), }).Debug("generatedSchemaObject(true)/retrying.post p.apiSpecEndpointCall") - p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyBytes) + p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyBytes, hdrs) } } } @@ -574,8 +710,8 @@ func (p *CustomProbe) apiSpecEndpointCall( method string, contentType string, bodyBytes []byte, + headers map[string]string, ) bool { - const op = "probe.http.CustomProbe.apiSpecEndpointCall" maxRetryCount := p.retryCount() notReadyErrorWait := time.Duration(16) @@ -605,7 +741,14 @@ func (p *CustomProbe) apiSpecEndpointCall( req.Header.Set(HeaderContentType, contentType) } - //no request headers and no credentials for now + for hname, hvalue := range headers { + if strings.EqualFold(hname, HeaderContentType) { + continue + } + req.Header.Add(hname, hvalue) + } + + //no credentials for now res, err := client.Do(req) p.CallCount.Inc() diff --git a/pkg/app/master/probe/http/swagger_test.go b/pkg/app/master/probe/http/swagger_test.go new file mode 100644 index 00000000..5bba0765 --- /dev/null +++ b/pkg/app/master/probe/http/swagger_test.go @@ -0,0 +1,178 @@ +package http + +import ( + "net/url" + "strings" + "testing" + + "github.com/getkin/kin-openapi/openapi3" +) + +func TestSubstitutePathParams(t *testing.T) { + apiPath := "/pets/{id}/owners/{ownerId}" + + params := []*openapi3.Parameter{ + { + Name: "id", + In: "path", + Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: "integer", + }}, + }, + { + Name: "ownerId", + In: "path", + Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: "string", + }}, + }, + } + + got := substitutePathParams(apiPath, params) + want := "/pets/1/owners/x" + if got != want { + t.Fatalf("substitutePathParams() = %q; want %q", got, want) + } +} + +func TestBuildQueryAndHeaders(t *testing.T) { + params := []*openapi3.Parameter{ + { + Name: "q", + In: "query", + Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: "string", + }}, + }, + { + Name: "count", + In: "query", + Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: "integer", + }}, + }, + { + Name: "meta", + In: "query", + Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: "object", + Properties: openapi3.Schemas{"x": &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "string"}}}, + }}, + }, + { + Name: "X-Token", + In: "header", + Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: "string", + }}, + }, + } + + qs, headers := buildQueryAndHeaders(params) + values, err := url.ParseQuery(qs) + if err != nil { + t.Fatalf("failed to parse query string: %v", err) + } + if values.Get("q") != "x" { + t.Fatalf("query param q = %q; want %q", values.Get("q"), "x") + } + if values.Get("count") != "1" { + t.Fatalf("query param count = %q; want %q", values.Get("count"), "1") + } + if headers["X-Token"] != "x" { + t.Fatalf("header X-Token = %q; want %q", headers["X-Token"], "x") + } + + // object param should be JSON-stringified + if values.Get("meta") == "" { + t.Fatalf("expected meta query param to be present") + } + if !strings.HasPrefix(values.Get("meta"), "{") { + t.Fatalf("expected meta to be JSON string, got: %q", values.Get("meta")) + } +} + +func TestCollectParameters_MergeAndOverride(t *testing.T) { + pathItem := &openapi3.PathItem{ + Parameters: openapi3.Parameters{ + &openapi3.ParameterRef{Value: &openapi3.Parameter{Name: "x", In: "query", Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "string"}}}}, + &openapi3.ParameterRef{Value: &openapi3.Parameter{Name: "id", In: "path", Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "integer"}}}}, + }, + } + op := &openapi3.Operation{ + Parameters: openapi3.Parameters{ + // override x + &openapi3.ParameterRef{Value: &openapi3.Parameter{Name: "x", In: "query", Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "integer"}}}}, + // add new y + &openapi3.ParameterRef{Value: &openapi3.Parameter{Name: "y", In: "query", Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "boolean"}}}}, + }, + } + + merged := collectParameters(pathItem, op) + get := func(in, name string) *openapi3.Parameter { + for _, p := range merged { + if p != nil && p.In == in && p.Name == name { + return p + } + } + return nil + } + + if p := get("query", "x"); p == nil || p.Schema == nil || p.Schema.Value == nil || p.Schema.Value.Type != "integer" { + t.Fatalf("expected op-level override for x to be integer, got: %#v", p) + } + if p := get("query", "y"); p == nil || p.Schema == nil || p.Schema.Value == nil || p.Schema.Value.Type != "boolean" { + t.Fatalf("expected new param y boolean, got: %#v", p) + } + if p := get("path", "id"); p == nil || p.Schema == nil || p.Schema.Value == nil || p.Schema.Value.Type != "integer" { + t.Fatalf("expected path id integer preserved, got: %#v", p) + } +} + +func TestParamStringForSchema_TypesAndEnum(t *testing.T) { + if got := paramStringForSchema(&openapi3.SchemaRef{Value: &openapi3.Schema{Enum: []interface{}{"A", "B"}}}); got != "A" { + t.Fatalf("enum preferred value = %q; want %q", got, "A") + } + if got := paramStringForSchema(&openapi3.SchemaRef{Value: &openapi3.Schema{Type: "integer"}}); got != "1" { + t.Fatalf("integer -> %q; want %q", got, "1") + } + if got := paramStringForSchema(&openapi3.SchemaRef{Value: &openapi3.Schema{Type: "number"}}); got != "1" { + t.Fatalf("number -> %q; want %q", got, "1") + } + if got := paramStringForSchema(&openapi3.SchemaRef{Value: &openapi3.Schema{Type: "boolean"}}); got != "true" { + t.Fatalf("boolean -> %q; want %q", got, "true") + } + if got := paramStringForSchema(&openapi3.SchemaRef{Value: &openapi3.Schema{Type: "array", Items: &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "integer"}}}}); got != "1" { + t.Fatalf("array[integer] -> %q; want %q", got, "1") + } + if got := paramStringForSchema(&openapi3.SchemaRef{Value: &openapi3.Schema{Type: "object"}}); got != "x" { + t.Fatalf("object -> %q; want %q", got, "x") + } +} + +func TestSubstitutePathParams_FallbackStripsUnknown(t *testing.T) { + apiPath := "/stores/{known}/items/{unknown}" + params := []*openapi3.Parameter{ + {Name: "known", In: "path", Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "string"}}}, + } + got := substitutePathParams(apiPath, params) + want := "/stores/x/items/unknown" + if got != want { + t.Fatalf("fallback strip = %q; want %q", got, want) + } +} + +func TestBuildQueryAndHeaders_ArrayHandling(t *testing.T) { + params := []*openapi3.Parameter{ + {Name: "ids", In: "query", Schema: &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "array", Items: &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "integer"}}}}}, + } + qs, _ := buildQueryAndHeaders(params) + values, err := url.ParseQuery(qs) + if err != nil { + t.Fatalf("failed to parse query string: %v", err) + } + if values.Get("ids") != "1" { + t.Fatalf("array item -> %q; want %q", values.Get("ids"), "1") + } +} +