Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gcp/observability: update method name validation #5951

Merged
merged 2 commits into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions gcp/observability/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,14 @@ import (
"errors"
"fmt"
"os"
"regexp"
"strings"

gcplogging "cloud.google.com/go/logging"
"golang.org/x/oauth2/google"
"google.golang.org/grpc/internal/envconfig"
)

const (
envProjectID = "GOOGLE_CLOUD_PROJECT"
methodStringRegexpStr = `^([\w./]+)/((?:\w+)|[*])$`
)

var methodStringRegexp = regexp.MustCompile(methodStringRegexpStr)
const envProjectID = "GOOGLE_CLOUD_PROJECT"

// fetchDefaultProjectID fetches the default GCP project id from environment.
func fetchDefaultProjectID(ctx context.Context) string {
Expand All @@ -59,6 +54,25 @@ func fetchDefaultProjectID(ctx context.Context) string {
return credentials.ProjectID
}

// validateMethodString validates whether the string passed in is a valid
// pattern.
func validateMethodString(method string) error {
if strings.HasPrefix(method, "/") {
return errors.New("cannot have a leading slash")
}
serviceMethod := strings.Split(method, "/")
if len(serviceMethod) != 2 {
return errors.New("/ must come in between service and method, only one /")
}
if serviceMethod[1] == "" {
return errors.New("method name must be non empty")
}
if serviceMethod[0] == "*" {
return errors.New("cannot have service wildcard * i.e. (*/m)")
}
return nil
}

func validateLogEventMethod(methods []string, exclude bool) error {
for _, method := range methods {
if method == "*" {
Expand All @@ -67,9 +81,8 @@ func validateLogEventMethod(methods []string, exclude bool) error {
}
continue
}
match := methodStringRegexp.FindStringSubmatch(method)
if match == nil {
return fmt.Errorf("invalid method string: %v", method)
if err := validateMethodString(method); err != nil {
return fmt.Errorf("invalid method string: %v, err: %v", method, err)
}
}
return nil
Expand Down
23 changes: 18 additions & 5 deletions gcp/observability/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -322,6 +323,7 @@ func (bml *binaryMethodLogger) Log(c iblog.LogEntryConfig) {
}

type eventConfig struct {
// ServiceMethod has /s/m syntax for fast matching.
ServiceMethod map[string]bool
Services map[string]bool
MatchAll bool
Expand Down Expand Up @@ -364,6 +366,17 @@ func (bl *binaryLogger) GetMethodLogger(methodName string) iblog.MethodLogger {
return nil
}

// parseMethod splits service and method from the input. It expects format
// "service/method".
func parseMethod(method string) (string, string, error) {
pos := strings.Index(method, "/")
if pos < 0 {
// Shouldn't happen, config already validated.
return "", "", errors.New("invalid method name: no / found")
}
return method[:pos], method[pos+1:], nil
}

func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter loggingExporter) {
if len(clientRPCEvents) == 0 {
return
Expand All @@ -382,15 +395,15 @@ func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter logging
eventConfig.MatchAll = true
continue
}
s, m, err := grpcutil.ParseMethod(method)
s, m, err := parseMethod(method)
if err != nil {
continue
}
if m == "*" {
eventConfig.Services[s] = true
continue
}
eventConfig.ServiceMethod[method] = true
eventConfig.ServiceMethod["/"+method] = true
}
eventConfigs = append(eventConfigs, eventConfig)
}
Expand Down Expand Up @@ -419,15 +432,15 @@ func registerServerRPCEvents(serverRPCEvents []serverRPCEvents, exporter logging
eventConfig.MatchAll = true
continue
}
s, m, err := grpcutil.ParseMethod(method)
if err != nil { // Shouldn't happen, already validated at this point.
s, m, err := parseMethod(method)
if err != nil {
continue
}
if m == "*" {
eventConfig.Services[s] = true
continue
}
eventConfig.ServiceMethod[method] = true
eventConfig.ServiceMethod["/"+method] = true
}
eventConfigs = append(eventConfigs, eventConfig)
}
Expand Down
126 changes: 118 additions & 8 deletions gcp/observability/logging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"testing"

Expand Down Expand Up @@ -99,13 +100,14 @@ func setupObservabilitySystemWithConfig(cfg *config) (func(), error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
err = Start(ctx)
if err != nil {
return nil, fmt.Errorf("error in Start: %v", err)
}
return func() {
cleanup := func() {
End()
envconfig.ObservabilityConfig = oldObservabilityConfig
}, nil
}
if err != nil {
return cleanup, fmt.Errorf("error in Start: %v", err)
}
return cleanup, nil
}

// TestClientRPCEventsLogAll tests the observability system configured with a
Expand Down Expand Up @@ -777,18 +779,18 @@ func (s) TestPrecedenceOrderingInConfiguration(t *testing.T) {
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"/grpc.testing.TestService/UnaryCall"},
Methods: []string{"grpc.testing.TestService/UnaryCall"},
MaxMetadataBytes: 30,
MaxMessageBytes: 30,
},
{
Methods: []string{"/grpc.testing.TestService/EmptyCall"},
Methods: []string{"grpc.testing.TestService/EmptyCall"},
Exclude: true,
MaxMetadataBytes: 30,
MaxMessageBytes: 30,
},
{
Methods: []string{"/grpc.testing.TestService/*"},
Methods: []string{"grpc.testing.TestService/*"},
MaxMetadataBytes: 30,
MaxMessageBytes: 30,
},
Expand Down Expand Up @@ -1273,3 +1275,111 @@ func (s) TestMetadataTruncationAccountsKey(t *testing.T) {
}
fle.mu.Unlock()
}

// TestMethodInConfiguration tests different method names with an expectation on
// whether they should error or not.
func (s) TestMethodInConfiguration(t *testing.T) {
// To skip creating a stackdriver exporter.
fle := &fakeLoggingExporter{
t: t,
}

defer func(ne func(ctx context.Context, config *config) (loggingExporter, error)) {
newLoggingExporter = ne
}(newLoggingExporter)

newLoggingExporter = func(ctx context.Context, config *config) (loggingExporter, error) {
return fle, nil
}

tests := []struct {
name string
config *config
wantErr string
}{
{
name: "leading-slash",
config: &config{
ProjectID: "fake",
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"/service/method"},
},
},
},
},
wantErr: "cannot have a leading slash",
},
{
name: "wildcard service/method",
config: &config{
ProjectID: "fake",
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"*/method"},
},
},
},
},
wantErr: "cannot have service wildcard *",
},
{
name: "/ in service name",
config: &config{
ProjectID: "fake",
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"ser/vice/method"},
},
},
},
},
wantErr: "only one /",
},
{
name: "empty method name",
config: &config{
ProjectID: "fake",
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"service/"},
},
},
},
},
wantErr: "method name must be non empty",
},
{
name: "normal",
config: &config{
ProjectID: "fake",
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"service/method"},
},
},
},
},
wantErr: "",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
cleanup, gotErr := setupObservabilitySystemWithConfig(test.config)
if cleanup != nil {
defer cleanup()
}
if gotErr != nil && !strings.Contains(gotErr.Error(), test.wantErr) {
t.Fatalf("Start(%v) = %v, wantErr %v", test.config, gotErr, test.wantErr)
}
if (gotErr != nil) != (test.wantErr != "") {
t.Fatalf("Start(%v) = %v, wantErr %v", test.config, gotErr, test.wantErr)
}
})
}
}