diff --git a/pkg/cmd/roachtest/cluster.go b/pkg/cmd/roachtest/cluster.go index cec6b8365e54..af17802c1011 100644 --- a/pkg/cmd/roachtest/cluster.go +++ b/pkg/cmd/roachtest/cluster.go @@ -532,15 +532,8 @@ func (r *clusterRegistry) destroyAllClusters(ctx context.Context, l *logger.Logg } } -func makeGCEClusterName(name string) string { - name = strings.ToLower(name) - name = regexp.MustCompile(`[^-a-z0-9]+`).ReplaceAllString(name, "-") - name = regexp.MustCompile(`-+`).ReplaceAllString(name, "-") - return name -} - func makeClusterName(name string) string { - return makeGCEClusterName(name) + return vm.DNSSafeName(name) } // MachineTypeToCPUs returns a CPU count for GCE, AWS, and Azure machine types. @@ -973,6 +966,9 @@ func (f *clusterFactory) newCluster( // or a destroy from the previous iteration failed. return nil, nil, err } + if errors.HasType(err, (*roachprod.MalformedClusterNameError)(nil)) { + return nil, nil, err + } l.PrintfCtx(ctx, "cluster creation failed, cleaning up in case it was partially created: %s", err) c.Destroy(ctx, closeLogger, l) diff --git a/pkg/roachprod/roachprod.go b/pkg/roachprod/roachprod.go index 1493c63ba5ec..f2f261c52c6e 100644 --- a/pkg/roachprod/roachprod.go +++ b/pkg/roachprod/roachprod.go @@ -66,6 +66,21 @@ const ( prometheusHostUrlEnv = "COCKROACH_PROM_HOST_URL" ) +// MalformedClusterNameError is returned when the cluster name passed to Create is invalid. +type MalformedClusterNameError struct { + name string + reason string + suggestions []string +} + +func (e *MalformedClusterNameError) Error() string { + return fmt.Sprintf("Malformed cluster name %s, %s. Did you mean one of %s", e.name, e.reason, e.suggestions) +} + +// _findActiveAccounts is a test only variable, used for mocking a provider finding +// active accounts in a unit test, where we don't want to actually access a provider. +var _findActiveAccounts func(l *logger.Logger) (map[string]string, error) + // verifyClusterName ensures that the given name conforms to // our naming pattern of "-". The // username must match one of the vm.Provider account names @@ -75,12 +90,9 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { return fmt.Errorf("cluster name cannot be blank") } - alphaNum, err := regexp.Compile(`^[a-zA-Z0-9\-]+$`) - if err != nil { - return err - } - if !alphaNum.MatchString(clusterName) { - return errors.Errorf("cluster name must match %s", alphaNum.String()) + sanitizedName := vm.DNSSafeName(clusterName) + if sanitizedName != clusterName { + return &MalformedClusterNameError{name: clusterName, reason: "invalid characters", suggestions: []string{sanitizedName}} } if config.IsLocalClusterName(clusterName) { @@ -90,17 +102,29 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { // Use the vm.Provider account names, or --username. var accounts []string if len(username) > 0 { - accounts = []string{username} + cleanAccount := vm.DNSSafeName(username) + if cleanAccount != username { + l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, username) + } + accounts = []string{cleanAccount} } else { seenAccounts := map[string]bool{} - active, err := vm.FindActiveAccounts(l) + var active map[string]string + var err error + if _findActiveAccounts != nil { + // Test only path used by unit tests. + active, err = _findActiveAccounts(l) + } else { + active, err = vm.FindActiveAccounts(l) + } + if err != nil { return err } for _, account := range active { if !seenAccounts[account] { seenAccounts[account] = true - cleanAccount := vm.DNSSafeAccount(account) + cleanAccount := vm.DNSSafeName(account) if cleanAccount != account { l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, account) } @@ -118,17 +142,20 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { // Try to pick out a reasonable cluster name from the input. var suffix string + var reason string if i := strings.Index(clusterName, "-"); i != -1 { // The user specified a username prefix, but it didn't match an active // account name. For example, assuming the account is "peter", `roachprod // create joe-perf` should be specified as `roachprod create joe-perf -u // joe`. suffix = clusterName[i+1:] + reason = "username prefix does not match an active account name" } else { // The user didn't specify a username prefix. For example, assuming the // account is "peter", `roachprod create perf` should be specified as // `roachprod create peter-perf`. suffix = clusterName + reason = "cluster name should start with a username prefix: -" } // Suggest acceptable cluster names. @@ -136,8 +163,7 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error { for _, account := range accounts { suggestions = append(suggestions, fmt.Sprintf("%s-%s", account, suffix)) } - return fmt.Errorf("malformed cluster name %s, did you mean one of %s", - clusterName, suggestions) + return &MalformedClusterNameError{name: clusterName, reason: reason, suggestions: suggestions} } func sortedClusters() []string { diff --git a/pkg/roachprod/roachprod_test.go b/pkg/roachprod/roachprod_test.go new file mode 100644 index 000000000000..a64b661f1cc1 --- /dev/null +++ b/pkg/roachprod/roachprod_test.go @@ -0,0 +1,81 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package roachprod + +import ( + "io" + "testing" + + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" + "github.com/stretchr/testify/assert" +) + +func nilLogger() *logger.Logger { + lcfg := logger.Config{ + Stdout: io.Discard, + Stderr: io.Discard, + } + l, err := lcfg.NewLogger("" /* path */) + if err != nil { + panic(err) + } + return l +} + +func TestVerifyClusterName(t *testing.T) { + _findActiveAccounts = func(l *logger.Logger) (map[string]string, error) { + return map[string]string{"1": "user1", "2": "user2", "3": "USER4"}, nil + } + defer func() { + _findActiveAccounts = nil + }() + cases := []struct { + description, clusterName, username string + errorExpected bool + }{ + { + "username found", "user1-clustername", "", false, + }, + { + "username not found", "user3-clustername", "", true, + }, + { + "specified username", "user3-clustername", "user3", false, + }, + { + "specified username that doesn't match", "user1-clustername", "fakeuser", true, + }, + { + "clustername not sanitized", "UserName-clustername", "", true, + }, + { + "no username", "clustername", "", true, + }, + { + "no clustername", "user1", "", true, + }, + { + "unsanitized found username", "user4-clustername", "", false, + }, + { + "unsanitized specified username", "user3-clustername", "USER3", false, + }, + } + for _, c := range cases { + t.Run(c.description, func(t *testing.T) { + if c.errorExpected { + assert.Error(t, verifyClusterName(nilLogger(), c.clusterName, c.username)) + } else { + assert.NoError(t, verifyClusterName(nilLogger(), c.clusterName, c.username)) + } + }) + } +} diff --git a/pkg/roachprod/vm/vm.go b/pkg/roachprod/vm/vm.go index 1e92d3692e8c..a6f737d1bcd3 100644 --- a/pkg/roachprod/vm/vm.go +++ b/pkg/roachprod/vm/vm.go @@ -710,21 +710,31 @@ func ExpandZonesFlag(zoneFlag []string) (zones []string, err error) { return zones, nil } -// DNSSafeAccount takes a string and returns a cleaned version of the string that can be used in DNS entries. +// DNSSafeName takes a string and returns a cleaned version of the string that can be used in DNS entries. // Unsafe characters are dropped. No length check is performed. -func DNSSafeAccount(account string) string { +func DNSSafeName(name string) string { safe := func(r rune) rune { switch { case r >= 'a' && r <= 'z': return r case r >= 'A' && r <= 'Z': return unicode.ToLower(r) + case r >= '0' && r <= '9': + return r + case r == '-': + return r default: // Negative value tells strings.Map to drop the rune. return -1 } } - return strings.Map(safe, account) + name = strings.Map(safe, name) + + // DNS entries cannot start or end with hyphens. + name = strings.Trim(name, "-") + + // Consecutive hyphens are allowed in DNS entries, but disallow it for readability. + return regexp.MustCompile(`-+`).ReplaceAllString(name, "-") } // SanitizeLabel returns a version of the string that can be used as a label. diff --git a/pkg/roachprod/vm/vm_test.go b/pkg/roachprod/vm/vm_test.go index 9f1a5e5e4530..2c3951b5c44b 100644 --- a/pkg/roachprod/vm/vm_test.go +++ b/pkg/roachprod/vm/vm_test.go @@ -132,12 +132,18 @@ func TestDNSSafeAccount(t *testing.T) { "dot and underscore", "u.ser_n.a_me", "username", }, { - "Unicode and other characters", "~/❦u.ser_ऄn.a_meλ", "username", + "leading and trailing hyphens", "--username-clustername-&", "username-clustername", + }, + { + "consecutive hyphens", "username---clustername", "username-clustername", + }, + { + "Unicode and other characters", "~/❦--u.ser_ऄn.a_meλ", "username", }, } for _, c := range cases { t.Run(c.description, func(t *testing.T) { - assert.EqualValues(t, DNSSafeAccount(c.input), c.expected) + assert.EqualValues(t, c.expected, DNSSafeName(c.input)) }) } }