From 09f6f5f923ff22c0c5bfa3e66e4fed9192d14d1d Mon Sep 17 00:00:00 2001 From: Jacques Rascagneres Date: Tue, 22 Jun 2021 17:52:42 +0100 Subject: [PATCH] CBG-1504: Add ability to obtain mgmt endpoints from cluster (#5046) * CBG-1504: Add ability to obtain mgmt endpoints from cluster * Added bucket accessor * Address PR comments * Additional comment * Only use SystemCertPool if no caCertPath specified * Add X509 test * Handle error * Use ServerReadTimeout for WaitUntilReady if available --- base/bucket_gocb.go | 14 +++++-- base/gocb_utils.go | 76 ++++++++++++++++++++++++++++++++++++ db/database.go | 10 +++++ rest/server_context.go | 70 +++++++++++++++++++++++++++++++++ rest/server_context_test.go | 78 +++++++++++++++++++++++++++++++++++++ 5 files changed, 245 insertions(+), 3 deletions(-) diff --git a/base/bucket_gocb.go b/base/bucket_gocb.go index e604e5b191..fe15af0b4b 100644 --- a/base/bucket_gocb.go +++ b/base/bucket_gocb.go @@ -2037,11 +2037,19 @@ func AsGoCBBucket(bucket Bucket) (*CouchbaseBucketGoCB, bool) { return AsGoCBBucket(underlyingBucket) } -// Get one of the management endpoints. It will be a string such as http://couchbase -func GoCBBucketMgmtEndpoint(bucket *gocb.Bucket) (url string, err error) { +func GoCBBucketMgmtEndpoints(bucket *gocb.Bucket) (url []string, err error) { mgmtEps := bucket.IoRouter().MgmtEps() if len(mgmtEps) == 0 { - return "", fmt.Errorf("No available Couchbase Server nodes") + return nil, fmt.Errorf("No available Couchbase Server nodes") + } + return mgmtEps, nil +} + +// Get one of the management endpoints. It will be a string such as http://couchbase +func GoCBBucketMgmtEndpoint(bucket *gocb.Bucket) (url string, err error) { + mgmtEps, err := GoCBBucketMgmtEndpoints(bucket) + if err != nil { + return "", err } bucketEp := mgmtEps[rand.Intn(len(mgmtEps))] return bucketEp, nil diff --git a/base/gocb_utils.go b/base/gocb_utils.go index 6d3bceb8de..cbe8013695 100644 --- a/base/gocb_utils.go +++ b/base/gocb_utils.go @@ -8,6 +8,7 @@ import ( "time" "github.com/couchbase/gocb" + "github.com/couchbase/gocbcore" ) // GoCBv2SecurityConfig returns a gocb.SecurityConfig to use when connecting given a CA Cert path. @@ -60,3 +61,78 @@ func GoCBv2TimeoutsConfig(bucketOpTimeout, viewQueryTimeout *time.Duration) (tc } return tc } + +// GOCBCORE Utilities + +// CertificateAuthenticator allows for certificate auth in gocbcore +type CertificateAuthenticator struct { + ClientCertificate *tls.Certificate +} + +func (ca CertificateAuthenticator) SupportsTLS() bool { + return true +} +func (ca CertificateAuthenticator) SupportsNonTLS() bool { + return false +} +func (ca CertificateAuthenticator) Certificate(req gocbcore.AuthCertRequest) (*tls.Certificate, error) { + return ca.ClientCertificate, nil +} +func (ca CertificateAuthenticator) Credentials(req gocbcore.AuthCredsRequest) ([]gocbcore.UserPassPair, error) { + return []gocbcore.UserPassPair{{ + Username: "", + Password: "", + }}, nil +} + +// GoCBCoreAuthConfig returns a gocbcore.AuthProvider to use when connecting given a set of credentials via a gocbcore agent. +func GoCBCoreAuthConfig(username, password, certPath, keyPath string) (a gocbcore.AuthProvider, err error) { + if certPath != "" && keyPath != "" { + cert, certLoadErr := tls.LoadX509KeyPair(certPath, keyPath) + if certLoadErr != nil { + return nil, err + } + return CertificateAuthenticator{ + ClientCertificate: &cert, + }, nil + } + + return &gocbcore.PasswordAuthProvider{ + Username: username, + Password: password, + }, nil +} + +func GoCBCoreTLSRootCAProvider(caCertPath string) (func() *x509.CertPool, error) { + rootCAs, err := getRootCAs(caCertPath) + if err != nil { + return nil, err + } + + return func() *x509.CertPool { + return rootCAs + }, nil +} + +func getRootCAs(caCertPath string) (*x509.CertPool, error) { + if caCertPath != "" { + rootCAs := x509.NewCertPool() + + caCert, err := ioutil.ReadFile(caCertPath) + if err != nil { + return nil, err + } + + ok := rootCAs.AppendCertsFromPEM(caCert) + if !ok { + return nil, errors.New("Invalid CA cert") + } + + return rootCAs, nil + } + + // We're purposefully ignoring the error here Due to the fact that the main error case is that this call is not + // supported in Windows. + rootCAs, _ := x509.SystemCertPool() + return rootCAs, nil +} diff --git a/db/database.go b/db/database.go index 63bd94ae89..b9154a8512 100644 --- a/db/database.go +++ b/db/database.go @@ -1370,6 +1370,16 @@ func (db *Database) invalUserOrRoleChannels(name string, invalSeq uint64) { } } +func (context *DatabaseContext) ObtainManagementEndpoints() ([]string, error) { + gocbBucket, ok := base.AsGoCBBucket(context.Bucket) + if !ok { + base.Warnf("Database %v: Unable to get server management endpoints. Underlying bucket type was not GoCBBucket.", base.MD(context.Name)) + return nil, nil + } + + return base.GoCBBucketMgmtEndpoints(gocbBucket.Bucket) +} + func (context *DatabaseContext) GetUserViewsEnabled() bool { if context.Options.UnsupportedOptions.UserViews.Enabled != nil { return *context.Options.UnsupportedOptions.UserViews.Enabled diff --git a/rest/server_context.go b/rest/server_context.go index 7334401c37..565c97a802 100644 --- a/rest/server_context.go +++ b/rest/server_context.go @@ -22,6 +22,7 @@ import ( "time" "github.com/couchbase/go-couchbase" + "github.com/couchbase/gocbcore" sgbucket "github.com/couchbase/sg-bucket" "github.com/couchbase/sync_gateway/base" "github.com/couchbase/sync_gateway/db" @@ -1147,6 +1148,75 @@ func (sc *ServerContext) updateCalculatedStats() { } +func initClusterAgent(clusterAddress, clusterUser, clusterPass, certPath, keyPath, caCertPath string, timeoutSeconds *int) (*gocbcore.Agent, error) { + authenticator, err := base.GoCBCoreAuthConfig(clusterUser, clusterPass, certPath, keyPath) + if err != nil { + return nil, err + } + + tlsRootCAProvider, err := base.GoCBCoreTLSRootCAProvider(caCertPath) + if err != nil { + return nil, err + } + + config := gocbcore.AgentConfig{ + Auth: authenticator, + TLSRootCAProvider: tlsRootCAProvider, + } + + err = config.FromConnStr(clusterAddress) + if err != nil { + return nil, err + } + + agent, err := gocbcore.CreateAgent(&config) + if err != nil { + return nil, err + } + + agentWaitUntilReadyTimeoutSeconds := 5 * time.Second + if timeoutSeconds != nil { + agentWaitUntilReadyTimeoutSeconds = time.Duration(*timeoutSeconds) * time.Second + } + + agentReadyErr := make(chan error) + _, err = agent.WaitUntilReady(time.Now().Add(agentWaitUntilReadyTimeoutSeconds), gocbcore.WaitUntilReadyOptions{ServiceTypes: []gocbcore.ServiceType{gocbcore.MgmtService}}, func(result *gocbcore.WaitUntilReadyResult, err error) { + agentReadyErr <- err + }) + + if err != nil { + return nil, err + } + + if err := <-agentReadyErr; err != nil { + return nil, err + } + + return agent, nil +} + +// FIXME: Temporary connection settings. Awaiting bootstrap PR so we can use those details directly from server context +var tempConnectionDetailsForManagementEndpoints = func() (serverAddress string, username string, password string, certPath string, keyPath string, caCertPath string) { + return base.UnitTestUrl(), base.TestClusterUsername(), base.TestClusterPassword(), "", "", "" +} + +func (sc *ServerContext) ObtainManagementEndpoints() ([]string, error) { + clusterAddress, clusterUser, clusterPass, certPath, keyPath, caCertPath := tempConnectionDetailsForManagementEndpoints() + agent, err := initClusterAgent(clusterAddress, clusterUser, clusterPass, certPath, keyPath, caCertPath, sc.config.ServerReadTimeout) + if err != nil { + return nil, err + } + + managementEndpoints := agent.MgmtEps() + + err = agent.Close() + if err != nil { + return nil, err + } + + return managementEndpoints, nil +} + // For test use func (sc *ServerContext) Database(name string) *db.DatabaseContext { db, err := sc.GetDatabase(name) diff --git a/rest/server_context_test.go b/rest/server_context_test.go index 668639e583..96d07a472d 100644 --- a/rest/server_context_test.go +++ b/rest/server_context_test.go @@ -18,8 +18,10 @@ import ( "testing" "time" + "github.com/couchbase/gocbcore/connstr" "github.com/couchbase/sync_gateway/base" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // Tests the ConfigServer feature. @@ -289,3 +291,79 @@ func TestStatsLoggerStopped(t *testing.T) { // sleep a bit to allow the "Stopping stats logging goroutine" debug logging to be printed time.Sleep(time.Millisecond * 10) } + +func TestObtainManagementEndpointsFromServerContext(t *testing.T) { + if base.UnitTestUrlIsWalrus() { + t.Skip("Test requires Couchbase Server") + } + + ctx := NewServerContext(&ServerConfig{}) + defer ctx.Close() + + eps, err := ctx.ObtainManagementEndpoints() + assert.NoError(t, err) + + clusterAddress, _, _, _, _, _ := tempConnectionDetailsForManagementEndpoints() + baseSpec, err := connstr.Parse(clusterAddress) + require.NoError(t, err) + + spec, err := connstr.Resolve(baseSpec) + require.NoError(t, err) + + existsOneMatchingEndpoint := false + +outerLoop: + for _, httpHost := range spec.HttpHosts { + for _, ep := range eps { + formattedHttpHost := fmt.Sprintf("http://%s:%d", httpHost.Host, httpHost.Port) + if formattedHttpHost == ep { + existsOneMatchingEndpoint = true + break outerLoop + } + } + } + + assert.True(t, existsOneMatchingEndpoint) +} + +func TestObtainManagementEndpointsFromServerContextWithX509(t *testing.T) { + tb, teardownFn, caCertPath, certPath, keyPath := setupX509Tests(t, true) + defer tb.Close() + defer teardownFn() + + original := tempConnectionDetailsForManagementEndpoints + defer func() { + tempConnectionDetailsForManagementEndpoints = original + }() + + tempConnectionDetailsForManagementEndpoints = func() (string, string, string, string, string, string) { + return base.UnitTestUrl(), base.TestClusterUsername(), base.TestClusterPassword(), certPath, keyPath, caCertPath + } + + ctx := NewServerContext(&ServerConfig{}) + defer ctx.Close() + + eps, err := ctx.ObtainManagementEndpoints() + assert.NoError(t, err) + + baseSpec, err := connstr.Parse(base.UnitTestUrl()) + require.NoError(t, err) + + spec, err := connstr.Resolve(baseSpec) + require.NoError(t, err) + + existsOneMatchingEndpoint := false + +outerLoop: + for _, httpHost := range spec.HttpHosts { + for _, ep := range eps { + formattedHttpHost := fmt.Sprintf("https://%s:%d", httpHost.Host, httpHost.Port) + if formattedHttpHost == ep { + existsOneMatchingEndpoint = true + break outerLoop + } + } + } + + assert.True(t, existsOneMatchingEndpoint) +}