Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 160 additions & 17 deletions pkg/app/master/probe/http/swagger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -274,6 +275,135 @@ func addPathOp(m *map[string]*openapi3.Operation, op *openapi3.Operation, name s
}
}

// collectParameters merges path-level and operation-level parameters.
// Operation-level parameters override path-level ones with the same (in,name).
func collectParameters(pathItem *openapi3.PathItem, op *openapi3.Operation) []*openapi3.Parameter {
var result []*openapi3.Parameter
// index to handle overrides
key := func(p *openapi3.Parameter) string { return p.In + "\x00" + p.Name }
seen := map[string]bool{}

if pathItem != nil {
for _, pref := range pathItem.Parameters {
if pref == nil || pref.Value == nil {
continue
}
p := pref.Value
result = append(result, p)
seen[key(p)] = true
}
}

if op != nil {
for _, pref := range op.Parameters {
if pref == nil || pref.Value == nil {
continue
}
p := pref.Value
k := key(p)
if seen[k] {
for i := range result {
// override by replacing prior entry
if key(result[i]) == k {
result[i] = p
// OpenAPI params are unique per operation by (in,name). An op-level param
// overrides at most one path-level entry, so replace once and stop.
break
}
}
} else {
result = append(result, p)
seen[k] = true
}
}
}

return result
}

func paramStringForSchema(sref *openapi3.SchemaRef) string {
if sref == nil || sref.Value == nil {
return "x"
}
s := sref.Value

if len(s.Enum) > 0 {
if v, ok := s.Enum[0].(string); ok {
return v
}
return fmt.Sprint(s.Enum[0])
}

switch s.Type {
case "integer", "number":
return "1"
case "boolean":
return "true"
case "array":
return paramStringForSchema(s.Items)
case "object":
return "x"
default:
return "x"
}
}

func substitutePathParams(apiPath string, params []*openapi3.Parameter) string {
if !strings.Contains(apiPath, "{") {
return apiPath
}

for _, p := range params {
if p == nil || p.In != "path" {
continue
}
placeholder := "{" + p.Name + "}"
if strings.Contains(apiPath, placeholder) {
apiPath = strings.ReplaceAll(apiPath, placeholder, url.PathEscape(paramStringForSchema(p.Schema)))
}
}

// fallback: strip any remaining braces
if strings.Contains(apiPath, "{") {
apiPath = strings.ReplaceAll(apiPath, "{", "")
apiPath = strings.ReplaceAll(apiPath, "}", "")
}

return apiPath
}

func buildQueryAndHeaders(params []*openapi3.Parameter) (string, map[string]string) {
var parts []string
headers := make(map[string]string)
for _, p := range params {
if p == nil {
continue
}

getStringValue := func(pref *openapi3.SchemaRef) string {
if pref != nil && pref.Value != nil && pref.Value.Type == "object" {
// generate a small object and JSON-stringify it
obj, _ := genSchemaObject(pref.Value, false)
if data, err := json.Marshal(obj); err == nil {
return string(data)
}
return "{}"
}
return paramStringForSchema(pref)
}

switch p.In {
case "query":
v := getStringValue(p.Schema)
parts = append(parts, url.QueryEscape(p.Name)+"="+url.QueryEscape(v))
case "header":
v := getStringValue(p.Schema)
headers[p.Name] = v
}
}
return strings.Join(parts, "&"), headers
}

func genSchemaObject(schema *openapi3.Schema, minimal bool) (interface{}, bool) {
//todo: also need 'max' as a param to generate as many fields as possible

Expand Down Expand Up @@ -330,7 +460,7 @@ func genSchemaObject(schema *openapi3.Schema, minimal bool) (interface{}, bool)
stringVal, _ = schema.Example.(string)
} else if schema.Default != nil {
stringVal, _ = schema.Default.(string)
} else if schema.Enum != nil && len(schema.Enum) > 0 {
} else if len(schema.Enum) > 0 {
stringVal, _ = schema.Enum[0].(string)
}

Expand Down Expand Up @@ -393,16 +523,9 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri
}

for apiPath, pathInfo := range spec.Paths {
//very primitive way to set the path params (will break for numeric values)
if strings.Contains(apiPath, "{") {
apiPath = strings.ReplaceAll(apiPath, "{", "")
rawRoute := apiPath
// Path param substitution is handled per operation below.

if strings.Contains(apiPath, "}") {
apiPath = strings.ReplaceAll(apiPath, "}", "")
}
}

endpoint := fmt.Sprintf("%s%s%s", addr, prefix, apiPath)
ops := pathOps(pathInfo)
for apiMethod, apiInfo := range ops {
if apiInfo == nil {
Expand All @@ -412,6 +535,19 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri
continue
}

// Build endpoint for this operation using dummy path/query params
params := collectParameters(pathInfo, apiInfo)
finalPath := substitutePathParams(rawRoute, params)
endpoint := fmt.Sprintf("%s%s%s", addr, prefix, finalPath)
qstr, hdrs := buildQueryAndHeaders(params)
if qstr != "" {
if strings.Contains(endpoint, "?") {
endpoint = endpoint + "&" + qstr
} else {
endpoint = endpoint + "?" + qstr
}
}

var bodyBytes []byte
var contentType string
var formFieldName string
Expand Down Expand Up @@ -510,8 +646,8 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri
}
}

//make a call (no params for now)
if p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyBytes) {
//make a call
if p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyBytes, hdrs) {
if formFieldName != "" {
//trying again with a different generated body (simple hacky version)
//retrying only for form data for now
Expand All @@ -532,7 +668,7 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri
"op": op,
}).Debug("retrying.form.submit.image p.apiSpecEndpointCall")

if p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyForm.Bytes()) &&
if p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyForm.Bytes(), hdrs) &&
formFieldName != "" {
strBody := strings.NewReader(data.DefaultTextJSON)
var bodyForm *bytes.Buffer
Expand All @@ -545,7 +681,7 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri
log.WithFields(log.Fields{
"op": op,
}).Debug("retrying.form.submit.json p.apiSpecEndpointCall")
p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyForm.Bytes())
p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyForm.Bytes(), hdrs)
}
}
}
Expand All @@ -561,7 +697,7 @@ func (p *CustomProbe) probeAPISpecEndpoints(proto, targetHost, port, prefix stri
"data": string(bodyBytes),
}).Debug("generatedSchemaObject(true)/retrying.post p.apiSpecEndpointCall")

p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyBytes)
p.apiSpecEndpointCall(httpClient, endpoint, apiMethod, contentType, bodyBytes, hdrs)
}
}
}
Expand All @@ -574,8 +710,8 @@ func (p *CustomProbe) apiSpecEndpointCall(
method string,
contentType string,
bodyBytes []byte,
headers map[string]string,
) bool {
const op = "probe.http.CustomProbe.apiSpecEndpointCall"
maxRetryCount := p.retryCount()

notReadyErrorWait := time.Duration(16)
Expand Down Expand Up @@ -605,7 +741,14 @@ func (p *CustomProbe) apiSpecEndpointCall(
req.Header.Set(HeaderContentType, contentType)
}

//no request headers and no credentials for now
for hname, hvalue := range headers {
if strings.EqualFold(hname, HeaderContentType) {
continue
}
req.Header.Add(hname, hvalue)
}

//no credentials for now
res, err := client.Do(req)
p.CallCount.Inc()

Expand Down
Loading
Loading