diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 5ade437fbf28..7139f7800c03 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -299,6 +299,7 @@ ALL_TESTS = [ "//pkg/roachprod/vm/gce:gce_test", "//pkg/roachprod/vm/local:local_test", "//pkg/roachprod/vm:vm_test", + "//pkg/roachprod:roachprod_test", "//pkg/rpc/nodedialer:nodedialer_test", "//pkg/rpc:rpc_test", "//pkg/scheduledjobs/schedulebase:schedulebase_test", @@ -1583,6 +1584,7 @@ GO_TARGETS = [ "//pkg/roachprod/vm:vm", "//pkg/roachprod/vm:vm_test", "//pkg/roachprod:roachprod", + "//pkg/roachprod:roachprod_test", "//pkg/rpc/nodedialer:nodedialer", "//pkg/rpc/nodedialer:nodedialer_test", "//pkg/rpc/rpcpb:rpcpb", diff --git a/pkg/cmd/roachtest/BUILD.bazel b/pkg/cmd/roachtest/BUILD.bazel index beb571afef96..ca0bb48f31dc 100644 --- a/pkg/cmd/roachtest/BUILD.bazel +++ b/pkg/cmd/roachtest/BUILD.bazel @@ -104,6 +104,7 @@ go_test( "//pkg/cmd/roachtest/spec", "//pkg/cmd/roachtest/test", "//pkg/internal/team", + "//pkg/roachprod", "//pkg/roachprod/errors", "//pkg/roachprod/logger", "//pkg/roachprod/vm", diff --git a/pkg/cmd/roachtest/cluster.go b/pkg/cmd/roachtest/cluster.go index 3d29ac8e6b51..e7cc8fbcbc88 100644 --- a/pkg/cmd/roachtest/cluster.go +++ b/pkg/cmd/roachtest/cluster.go @@ -532,15 +532,8 @@ func (r *clusterRegistry) destroyAllClusters(ctx context.Context, l *logger.Logg } } -func makeGCEClusterName(name string) string { - name = strings.ToLower(name) - name = regexp.MustCompile(`[^-a-z0-9]+`).ReplaceAllString(name, "-") - name = regexp.MustCompile(`-+`).ReplaceAllString(name, "-") - return name -} - func makeClusterName(name string) string { - return makeGCEClusterName(name) + return vm.DNSSafeName(name) } // MachineTypeToCPUs returns a CPU count for GCE, AWS, and Azure machine types. @@ -847,6 +840,10 @@ func (f *clusterFactory) clusterMock(cfg clusterConfig) *clusterImpl { } } +// create is a hook for tests to inject their own cluster create implementation. +// i.e. unit tests that don't want to actually access a provider. +var create = roachprod.Create + // newCluster creates a new roachprod cluster. // // setStatus is called with status messages indicating the stage of cluster @@ -957,7 +954,7 @@ func (f *clusterFactory) newCluster( l.PrintfCtx(ctx, "Attempting cluster creation (attempt #%d/%d)", i, maxAttempts) createVMOpts.ClusterName = c.name - err = roachprod.Create(ctx, l, cfg.username, cfg.spec.NodeCount, createVMOpts, providerOptsContainer) + err = create(ctx, l, cfg.username, cfg.spec.NodeCount, createVMOpts, providerOptsContainer) if err == nil { if err := f.r.registerCluster(c); err != nil { return nil, nil, err @@ -973,6 +970,9 @@ func (f *clusterFactory) newCluster( // or a destroy from the previous iteration failed. return nil, nil, err } + if errors.HasType(err, (*roachprod.MalformedClusterNameError)(nil)) { + return nil, nil, err + } l.PrintfCtx(ctx, "cluster creation failed, cleaning up in case it was partially created: %s", err) c.Destroy(ctx, closeLogger, l) diff --git a/pkg/cmd/roachtest/test_test.go b/pkg/cmd/roachtest/test_test.go index d26e8c879038..9e9012e8b128 100644 --- a/pkg/cmd/roachtest/test_test.go +++ b/pkg/cmd/roachtest/test_test.go @@ -28,6 +28,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/roachtestflags" "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/spec" "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test" + "github.com/cockroachdb/cockroach/pkg/roachprod" "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" "github.com/cockroachdb/cockroach/pkg/testutils" @@ -433,3 +434,49 @@ func TestExitCode(t *testing.T) { err := runExitCodeTest(t, errors.New("boom")) require.True(t, errors.Is(err, errTestsFailed)) } + +func TestNewCluster(t *testing.T) { + ctx := context.Background() + factory := &clusterFactory{sem: make(chan struct{}, 1)} + cfg := clusterConfig{spec: spec.MakeClusterSpec(1)} + setStatus := func(string) {} + + defer func() { + create = roachprod.Create + }() + + var createCallsCounter int + + testCases := []struct { + name string + createMock func(ctx context.Context, l *logger.Logger, username string, numNodes int, createVMOpts vm.CreateOpts, providerOptsContainer vm.ProviderOptionsContainer) (retErr error) + expectedRetries int + }{ + { + "Malformed Cluster Name Error", + func(ctx context.Context, l *logger.Logger, username string, numNodes int, createVMOpts vm.CreateOpts, providerOptsContainer vm.ProviderOptionsContainer) (retErr error) { + createCallsCounter++ + return &roachprod.MalformedClusterNameError{} + }, + 1, /* expectedRetries */ + }, + { + "Cluster Already Exists Error", + func(ctx context.Context, l *logger.Logger, username string, numNodes int, createVMOpts vm.CreateOpts, providerOptsContainer vm.ProviderOptionsContainer) (retErr error) { + createCallsCounter++ + return &roachprod.ClusterAlreadyExistsError{} + }, + 1, /* expectedRetries */ + }, + } + + for _, c := range testCases { + t.Run(c.name, func(t *testing.T) { + createCallsCounter = 0 + create = c.createMock + _, _, err := factory.newCluster(ctx, cfg, setStatus, true) + require.Error(t, err) + require.Equal(t, c.expectedRetries, createCallsCounter) + }) + } +} diff --git a/pkg/roachprod/BUILD.bazel b/pkg/roachprod/BUILD.bazel index cf264585eccf..b91703411bda 100644 --- a/pkg/roachprod/BUILD.bazel +++ b/pkg/roachprod/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "roachprod", @@ -38,3 +38,14 @@ go_library( "@org_golang_x_sys//unix", ], ) + +go_test( + name = "roachprod_test", + srcs = ["roachprod_test.go"], + embed = [":roachprod"], + deps = [ + "//pkg/roachprod/logger", + "//pkg/roachprod/vm", + "@com_github_stretchr_testify//assert", + ], +) diff --git a/pkg/roachprod/roachprod.go b/pkg/roachprod/roachprod.go index 3085c502e165..3a55a2286a7e 100644 --- a/pkg/roachprod/roachprod.go +++ b/pkg/roachprod/roachprod.go @@ -66,6 +66,21 @@ const ( prometheusHostUrlEnv = "COCKROACH_PROM_HOST_URL" ) +// MalformedClusterNameError is returned when the cluster name passed to Create is invalid. +type MalformedClusterNameError struct { + name string + reason string + suggestions []string +} + +func (e *MalformedClusterNameError) Error() string { + return fmt.Sprintf("Malformed cluster name %s, %s. Did you mean one of %s", e.name, e.reason, e.suggestions) +} + +// findActiveAccounts is a hook for tests to inject their own FindActiveAccounts +// implementation. i.e. unit tests that don't want to actually access a provider. +var findActiveAccounts = vm.FindActiveAccounts + // verifyClusterName ensures that the given name conforms to // our naming pattern of "-". The // username must match one of the vm.Provider account names @@ -75,12 +90,9 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { return fmt.Errorf("cluster name cannot be blank") } - alphaNum, err := regexp.Compile(`^[a-zA-Z0-9\-]+$`) - if err != nil { - return err - } - if !alphaNum.MatchString(clusterName) { - return errors.Errorf("cluster name must match %s", alphaNum.String()) + sanitizedName := vm.DNSSafeName(clusterName) + if sanitizedName != clusterName { + return &MalformedClusterNameError{name: clusterName, reason: "invalid characters", suggestions: []string{sanitizedName}} } if config.IsLocalClusterName(clusterName) { @@ -90,17 +102,21 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { // Use the vm.Provider account names, or --username. var accounts []string if len(username) > 0 { - accounts = []string{username} + cleanAccount := vm.DNSSafeName(username) + if cleanAccount != username { + l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, username) + } + accounts = []string{cleanAccount} } else { seenAccounts := map[string]bool{} - active, err := vm.FindActiveAccounts(l) + active, err := findActiveAccounts(l) if err != nil { return err } for _, account := range active { if !seenAccounts[account] { seenAccounts[account] = true - cleanAccount := vm.DNSSafeAccount(account) + cleanAccount := vm.DNSSafeName(account) if cleanAccount != account { l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, account) } @@ -118,17 +134,20 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { // Try to pick out a reasonable cluster name from the input. var suffix string + var reason string if i := strings.Index(clusterName, "-"); i != -1 { // The user specified a username prefix, but it didn't match an active // account name. For example, assuming the account is "peter", `roachprod // create joe-perf` should be specified as `roachprod create joe-perf -u // joe`. suffix = clusterName[i+1:] + reason = "username prefix does not match an active account name" } else { // The user didn't specify a username prefix. For example, assuming the // account is "peter", `roachprod create perf` should be specified as // `roachprod create peter-perf`. suffix = clusterName + reason = "cluster name should start with a username prefix: -" } // Suggest acceptable cluster names. @@ -136,8 +155,7 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { for _, account := range accounts { suggestions = append(suggestions, fmt.Sprintf("%s-%s", account, suffix)) } - return fmt.Errorf("malformed cluster name %s, did you mean one of %s", - clusterName, suggestions) + return &MalformedClusterNameError{name: clusterName, reason: reason, suggestions: suggestions} } func sortedClusters() []string { diff --git a/pkg/roachprod/roachprod_test.go b/pkg/roachprod/roachprod_test.go new file mode 100644 index 000000000000..a0bfca0b7356 --- /dev/null +++ b/pkg/roachprod/roachprod_test.go @@ -0,0 +1,82 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package roachprod + +import ( + "io" + "testing" + + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm" + "github.com/stretchr/testify/assert" +) + +func nilLogger() *logger.Logger { + lcfg := logger.Config{ + Stdout: io.Discard, + Stderr: io.Discard, + } + l, err := lcfg.NewLogger("" /* path */) + if err != nil { + panic(err) + } + return l +} + +func TestVerifyClusterName(t *testing.T) { + findActiveAccounts = func(l *logger.Logger) (map[string]string, error) { + return map[string]string{"1": "user1", "2": "user2", "3": "USER4"}, nil + } + defer func() { + findActiveAccounts = vm.FindActiveAccounts + }() + cases := []struct { + description, clusterName, username string + errorExpected bool + }{ + { + "username found", "user1-clustername", "", false, + }, + { + "username not found", "user3-clustername", "", true, + }, + { + "specified username", "user3-clustername", "user3", false, + }, + { + "specified username that doesn't match", "user1-clustername", "fakeuser", true, + }, + { + "clustername not sanitized", "UserName-clustername", "", true, + }, + { + "no username", "clustername", "", true, + }, + { + "no clustername", "user1", "", true, + }, + { + "unsanitized found username", "user4-clustername", "", false, + }, + { + "unsanitized specified username", "user3-clustername", "USER3", false, + }, + } + for _, c := range cases { + t.Run(c.description, func(t *testing.T) { + if c.errorExpected { + assert.Error(t, verifyClusterName(nilLogger(), c.clusterName, c.username)) + } else { + assert.NoError(t, verifyClusterName(nilLogger(), c.clusterName, c.username)) + } + }) + } +} diff --git a/pkg/roachprod/vm/vm.go b/pkg/roachprod/vm/vm.go index 1e92d3692e8c..a6f737d1bcd3 100644 --- a/pkg/roachprod/vm/vm.go +++ b/pkg/roachprod/vm/vm.go @@ -710,21 +710,31 @@ func ExpandZonesFlag(zoneFlag []string) (zones []string, err error) { return zones, nil } -// DNSSafeAccount takes a string and returns a cleaned version of the string that can be used in DNS entries. +// DNSSafeName takes a string and returns a cleaned version of the string that can be used in DNS entries. // Unsafe characters are dropped. No length check is performed. -func DNSSafeAccount(account string) string { +func DNSSafeName(name string) string { safe := func(r rune) rune { switch { case r >= 'a' && r <= 'z': return r case r >= 'A' && r <= 'Z': return unicode.ToLower(r) + case r >= '0' && r <= '9': + return r + case r == '-': + return r default: // Negative value tells strings.Map to drop the rune. return -1 } } - return strings.Map(safe, account) + name = strings.Map(safe, name) + + // DNS entries cannot start or end with hyphens. + name = strings.Trim(name, "-") + + // Consecutive hyphens are allowed in DNS entries, but disallow it for readability. + return regexp.MustCompile(`-+`).ReplaceAllString(name, "-") } // SanitizeLabel returns a version of the string that can be used as a label. diff --git a/pkg/roachprod/vm/vm_test.go b/pkg/roachprod/vm/vm_test.go index 9f1a5e5e4530..2c3951b5c44b 100644 --- a/pkg/roachprod/vm/vm_test.go +++ b/pkg/roachprod/vm/vm_test.go @@ -132,12 +132,18 @@ func TestDNSSafeAccount(t *testing.T) { "dot and underscore", "u.ser_n.a_me", "username", }, { - "Unicode and other characters", "~/❦u.ser_ऄn.a_meλ", "username", + "leading and trailing hyphens", "--username-clustername-&", "username-clustername", + }, + { + "consecutive hyphens", "username---clustername", "username-clustername", + }, + { + "Unicode and other characters", "~/❦--u.ser_ऄn.a_meλ", "username", }, } for _, c := range cases { t.Run(c.description, func(t *testing.T) { - assert.EqualValues(t, DNSSafeAccount(c.input), c.expected) + assert.EqualValues(t, c.expected, DNSSafeName(c.input)) }) } }