diff --git a/Gopkg.lock b/Gopkg.lock index 6586ed00..9b4d6140 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -365,6 +365,7 @@ "github.com/stretchr/testify/assert", "go.opencensus.io/plugin/ocgrpc", "go.opencensus.io/plugin/ochttp", + "go.opencensus.io/plugin/ochttp/propagation/b3", "go.opencensus.io/trace", "go.opencensus.io/trace/propagation", "golang.org/x/net/context", diff --git a/gateway/header.go b/gateway/header.go index b47c5375..c78c79a3 100644 --- a/gateway/header.go +++ b/gateway/header.go @@ -114,19 +114,71 @@ func handleForwardResponseTrailer(w http.ResponseWriter, md runtime.ServerMetada } } +// AtlasDefaultHeaderMatcher func used to add all headers used by atlas-app-toolkit +// This function also passes through all the headers that runtime.DefaultHeaderMatcher handles. +// AtlasDefaultHeaderMatcher can be used as a Incoming/Outgoing header matcher. +func AtlasDefaultHeaderMatcher() func(string) (string, bool) { + //Put headers only in lower case + allow := map[string]struct{}{ + //X-Geo-* headers are set of geo metadata from MaxMind DB injected on ingress nginx + "x-geo-org": struct{}{}, + "x-geo-country-code": struct{}{}, + "x-geo-country-name": struct{}{}, + "x-geo-region-code": struct{}{}, + "x-geo-region-name": struct{}{}, + "x-geo-city-name": struct{}{}, + "x-geo-postal-code": struct{}{}, + "x-geo-latitude": struct{}{}, + "x-geo-longitude": struct{}{}, + //request id header contains unique identifier for request + "request-id": struct{}{}, + //Tracing headers + "x-b3-traceid": struct{}{}, + "x-b3-parentspanid": struct{}{}, + "x-b3-spanid": struct{}{}, + "x-b3-sampled": struct{}{}, + } + + return func(h string) (string, bool) { + if key, ok := runtime.DefaultHeaderMatcher(h); ok { + return key, ok + } + + _, ok := allow[strings.ToLower(h)] + return h, ok + } +} + // ExtendedDefaultHeaderMatcher func is used to add custom headers to be matched // from incoming http requests, If this returns true the header will be added to grpc context. -// This function also passes through all the headers that runtime.DefaultHeaderMatcher handles. +// This function also passes through all the headers that AtlasDefaultHeaderMatcher handles. func ExtendedDefaultHeaderMatcher(headerNames ...string) func(string) (string, bool) { customHeaders := map[string]bool{} for _, name := range headerNames { customHeaders[strings.ToLower(name)] = true } + + atlasMatcher := AtlasDefaultHeaderMatcher() return func(headerName string) (string, bool) { - if key, ok := runtime.DefaultHeaderMatcher(headerName); ok { + if key, ok := atlasMatcher(headerName); ok { return key, ok } _, ok := customHeaders[strings.ToLower(headerName)] return headerName, ok } } + +// ChainHeaderMatcher func is used to build chain on header matcher funcitons +// this function can be used as incoming or outgoing header matcher +// keep in mind that gRPC metadata treat as case insensitive strings +func ChainHeaderMatcher(matchers ...runtime.HeaderMatcherFunc) runtime.HeaderMatcherFunc { + return func(h string) (string, bool) { + for _, m := range matchers { + if k, allow := m(h); allow { + return k, allow + } + } + + return "", false + } +} diff --git a/gateway/header_test.go b/gateway/header_test.go index 5c4342ff..c9dda684 100644 --- a/gateway/header_test.go +++ b/gateway/header_test.go @@ -97,7 +97,7 @@ func TestExtendedDefaultHeaderMatcher(t *testing.T) { { name: "custom headers in | without custom headers | failure", customHeaders: []string{}, - in: "Request-Id", + in: "CustomHeader", isValid: false, }, } @@ -111,3 +111,163 @@ func TestExtendedDefaultHeaderMatcher(t *testing.T) { }) } } + +func TestAtlasDefaultHeaderMatcher(t *testing.T) { + var customMatcherTests = []struct { + name string + in string + isValid bool + }{ + { + name: "X-Geo-Org | success", + in: "X-Geo-Org", + isValid: true, + }, + { + name: "X-Geo-Country-Code | success", + in: "X-Geo-Country-Code", + isValid: true, + }, + { + name: "X-Geo-Country-Name | success", + in: "X-Geo-Country-Name", + isValid: true, + }, + { + name: "X-Geo-Region-Code | success", + in: "X-Geo-Region-Code", + isValid: true, + }, + { + name: "X-Geo-Region-Name | success", + in: "X-Geo-Region-Name", + isValid: true, + }, + { + name: "X-Geo-City-Name | success", + in: "X-Geo-City-Name", + isValid: true, + }, + { + name: "X-Geo-Postal-Code | success", + in: "X-Geo-Postal-Code", + isValid: true, + }, + { + name: "X-Geo-Latitude | success", + in: "X-Geo-Latitude", + isValid: true, + }, + { + name: "X-Geo-Longitude | success", + in: "X-Geo-Longitude", + isValid: true, + }, + { + name: "Request-Id | success", + in: "Request-Id", + isValid: true, + }, + { + name: "X-B3-TraceId | success", + in: "X-B3-TraceId", + isValid: true, + }, + { + name: "X-B3-ParentSpanId | success", + in: "X-B3-ParentSpanId", + isValid: true, + }, + { + name: "X-B3-SpanId | success", + in: "X-B3-SpanId", + isValid: true, + }, + { + name: "X-B3-Sampled | success", + in: "X-B3-Sampled", + isValid: true, + }, + { + name: "x-b3-sampled | success", + in: "x-b3-sampled", + isValid: true, + }, + { + name: "Failed-Header | failure", + in: "Failed-Header", + isValid: false, + }, + } + for _, tt := range customMatcherTests { + t.Run(tt.name, func(t *testing.T) { + f := AtlasDefaultHeaderMatcher() + _, ok := f(tt.in) + if ok != tt.isValid { + t.Errorf("got %v, want %v", ok, tt.isValid) + } + }) + } +} + +func TestChainHeaderMatcher(t *testing.T) { + chain := ChainHeaderMatcher( + func(h string) (string, bool) { + if h == "first" { + return h, true + } + + return "", false + }, + func(h string) (string, bool) { + if h == "second" { + return h, true + } + + return "", false + }, + func(h string) (string, bool) { + if h == "third" { + return h, true + } + + return "", false + }, + ) + + var customMatcherTests = []struct { + name string + in string + isValid bool + }{ + { + name: "first | success", + in: "first", + isValid: true, + }, + { + name: "second | success", + in: "second", + isValid: true, + }, + { + name: "third | success", + in: "third", + isValid: true, + }, + { + name: "fourth | success", + in: "fourth", + isValid: false, + }, + } + + for _, tt := range customMatcherTests { + t.Run(tt.name, func(t *testing.T) { + _, ok := chain(tt.in) + if ok != tt.isValid { + t.Errorf("got %v, want %v", ok, tt.isValid) + } + }) + } +} diff --git a/server/server.go b/server/server.go index 68c68da8..3c49c7b8 100644 --- a/server/server.go +++ b/server/server.go @@ -12,6 +12,7 @@ import ( "errors" + "github.com/grpc-ecosystem/grpc-gateway/runtime" "github.com/infobloxopen/atlas-app-toolkit/gateway" "github.com/infobloxopen/atlas-app-toolkit/health" "google.golang.org/grpc" @@ -135,14 +136,18 @@ func WithHealthChecks(checker health.Checker) Option { func WithGateway(options ...gateway.Option) Option { return func(s *Server) error { s.registrars = append(s.registrars, func(mux *http.ServeMux) error { - _, err := gateway.NewGateway(append(options, gateway.WithMux(mux))...) + _, err := gateway.NewGateway(append(options, + gateway.WithGatewayOptions( + runtime.WithIncomingHeaderMatcher( + gateway.AtlasDefaultHeaderMatcher())), + gateway.WithMux(mux))...) return err }) return nil } } -// WithMiddlewaries add opportunity to add different middleware +// WithMiddlewares add opportunity to add different middleware func WithMiddlewares(middleware ...Middleware) Option { return func(s *Server) error { s.middlewares = append(s.middlewares, middleware...)