Skip to content

Commit

Permalink
roachprod, roachtest: use same cluster name sanitization
Browse files Browse the repository at this point in the history
Previously, roachtest had it's own function to sanitize
cluster names, while roachprod had it's own function to
verify cluster names. This change removes both and opts
instead to use vm.DNSSafeName.

This change also introduces MalformedClusterNameError
which gives a hint on what is wrong with the name and
tells roachtest not to retry cluster creation.
  • Loading branch information
DarrylWong committed May 14, 2024
1 parent cd17692 commit 927338d
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 24 deletions.
12 changes: 4 additions & 8 deletions pkg/cmd/roachtest/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,15 +532,8 @@ func (r *clusterRegistry) destroyAllClusters(ctx context.Context, l *logger.Logg
}
}

func makeGCEClusterName(name string) string {
name = strings.ToLower(name)
name = regexp.MustCompile(`[^-a-z0-9]+`).ReplaceAllString(name, "-")
name = regexp.MustCompile(`-+`).ReplaceAllString(name, "-")
return name
}

func makeClusterName(name string) string {
return makeGCEClusterName(name)
return vm.DNSSafeName(name)
}

// MachineTypeToCPUs returns a CPU count for GCE, AWS, and Azure machine types.
Expand Down Expand Up @@ -973,6 +966,9 @@ func (f *clusterFactory) newCluster(
// or a destroy from the previous iteration failed.
return nil, nil, err
}
if errors.HasType(err, (*roachprod.MalformedClusterNameError)(nil)) {
return nil, nil, err
}

l.PrintfCtx(ctx, "cluster creation failed, cleaning up in case it was partially created: %s", err)
c.Destroy(ctx, closeLogger, l)
Expand Down
48 changes: 37 additions & 11 deletions pkg/roachprod/roachprod.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ const (
prometheusHostUrlEnv = "COCKROACH_PROM_HOST_URL"
)

// MalformedClusterNameError is returned when the cluster name passed to Create is invalid.
type MalformedClusterNameError struct {
name string
reason string
suggestions []string
}

func (e *MalformedClusterNameError) Error() string {
return fmt.Sprintf("Malformed cluster name %s, %s. Did you mean one of %s", e.name, e.reason, e.suggestions)
}

// _findActiveAccounts is a test only variable, used for mocking a provider finding
// active accounts in a unit test, where we don't want to actually access a provider.
var _findActiveAccounts func(l *logger.Logger) (map[string]string, error)

// verifyClusterName ensures that the given name conforms to
// our naming pattern of "<username>-<clustername>". The
// username must match one of the vm.Provider account names
Expand All @@ -75,12 +90,9 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error {
return fmt.Errorf("cluster name cannot be blank")
}

alphaNum, err := regexp.Compile(`^[a-zA-Z0-9\-]+$`)
if err != nil {
return err
}
if !alphaNum.MatchString(clusterName) {
return errors.Errorf("cluster name must match %s", alphaNum.String())
sanitizedName := vm.DNSSafeName(clusterName)
if sanitizedName != clusterName {
return &MalformedClusterNameError{name: clusterName, reason: "invalid characters", suggestions: []string{sanitizedName}}
}

if config.IsLocalClusterName(clusterName) {
Expand All @@ -90,17 +102,29 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error {
// Use the vm.Provider account names, or --username.
var accounts []string
if len(username) > 0 {
accounts = []string{username}
cleanAccount := vm.DNSSafeName(username)
if cleanAccount != username {
l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, username)
}
accounts = []string{cleanAccount}
} else {
seenAccounts := map[string]bool{}
active, err := vm.FindActiveAccounts(l)
var active map[string]string
var err error
if _findActiveAccounts != nil {
// Test only path used by unit tests.
active, err = _findActiveAccounts(l)
} else {
active, err = vm.FindActiveAccounts(l)
}

if err != nil {
return err
}
for _, account := range active {
if !seenAccounts[account] {
seenAccounts[account] = true
cleanAccount := vm.DNSSafeAccount(account)
cleanAccount := vm.DNSSafeName(account)
if cleanAccount != account {
l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, account)
}
Expand All @@ -118,26 +142,28 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error {

// Try to pick out a reasonable cluster name from the input.
var suffix string
var reason string
if i := strings.Index(clusterName, "-"); i != -1 {
// The user specified a username prefix, but it didn't match an active
// account name. For example, assuming the account is "peter", `roachprod
// create joe-perf` should be specified as `roachprod create joe-perf -u
// joe`.
suffix = clusterName[i+1:]
reason = "username prefix does not match an active account name"
} else {
// The user didn't specify a username prefix. For example, assuming the
// account is "peter", `roachprod create perf` should be specified as
// `roachprod create peter-perf`.
suffix = clusterName
reason = "cluster name should start with a username prefix: <username>-<clustername>"
}

// Suggest acceptable cluster names.
var suggestions []string
for _, account := range accounts {
suggestions = append(suggestions, fmt.Sprintf("%s-%s", account, suffix))
}
return fmt.Errorf("malformed cluster name %s, did you mean one of %s",
clusterName, suggestions)
return &MalformedClusterNameError{name: clusterName, reason: reason, suggestions: suggestions}
}

func sortedClusters() []string {
Expand Down
81 changes: 81 additions & 0 deletions pkg/roachprod/roachprod_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package roachprod

import (
"io"
"testing"

"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
"github.com/stretchr/testify/assert"
)

func nilLogger() *logger.Logger {
lcfg := logger.Config{
Stdout: io.Discard,
Stderr: io.Discard,
}
l, err := lcfg.NewLogger("" /* path */)
if err != nil {
panic(err)
}
return l
}

func TestVerifyClusterName(t *testing.T) {
_findActiveAccounts = func(l *logger.Logger) (map[string]string, error) {
return map[string]string{"1": "user1", "2": "user2", "3": "USER4"}, nil
}
defer func() {
_findActiveAccounts = nil
}()
cases := []struct {
description, clusterName, username string
errorExpected bool
}{
{
"username found", "user1-clustername", "", false,
},
{
"username not found", "user3-clustername", "", true,
},
{
"specified username", "user3-clustername", "user3", false,
},
{
"specified username that doesn't match", "user1-clustername", "fakeuser", true,
},
{
"clustername not sanitized", "UserName-clustername", "", true,
},
{
"no username", "clustername", "", true,
},
{
"no clustername", "user1", "", true,
},
{
"unsanitized found username", "user4-clustername", "", false,
},
{
"unsanitized specified username", "user3-clustername", "USER3", false,
},
}
for _, c := range cases {
t.Run(c.description, func(t *testing.T) {
if c.errorExpected {
assert.Error(t, verifyClusterName(nilLogger(), c.clusterName, c.username))
} else {
assert.NoError(t, verifyClusterName(nilLogger(), c.clusterName, c.username))
}
})
}
}
16 changes: 13 additions & 3 deletions pkg/roachprod/vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,21 +710,31 @@ func ExpandZonesFlag(zoneFlag []string) (zones []string, err error) {
return zones, nil
}

// DNSSafeAccount takes a string and returns a cleaned version of the string that can be used in DNS entries.
// DNSSafeName takes a string and returns a cleaned version of the string that can be used in DNS entries.
// Unsafe characters are dropped. No length check is performed.
func DNSSafeAccount(account string) string {
func DNSSafeName(name string) string {
safe := func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= 'A' && r <= 'Z':
return unicode.ToLower(r)
case r >= '0' && r <= '9':
return r
case r == '-':
return r
default:
// Negative value tells strings.Map to drop the rune.
return -1
}
}
return strings.Map(safe, account)
name = strings.Map(safe, name)

// DNS entries cannot start or end with hyphens.
name = strings.Trim(name, "-")

// Consecutive hyphens are allowed in DNS entries, but disallow it for readability.
return regexp.MustCompile(`-+`).ReplaceAllString(name, "-")
}

// SanitizeLabel returns a version of the string that can be used as a label.
Expand Down
10 changes: 8 additions & 2 deletions pkg/roachprod/vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,18 @@ func TestDNSSafeAccount(t *testing.T) {
"dot and underscore", "u.ser_n.a_me", "username",
},
{
"Unicode and other characters", "~/❦u.ser_ऄn.a_meλ", "username",
"leading and trailing hyphens", "--username-clustername-&", "username-clustername",
},
{
"consecutive hyphens", "username---clustername", "username-clustername",
},
{
"Unicode and other characters", "~/❦--u.ser_ऄn.a_meλ", "username",
},
}
for _, c := range cases {
t.Run(c.description, func(t *testing.T) {
assert.EqualValues(t, DNSSafeAccount(c.input), c.expected)
assert.EqualValues(t, c.expected, DNSSafeName(c.input))
})
}
}
Expand Down

0 comments on commit 927338d

Please sign in to comment.