Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cassandra: Refactor PEM parsing logic #11861

Merged
merged 21 commits into from
Jun 21, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 50 additions & 11 deletions helper/testhelpers/cassandra/cassandrahelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,41 @@ import (
)

type containerConfig struct {
version string
copyFromTo map[string]string
sslOpts *gocql.SslOptions
containerName string
doNotAppendUUID bool
imageName string
version string
copyFromTo map[string]string
env []string

sslOpts *gocql.SslOptions
}

type ContainerOpt func(*containerConfig)

func ContainerName(name string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.containerName = name
}
}

func Image(imageName string, version string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.imageName = imageName
cfg.version = version

// Reset the environment because there's a very good chance the default environment doesn't apply to the
// non-default image being used
cfg.env = nil
}
}

func DoNotAppendUUID(doNotAppendUUID bool) ContainerOpt {
return func(cfg *containerConfig) {
cfg.doNotAppendUUID = doNotAppendUUID
}
}

func Version(version string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.version = version
Expand All @@ -33,6 +61,12 @@ func CopyFromTo(copyFromTo map[string]string) ContainerOpt {
}
}

func Env(keyValue string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.env = append(cfg.env, keyValue)
}
}

func SslOpts(sslOpts *gocql.SslOptions) ContainerOpt {
return func(cfg *containerConfig) {
cfg.sslOpts = sslOpts
Expand Down Expand Up @@ -63,7 +97,9 @@ func PrepareTestContainer(t *testing.T, opts ...ContainerOpt) (Host, func()) {
}

containerCfg := &containerConfig{
version: "3.11",
imageName: "cassandra",
version: "3.11",
env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
}

for _, opt := range opts {
Expand All @@ -79,13 +115,16 @@ func PrepareTestContainer(t *testing.T, opts ...ContainerOpt) (Host, func()) {
copyFromTo[absFrom] = to
}

runner, err := docker.NewServiceRunner(docker.RunOptions{
ImageRepo: "cassandra",
ImageTag: containerCfg.version,
Ports: []string{"9042/tcp"},
CopyFromTo: copyFromTo,
Env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
})
runOpts := docker.RunOptions{
ContainerName: containerCfg.containerName,
DoNotAppendUUID: containerCfg.doNotAppendUUID,
ImageRepo: containerCfg.imageName,
ImageTag: containerCfg.version,
Ports: []string{"9042/tcp"},
CopyFromTo: copyFromTo,
Env: containerCfg.env,
}
runner, err := docker.NewServiceRunner(runOpts)
if err != nil {
t.Fatalf("Could not start docker cassandra: %s", err)
}
Expand Down
21 changes: 14 additions & 7 deletions helper/testhelpers/docker/testhelpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@ type Runner struct {
}

type RunOptions struct {
ImageRepo string
ImageTag string
ContainerName string
ImageRepo string
ImageTag string

ContainerName string
// DoNotAppendUUID to the container name for uniqueness
DoNotAppendUUID bool
calvn marked this conversation as resolved.
Show resolved Hide resolved

Cmd []string
Env []string
NetworkID string
Expand Down Expand Up @@ -186,11 +190,14 @@ type Service struct {
}

func (d *Runner) Start(ctx context.Context) (*types.ContainerJSON, []string, error) {
suffix, err := uuid.GenerateUUID()
if err != nil {
return nil, nil, err
name := d.RunOptions.ContainerName
if !d.RunOptions.DoNotAppendUUID {
suffix, err := uuid.GenerateUUID()
if err != nil {
return nil, nil, err
}
name += "-" + suffix
}
name := d.RunOptions.ContainerName + "-" + suffix

cfg := &container.Config{
Hostname: name,
Expand Down
27 changes: 13 additions & 14 deletions plugins/database/cassandra/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,14 @@ func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()
}

func TestInitialize(t *testing.T) {
// getCassandra performs an Initialize call
db, cleanup := getCassandra(t, 4)
defer cleanup()

err := db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}

db, cleanup = getCassandra(t, "4")
calvn marked this conversation as resolved.
Show resolved Hide resolved
defer cleanup()
}

func TestCreateUser(t *testing.T) {
Expand All @@ -74,7 +72,7 @@ func TestCreateUser(t *testing.T) {
newUserReq dbplugin.NewUserRequest
expectErr bool
expectedUsernameRegex string
assertCreds func(t testing.TB, address string, port int, username, password string, timeout time.Duration)
assertCreds func(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions, timeout time.Duration)
}

tests := map[string]testCase{
Expand Down Expand Up @@ -160,7 +158,7 @@ func TestCreateUser(t *testing.T) {
t.Fatalf("no error expected, got: %s", err)
}
require.Regexp(t, test.expectedUsernameRegex, newUserResp.Username)
test.assertCreds(t, db.Hosts, db.Port, newUserResp.Username, test.newUserReq.Password, 5*time.Second)
test.assertCreds(t, db.Hosts, db.Port, newUserResp.Username, test.newUserReq.Password, nil, 5*time.Second)
})
}
}
Expand All @@ -184,7 +182,7 @@ func TestUpdateUserPassword(t *testing.T) {

createResp := dbtesting.AssertNewUser(t, db, createReq)

assertCreds(t, db.Hosts, db.Port, createResp.Username, password, 5*time.Second)
assertCreds(t, db.Hosts, db.Port, createResp.Username, password, nil, 5*time.Second)

newPassword := "somenewpassword"
updateReq := dbplugin.UpdateUserRequest{
Expand All @@ -198,7 +196,7 @@ func TestUpdateUserPassword(t *testing.T) {

dbtesting.AssertUpdateUser(t, db, updateReq)

assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, 5*time.Second)
assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, nil, 5*time.Second)
}

func TestDeleteUser(t *testing.T) {
Expand All @@ -220,21 +218,21 @@ func TestDeleteUser(t *testing.T) {

createResp := dbtesting.AssertNewUser(t, db, createReq)

assertCreds(t, db.Hosts, db.Port, createResp.Username, password, 5*time.Second)
assertCreds(t, db.Hosts, db.Port, createResp.Username, password, nil, 5*time.Second)

deleteReq := dbplugin.DeleteUserRequest{
Username: createResp.Username,
}

dbtesting.AssertDeleteUser(t, db, deleteReq)

assertNoCreds(t, db.Hosts, db.Port, createResp.Username, password, 5*time.Second)
assertNoCreds(t, db.Hosts, db.Port, createResp.Username, password, nil, 5*time.Second)
}

func assertCreds(t testing.TB, address string, port int, username, password string, timeout time.Duration) {
func assertCreds(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions, timeout time.Duration) {
t.Helper()
op := func() error {
return connect(t, address, port, username, password)
return connect(t, address, port, username, password, sslOpts)
}
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = timeout
Expand All @@ -248,7 +246,7 @@ func assertCreds(t testing.TB, address string, port int, username, password stri
}
}

func connect(t testing.TB, address string, port int, username, password string) error {
func connect(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions) error {
t.Helper()
clusterConfig := gocql.NewCluster(address)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Expand All @@ -257,6 +255,7 @@ func connect(t testing.TB, address string, port int, username, password string)
}
clusterConfig.ProtoVersion = 4
clusterConfig.Port = port
clusterConfig.SslOpts = sslOpts

session, err := clusterConfig.CreateSession()
if err != nil {
Expand All @@ -266,12 +265,12 @@ func connect(t testing.TB, address string, port int, username, password string)
return nil
}

func assertNoCreds(t testing.TB, address string, port int, username, password string, timeout time.Duration) {
func assertNoCreds(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions, timeout time.Duration) {
t.Helper()

op := func() error {
// "Invert" the error so the backoff logic sees a failure to connect as a success
err := connect(t, address, port, username, password)
err := connect(t, address, port, username, password, sslOpts)
if err != nil {
return nil
}
Expand Down
100 changes: 27 additions & 73 deletions plugins/database/cassandra/connection_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/helper/tlsutil"
"github.com/mitchellh/mapstructure"
Expand Down Expand Up @@ -40,7 +39,7 @@ type cassandraConnectionProducer struct {

connectTimeout time.Duration
socketKeepAlive time.Duration
certBundle *certutil.CertBundle
sslOpts *gocql.SslOptions
rawConfig map[string]interface{}

Initialized bool
Expand Down Expand Up @@ -83,38 +82,46 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
return fmt.Errorf("username cannot be empty")
case len(c.Password) == 0:
return fmt.Errorf("password cannot be empty")
case len(c.PemJSON) > 0 && len(c.PemBundle) > 0:
return fmt.Errorf("cannot specify both pem_json and pem_bundle")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this is good to do, but note that we previously allowed both to be specified at the same time, with preference on pem_json so it's technically a breaking change. Would be good to get the docs updated as well to indicate that this is a one of between the two parameters.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, this is a breaking change. I'm not a fan of allowing for both as it can create confusion ("which one is being used?"). I'll update the docs to describe this change unless anyone has any objections to making this change.

}

var tlsMinVersion uint16 = tls.VersionTLS12
if c.TLSMinVersion != "" {
ver, exists := tlsutil.TLSLookup[c.TLSMinVersion]
if !exists {
return fmt.Errorf("unrecognized TLS version [%s]", c.TLSMinVersion)
}
tlsMinVersion = ver
}

var certBundle *certutil.CertBundle
var parsedCertBundle *certutil.ParsedCertBundle
switch {
case len(c.PemJSON) != 0:
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
cfg, err := jsonBundleToTLSConfig(c.PemJSON, tlsMinVersion, c.TLSServerName, c.InsecureTLS)
if err != nil {
return fmt.Errorf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %w", err)
return fmt.Errorf("failed to parse pem_json: %w", err)
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return fmt.Errorf("error marshaling PEM information: %w", err)
c.sslOpts = &gocql.SslOptions{
Config: cfg,
EnableHostVerification: !cfg.InsecureSkipVerify,
}
c.certBundle = certBundle
c.TLS = true

case len(c.PemBundle) != 0:
parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
cfg, err := pemBundleToTLSConfig(c.PemBundle, tlsMinVersion, c.TLSServerName, c.InsecureTLS)
if err != nil {
return fmt.Errorf("error parsing the given PEM information: %w", err)
return fmt.Errorf("failed to parse pem_bundle: %w", err)
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return fmt.Errorf("error marshaling PEM information: %w", err)
c.sslOpts = &gocql.SslOptions{
Config: cfg,
EnableHostVerification: !cfg.InsecureSkipVerify,
}
c.certBundle = certBundle
c.TLS = true
}

if c.InsecureTLS {
c.TLS = true
case c.InsecureTLS:
c.sslOpts = &gocql.SslOptions{
EnableHostVerification: !c.InsecureTLS,
}
}

// Set initialized to true at this point since all fields are set,
Expand Down Expand Up @@ -183,14 +190,7 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql

clusterConfig.Timeout = c.connectTimeout
clusterConfig.SocketKeepalive = c.socketKeepAlive

if c.TLS {
sslOpts, err := getSslOpts(c.certBundle, c.TLSMinVersion, c.TLSServerName, c.InsecureTLS)
if err != nil {
return nil, err
}
clusterConfig.SslOpts = sslOpts
}
clusterConfig.SslOpts = c.sslOpts

if c.LocalDatacenter != "" {
clusterConfig.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(c.LocalDatacenter)
Expand Down Expand Up @@ -231,52 +231,6 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql
return session, nil
}

func getSslOpts(certBundle *certutil.CertBundle, minTLSVersion, serverName string, insecureSkipVerify bool) (*gocql.SslOptions, error) {
tlsConfig := &tls.Config{}
if certBundle != nil {
if certBundle.Certificate == "" && certBundle.PrivateKey != "" {
return nil, fmt.Errorf("found private key for TLS authentication but no certificate")
}
if certBundle.Certificate != "" && certBundle.PrivateKey == "" {
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
}

parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
}

tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
}
}

tlsConfig.InsecureSkipVerify = insecureSkipVerify

if serverName != "" {
tlsConfig.ServerName = serverName
}

if minTLSVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[minTLSVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}

opts := &gocql.SslOptions{
Config: tlsConfig,
EnableHostVerification: !insecureSkipVerify,
}
return opts, nil
}

func (c *cassandraConnectionProducer) secretValues() map[string]string {
return map[string]string{
c.Password: "[password]",
Expand Down