From 8b823e590fd92c0244e54e80fde13503b47f5844 Mon Sep 17 00:00:00 2001 From: Zach Reyes Date: Wed, 18 Jan 2023 15:34:08 -0500 Subject: [PATCH] Iterate vaild method name --- gcp/observability/config.go | 36 ++++++--- gcp/observability/logging.go | 23 ++++-- gcp/observability/logging_test.go | 126 ++++++++++++++++++++++++++++-- 3 files changed, 162 insertions(+), 23 deletions(-) diff --git a/gcp/observability/config.go b/gcp/observability/config.go index b361bc367c01..a44d3db614fe 100644 --- a/gcp/observability/config.go +++ b/gcp/observability/config.go @@ -24,19 +24,14 @@ import ( "errors" "fmt" "os" - "regexp" + "strings" gcplogging "cloud.google.com/go/logging" "golang.org/x/oauth2/google" "google.golang.org/grpc/internal/envconfig" ) -const ( - envProjectID = "GOOGLE_CLOUD_PROJECT" - methodStringRegexpStr = `^([\w./]+)/((?:\w+)|[*])$` -) - -var methodStringRegexp = regexp.MustCompile(methodStringRegexpStr) +const envProjectID = "GOOGLE_CLOUD_PROJECT" // fetchDefaultProjectID fetches the default GCP project id from environment. func fetchDefaultProjectID(ctx context.Context) string { @@ -59,6 +54,28 @@ func fetchDefaultProjectID(ctx context.Context) string { return credentials.ProjectID } +// validateMethodString validates whether the string passed in is a valid +// pattern. +func validateMethodString(method string) error { + if method == "*" { + return nil + } + if strings.HasPrefix(method, "/") { + return errors.New("cannot have a leading slash") + } + serviceMethod := strings.Split(method, "/") + if len(serviceMethod) != 2 { + return errors.New("/ must come in between service and method, only one /") + } + if serviceMethod[1] == "" { + return errors.New("method name must be non empty") + } + if serviceMethod[0] == "*" { + return errors.New("cannot have service wildcard * i.e. (*/m)") + } + return nil +} + func validateLogEventMethod(methods []string, exclude bool) error { for _, method := range methods { if method == "*" { @@ -67,9 +84,8 @@ func validateLogEventMethod(methods []string, exclude bool) error { } continue } - match := methodStringRegexp.FindStringSubmatch(method) - if match == nil { - return fmt.Errorf("invalid method string: %v", method) + if err := validateMethodString(method); err != nil { + return fmt.Errorf("invalid method string: %v, err: %v", method, err) } } return nil diff --git a/gcp/observability/logging.go b/gcp/observability/logging.go index dcd7bf848fd7..846694b6b095 100644 --- a/gcp/observability/logging.go +++ b/gcp/observability/logging.go @@ -22,6 +22,7 @@ import ( "bytes" "context" "encoding/base64" + "errors" "fmt" "strings" "time" @@ -322,6 +323,7 @@ func (bml *binaryMethodLogger) Log(c iblog.LogEntryConfig) { } type eventConfig struct { + // ServiceMethod has /s/m syntax for fast matching. ServiceMethod map[string]bool Services map[string]bool MatchAll bool @@ -364,6 +366,17 @@ func (bl *binaryLogger) GetMethodLogger(methodName string) iblog.MethodLogger { return nil } +// parseMethod splits service and method from the input. It expects format +// "service/method". +func parseMethod(method string) (string, string, error) { + pos := strings.LastIndex(method, "/") + if pos < 0 { + // Shouldn't happen, config already validated. + return "", "", errors.New("invalid method name: no / found") + } + return method[:pos], method[pos+1:], nil +} + func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter loggingExporter) { if len(clientRPCEvents) == 0 { return @@ -382,7 +395,7 @@ func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter logging eventConfig.MatchAll = true continue } - s, m, err := grpcutil.ParseMethod(method) + s, m, err := parseMethod(method) if err != nil { continue } @@ -390,7 +403,7 @@ func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter logging eventConfig.Services[s] = true continue } - eventConfig.ServiceMethod[method] = true + eventConfig.ServiceMethod["/"+method] = true } eventConfigs = append(eventConfigs, eventConfig) } @@ -419,15 +432,15 @@ func registerServerRPCEvents(serverRPCEvents []serverRPCEvents, exporter logging eventConfig.MatchAll = true continue } - s, m, err := grpcutil.ParseMethod(method) - if err != nil { // Shouldn't happen, already validated at this point. + s, m, err := parseMethod(method) + if err != nil { continue } if m == "*" { eventConfig.Services[s] = true continue } - eventConfig.ServiceMethod[method] = true + eventConfig.ServiceMethod["/"+method] = true } eventConfigs = append(eventConfigs, eventConfig) } diff --git a/gcp/observability/logging_test.go b/gcp/observability/logging_test.go index 1489a60ea22e..0265c45ddc04 100644 --- a/gcp/observability/logging_test.go +++ b/gcp/observability/logging_test.go @@ -24,6 +24,7 @@ import ( "encoding/json" "fmt" "io" + "strings" "sync" "testing" @@ -99,13 +100,14 @@ func setupObservabilitySystemWithConfig(cfg *config) (func(), error) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() err = Start(ctx) - if err != nil { - return nil, fmt.Errorf("error in Start: %v", err) - } - return func() { + cleanup := func() { End() envconfig.ObservabilityConfig = oldObservabilityConfig - }, nil + } + if err != nil { + return cleanup, fmt.Errorf("error in Start: %v", err) + } + return cleanup, nil } // TestClientRPCEventsLogAll tests the observability system configured with a @@ -777,18 +779,18 @@ func (s) TestPrecedenceOrderingInConfiguration(t *testing.T) { CloudLogging: &cloudLogging{ ClientRPCEvents: []clientRPCEvents{ { - Methods: []string{"/grpc.testing.TestService/UnaryCall"}, + Methods: []string{"grpc.testing.TestService/UnaryCall"}, MaxMetadataBytes: 30, MaxMessageBytes: 30, }, { - Methods: []string{"/grpc.testing.TestService/EmptyCall"}, + Methods: []string{"grpc.testing.TestService/EmptyCall"}, Exclude: true, MaxMetadataBytes: 30, MaxMessageBytes: 30, }, { - Methods: []string{"/grpc.testing.TestService/*"}, + Methods: []string{"grpc.testing.TestService/*"}, MaxMetadataBytes: 30, MaxMessageBytes: 30, }, @@ -1273,3 +1275,111 @@ func (s) TestMetadataTruncationAccountsKey(t *testing.T) { } fle.mu.Unlock() } + +// TestMethodInConfiguration tests different method names with an expectation on +// whether they should error or not. +func (s) TestMethodInConfiguration(t *testing.T) { + // To skip creating a stackdriver exporter. + fle := &fakeLoggingExporter{ + t: t, + } + + defer func(ne func(ctx context.Context, config *config) (loggingExporter, error)) { + newLoggingExporter = ne + }(newLoggingExporter) + + newLoggingExporter = func(ctx context.Context, config *config) (loggingExporter, error) { + return fle, nil + } + + tests := []struct { + name string + config *config + wantErr string + }{ + { + name: "leading-slash", + config: &config{ + ProjectID: "fake", + CloudLogging: &cloudLogging{ + ClientRPCEvents: []clientRPCEvents{ + { + Methods: []string{"/service/method"}, + }, + }, + }, + }, + wantErr: "cannot have a leading slash", + }, + { + name: "wildcard service/method", + config: &config{ + ProjectID: "fake", + CloudLogging: &cloudLogging{ + ClientRPCEvents: []clientRPCEvents{ + { + Methods: []string{"*/method"}, + }, + }, + }, + }, + wantErr: "cannot have service wildcard *", + }, + { + name: "/ in service name", + config: &config{ + ProjectID: "fake", + CloudLogging: &cloudLogging{ + ClientRPCEvents: []clientRPCEvents{ + { + Methods: []string{"ser/vice/method"}, + }, + }, + }, + }, + wantErr: "only one /", + }, + { + name: "empty method name", + config: &config{ + ProjectID: "fake", + CloudLogging: &cloudLogging{ + ClientRPCEvents: []clientRPCEvents{ + { + Methods: []string{"service/"}, + }, + }, + }, + }, + wantErr: "method name must be non empty", + }, + { + name: "normal", + config: &config{ + ProjectID: "fake", + CloudLogging: &cloudLogging{ + ClientRPCEvents: []clientRPCEvents{ + { + Methods: []string{"service/method"}, + }, + }, + }, + }, + wantErr: "", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cleanup, gotErr := setupObservabilitySystemWithConfig(test.config) + if cleanup != nil { + defer cleanup() + } + if gotErr != nil && !strings.Contains(gotErr.Error(), test.wantErr) { + t.Fatalf("Start(%v) = %v, wantErr %v", test.config, gotErr, test.wantErr) + } + if (gotErr != nil) != (test.wantErr != "") { + t.Fatalf("Start(%v) = %v, wantErr %v", test.config, gotErr, test.wantErr) + } + }) + } +}