Skip to content

Commit

Permalink
roachprod, roachtest: use same cluster name sanitization
Browse files Browse the repository at this point in the history
Previously, roachtest had its own function to sanitize
cluster names, while roachprod had its own function to
verify cluster names. This change removes both and opts
instead to use vm.DNSSafeName.

This change also introduces MalformedClusterNameError
which gives a hint on what is wrong with the name and
tells roachtest not to retry cluster creation.
  • Loading branch information
DarrylWong committed May 15, 2024
1 parent 7641e6e commit 27ab2c2
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 26 deletions.
2 changes: 2 additions & 0 deletions pkg/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ ALL_TESTS = [
"//pkg/roachprod/vm/gce:gce_test",
"//pkg/roachprod/vm/local:local_test",
"//pkg/roachprod/vm:vm_test",
"//pkg/roachprod:roachprod_test",
"//pkg/rpc/nodedialer:nodedialer_test",
"//pkg/rpc:rpc_test",
"//pkg/scheduledjobs/schedulebase:schedulebase_test",
Expand Down Expand Up @@ -1583,6 +1584,7 @@ GO_TARGETS = [
"//pkg/roachprod/vm:vm",
"//pkg/roachprod/vm:vm_test",
"//pkg/roachprod:roachprod",
"//pkg/roachprod:roachprod_test",
"//pkg/rpc/nodedialer:nodedialer",
"//pkg/rpc/nodedialer:nodedialer_test",
"//pkg/rpc/rpcpb:rpcpb",
Expand Down
1 change: 1 addition & 0 deletions pkg/cmd/roachtest/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ go_test(
"//pkg/cmd/roachtest/spec",
"//pkg/cmd/roachtest/test",
"//pkg/internal/team",
"//pkg/roachprod",
"//pkg/roachprod/errors",
"//pkg/roachprod/logger",
"//pkg/roachprod/vm",
Expand Down
18 changes: 9 additions & 9 deletions pkg/cmd/roachtest/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,15 +532,8 @@ func (r *clusterRegistry) destroyAllClusters(ctx context.Context, l *logger.Logg
}
}

// gceClusterNameSeparatorRE matches every maximal run of characters that are
// not lowercase ASCII letters or digits. Hyphens are intentionally part of the
// run so that mixed sequences such as "_-_" collapse into a single separator.
// It is compiled once at package scope rather than on every call.
var gceClusterNameSeparatorRE = regexp.MustCompile(`[^a-z0-9]+`)

// makeGCEClusterName lowercases name and replaces each run of characters
// outside [a-z0-9] (including runs of hyphens) with a single hyphen.
//
// Leading and trailing hyphens are not trimmed, and no length check is
// performed.
func makeGCEClusterName(name string) string {
	return gceClusterNameSeparatorRE.ReplaceAllString(strings.ToLower(name), "-")
}

// makeClusterName converts name into the sanitized form used for cluster
// names.
//
// NOTE(review): this span comes from a rendered diff, so both the pre-change
// return (makeGCEClusterName) and the post-change return (vm.DNSSafeName)
// appear. As written only the first return is reachable — confirm against the
// committed file, which is expected to keep only the vm.DNSSafeName call.
func makeClusterName(name string) string {
return makeGCEClusterName(name)
return vm.DNSSafeName(name)
}

// MachineTypeToCPUs returns a CPU count for GCE, AWS, and Azure machine types.
Expand Down Expand Up @@ -847,6 +840,10 @@ func (f *clusterFactory) clusterMock(cfg clusterConfig) *clusterImpl {
}
}

// create is a hook that lets tests inject their own cluster create
// implementation, e.g. unit tests that must not actually access a provider.
// It defaults to roachprod.Create.
var create = roachprod.Create

// newCluster creates a new roachprod cluster.
//
// setStatus is called with status messages indicating the stage of cluster
Expand Down Expand Up @@ -957,7 +954,7 @@ func (f *clusterFactory) newCluster(

l.PrintfCtx(ctx, "Attempting cluster creation (attempt #%d/%d)", i, maxAttempts)
createVMOpts.ClusterName = c.name
err = roachprod.Create(ctx, l, cfg.username, cfg.spec.NodeCount, createVMOpts, providerOptsContainer)
err = create(ctx, l, cfg.username, cfg.spec.NodeCount, createVMOpts, providerOptsContainer)
if err == nil {
if err := f.r.registerCluster(c); err != nil {
return nil, nil, err
Expand All @@ -973,6 +970,9 @@ func (f *clusterFactory) newCluster(
// or a destroy from the previous iteration failed.
return nil, nil, err
}
if errors.HasType(err, (*roachprod.MalformedClusterNameError)(nil)) {
return nil, nil, err
}

l.PrintfCtx(ctx, "cluster creation failed, cleaning up in case it was partially created: %s", err)
c.Destroy(ctx, closeLogger, l)
Expand Down
47 changes: 47 additions & 0 deletions pkg/cmd/roachtest/test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/roachtestflags"
"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/spec"
"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test"
"github.com/cockroachdb/cockroach/pkg/roachprod"
"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
"github.com/cockroachdb/cockroach/pkg/roachprod/vm"
"github.com/cockroachdb/cockroach/pkg/testutils"
Expand Down Expand Up @@ -433,3 +434,49 @@ func TestExitCode(t *testing.T) {
err := runExitCodeTest(t, errors.New("boom"))
require.True(t, errors.Is(err, errTestsFailed))
}

// TestNewCluster verifies that newCluster stops after a single create attempt
// when the create hook fails with an error that retrying cannot fix.
func TestNewCluster(t *testing.T) {
	ctx := context.Background()
	factory := &clusterFactory{sem: make(chan struct{}, 1)}
	cfg := clusterConfig{spec: spec.MakeClusterSpec(1)}
	setStatus := func(string) {}

	// Put the real implementation back once the test is done so that other
	// tests in the package are unaffected by the stubbed hook.
	defer func() {
		create = roachprod.Create
	}()

	testCases := []struct {
		name string
		// createErr is the error returned by every call to the stubbed hook.
		createErr error
		// expectedCreateCalls is the total number of times the hook should be
		// invoked; 1 means the initial attempt was made and never retried.
		expectedCreateCalls int
	}{
		{
			name:                "Malformed Cluster Name Error",
			createErr:           &roachprod.MalformedClusterNameError{},
			expectedCreateCalls: 1,
		},
		{
			name:                "Cluster Already Exists Error",
			createErr:           &roachprod.ClusterAlreadyExistsError{},
			expectedCreateCalls: 1,
		},
	}

	for _, c := range testCases {
		t.Run(c.name, func(t *testing.T) {
			// The counter is scoped to the subtest so cases cannot leak state
			// into one another.
			var createCalls int
			create = func(
				ctx context.Context,
				l *logger.Logger,
				username string,
				numNodes int,
				createVMOpts vm.CreateOpts,
				providerOptsContainer vm.ProviderOptionsContainer,
			) error {
				createCalls++
				return c.createErr
			}

			_, _, err := factory.newCluster(ctx, cfg, setStatus, true)
			require.Error(t, err)
			require.Equal(t, c.expectedCreateCalls, createCalls)
		})
	}
}
13 changes: 12 additions & 1 deletion pkg/roachprod/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "roachprod",
Expand Down Expand Up @@ -38,3 +38,14 @@ go_library(
"@org_golang_x_sys//unix",
],
)

go_test(
name = "roachprod_test",
srcs = ["roachprod_test.go"],
embed = [":roachprod"],
deps = [
"//pkg/roachprod/logger",
"//pkg/roachprod/vm",
"@com_github_stretchr_testify//assert",
],
)
40 changes: 29 additions & 11 deletions pkg/roachprod/roachprod.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ const (
prometheusHostUrlEnv = "COCKROACH_PROM_HOST_URL"
)

// MalformedClusterNameError is returned when the cluster name passed to Create is invalid.
type MalformedClusterNameError struct {
	// name is the cluster name that was rejected.
	name string
	// reason is a short explanation of why name was rejected.
	reason string
	// suggestions lists acceptable alternative names; it may be empty.
	suggestions []string
}

// Error implements the error interface. The message always contains the
// rejected name and the reason; the "did you mean" hint is appended only when
// there is at least one suggestion, so the message never ends in an empty
// "[]" list. The string is lowercase with no trailing punctuation so that it
// composes cleanly when wrapped by callers.
func (e *MalformedClusterNameError) Error() string {
	msg := fmt.Sprintf("malformed cluster name %s, %s", e.name, e.reason)
	if len(e.suggestions) == 0 {
		return msg
	}
	return fmt.Sprintf("%s, did you mean one of %s", msg, e.suggestions)
}

// findActiveAccounts is a hook that lets tests inject their own
// FindActiveAccounts implementation, e.g. unit tests that must not actually
// access a provider. It defaults to vm.FindActiveAccounts.
var findActiveAccounts = vm.FindActiveAccounts

// verifyClusterName ensures that the given name conforms to
// our naming pattern of "<username>-<clustername>". The
// username must match one of the vm.Provider account names
Expand All @@ -75,12 +90,9 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error {
return fmt.Errorf("cluster name cannot be blank")
}

alphaNum, err := regexp.Compile(`^[a-zA-Z0-9\-]+$`)
if err != nil {
return err
}
if !alphaNum.MatchString(clusterName) {
return errors.Errorf("cluster name must match %s", alphaNum.String())
sanitizedName := vm.DNSSafeName(clusterName)
if sanitizedName != clusterName {
return &MalformedClusterNameError{name: clusterName, reason: "invalid characters", suggestions: []string{sanitizedName}}
}

if config.IsLocalClusterName(clusterName) {
Expand All @@ -90,17 +102,21 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error {
// Use the vm.Provider account names, or --username.
var accounts []string
if len(username) > 0 {
accounts = []string{username}
cleanAccount := vm.DNSSafeName(username)
if cleanAccount != username {
l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, username)
}
accounts = []string{cleanAccount}
} else {
seenAccounts := map[string]bool{}
active, err := vm.FindActiveAccounts(l)
active, err := findActiveAccounts(l)
if err != nil {
return err
}
for _, account := range active {
if !seenAccounts[account] {
seenAccounts[account] = true
cleanAccount := vm.DNSSafeAccount(account)
cleanAccount := vm.DNSSafeName(account)
if cleanAccount != account {
l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, account)
}
Expand All @@ -118,26 +134,28 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error {

// Try to pick out a reasonable cluster name from the input.
var suffix string
var reason string
if i := strings.Index(clusterName, "-"); i != -1 {
// The user specified a username prefix, but it didn't match an active
// account name. For example, assuming the account is "peter", `roachprod
// create joe-perf` should be specified as `roachprod create joe-perf -u
// joe`.
suffix = clusterName[i+1:]
reason = "username prefix does not match an active account name"
} else {
// The user didn't specify a username prefix. For example, assuming the
// account is "peter", `roachprod create perf` should be specified as
// `roachprod create peter-perf`.
suffix = clusterName
reason = "cluster name should start with a username prefix: <username>-<clustername>"
}

// Suggest acceptable cluster names.
var suggestions []string
for _, account := range accounts {
suggestions = append(suggestions, fmt.Sprintf("%s-%s", account, suffix))
}
return fmt.Errorf("malformed cluster name %s, did you mean one of %s",
clusterName, suggestions)
return &MalformedClusterNameError{name: clusterName, reason: reason, suggestions: suggestions}
}

func sortedClusters() []string {
Expand Down
82 changes: 82 additions & 0 deletions pkg/roachprod/roachprod_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package roachprod

import (
"io"
"testing"

"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
"github.com/cockroachdb/cockroach/pkg/roachprod/vm"
"github.com/stretchr/testify/assert"
)

// nilLogger builds a logger whose stdout and stderr are both discarded, for
// tests that need a *logger.Logger but do not care about its output. It panics
// if the logger cannot be constructed.
func nilLogger() *logger.Logger {
	cfg := logger.Config{Stdout: io.Discard, Stderr: io.Discard}
	discard, err := cfg.NewLogger("" /* path */)
	if err != nil {
		panic(err)
	}
	return discard
}

// TestVerifyClusterName runs verifyClusterName against a stubbed set of active
// accounts and checks, for each input, only whether an error is returned.
func TestVerifyClusterName(t *testing.T) {
	// Replace the provider lookup with a fixed set of accounts and restore the
	// real implementation when the test finishes.
	findActiveAccounts = func(l *logger.Logger) (map[string]string, error) {
		return map[string]string{"1": "user1", "2": "user2", "3": "USER4"}, nil
	}
	defer func() {
		findActiveAccounts = vm.FindActiveAccounts
	}()

	type testCase struct {
		description   string
		clusterName   string
		username      string
		errorExpected bool
	}
	tests := []testCase{
		{description: "username found", clusterName: "user1-clustername", username: "", errorExpected: false},
		{description: "username not found", clusterName: "user3-clustername", username: "", errorExpected: true},
		{description: "specified username", clusterName: "user3-clustername", username: "user3", errorExpected: false},
		{description: "specified username that doesn't match", clusterName: "user1-clustername", username: "fakeuser", errorExpected: true},
		{description: "clustername not sanitized", clusterName: "UserName-clustername", username: "", errorExpected: true},
		{description: "no username", clusterName: "clustername", username: "", errorExpected: true},
		{description: "no clustername", clusterName: "user1", username: "", errorExpected: true},
		{description: "unsanitized found username", clusterName: "user4-clustername", username: "", errorExpected: false},
		{description: "unsanitized specified username", clusterName: "user3-clustername", username: "USER3", errorExpected: false},
	}

	for _, tc := range tests {
		t.Run(tc.description, func(t *testing.T) {
			err := verifyClusterName(nilLogger(), tc.clusterName, tc.username)
			if tc.errorExpected {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}
16 changes: 13 additions & 3 deletions pkg/roachprod/vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,21 +710,31 @@ func ExpandZonesFlag(zoneFlag []string) (zones []string, err error) {
return zones, nil
}

// DNSSafeName takes a string and returns a cleaned version of the string that
// can be used in DNS entries. ASCII uppercase letters are lowercased; any
// character other than an ASCII letter, digit, or hyphen is dropped. Leading
// and trailing hyphens are then removed and runs of hyphens are collapsed to a
// single hyphen. No length check is performed, and the result may be empty.
func DNSSafeName(name string) string {
	safe := func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z', r >= '0' && r <= '9', r == '-':
			return r
		case r >= 'A' && r <= 'Z':
			return unicode.ToLower(r)
		default:
			// Negative value tells strings.Map to drop the rune.
			return -1
		}
	}
	name = strings.Map(safe, name)

	// DNS entries cannot start or end with hyphens.
	name = strings.Trim(name, "-")

	// Consecutive hyphens are allowed in DNS entries, but we collapse them
	// for readability.
	return consecutiveHyphensRE.ReplaceAllString(name, "-")
}

// consecutiveHyphensRE matches one or more hyphens in a row. It is compiled
// once at package scope so DNSSafeName does not recompile it on every call.
var consecutiveHyphensRE = regexp.MustCompile(`-+`)

// SanitizeLabel returns a version of the string that can be used as a label.
Expand Down
10 changes: 8 additions & 2 deletions pkg/roachprod/vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,18 @@ func TestDNSSafeAccount(t *testing.T) {
"dot and underscore", "u.ser_n.a_me", "username",
},
{
"Unicode and other characters", "~/❦u.ser_ऄn.a_meλ", "username",
"leading and trailing hyphens", "--username-clustername-&", "username-clustername",
},
{
"consecutive hyphens", "username---clustername", "username-clustername",
},
{
"Unicode and other characters", "~/❦--u.ser_ऄn.a_meλ", "username",
},
}
for _, c := range cases {
t.Run(c.description, func(t *testing.T) {
assert.EqualValues(t, DNSSafeAccount(c.input), c.expected)
assert.EqualValues(t, c.expected, DNSSafeName(c.input))
})
}
}
Expand Down

0 comments on commit 27ab2c2

Please sign in to comment.