From 27ab2c2377ebb2b0a416cb21d09e17541b856122 Mon Sep 17 00:00:00 2001 From: DarrylWong Date: Fri, 10 May 2024 14:42:13 -0400 Subject: [PATCH] roachprod, roachtest: use same cluster name sanitization Previously, roachtest had it's own function to sanitize cluster names, while roachprod had it's own function to verify cluster names. This change removes both and opts instead to use vm.DNSSafeName. This change also introduces MalformedClusterNameError which gives a hint on what is wrong with the name and tells roachtest not to retry cluster creation. --- pkg/BUILD.bazel | 2 + pkg/cmd/roachtest/BUILD.bazel | 1 + pkg/cmd/roachtest/cluster.go | 18 ++++---- pkg/cmd/roachtest/test_test.go | 47 +++++++++++++++++++ pkg/roachprod/BUILD.bazel | 13 +++++- pkg/roachprod/roachprod.go | 40 +++++++++++----- pkg/roachprod/roachprod_test.go | 82 +++++++++++++++++++++++++++++++++ pkg/roachprod/vm/vm.go | 16 +++++-- pkg/roachprod/vm/vm_test.go | 10 +++- 9 files changed, 203 insertions(+), 26 deletions(-) create mode 100644 pkg/roachprod/roachprod_test.go diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 5ade437fbf28..7139f7800c03 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -299,6 +299,7 @@ ALL_TESTS = [ "//pkg/roachprod/vm/gce:gce_test", "//pkg/roachprod/vm/local:local_test", "//pkg/roachprod/vm:vm_test", + "//pkg/roachprod:roachprod_test", "//pkg/rpc/nodedialer:nodedialer_test", "//pkg/rpc:rpc_test", "//pkg/scheduledjobs/schedulebase:schedulebase_test", @@ -1583,6 +1584,7 @@ GO_TARGETS = [ "//pkg/roachprod/vm:vm", "//pkg/roachprod/vm:vm_test", "//pkg/roachprod:roachprod", + "//pkg/roachprod:roachprod_test", "//pkg/rpc/nodedialer:nodedialer", "//pkg/rpc/nodedialer:nodedialer_test", "//pkg/rpc/rpcpb:rpcpb", diff --git a/pkg/cmd/roachtest/BUILD.bazel b/pkg/cmd/roachtest/BUILD.bazel index beb571afef96..ca0bb48f31dc 100644 --- a/pkg/cmd/roachtest/BUILD.bazel +++ b/pkg/cmd/roachtest/BUILD.bazel @@ -104,6 +104,7 @@ go_test( "//pkg/cmd/roachtest/spec", "//pkg/cmd/roachtest/test", "//pkg/internal/team", + "//pkg/roachprod", "//pkg/roachprod/errors", "//pkg/roachprod/logger", "//pkg/roachprod/vm", diff --git a/pkg/cmd/roachtest/cluster.go b/pkg/cmd/roachtest/cluster.go index 3d29ac8e6b51..e7cc8fbcbc88 100644 --- a/pkg/cmd/roachtest/cluster.go +++ b/pkg/cmd/roachtest/cluster.go @@ -532,15 +532,8 @@ func (r *clusterRegistry) destroyAllClusters(ctx context.Context, l *logger.Logg } } -func makeGCEClusterName(name string) string { - name = strings.ToLower(name) - name = regexp.MustCompile(`[^-a-z0-9]+`).ReplaceAllString(name, "-") - name = regexp.MustCompile(`-+`).ReplaceAllString(name, "-") - return name -} - func makeClusterName(name string) string { - return makeGCEClusterName(name) + return vm.DNSSafeName(name) } // MachineTypeToCPUs returns a CPU count for GCE, AWS, and Azure machine types. @@ -847,6 +840,10 @@ func (f *clusterFactory) clusterMock(cfg clusterConfig) *clusterImpl { } } +// create is a hook for tests to inject their own cluster create implementation. +// i.e. unit tests that don't want to actually access a provider. +var create = roachprod.Create + // newCluster creates a new roachprod cluster. // // setStatus is called with status messages indicating the stage of cluster @@ -957,7 +954,7 @@ func (f *clusterFactory) newCluster( l.PrintfCtx(ctx, "Attempting cluster creation (attempt #%d/%d)", i, maxAttempts) createVMOpts.ClusterName = c.name - err = roachprod.Create(ctx, l, cfg.username, cfg.spec.NodeCount, createVMOpts, providerOptsContainer) + err = create(ctx, l, cfg.username, cfg.spec.NodeCount, createVMOpts, providerOptsContainer) if err == nil { if err := f.r.registerCluster(c); err != nil { return nil, nil, err @@ -973,6 +970,9 @@ func (f *clusterFactory) newCluster( // or a destroy from the previous iteration failed. return nil, nil, err } + if errors.HasType(err, (*roachprod.MalformedClusterNameError)(nil)) { + return nil, nil, err + } l.PrintfCtx(ctx, "cluster creation failed, cleaning up in case it was partially created: %s", err) c.Destroy(ctx, closeLogger, l) diff --git a/pkg/cmd/roachtest/test_test.go b/pkg/cmd/roachtest/test_test.go index d26e8c879038..9e9012e8b128 100644 --- a/pkg/cmd/roachtest/test_test.go +++ b/pkg/cmd/roachtest/test_test.go @@ -28,6 +28,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/roachtestflags" "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/spec" "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test" + "github.com/cockroachdb/cockroach/pkg/roachprod" "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" "github.com/cockroachdb/cockroach/pkg/testutils" @@ -433,3 +434,49 @@ func TestExitCode(t *testing.T) { err := runExitCodeTest(t, errors.New("boom")) require.True(t, errors.Is(err, errTestsFailed)) } + +func TestNewCluster(t *testing.T) { + ctx := context.Background() + factory := &clusterFactory{sem: make(chan struct{}, 1)} + cfg := clusterConfig{spec: spec.MakeClusterSpec(1)} + setStatus := func(string) {} + + defer func() { + create = roachprod.Create + }() + + var createCallsCounter int + + testCases := []struct { + name string + createMock func(ctx context.Context, l *logger.Logger, username string, numNodes int, createVMOpts vm.CreateOpts, providerOptsContainer vm.ProviderOptionsContainer) (retErr error) + expectedRetries int + }{ + { + "Malformed Cluster Name Error", + func(ctx context.Context, l *logger.Logger, username string, numNodes int, createVMOpts vm.CreateOpts, providerOptsContainer vm.ProviderOptionsContainer) (retErr error) { + createCallsCounter++ + return &roachprod.MalformedClusterNameError{} + }, + 1, /* expectedRetries */ + }, + { + "Cluster Already Exists Error", + func(ctx context.Context, l *logger.Logger, username string, numNodes int, createVMOpts vm.CreateOpts, providerOptsContainer vm.ProviderOptionsContainer) (retErr error) { + createCallsCounter++ + return &roachprod.ClusterAlreadyExistsError{} + }, + 1, /* expectedRetries */ + }, + } + + for _, c := range testCases { + t.Run(c.name, func(t *testing.T) { + createCallsCounter = 0 + create = c.createMock + _, _, err := factory.newCluster(ctx, cfg, setStatus, true) + require.Error(t, err) + require.Equal(t, c.expectedRetries, createCallsCounter) + }) + } +} diff --git a/pkg/roachprod/BUILD.bazel b/pkg/roachprod/BUILD.bazel index cf264585eccf..b91703411bda 100644 --- a/pkg/roachprod/BUILD.bazel +++ b/pkg/roachprod/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "roachprod", @@ -38,3 +38,14 @@ go_library( "@org_golang_x_sys//unix", ], ) + +go_test( + name = "roachprod_test", + srcs = ["roachprod_test.go"], + embed = [":roachprod"], + deps = [ + "//pkg/roachprod/logger", + "//pkg/roachprod/vm", + "@com_github_stretchr_testify//assert", + ], +) diff --git a/pkg/roachprod/roachprod.go b/pkg/roachprod/roachprod.go index 3085c502e165..3a55a2286a7e 100644 --- a/pkg/roachprod/roachprod.go +++ b/pkg/roachprod/roachprod.go @@ -66,6 +66,21 @@ const ( prometheusHostUrlEnv = "COCKROACH_PROM_HOST_URL" ) +// MalformedClusterNameError is returned when the cluster name passed to Create is invalid. +type MalformedClusterNameError struct { + name string + reason string + suggestions []string +} + +func (e *MalformedClusterNameError) Error() string { + return fmt.Sprintf("Malformed cluster name %s, %s. Did you mean one of %s", e.name, e.reason, e.suggestions) +} + +// findActiveAccounts is a hook for tests to inject their own FindActiveAccounts +// implementation. i.e. unit tests that don't want to actually access a provider. +var findActiveAccounts = vm.FindActiveAccounts + // verifyClusterName ensures that the given name conforms to // our naming pattern of "-". The // username must match one of the vm.Provider account names @@ -75,12 +90,9 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { return fmt.Errorf("cluster name cannot be blank") } - alphaNum, err := regexp.Compile(`^[a-zA-Z0-9\-]+$`) - if err != nil { - return err - } - if !alphaNum.MatchString(clusterName) { - return errors.Errorf("cluster name must match %s", alphaNum.String()) + sanitizedName := vm.DNSSafeName(clusterName) + if sanitizedName != clusterName { + return &MalformedClusterNameError{name: clusterName, reason: "invalid characters", suggestions: []string{sanitizedName}} } if config.IsLocalClusterName(clusterName) { @@ -90,17 +102,21 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { // Use the vm.Provider account names, or --username. var accounts []string if len(username) > 0 { - accounts = []string{username} + cleanAccount := vm.DNSSafeName(username) + if cleanAccount != username { + l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, username) + } + accounts = []string{cleanAccount} } else { seenAccounts := map[string]bool{} - active, err := vm.FindActiveAccounts(l) + active, err := findActiveAccounts(l) if err != nil { return err } for _, account := range active { if !seenAccounts[account] { seenAccounts[account] = true - cleanAccount := vm.DNSSafeAccount(account) + cleanAccount := vm.DNSSafeName(account) if cleanAccount != account { l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, account) } @@ -118,17 +134,20 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { // Try to pick out a reasonable cluster name from the input. var suffix string + var reason string if i := strings.Index(clusterName, "-"); i != -1 { // The user specified a username prefix, but it didn't match an active // account name. For example, assuming the account is "peter", `roachprod // create joe-perf` should be specified as `roachprod create joe-perf -u // joe`. suffix = clusterName[i+1:] + reason = "username prefix does not match an active account name" } else { // The user didn't specify a username prefix. For example, assuming the // account is "peter", `roachprod create perf` should be specified as // `roachprod create peter-perf`. suffix = clusterName + reason = "cluster name should start with a username prefix: -" } // Suggest acceptable cluster names. @@ -136,8 +155,7 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { for _, account := range accounts { suggestions = append(suggestions, fmt.Sprintf("%s-%s", account, suffix)) } - return fmt.Errorf("malformed cluster name %s, did you mean one of %s", - clusterName, suggestions) + return &MalformedClusterNameError{name: clusterName, reason: reason, suggestions: suggestions} } func sortedClusters() []string { diff --git a/pkg/roachprod/roachprod_test.go b/pkg/roachprod/roachprod_test.go new file mode 100644 index 000000000000..a0bfca0b7356 --- /dev/null +++ b/pkg/roachprod/roachprod_test.go @@ -0,0 +1,82 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package roachprod + +import ( + "io" + "testing" + + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm" + "github.com/stretchr/testify/assert" +) + +func nilLogger() *logger.Logger { + lcfg := logger.Config{ + Stdout: io.Discard, + Stderr: io.Discard, + } + l, err := lcfg.NewLogger("" /* path */) + if err != nil { + panic(err) + } + return l +} + +func TestVerifyClusterName(t *testing.T) { + findActiveAccounts = func(l *logger.Logger) (map[string]string, error) { + return map[string]string{"1": "user1", "2": "user2", "3": "USER4"}, nil + } + defer func() { + findActiveAccounts = vm.FindActiveAccounts + }() + cases := []struct { + description, clusterName, username string + errorExpected bool + }{ + { + "username found", "user1-clustername", "", false, + }, + { + "username not found", "user3-clustername", "", true, + }, + { + "specified username", "user3-clustername", "user3", false, + }, + { + "specified username that doesn't match", "user1-clustername", "fakeuser", true, + }, + { + "clustername not sanitized", "UserName-clustername", "", true, + }, + { + "no username", "clustername", "", true, + }, + { + "no clustername", "user1", "", true, + }, + { + "unsanitized found username", "user4-clustername", "", false, + }, + { + "unsanitized specified username", "user3-clustername", "USER3", false, + }, + } + for _, c := range cases { + t.Run(c.description, func(t *testing.T) { + if c.errorExpected { + assert.Error(t, verifyClusterName(nilLogger(), c.clusterName, c.username)) + } else { + assert.NoError(t, verifyClusterName(nilLogger(), c.clusterName, c.username)) + } + }) + } +} diff --git a/pkg/roachprod/vm/vm.go b/pkg/roachprod/vm/vm.go index 1e92d3692e8c..a6f737d1bcd3 100644 --- a/pkg/roachprod/vm/vm.go +++ b/pkg/roachprod/vm/vm.go @@ -710,21 +710,31 @@ func ExpandZonesFlag(zoneFlag []string) (zones []string, err error) { return zones, nil } -// DNSSafeAccount takes a string and returns a cleaned version of the string that can be used in DNS entries. +// DNSSafeName takes a string and returns a cleaned version of the string that can be used in DNS entries. // Unsafe characters are dropped. No length check is performed. -func DNSSafeAccount(account string) string { +func DNSSafeName(name string) string { safe := func(r rune) rune { switch { case r >= 'a' && r <= 'z': return r case r >= 'A' && r <= 'Z': return unicode.ToLower(r) + case r >= '0' && r <= '9': + return r + case r == '-': + return r default: // Negative value tells strings.Map to drop the rune. return -1 } } - return strings.Map(safe, account) + name = strings.Map(safe, name) + + // DNS entries cannot start or end with hyphens. + name = strings.Trim(name, "-") + + // Consecutive hyphens are allowed in DNS entries, but disallow it for readability. + return regexp.MustCompile(`-+`).ReplaceAllString(name, "-") } // SanitizeLabel returns a version of the string that can be used as a label. diff --git a/pkg/roachprod/vm/vm_test.go b/pkg/roachprod/vm/vm_test.go index 9f1a5e5e4530..2c3951b5c44b 100644 --- a/pkg/roachprod/vm/vm_test.go +++ b/pkg/roachprod/vm/vm_test.go @@ -132,12 +132,18 @@ func TestDNSSafeAccount(t *testing.T) { "dot and underscore", "u.ser_n.a_me", "username", }, { - "Unicode and other characters", "~/❦u.ser_ऄn.a_meλ", "username", + "leading and trailing hyphens", "--username-clustername-&", "username-clustername", + }, + { + "consecutive hyphens", "username---clustername", "username-clustername", + }, + { + "Unicode and other characters", "~/❦--u.ser_ऄn.a_meλ", "username", }, } for _, c := range cases { t.Run(c.description, func(t *testing.T) { - assert.EqualValues(t, DNSSafeAccount(c.input), c.expected) + assert.EqualValues(t, c.expected, DNSSafeName(c.input)) }) } }