From 74563f566bbcaf0a83c8a1dff1b9654e3932fdf5 Mon Sep 17 00:00:00 2001 From: Ellie Sterner Date: Thu, 5 Jan 2023 09:43:03 -0600 Subject: [PATCH 1/6] add core state lock deadlock detection config option v2 --- command/server.go | 1 + command/server/config.go | 9 ++++ command/server/config_test_helpers.go | 1 + helper/locking/deadlock.go | 21 -------- helper/locking/lock.go | 39 ++++++++++++--- http/sys_config_state_test.go | 1 + vault/core.go | 14 +++++- vault/core_test.go | 50 ++++++++++++++++++++ vault/expiration.go | 4 +- vault/testing.go | 1 + website/content/docs/configuration/index.mdx | 5 ++ 11 files changed, 116 insertions(+), 30 deletions(-) delete mode 100644 helper/locking/deadlock.go diff --git a/command/server.go b/command/server.go index f3f7db537b45..ed194f07c6fe 100644 --- a/command/server.go +++ b/command/server.go @@ -2619,6 +2619,7 @@ func createCoreConfig(c *ServerCommand, config *server.Config, backend physical. CredentialBackends: c.CredentialBackends, LogicalBackends: c.LogicalBackends, Logger: c.logger, + DetectDeadlocks: config.DetectDeadlocks, DisableSentinelTrace: config.DisableSentinelTrace, DisableCache: config.DisableCache, DisableMlock: config.DisableMlock, diff --git a/command/server/config.go b/command/server/config.go index b83a9fe2f7da..63b43def42d0 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -97,6 +97,8 @@ type Config struct { LogRequestsLevel string `hcl:"-"` LogRequestsLevelRaw interface{} `hcl:"log_requests_level"` + DetectDeadlocks string `hcl:"detect_deadlocks"` + EnableResponseHeaderRaftNodeID bool `hcl:"-"` EnableResponseHeaderRaftNodeIDRaw interface{} `hcl:"enable_response_header_raft_node_id"` @@ -389,6 +391,11 @@ func (c *Config) Merge(c2 *Config) *Config { result.LogRequestsLevel = c2.LogRequestsLevel } + result.DetectDeadlocks = c.DetectDeadlocks + if c2.DetectDeadlocks != "" { + result.DetectDeadlocks = c2.DetectDeadlocks + } + result.EnableResponseHeaderRaftNodeID = 
c.EnableResponseHeaderRaftNodeID if c2.EnableResponseHeaderRaftNodeID { result.EnableResponseHeaderRaftNodeID = c2.EnableResponseHeaderRaftNodeID @@ -1025,6 +1032,8 @@ func (c *Config) Sanitized() map[string]interface{} { "enable_response_header_raft_node_id": c.EnableResponseHeaderRaftNodeID, "log_requests_level": c.LogRequestsLevel, + + "detect_deadlocks": c.DetectDeadlocks, } for k, v := range sharedResult { result[k] = v diff --git a/command/server/config_test_helpers.go b/command/server/config_test_helpers.go index aac19b5df6dc..bb06dda93078 100644 --- a/command/server/config_test_helpers.go +++ b/command/server/config_test_helpers.go @@ -745,6 +745,7 @@ func testConfig_Sanitized(t *testing.T) { "raw_storage_endpoint": true, "introspection_endpoint": false, "disable_sentinel_trace": true, + "detect_deadlocks": "", "enable_ui": true, "enable_response_header_hostname": false, "enable_response_header_raft_node_id": false, diff --git a/helper/locking/deadlock.go b/helper/locking/deadlock.go deleted file mode 100644 index e250abd1aecb..000000000000 --- a/helper/locking/deadlock.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build deadlock - -package locking - -import ( - "github.com/sasha-s/go-deadlock" -) - -// DeadlockMutex, when the build tag `deadlock` is present, behaves like a -// sync.Mutex but does periodic checking to see if outstanding locks and requests -// look like a deadlock. If it finds a deadlock candidate it will output it -// prefixed with "POTENTIAL DEADLOCK", as described at -// https://github.com/sasha-s/go-deadlock -type DeadlockMutex struct { - deadlock.Mutex -} - -// DeadlockRWMutex is the RW version of DeadlockMutex. 
-type DeadlockRWMutex struct { - deadlock.RWMutex -} diff --git a/helper/locking/lock.go b/helper/locking/lock.go index 1b1fae3af9ec..8043f01ad617 100644 --- a/helper/locking/lock.go +++ b/helper/locking/lock.go @@ -1,19 +1,46 @@ -//go:build !deadlock - package locking import ( "sync" + + "github.com/sasha-s/go-deadlock" ) -// DeadlockMutex is just a sync.Mutex when the build tag `deadlock` is absent. -// See its other definition in the corresponding deadlock-build-tag-constrained -// file for more details. +// Common mutex interface to allow either built-in or imported deadlock use +type Mutex interface { + Lock() + Unlock() +} + +// Common r/w mutex interface to allow either built-in or imported deadlock use +type RWMutex interface { + Lock() + RLock() + RLocker() sync.Locker + RUnlock() + Unlock() +} + +// DeadlockMutex (used when requested via config option `detect_deadlocks`) +// behaves like a sync.Mutex but does periodic checking to see if outstanding +// locks and requests look like a deadlock. If it finds a deadlock candidate it +// will output it prefixed with "POTENTIAL DEADLOCK", as described at +// https://github.com/sasha-s/go-deadlock type DeadlockMutex struct { - sync.Mutex + deadlock.Mutex } // DeadlockRWMutex is the RW version of DeadlockMutex. type DeadlockRWMutex struct { + deadlock.RWMutex +} + +// Regular sync/mutex. +type SyncMutex struct { + sync.Mutex +} + +// SyncRWMutex is the RW version of SyncMutex. 
+type SyncRWMutex struct { sync.RWMutex } diff --git a/http/sys_config_state_test.go b/http/sys_config_state_test.go index 4cd2aae8b827..d55897854170 100644 --- a/http/sys_config_state_test.go +++ b/http/sys_config_state_test.go @@ -39,6 +39,7 @@ func TestSysConfigState_Sanitized(t *testing.T) { "disable_printable_check": false, "disable_sealwrap": false, "raw_storage_endpoint": false, + "detect_deadlocks": "", "introspection_endpoint": false, "disable_sentinel_trace": false, "enable_ui": false, diff --git a/vault/core.go b/vault/core.go index 2f8b8056832e..78952f682646 100644 --- a/vault/core.go +++ b/vault/core.go @@ -304,7 +304,7 @@ type Core struct { auditBackends map[string]audit.Factory // stateLock protects mutable state - stateLock locking.DeadlockRWMutex + stateLock locking.RWMutex sealed *uint32 standby bool @@ -713,6 +713,9 @@ type CoreConfig struct { Logger log.Logger + // Use the deadlocks library to detect deadlocks + DetectDeadlocks string + // Disables the trace display for Sentinel checks DisableSentinelTrace bool @@ -885,6 +888,14 @@ func CreateCore(conf *CoreConfig) (*Core, error) { conf.NumExpirationWorkers = numExpirationWorkersDefault } + // Use imported logging deadlock if requested + var stateLock locking.RWMutex + if conf.DetectDeadlocks != "" && strings.Contains(conf.DetectDeadlocks, "statelock") { + stateLock = &locking.DeadlockRWMutex{} + } else { + stateLock = &locking.SyncRWMutex{} + } + effectiveSDKVersion := conf.EffectiveSDKVersion if effectiveSDKVersion == "" { effectiveSDKVersion = version.GetVersion().Version @@ -903,6 +914,7 @@ func CreateCore(conf *CoreConfig) (*Core, error) { clusterListener: new(atomic.Value), customListenerHeader: new(atomic.Value), seal: conf.Seal, + stateLock: stateLock, router: NewRouter(), sealed: new(uint32), sealMigrationDone: new(uint32), diff --git a/vault/core_test.go b/vault/core_test.go index 3789c6853927..e72e705e8dbb 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -6,6 +6,7 @@ import 
( "reflect" "strings" "sync" + "sync/atomic" "testing" "time" @@ -21,6 +22,7 @@ import ( "github.com/hashicorp/vault/sdk/physical" "github.com/hashicorp/vault/sdk/physical/inmem" "github.com/hashicorp/vault/version" + "github.com/sasha-s/go-deadlock" ) // invalidKey is used to test Unseal @@ -2836,3 +2838,51 @@ func TestCore_ServiceRegistration(t *testing.T) { t.Fatal(diff) } } + +func TestDetectedDeadlock(t *testing.T) { + testCore, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{DetectDeadlocks: "statelock"}) + InduceDeadlock(t, testCore, 1) +} + +func TestDefaultDeadlock(t *testing.T) { + testCore, _, _ := TestCoreUnsealed(t) + InduceDeadlock(t, testCore, 0) +} + +func RestoreDeadlockOpts() func() { + opts := deadlock.Opts + return func() { + deadlock.Opts = opts + } +} + +func InduceDeadlock(t *testing.T, vaultcore *Core, expected uint32) { + defer RestoreDeadlockOpts()() + var deadlocks uint32 + deadlock.Opts.OnPotentialDeadlock = func() { + atomic.AddUint32(&deadlocks, 1) + } + var mtx deadlock.Mutex + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + vaultcore.expiration.coreStateLock.Lock() + mtx.Lock() + mtx.Unlock() + vaultcore.expiration.coreStateLock.Unlock() + }() + wg.Wait() + wg.Add(1) + go func() { + defer wg.Done() + mtx.Lock() + vaultcore.expiration.coreStateLock.RLock() + vaultcore.expiration.coreStateLock.RUnlock() + mtx.Unlock() + }() + wg.Wait() + if atomic.LoadUint32(&deadlocks) != expected { + t.Fatalf("expected 1 deadlock, detected %d", deadlocks) + } +} diff --git a/vault/expiration.go b/vault/expiration.go index a5f1918dcf22..a49fed466aae 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -139,7 +139,7 @@ type ExpirationManager struct { quitCh chan struct{} // do not hold coreStateLock in any API handler code - it is already held - coreStateLock *locking.DeadlockRWMutex + coreStateLock locking.RWMutex quitContext context.Context leaseCheckCounter *uint32 @@ -350,7 +350,7 @@ func NewExpirationManager(c *Core, 
view *BarrierView, e ExpireLeaseStrategy, log restoreLocks: locksutil.CreateLocks(), quitCh: make(chan struct{}), - coreStateLock: &c.stateLock, + coreStateLock: c.stateLock, quitContext: c.activeContext, leaseCheckCounter: new(uint32), diff --git a/vault/testing.go b/vault/testing.go index 05853c09e787..d1438999ab58 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -212,6 +212,7 @@ func TestCoreWithSealAndUINoCleanup(t testing.T, opts *CoreConfig) *Core { conf.EnableResponseHeaderHostname = opts.EnableResponseHeaderHostname conf.DisableSSCTokens = opts.DisableSSCTokens conf.PluginDirectory = opts.PluginDirectory + conf.DetectDeadlocks = opts.DetectDeadlocks if opts.Logger != nil { conf.Logger = opts.Logger diff --git a/website/content/docs/configuration/index.mdx b/website/content/docs/configuration/index.mdx index b7166aa2e999..287d64d1fadd 100644 --- a/website/content/docs/configuration/index.mdx +++ b/website/content/docs/configuration/index.mdx @@ -149,6 +149,11 @@ to specify where the configuration is. maximum request duration allowed before Vault cancels the request. This can be overridden per listener via the `max_request_duration` value. +- `detect_deadlocks` `(string: "")` - Specifies the internal mutex locks that should be monitored for +potential deadlocks. Currently supported value is `statelock`, which will cause "POTENTIAL DEADLOCK:" +to be logged when an attempt at a core state lock appears to be deadlocked. Enabling this can have +a negative effect on performance due to the tracking of each lock attempt. + - `raw_storage_endpoint` `(bool: false)` – Enables the `sys/raw` endpoint which allows the decryption/encryption of raw data into and out of the security barrier. This is a highly privileged endpoint. 
From fd158d38da38414d3589bdb9e398dbfe5454580c Mon Sep 17 00:00:00 2001 From: Ellie Sterner Date: Thu, 5 Jan 2023 09:51:57 -0600 Subject: [PATCH 2/6] add changelog --- changelog/18604.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 changelog/18604.txt diff --git a/changelog/18604.txt b/changelog/18604.txt new file mode 100644 index 000000000000..7645cbb40394 --- /dev/null +++ b/changelog/18604.txt @@ -0,0 +1,3 @@ +```release-note:improvement +core: add `detect_deadlocks` config to optionally detect core state deadlocks +``` \ No newline at end of file From fc4f1c75ef832fb85c8437c890e47c66e6e24db0 Mon Sep 17 00:00:00 2001 From: Ellie Sterner Date: Fri, 6 Jan 2023 14:34:42 -0600 Subject: [PATCH 3/6] split out NewTestCluster function to maintain build flag --- vault/test_cluster.go | 553 +++++++++++++++++++++++++ vault/test_cluster_detect_deadlock.go | 556 ++++++++++++++++++++++++++ vault/testing.go | 523 ------------------------ 3 files changed, 1109 insertions(+), 523 deletions(-) create mode 100644 vault/test_cluster.go create mode 100644 vault/test_cluster_detect_deadlock.go diff --git a/vault/test_cluster.go b/vault/test_cluster.go new file mode 100644 index 000000000000..355482f34b66 --- /dev/null +++ b/vault/test_cluster.go @@ -0,0 +1,553 @@ +//go:build !deadlock + +package vault + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + mathrand "math/rand" + "net" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/mitchellh/go-testing-interface" + + "github.com/hashicorp/go-secure-stdlib/reloadutil" + "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/command/server" + "github.com/hashicorp/vault/internalshared/configutil" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/physical" + physInmem 
"github.com/hashicorp/vault/sdk/physical/inmem" + "github.com/hashicorp/vault/vault/cluster" +) + +// NewTestCluster creates a new test cluster based on the provided core config +// and test cluster options. +// +// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a +// core config for each core it creates. If separate seal per core is desired, opts.SealFunc +// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be +// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the +// provided Seal in coreConfig (i.e. base.Seal) is nil. +// +// If opts.Logger is provided, it takes precedence and will be used as the cluster +// logger and will be the basis for each core's logger. If no opts.Logger is +// given, one will be generated based on t.Name() for the cluster logger, and if +// no base.Logger is given will also be used as the basis for each core's logger. + +func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { + var err error + + var numCores int + if opts == nil || opts.NumCores == 0 { + numCores = DefaultNumCores + } else { + numCores = opts.NumCores + } + + certIPs := []net.IP{ + net.IPv6loopback, + net.ParseIP("127.0.0.1"), + } + var baseAddr *net.TCPAddr + if opts != nil && opts.BaseListenAddress != "" { + baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress) + if err != nil { + t.Fatal("could not parse given base IP") + } + certIPs = append(certIPs, baseAddr.IP) + } else { + baseAddr = &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + } + } + + var testCluster TestCluster + testCluster.base = base + + switch { + case opts != nil && opts.Logger != nil: + testCluster.Logger = opts.Logger + default: + testCluster.Logger = NewTestLogger(t) + } + + if opts != nil && opts.TempDir != "" { + if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) { + if err := os.MkdirAll(opts.TempDir, 0o700); err != nil { 
+ t.Fatal(err) + } + } + testCluster.TempDir = opts.TempDir + } else { + tempDir, err := ioutil.TempDir("", "vault-test-cluster-") + if err != nil { + t.Fatal(err) + } + testCluster.TempDir = tempDir + } + + var caKey *ecdsa.PrivateKey + if opts != nil && opts.CAKey != nil { + caKey = opts.CAKey + } else { + caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + } + testCluster.CAKey = caKey + var caBytes []byte + if opts != nil && len(opts.CACert) > 0 { + caBytes = opts.CACert + } else { + caCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "localhost", + }, + DNSNames: []string{"localhost"}, + IPAddresses: certIPs, + KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) + if err != nil { + t.Fatal(err) + } + } + caCert, err := x509.ParseCertificate(caBytes) + if err != nil { + t.Fatal(err) + } + testCluster.CACert = caCert + testCluster.CACertBytes = caBytes + testCluster.RootCAs = x509.NewCertPool() + testCluster.RootCAs.AddCert(caCert) + caCertPEMBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + } + testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock) + testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem") + err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0o755) + if err != nil { + t.Fatal(err) + } + marshaledCAKey, err := x509.MarshalECPrivateKey(caKey) + if err != nil { + t.Fatal(err) + } + caKeyPEMBlock := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: marshaledCAKey, + } + testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock) + err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), 
testCluster.CAKeyPEM, 0o755) + if err != nil { + t.Fatal(err) + } + + var certInfoSlice []*certInfo + + // + // Certs generation + // + for i := 0; i < numCores; i++ { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + certTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "localhost", + }, + // Include host.docker.internal for the sake of benchmark-vault running on MacOS/Windows. + // This allows Prometheus running in docker to scrape the cluster for metrics. + DNSNames: []string{"localhost", "host.docker.internal"}, + IPAddresses: certIPs, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + } + certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) + if err != nil { + t.Fatal(err) + } + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + t.Fatal(err) + } + certPEMBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + } + certPEM := pem.EncodeToMemory(certPEMBlock) + marshaledKey, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatal(err) + } + keyPEMBlock := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: marshaledKey, + } + keyPEM := pem.EncodeToMemory(keyPEMBlock) + + certInfoSlice = append(certInfoSlice, &certInfo{ + cert: cert, + certPEM: certPEM, + certBytes: certBytes, + key: key, + keyPEM: keyPEM, + }) + } + + // + // Listener setup + // + addresses := []*net.TCPAddr{} + listeners := [][]*TestListener{} + servers := []*http.Server{} + handlers := []http.Handler{} + tlsConfigs := []*tls.Config{} + certGetters := []*reloadutil.CertificateGetter{} + for i := 0; i < numCores; i++ { + addr := &net.TCPAddr{ + IP: baseAddr.IP, + Port: 
0, + } + if baseAddr.Port != 0 { + addr.Port = baseAddr.Port + i + } + + ln, err := net.ListenTCP("tcp", addr) + if err != nil { + t.Fatal(err) + } + addresses = append(addresses, addr) + + certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) + keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) + err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0o755) + if err != nil { + t.Fatal(err) + } + err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0o755) + if err != nil { + t.Fatal(err) + } + tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM) + if err != nil { + t.Fatal(err) + } + certGetter := reloadutil.NewCertificateGetter(certFile, keyFile, "") + certGetters = append(certGetters, certGetter) + certGetter.Reload() + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + RootCAs: testCluster.RootCAs, + ClientCAs: testCluster.RootCAs, + ClientAuth: tls.RequestClientCert, + NextProtos: []string{"h2", "http/1.1"}, + GetCertificate: certGetter.GetCertificate, + } + if opts != nil && opts.RequireClientAuth { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + testCluster.ClientAuthRequired = true + } + tlsConfigs = append(tlsConfigs, tlsConfig) + lns := []*TestListener{ + { + Listener: tls.NewListener(ln, tlsConfig), + Address: ln.Addr().(*net.TCPAddr), + }, + } + listeners = append(listeners, lns) + var handler http.Handler = http.NewServeMux() + handlers = append(handlers, handler) + server := &http.Server{ + Handler: handler, + ErrorLog: testCluster.Logger.StandardLogger(nil), + } + servers = append(servers, server) + } + + // Create three cores with the same physical and different redirect/cluster + // addrs. + // N.B.: On OSX, instead of random ports, it assigns new ports to new + // listeners sequentially. 
Aside from being a bad idea in a security sense, + // it also broke tests that assumed it was OK to just use the port above + // the redirect addr. This has now been changed to 105 ports above, but if + // we ever do more than three nodes in a cluster it may need to be bumped. + // Note: it's 105 so that we don't conflict with a running Consul by + // default. + coreConfig := &CoreConfig{ + LogicalBackends: make(map[string]logical.Factory), + CredentialBackends: make(map[string]logical.Factory), + AuditBackends: make(map[string]audit.Factory), + RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port), + ClusterAddr: "https://127.0.0.1:0", + DisableMlock: true, + EnableUI: true, + EnableRaw: true, + BuiltinRegistry: NewMockBuiltinRegistry(), + } + + if base != nil { + coreConfig.RawConfig = base.RawConfig + coreConfig.DisableCache = base.DisableCache + coreConfig.EnableUI = base.EnableUI + coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL + coreConfig.MaxLeaseTTL = base.MaxLeaseTTL + coreConfig.CacheSize = base.CacheSize + coreConfig.PluginDirectory = base.PluginDirectory + coreConfig.Seal = base.Seal + coreConfig.UnwrapSeal = base.UnwrapSeal + coreConfig.DevToken = base.DevToken + coreConfig.EnableRaw = base.EnableRaw + coreConfig.DisableSealWrap = base.DisableSealWrap + coreConfig.DisableCache = base.DisableCache + coreConfig.LicensingConfig = base.LicensingConfig + coreConfig.License = base.License + coreConfig.LicensePath = base.LicensePath + coreConfig.DisablePerformanceStandby = base.DisablePerformanceStandby + coreConfig.MetricsHelper = base.MetricsHelper + coreConfig.MetricSink = base.MetricSink + coreConfig.SecureRandomReader = base.SecureRandomReader + coreConfig.DisableSentinelTrace = base.DisableSentinelTrace + coreConfig.ClusterName = base.ClusterName + coreConfig.DisableAutopilot = base.DisableAutopilot + + if base.BuiltinRegistry != nil { + coreConfig.BuiltinRegistry = base.BuiltinRegistry + } + + if !coreConfig.DisableMlock { 
+ base.DisableMlock = false + } + + if base.Physical != nil { + coreConfig.Physical = base.Physical + } + + if base.HAPhysical != nil { + coreConfig.HAPhysical = base.HAPhysical + } + + // Used to set something non-working to test fallback + switch base.ClusterAddr { + case "empty": + coreConfig.ClusterAddr = "" + case "": + default: + coreConfig.ClusterAddr = base.ClusterAddr + } + + if base.LogicalBackends != nil { + for k, v := range base.LogicalBackends { + coreConfig.LogicalBackends[k] = v + } + } + if base.CredentialBackends != nil { + for k, v := range base.CredentialBackends { + coreConfig.CredentialBackends[k] = v + } + } + if base.AuditBackends != nil { + for k, v := range base.AuditBackends { + coreConfig.AuditBackends[k] = v + } + } + if base.Logger != nil { + coreConfig.Logger = base.Logger + } + + coreConfig.ClusterCipherSuites = base.ClusterCipherSuites + coreConfig.DisableCache = base.DisableCache + coreConfig.DevToken = base.DevToken + coreConfig.RecoveryMode = base.RecoveryMode + coreConfig.ActivityLogConfig = base.ActivityLogConfig + coreConfig.EnableResponseHeaderHostname = base.EnableResponseHeaderHostname + coreConfig.EnableResponseHeaderRaftNodeID = base.EnableResponseHeaderRaftNodeID + coreConfig.RollbackPeriod = base.RollbackPeriod + coreConfig.PendingRemovalMountsAllowed = base.PendingRemovalMountsAllowed + coreConfig.ExpirationRevokeRetryBase = base.ExpirationRevokeRetryBase + testApplyEntBaseConfig(coreConfig, base) + } + if coreConfig.ClusterName == "" { + coreConfig.ClusterName = t.Name() + } + + if coreConfig.ClusterName == "" { + coreConfig.ClusterName = t.Name() + } + + if coreConfig.ClusterHeartbeatInterval == 0 { + // Set this lower so that state populates quickly to standby nodes + coreConfig.ClusterHeartbeatInterval = 2 * time.Second + } + + if coreConfig.RawConfig == nil { + c := new(server.Config) + c.SharedConfig = &configutil.SharedConfig{LogFormat: logging.UnspecifiedFormat.String()} + coreConfig.RawConfig = c + } + + 
addAuditBackend := len(coreConfig.AuditBackends) == 0 + if addAuditBackend { + AddNoopAudit(coreConfig, nil) + } + + if coreConfig.Physical == nil && (opts == nil || opts.PhysicalFactory == nil) { + coreConfig.Physical, err = physInmem.NewInmem(nil, testCluster.Logger) + if err != nil { + t.Fatal(err) + } + } + if coreConfig.HAPhysical == nil && (opts == nil || opts.PhysicalFactory == nil) { + haPhys, err := physInmem.NewInmemHA(nil, testCluster.Logger) + if err != nil { + t.Fatal(err) + } + coreConfig.HAPhysical = haPhys.(physical.HABackend) + } + + if testCluster.LicensePublicKey == nil { + pubKey, priKey, err := GenerateTestLicenseKeys() + if err != nil { + t.Fatalf("err: %v", err) + } + testCluster.LicensePublicKey = pubKey + testCluster.LicensePrivateKey = priKey + } + + if opts != nil && opts.InmemClusterLayers { + if opts.ClusterLayers != nil { + t.Fatalf("cannot specify ClusterLayers when InmemClusterLayers is true") + } + inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", numCores, testCluster.Logger.Named("inmem-cluster")) + if err != nil { + t.Fatal(err) + } + opts.ClusterLayers = inmemCluster + } + + // Create cores + testCluster.cleanupFuncs = []func(){} + cores := []*Core{} + coreConfigs := []*CoreConfig{} + + for i := 0; i < numCores; i++ { + cleanup, c, localConfig, handler := testCluster.newCore(t, i, coreConfig, opts, listeners[i], testCluster.LicensePublicKey) + + testCluster.cleanupFuncs = append(testCluster.cleanupFuncs, cleanup) + cores = append(cores, c) + coreConfigs = append(coreConfigs, &localConfig) + + if handler != nil { + handlers[i] = handler + servers[i].Handler = handlers[i] + } + } + + // Clustering setup + for i := 0; i < numCores; i++ { + testCluster.setupClusterListener(t, i, cores[i], coreConfigs[i], opts, listeners[i], handlers[i]) + } + + // Create TestClusterCores + var ret []*TestClusterCore + for i := 0; i < numCores; i++ { + tcc := &TestClusterCore{ + Core: cores[i], + CoreConfig: coreConfigs[i], + 
ServerKey: certInfoSlice[i].key, + ServerKeyPEM: certInfoSlice[i].keyPEM, + ServerCert: certInfoSlice[i].cert, + ServerCertBytes: certInfoSlice[i].certBytes, + ServerCertPEM: certInfoSlice[i].certPEM, + Address: addresses[i], + Listeners: listeners[i], + Handler: handlers[i], + Server: servers[i], + TLSConfig: tlsConfigs[i], + Barrier: cores[i].barrier, + NodeID: fmt.Sprintf("core-%d", i), + UnderlyingRawStorage: coreConfigs[i].Physical, + UnderlyingHAStorage: coreConfigs[i].HAPhysical, + } + tcc.ReloadFuncs = &cores[i].reloadFuncs + tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock + tcc.ReloadFuncsLock.Lock() + (*tcc.ReloadFuncs)["listener|tcp"] = []reloadutil.ReloadFunc{certGetters[i].Reload} + tcc.ReloadFuncsLock.Unlock() + + testAdjustUnderlyingStorage(tcc) + + ret = append(ret, tcc) + } + testCluster.Cores = ret + + // Initialize cores + if opts == nil || !opts.SkipInit { + testCluster.initCores(t, opts, addAuditBackend) + } + + // Assign clients + for i := 0; i < numCores; i++ { + testCluster.Cores[i].Client = testCluster.getAPIClient(t, opts, listeners[i][0].Address.Port, tlsConfigs[i]) + } + + // Extra Setup + for _, tcc := range testCluster.Cores { + testExtraTestCoreSetup(t, testCluster.LicensePrivateKey, tcc) + } + + // Cleanup + testCluster.CleanupFunc = func() { + for _, c := range testCluster.cleanupFuncs { + c() + } + if l, ok := testCluster.Logger.(*TestLogger); ok { + if t.Failed() { + _ = l.File.Close() + } else { + _ = os.Remove(l.Path) + } + } + } + + // Setup + if opts != nil { + if opts.SetupFunc != nil { + testCluster.SetupFunc = func() { + opts.SetupFunc(t, &testCluster) + } + } + } + + testCluster.opts = opts + testCluster.start(t) + return &testCluster +} diff --git a/vault/test_cluster_detect_deadlock.go b/vault/test_cluster_detect_deadlock.go new file mode 100644 index 000000000000..07652938331a --- /dev/null +++ b/vault/test_cluster_detect_deadlock.go @@ -0,0 +1,556 @@ +//go:build deadlock + +package vault + +import ( + "crypto/ecdsa" + 
"crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + mathrand "math/rand" + "net" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/mitchellh/go-testing-interface" + + "github.com/hashicorp/go-secure-stdlib/reloadutil" + "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/command/server" + "github.com/hashicorp/vault/internalshared/configutil" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/physical" + physInmem "github.com/hashicorp/vault/sdk/physical/inmem" + "github.com/hashicorp/vault/vault/cluster" +) + +// NewTestCluster creates a new test cluster based on the provided core config +// and test cluster options. +// +// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a +// core config for each core it creates. If separate seal per core is desired, opts.SealFunc +// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be +// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the +// provided Seal in coreConfig (i.e. base.Seal) is nil. +// +// If opts.Logger is provided, it takes precedence and will be used as the cluster +// logger and will be the basis for each core's logger. If no opts.Logger is +// given, one will be generated based on t.Name() for the cluster logger, and if +// no base.Logger is given will also be used as the basis for each core's logger. 
+ +func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { + base.DetectDeadlocks = "statelock" // detect deadlocks because of build tag "deadlock" + + var err error + + var numCores int + if opts == nil || opts.NumCores == 0 { + numCores = DefaultNumCores + } else { + numCores = opts.NumCores + } + + certIPs := []net.IP{ + net.IPv6loopback, + net.ParseIP("127.0.0.1"), + } + var baseAddr *net.TCPAddr + if opts != nil && opts.BaseListenAddress != "" { + baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress) + + if err != nil { + t.Fatal("could not parse given base IP") + } + certIPs = append(certIPs, baseAddr.IP) + } else { + baseAddr = &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + } + } + + var testCluster TestCluster + testCluster.base = base + + switch { + case opts != nil && opts.Logger != nil: + testCluster.Logger = opts.Logger + default: + testCluster.Logger = NewTestLogger(t) + } + + if opts != nil && opts.TempDir != "" { + if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) { + if err := os.MkdirAll(opts.TempDir, 0o700); err != nil { + t.Fatal(err) + } + } + testCluster.TempDir = opts.TempDir + } else { + tempDir, err := ioutil.TempDir("", "vault-test-cluster-") + if err != nil { + t.Fatal(err) + } + testCluster.TempDir = tempDir + } + + var caKey *ecdsa.PrivateKey + if opts != nil && opts.CAKey != nil { + caKey = opts.CAKey + } else { + caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + } + testCluster.CAKey = caKey + var caBytes []byte + if opts != nil && len(opts.CACert) > 0 { + caBytes = opts.CACert + } else { + caCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "localhost", + }, + DNSNames: []string{"localhost"}, + IPAddresses: certIPs, + KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: 
time.Now().Add(262980 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) + if err != nil { + t.Fatal(err) + } + } + caCert, err := x509.ParseCertificate(caBytes) + if err != nil { + t.Fatal(err) + } + testCluster.CACert = caCert + testCluster.CACertBytes = caBytes + testCluster.RootCAs = x509.NewCertPool() + testCluster.RootCAs.AddCert(caCert) + caCertPEMBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + } + testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock) + testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem") + err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0o755) + if err != nil { + t.Fatal(err) + } + marshaledCAKey, err := x509.MarshalECPrivateKey(caKey) + if err != nil { + t.Fatal(err) + } + caKeyPEMBlock := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: marshaledCAKey, + } + testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock) + err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0o755) + if err != nil { + t.Fatal(err) + } + + var certInfoSlice []*certInfo + + // + // Certs generation + // + for i := 0; i < numCores; i++ { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + certTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "localhost", + }, + // Include host.docker.internal for the sake of benchmark-vault running on MacOS/Windows. + // This allows Prometheus running in docker to scrape the cluster for metrics. 
+ DNSNames: []string{"localhost", "host.docker.internal"}, + IPAddresses: certIPs, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + } + certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) + if err != nil { + t.Fatal(err) + } + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + t.Fatal(err) + } + certPEMBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + } + certPEM := pem.EncodeToMemory(certPEMBlock) + marshaledKey, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatal(err) + } + keyPEMBlock := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: marshaledKey, + } + keyPEM := pem.EncodeToMemory(keyPEMBlock) + + certInfoSlice = append(certInfoSlice, &certInfo{ + cert: cert, + certPEM: certPEM, + certBytes: certBytes, + key: key, + keyPEM: keyPEM, + }) + } + + // + // Listener setup + // + addresses := []*net.TCPAddr{} + listeners := [][]*TestListener{} + servers := []*http.Server{} + handlers := []http.Handler{} + tlsConfigs := []*tls.Config{} + certGetters := []*reloadutil.CertificateGetter{} + for i := 0; i < numCores; i++ { + addr := &net.TCPAddr{ + IP: baseAddr.IP, + Port: 0, + } + if baseAddr.Port != 0 { + addr.Port = baseAddr.Port + i + } + + ln, err := net.ListenTCP("tcp", addr) + if err != nil { + t.Fatal(err) + } + addresses = append(addresses, addr) + + certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) + keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) + err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0o755) + if err != nil { + 
t.Fatal(err) + } + err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0o755) + if err != nil { + t.Fatal(err) + } + tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM) + if err != nil { + t.Fatal(err) + } + certGetter := reloadutil.NewCertificateGetter(certFile, keyFile, "") + certGetters = append(certGetters, certGetter) + certGetter.Reload() + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + RootCAs: testCluster.RootCAs, + ClientCAs: testCluster.RootCAs, + ClientAuth: tls.RequestClientCert, + NextProtos: []string{"h2", "http/1.1"}, + GetCertificate: certGetter.GetCertificate, + } + if opts != nil && opts.RequireClientAuth { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + testCluster.ClientAuthRequired = true + } + tlsConfigs = append(tlsConfigs, tlsConfig) + lns := []*TestListener{ + { + Listener: tls.NewListener(ln, tlsConfig), + Address: ln.Addr().(*net.TCPAddr), + }, + } + listeners = append(listeners, lns) + var handler http.Handler = http.NewServeMux() + handlers = append(handlers, handler) + server := &http.Server{ + Handler: handler, + ErrorLog: testCluster.Logger.StandardLogger(nil), + } + servers = append(servers, server) + } + + // Create three cores with the same physical and different redirect/cluster + // addrs. + // N.B.: On OSX, instead of random ports, it assigns new ports to new + // listeners sequentially. Aside from being a bad idea in a security sense, + // it also broke tests that assumed it was OK to just use the port above + // the redirect addr. This has now been changed to 105 ports above, but if + // we ever do more than three nodes in a cluster it may need to be bumped. + // Note: it's 105 so that we don't conflict with a running Consul by + // default. 
+ coreConfig := &CoreConfig{ + LogicalBackends: make(map[string]logical.Factory), + CredentialBackends: make(map[string]logical.Factory), + AuditBackends: make(map[string]audit.Factory), + RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port), + ClusterAddr: "https://127.0.0.1:0", + DisableMlock: true, + EnableUI: true, + EnableRaw: true, + BuiltinRegistry: NewMockBuiltinRegistry(), + } + + if base != nil { + coreConfig.RawConfig = base.RawConfig + coreConfig.DisableCache = base.DisableCache + coreConfig.EnableUI = base.EnableUI + coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL + coreConfig.MaxLeaseTTL = base.MaxLeaseTTL + coreConfig.CacheSize = base.CacheSize + coreConfig.PluginDirectory = base.PluginDirectory + coreConfig.Seal = base.Seal + coreConfig.UnwrapSeal = base.UnwrapSeal + coreConfig.DevToken = base.DevToken + coreConfig.EnableRaw = base.EnableRaw + coreConfig.DisableSealWrap = base.DisableSealWrap + coreConfig.DisableCache = base.DisableCache + coreConfig.LicensingConfig = base.LicensingConfig + coreConfig.License = base.License + coreConfig.LicensePath = base.LicensePath + coreConfig.DisablePerformanceStandby = base.DisablePerformanceStandby + coreConfig.MetricsHelper = base.MetricsHelper + coreConfig.MetricSink = base.MetricSink + coreConfig.SecureRandomReader = base.SecureRandomReader + coreConfig.DisableSentinelTrace = base.DisableSentinelTrace + coreConfig.ClusterName = base.ClusterName + coreConfig.DisableAutopilot = base.DisableAutopilot + + if base.BuiltinRegistry != nil { + coreConfig.BuiltinRegistry = base.BuiltinRegistry + } + + if !coreConfig.DisableMlock { + base.DisableMlock = false + } + + if base.Physical != nil { + coreConfig.Physical = base.Physical + } + + if base.HAPhysical != nil { + coreConfig.HAPhysical = base.HAPhysical + } + + // Used to set something non-working to test fallback + switch base.ClusterAddr { + case "empty": + coreConfig.ClusterAddr = "" + case "": + default: + coreConfig.ClusterAddr = 
base.ClusterAddr + } + + if base.LogicalBackends != nil { + for k, v := range base.LogicalBackends { + coreConfig.LogicalBackends[k] = v + } + } + if base.CredentialBackends != nil { + for k, v := range base.CredentialBackends { + coreConfig.CredentialBackends[k] = v + } + } + if base.AuditBackends != nil { + for k, v := range base.AuditBackends { + coreConfig.AuditBackends[k] = v + } + } + if base.Logger != nil { + coreConfig.Logger = base.Logger + } + + coreConfig.ClusterCipherSuites = base.ClusterCipherSuites + coreConfig.DisableCache = base.DisableCache + coreConfig.DevToken = base.DevToken + coreConfig.RecoveryMode = base.RecoveryMode + coreConfig.ActivityLogConfig = base.ActivityLogConfig + coreConfig.EnableResponseHeaderHostname = base.EnableResponseHeaderHostname + coreConfig.EnableResponseHeaderRaftNodeID = base.EnableResponseHeaderRaftNodeID + coreConfig.RollbackPeriod = base.RollbackPeriod + coreConfig.PendingRemovalMountsAllowed = base.PendingRemovalMountsAllowed + coreConfig.ExpirationRevokeRetryBase = base.ExpirationRevokeRetryBase + testApplyEntBaseConfig(coreConfig, base) + } + if coreConfig.ClusterName == "" { + coreConfig.ClusterName = t.Name() + } + + if coreConfig.ClusterName == "" { + coreConfig.ClusterName = t.Name() + } + + if coreConfig.ClusterHeartbeatInterval == 0 { + // Set this lower so that state populates quickly to standby nodes + coreConfig.ClusterHeartbeatInterval = 2 * time.Second + } + + if coreConfig.RawConfig == nil { + c := new(server.Config) + c.SharedConfig = &configutil.SharedConfig{LogFormat: logging.UnspecifiedFormat.String()} + coreConfig.RawConfig = c + } + + addAuditBackend := len(coreConfig.AuditBackends) == 0 + if addAuditBackend { + AddNoopAudit(coreConfig, nil) + } + + if coreConfig.Physical == nil && (opts == nil || opts.PhysicalFactory == nil) { + coreConfig.Physical, err = physInmem.NewInmem(nil, testCluster.Logger) + if err != nil { + t.Fatal(err) + } + } + if coreConfig.HAPhysical == nil && (opts == nil || 
opts.PhysicalFactory == nil) { + haPhys, err := physInmem.NewInmemHA(nil, testCluster.Logger) + if err != nil { + t.Fatal(err) + } + coreConfig.HAPhysical = haPhys.(physical.HABackend) + } + + if testCluster.LicensePublicKey == nil { + pubKey, priKey, err := GenerateTestLicenseKeys() + if err != nil { + t.Fatalf("err: %v", err) + } + testCluster.LicensePublicKey = pubKey + testCluster.LicensePrivateKey = priKey + } + + if opts != nil && opts.InmemClusterLayers { + if opts.ClusterLayers != nil { + t.Fatalf("cannot specify ClusterLayers when InmemClusterLayers is true") + } + inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", numCores, testCluster.Logger.Named("inmem-cluster")) + if err != nil { + t.Fatal(err) + } + opts.ClusterLayers = inmemCluster + } + + // Create cores + testCluster.cleanupFuncs = []func(){} + cores := []*Core{} + coreConfigs := []*CoreConfig{} + + for i := 0; i < numCores; i++ { + cleanup, c, localConfig, handler := testCluster.newCore(t, i, coreConfig, opts, listeners[i], testCluster.LicensePublicKey) + + testCluster.cleanupFuncs = append(testCluster.cleanupFuncs, cleanup) + cores = append(cores, c) + coreConfigs = append(coreConfigs, &localConfig) + + if handler != nil { + handlers[i] = handler + servers[i].Handler = handlers[i] + } + } + + // Clustering setup + for i := 0; i < numCores; i++ { + testCluster.setupClusterListener(t, i, cores[i], coreConfigs[i], opts, listeners[i], handlers[i]) + } + + // Create TestClusterCores + var ret []*TestClusterCore + for i := 0; i < numCores; i++ { + tcc := &TestClusterCore{ + Core: cores[i], + CoreConfig: coreConfigs[i], + ServerKey: certInfoSlice[i].key, + ServerKeyPEM: certInfoSlice[i].keyPEM, + ServerCert: certInfoSlice[i].cert, + ServerCertBytes: certInfoSlice[i].certBytes, + ServerCertPEM: certInfoSlice[i].certPEM, + Address: addresses[i], + Listeners: listeners[i], + Handler: handlers[i], + Server: servers[i], + TLSConfig: tlsConfigs[i], + Barrier: cores[i].barrier, + NodeID: 
fmt.Sprintf("core-%d", i), + UnderlyingRawStorage: coreConfigs[i].Physical, + UnderlyingHAStorage: coreConfigs[i].HAPhysical, + } + tcc.ReloadFuncs = &cores[i].reloadFuncs + tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock + tcc.ReloadFuncsLock.Lock() + (*tcc.ReloadFuncs)["listener|tcp"] = []reloadutil.ReloadFunc{certGetters[i].Reload} + tcc.ReloadFuncsLock.Unlock() + + testAdjustUnderlyingStorage(tcc) + + ret = append(ret, tcc) + } + testCluster.Cores = ret + + // Initialize cores + if opts == nil || !opts.SkipInit { + testCluster.initCores(t, opts, addAuditBackend) + } + + // Assign clients + for i := 0; i < numCores; i++ { + testCluster.Cores[i].Client = testCluster.getAPIClient(t, opts, listeners[i][0].Address.Port, tlsConfigs[i]) + } + + // Extra Setup + for _, tcc := range testCluster.Cores { + testExtraTestCoreSetup(t, testCluster.LicensePrivateKey, tcc) + } + + // Cleanup + testCluster.CleanupFunc = func() { + for _, c := range testCluster.cleanupFuncs { + c() + } + if l, ok := testCluster.Logger.(*TestLogger); ok { + if t.Failed() { + _ = l.File.Close() + } else { + _ = os.Remove(l.Path) + } + } + } + + // Setup + if opts != nil { + if opts.SetupFunc != nil { + testCluster.SetupFunc = func() { + opts.SetupFunc(t, &testCluster) + } + } + } + + testCluster.opts = opts + testCluster.start(t) + return &testCluster +} diff --git a/vault/testing.go b/vault/testing.go index d1438999ab58..86fca41050d3 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -4,20 +4,15 @@ import ( "bytes" "context" "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" "crypto/sha256" "crypto/tls" "crypto/x509" - "crypto/x509/pkix" "encoding/base64" - "encoding/pem" "errors" "fmt" "io" "io/ioutil" - "math/big" - mathrand "math/rand" "net" "net/http" "os" @@ -1353,524 +1348,6 @@ func (tl *TestLogger) StopLogging() { tl.Logger.(log.InterceptLogger).DeregisterSink(tl.sink) } -// NewTestCluster creates a new test cluster based on the provided core config -// and test cluster options. 
-// -// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a -// core config for each core it creates. If separate seal per core is desired, opts.SealFunc -// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be -// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the -// provided Seal in coreConfig (i.e. base.Seal) is nil. -// -// If opts.Logger is provided, it takes precedence and will be used as the cluster -// logger and will be the basis for each core's logger. If no opts.Logger is -// given, one will be generated based on t.Name() for the cluster logger, and if -// no base.Logger is given will also be used as the basis for each core's logger. -func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { - var err error - - var numCores int - if opts == nil || opts.NumCores == 0 { - numCores = DefaultNumCores - } else { - numCores = opts.NumCores - } - - certIPs := []net.IP{ - net.IPv6loopback, - net.ParseIP("127.0.0.1"), - } - var baseAddr *net.TCPAddr - if opts != nil && opts.BaseListenAddress != "" { - baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress) - if err != nil { - t.Fatal("could not parse given base IP") - } - certIPs = append(certIPs, baseAddr.IP) - } else { - baseAddr = &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 0, - } - } - - var testCluster TestCluster - testCluster.base = base - - switch { - case opts != nil && opts.Logger != nil: - testCluster.Logger = opts.Logger - default: - testCluster.Logger = NewTestLogger(t) - } - - if opts != nil && opts.TempDir != "" { - if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) { - if err := os.MkdirAll(opts.TempDir, 0o700); err != nil { - t.Fatal(err) - } - } - testCluster.TempDir = opts.TempDir - } else { - tempDir, err := ioutil.TempDir("", "vault-test-cluster-") - if err != nil { - t.Fatal(err) - } - testCluster.TempDir = tempDir - } - - 
var caKey *ecdsa.PrivateKey - if opts != nil && opts.CAKey != nil { - caKey = opts.CAKey - } else { - caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatal(err) - } - } - testCluster.CAKey = caKey - var caBytes []byte - if opts != nil && len(opts.CACert) > 0 { - caBytes = opts.CACert - } else { - caCertTemplate := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "localhost", - }, - DNSNames: []string{"localhost"}, - IPAddresses: certIPs, - KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - BasicConstraintsValid: true, - IsCA: true, - } - caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) - if err != nil { - t.Fatal(err) - } - } - caCert, err := x509.ParseCertificate(caBytes) - if err != nil { - t.Fatal(err) - } - testCluster.CACert = caCert - testCluster.CACertBytes = caBytes - testCluster.RootCAs = x509.NewCertPool() - testCluster.RootCAs.AddCert(caCert) - caCertPEMBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: caBytes, - } - testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock) - testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem") - err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0o755) - if err != nil { - t.Fatal(err) - } - marshaledCAKey, err := x509.MarshalECPrivateKey(caKey) - if err != nil { - t.Fatal(err) - } - caKeyPEMBlock := &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: marshaledCAKey, - } - testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock) - err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0o755) - if err != nil { - t.Fatal(err) - } - - var certInfoSlice []*certInfo - - // - // Certs generation - // - for i := 0; i < numCores; i++ { - key, err := ecdsa.GenerateKey(elliptic.P256(), 
rand.Reader) - if err != nil { - t.Fatal(err) - } - certTemplate := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "localhost", - }, - // Include host.docker.internal for the sake of benchmark-vault running on MacOS/Windows. - // This allows Prometheus running in docker to scrape the cluster for metrics. - DNSNames: []string{"localhost", "host.docker.internal"}, - IPAddresses: certIPs, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - x509.ExtKeyUsageClientAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - } - certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) - if err != nil { - t.Fatal(err) - } - cert, err := x509.ParseCertificate(certBytes) - if err != nil { - t.Fatal(err) - } - certPEMBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - } - certPEM := pem.EncodeToMemory(certPEMBlock) - marshaledKey, err := x509.MarshalECPrivateKey(key) - if err != nil { - t.Fatal(err) - } - keyPEMBlock := &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: marshaledKey, - } - keyPEM := pem.EncodeToMemory(keyPEMBlock) - - certInfoSlice = append(certInfoSlice, &certInfo{ - cert: cert, - certPEM: certPEM, - certBytes: certBytes, - key: key, - keyPEM: keyPEM, - }) - } - - // - // Listener setup - // - addresses := []*net.TCPAddr{} - listeners := [][]*TestListener{} - servers := []*http.Server{} - handlers := []http.Handler{} - tlsConfigs := []*tls.Config{} - certGetters := []*reloadutil.CertificateGetter{} - for i := 0; i < numCores; i++ { - addr := &net.TCPAddr{ - IP: baseAddr.IP, - Port: 0, - } - if baseAddr.Port != 0 { - addr.Port = baseAddr.Port + i - } - - ln, err := net.ListenTCP("tcp", addr) - if err != nil { - t.Fatal(err) - } - addresses = append(addresses, addr) - - certFile := 
filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) - keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) - err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0o755) - if err != nil { - t.Fatal(err) - } - err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0o755) - if err != nil { - t.Fatal(err) - } - tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM) - if err != nil { - t.Fatal(err) - } - certGetter := reloadutil.NewCertificateGetter(certFile, keyFile, "") - certGetters = append(certGetters, certGetter) - certGetter.Reload() - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - RootCAs: testCluster.RootCAs, - ClientCAs: testCluster.RootCAs, - ClientAuth: tls.RequestClientCert, - NextProtos: []string{"h2", "http/1.1"}, - GetCertificate: certGetter.GetCertificate, - } - if opts != nil && opts.RequireClientAuth { - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - testCluster.ClientAuthRequired = true - } - tlsConfigs = append(tlsConfigs, tlsConfig) - lns := []*TestListener{ - { - Listener: tls.NewListener(ln, tlsConfig), - Address: ln.Addr().(*net.TCPAddr), - }, - } - listeners = append(listeners, lns) - var handler http.Handler = http.NewServeMux() - handlers = append(handlers, handler) - server := &http.Server{ - Handler: handler, - ErrorLog: testCluster.Logger.StandardLogger(nil), - } - servers = append(servers, server) - } - - // Create three cores with the same physical and different redirect/cluster - // addrs. - // N.B.: On OSX, instead of random ports, it assigns new ports to new - // listeners sequentially. Aside from being a bad idea in a security sense, - // it also broke tests that assumed it was OK to just use the port above - // the redirect addr. 
This has now been changed to 105 ports above, but if - // we ever do more than three nodes in a cluster it may need to be bumped. - // Note: it's 105 so that we don't conflict with a running Consul by - // default. - coreConfig := &CoreConfig{ - LogicalBackends: make(map[string]logical.Factory), - CredentialBackends: make(map[string]logical.Factory), - AuditBackends: make(map[string]audit.Factory), - RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port), - ClusterAddr: "https://127.0.0.1:0", - DisableMlock: true, - EnableUI: true, - EnableRaw: true, - BuiltinRegistry: NewMockBuiltinRegistry(), - } - - if base != nil { - coreConfig.RawConfig = base.RawConfig - coreConfig.DisableCache = base.DisableCache - coreConfig.EnableUI = base.EnableUI - coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL - coreConfig.MaxLeaseTTL = base.MaxLeaseTTL - coreConfig.CacheSize = base.CacheSize - coreConfig.PluginDirectory = base.PluginDirectory - coreConfig.Seal = base.Seal - coreConfig.UnwrapSeal = base.UnwrapSeal - coreConfig.DevToken = base.DevToken - coreConfig.EnableRaw = base.EnableRaw - coreConfig.DisableSealWrap = base.DisableSealWrap - coreConfig.DisableCache = base.DisableCache - coreConfig.LicensingConfig = base.LicensingConfig - coreConfig.License = base.License - coreConfig.LicensePath = base.LicensePath - coreConfig.DisablePerformanceStandby = base.DisablePerformanceStandby - coreConfig.MetricsHelper = base.MetricsHelper - coreConfig.MetricSink = base.MetricSink - coreConfig.SecureRandomReader = base.SecureRandomReader - coreConfig.DisableSentinelTrace = base.DisableSentinelTrace - coreConfig.ClusterName = base.ClusterName - coreConfig.DisableAutopilot = base.DisableAutopilot - - if base.BuiltinRegistry != nil { - coreConfig.BuiltinRegistry = base.BuiltinRegistry - } - - if !coreConfig.DisableMlock { - base.DisableMlock = false - } - - if base.Physical != nil { - coreConfig.Physical = base.Physical - } - - if base.HAPhysical != nil { - 
coreConfig.HAPhysical = base.HAPhysical - } - - // Used to set something non-working to test fallback - switch base.ClusterAddr { - case "empty": - coreConfig.ClusterAddr = "" - case "": - default: - coreConfig.ClusterAddr = base.ClusterAddr - } - - if base.LogicalBackends != nil { - for k, v := range base.LogicalBackends { - coreConfig.LogicalBackends[k] = v - } - } - if base.CredentialBackends != nil { - for k, v := range base.CredentialBackends { - coreConfig.CredentialBackends[k] = v - } - } - if base.AuditBackends != nil { - for k, v := range base.AuditBackends { - coreConfig.AuditBackends[k] = v - } - } - if base.Logger != nil { - coreConfig.Logger = base.Logger - } - - coreConfig.ClusterCipherSuites = base.ClusterCipherSuites - coreConfig.DisableCache = base.DisableCache - coreConfig.DevToken = base.DevToken - coreConfig.RecoveryMode = base.RecoveryMode - coreConfig.ActivityLogConfig = base.ActivityLogConfig - coreConfig.EnableResponseHeaderHostname = base.EnableResponseHeaderHostname - coreConfig.EnableResponseHeaderRaftNodeID = base.EnableResponseHeaderRaftNodeID - coreConfig.RollbackPeriod = base.RollbackPeriod - coreConfig.PendingRemovalMountsAllowed = base.PendingRemovalMountsAllowed - coreConfig.ExpirationRevokeRetryBase = base.ExpirationRevokeRetryBase - testApplyEntBaseConfig(coreConfig, base) - } - if coreConfig.ClusterName == "" { - coreConfig.ClusterName = t.Name() - } - - if coreConfig.ClusterName == "" { - coreConfig.ClusterName = t.Name() - } - - if coreConfig.ClusterHeartbeatInterval == 0 { - // Set this lower so that state populates quickly to standby nodes - coreConfig.ClusterHeartbeatInterval = 2 * time.Second - } - - if coreConfig.RawConfig == nil { - c := new(server.Config) - c.SharedConfig = &configutil.SharedConfig{LogFormat: logging.UnspecifiedFormat.String()} - coreConfig.RawConfig = c - } - - addAuditBackend := len(coreConfig.AuditBackends) == 0 - if addAuditBackend { - AddNoopAudit(coreConfig, nil) - } - - if coreConfig.Physical == 
nil && (opts == nil || opts.PhysicalFactory == nil) { - coreConfig.Physical, err = physInmem.NewInmem(nil, testCluster.Logger) - if err != nil { - t.Fatal(err) - } - } - if coreConfig.HAPhysical == nil && (opts == nil || opts.PhysicalFactory == nil) { - haPhys, err := physInmem.NewInmemHA(nil, testCluster.Logger) - if err != nil { - t.Fatal(err) - } - coreConfig.HAPhysical = haPhys.(physical.HABackend) - } - - if testCluster.LicensePublicKey == nil { - pubKey, priKey, err := GenerateTestLicenseKeys() - if err != nil { - t.Fatalf("err: %v", err) - } - testCluster.LicensePublicKey = pubKey - testCluster.LicensePrivateKey = priKey - } - - if opts != nil && opts.InmemClusterLayers { - if opts.ClusterLayers != nil { - t.Fatalf("cannot specify ClusterLayers when InmemClusterLayers is true") - } - inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", numCores, testCluster.Logger.Named("inmem-cluster")) - if err != nil { - t.Fatal(err) - } - opts.ClusterLayers = inmemCluster - } - - // Create cores - testCluster.cleanupFuncs = []func(){} - cores := []*Core{} - coreConfigs := []*CoreConfig{} - - for i := 0; i < numCores; i++ { - cleanup, c, localConfig, handler := testCluster.newCore(t, i, coreConfig, opts, listeners[i], testCluster.LicensePublicKey) - - testCluster.cleanupFuncs = append(testCluster.cleanupFuncs, cleanup) - cores = append(cores, c) - coreConfigs = append(coreConfigs, &localConfig) - - if handler != nil { - handlers[i] = handler - servers[i].Handler = handlers[i] - } - } - - // Clustering setup - for i := 0; i < numCores; i++ { - testCluster.setupClusterListener(t, i, cores[i], coreConfigs[i], opts, listeners[i], handlers[i]) - } - - // Create TestClusterCores - var ret []*TestClusterCore - for i := 0; i < numCores; i++ { - tcc := &TestClusterCore{ - Core: cores[i], - CoreConfig: coreConfigs[i], - ServerKey: certInfoSlice[i].key, - ServerKeyPEM: certInfoSlice[i].keyPEM, - ServerCert: certInfoSlice[i].cert, - ServerCertBytes: 
certInfoSlice[i].certBytes, - ServerCertPEM: certInfoSlice[i].certPEM, - Address: addresses[i], - Listeners: listeners[i], - Handler: handlers[i], - Server: servers[i], - TLSConfig: tlsConfigs[i], - Barrier: cores[i].barrier, - NodeID: fmt.Sprintf("core-%d", i), - UnderlyingRawStorage: coreConfigs[i].Physical, - UnderlyingHAStorage: coreConfigs[i].HAPhysical, - } - tcc.ReloadFuncs = &cores[i].reloadFuncs - tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock - tcc.ReloadFuncsLock.Lock() - (*tcc.ReloadFuncs)["listener|tcp"] = []reloadutil.ReloadFunc{certGetters[i].Reload} - tcc.ReloadFuncsLock.Unlock() - - testAdjustUnderlyingStorage(tcc) - - ret = append(ret, tcc) - } - testCluster.Cores = ret - - // Initialize cores - if opts == nil || !opts.SkipInit { - testCluster.initCores(t, opts, addAuditBackend) - } - - // Assign clients - for i := 0; i < numCores; i++ { - testCluster.Cores[i].Client = testCluster.getAPIClient(t, opts, listeners[i][0].Address.Port, tlsConfigs[i]) - } - - // Extra Setup - for _, tcc := range testCluster.Cores { - testExtraTestCoreSetup(t, testCluster.LicensePrivateKey, tcc) - } - - // Cleanup - testCluster.CleanupFunc = func() { - for _, c := range testCluster.cleanupFuncs { - c() - } - if l, ok := testCluster.Logger.(*TestLogger); ok { - if t.Failed() { - _ = l.File.Close() - } else { - _ = os.Remove(l.Path) - } - } - } - - // Setup - if opts != nil { - if opts.SetupFunc != nil { - testCluster.SetupFunc = func() { - opts.SetupFunc(t, &testCluster) - } - } - } - - testCluster.opts = opts - testCluster.start(t) - return &testCluster -} - // StopCore performs an orderly shutdown of a core. 
func (cluster *TestCluster) StopCore(t testing.T, idx int) { t.Helper() From 70578b7b4a95e5da82a4180e38e7be07075e4661 Mon Sep 17 00:00:00 2001 From: Ellie Sterner Date: Fri, 6 Jan 2023 15:43:11 -0600 Subject: [PATCH 4/6] replace long func with constant --- vault/test_cluster.go | 550 +------------------------ vault/test_cluster_detect_deadlock.go | 553 +------------------------- vault/testing.go | 527 ++++++++++++++++++++++++ 3 files changed, 529 insertions(+), 1101 deletions(-) diff --git a/vault/test_cluster.go b/vault/test_cluster.go index 355482f34b66..06cf1a94a102 100644 --- a/vault/test_cluster.go +++ b/vault/test_cluster.go @@ -2,552 +2,4 @@ package vault -import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "fmt" - "io/ioutil" - "math/big" - mathrand "math/rand" - "net" - "net/http" - "os" - "path/filepath" - "time" - - "github.com/mitchellh/go-testing-interface" - - "github.com/hashicorp/go-secure-stdlib/reloadutil" - "github.com/hashicorp/vault/audit" - "github.com/hashicorp/vault/command/server" - "github.com/hashicorp/vault/internalshared/configutil" - "github.com/hashicorp/vault/sdk/helper/logging" - "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/sdk/physical" - physInmem "github.com/hashicorp/vault/sdk/physical/inmem" - "github.com/hashicorp/vault/vault/cluster" -) - -// NewTestCluster creates a new test cluster based on the provided core config -// and test cluster options. -// -// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a -// core config for each core it creates. If separate seal per core is desired, opts.SealFunc -// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be -// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the -// provided Seal in coreConfig (i.e. base.Seal) is nil. 
-// -// If opts.Logger is provided, it takes precedence and will be used as the cluster -// logger and will be the basis for each core's logger. If no opts.Logger is -// given, one will be generated based on t.Name() for the cluster logger, and if -// no base.Logger is given will also be used as the basis for each core's logger. - -func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { - var err error - - var numCores int - if opts == nil || opts.NumCores == 0 { - numCores = DefaultNumCores - } else { - numCores = opts.NumCores - } - - certIPs := []net.IP{ - net.IPv6loopback, - net.ParseIP("127.0.0.1"), - } - var baseAddr *net.TCPAddr - if opts != nil && opts.BaseListenAddress != "" { - baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress) - if err != nil { - t.Fatal("could not parse given base IP") - } - certIPs = append(certIPs, baseAddr.IP) - } else { - baseAddr = &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 0, - } - } - - var testCluster TestCluster - testCluster.base = base - - switch { - case opts != nil && opts.Logger != nil: - testCluster.Logger = opts.Logger - default: - testCluster.Logger = NewTestLogger(t) - } - - if opts != nil && opts.TempDir != "" { - if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) { - if err := os.MkdirAll(opts.TempDir, 0o700); err != nil { - t.Fatal(err) - } - } - testCluster.TempDir = opts.TempDir - } else { - tempDir, err := ioutil.TempDir("", "vault-test-cluster-") - if err != nil { - t.Fatal(err) - } - testCluster.TempDir = tempDir - } - - var caKey *ecdsa.PrivateKey - if opts != nil && opts.CAKey != nil { - caKey = opts.CAKey - } else { - caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatal(err) - } - } - testCluster.CAKey = caKey - var caBytes []byte - if opts != nil && len(opts.CACert) > 0 { - caBytes = opts.CACert - } else { - caCertTemplate := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "localhost", - }, - 
DNSNames: []string{"localhost"}, - IPAddresses: certIPs, - KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - BasicConstraintsValid: true, - IsCA: true, - } - caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) - if err != nil { - t.Fatal(err) - } - } - caCert, err := x509.ParseCertificate(caBytes) - if err != nil { - t.Fatal(err) - } - testCluster.CACert = caCert - testCluster.CACertBytes = caBytes - testCluster.RootCAs = x509.NewCertPool() - testCluster.RootCAs.AddCert(caCert) - caCertPEMBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: caBytes, - } - testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock) - testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem") - err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0o755) - if err != nil { - t.Fatal(err) - } - marshaledCAKey, err := x509.MarshalECPrivateKey(caKey) - if err != nil { - t.Fatal(err) - } - caKeyPEMBlock := &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: marshaledCAKey, - } - testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock) - err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0o755) - if err != nil { - t.Fatal(err) - } - - var certInfoSlice []*certInfo - - // - // Certs generation - // - for i := 0; i < numCores; i++ { - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatal(err) - } - certTemplate := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "localhost", - }, - // Include host.docker.internal for the sake of benchmark-vault running on MacOS/Windows. - // This allows Prometheus running in docker to scrape the cluster for metrics. 
- DNSNames: []string{"localhost", "host.docker.internal"}, - IPAddresses: certIPs, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - x509.ExtKeyUsageClientAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - } - certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) - if err != nil { - t.Fatal(err) - } - cert, err := x509.ParseCertificate(certBytes) - if err != nil { - t.Fatal(err) - } - certPEMBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - } - certPEM := pem.EncodeToMemory(certPEMBlock) - marshaledKey, err := x509.MarshalECPrivateKey(key) - if err != nil { - t.Fatal(err) - } - keyPEMBlock := &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: marshaledKey, - } - keyPEM := pem.EncodeToMemory(keyPEMBlock) - - certInfoSlice = append(certInfoSlice, &certInfo{ - cert: cert, - certPEM: certPEM, - certBytes: certBytes, - key: key, - keyPEM: keyPEM, - }) - } - - // - // Listener setup - // - addresses := []*net.TCPAddr{} - listeners := [][]*TestListener{} - servers := []*http.Server{} - handlers := []http.Handler{} - tlsConfigs := []*tls.Config{} - certGetters := []*reloadutil.CertificateGetter{} - for i := 0; i < numCores; i++ { - addr := &net.TCPAddr{ - IP: baseAddr.IP, - Port: 0, - } - if baseAddr.Port != 0 { - addr.Port = baseAddr.Port + i - } - - ln, err := net.ListenTCP("tcp", addr) - if err != nil { - t.Fatal(err) - } - addresses = append(addresses, addr) - - certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) - keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) - err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0o755) - if err != nil { - 
t.Fatal(err) - } - err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0o755) - if err != nil { - t.Fatal(err) - } - tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM) - if err != nil { - t.Fatal(err) - } - certGetter := reloadutil.NewCertificateGetter(certFile, keyFile, "") - certGetters = append(certGetters, certGetter) - certGetter.Reload() - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - RootCAs: testCluster.RootCAs, - ClientCAs: testCluster.RootCAs, - ClientAuth: tls.RequestClientCert, - NextProtos: []string{"h2", "http/1.1"}, - GetCertificate: certGetter.GetCertificate, - } - if opts != nil && opts.RequireClientAuth { - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - testCluster.ClientAuthRequired = true - } - tlsConfigs = append(tlsConfigs, tlsConfig) - lns := []*TestListener{ - { - Listener: tls.NewListener(ln, tlsConfig), - Address: ln.Addr().(*net.TCPAddr), - }, - } - listeners = append(listeners, lns) - var handler http.Handler = http.NewServeMux() - handlers = append(handlers, handler) - server := &http.Server{ - Handler: handler, - ErrorLog: testCluster.Logger.StandardLogger(nil), - } - servers = append(servers, server) - } - - // Create three cores with the same physical and different redirect/cluster - // addrs. - // N.B.: On OSX, instead of random ports, it assigns new ports to new - // listeners sequentially. Aside from being a bad idea in a security sense, - // it also broke tests that assumed it was OK to just use the port above - // the redirect addr. This has now been changed to 105 ports above, but if - // we ever do more than three nodes in a cluster it may need to be bumped. - // Note: it's 105 so that we don't conflict with a running Consul by - // default. 
- coreConfig := &CoreConfig{ - LogicalBackends: make(map[string]logical.Factory), - CredentialBackends: make(map[string]logical.Factory), - AuditBackends: make(map[string]audit.Factory), - RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port), - ClusterAddr: "https://127.0.0.1:0", - DisableMlock: true, - EnableUI: true, - EnableRaw: true, - BuiltinRegistry: NewMockBuiltinRegistry(), - } - - if base != nil { - coreConfig.RawConfig = base.RawConfig - coreConfig.DisableCache = base.DisableCache - coreConfig.EnableUI = base.EnableUI - coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL - coreConfig.MaxLeaseTTL = base.MaxLeaseTTL - coreConfig.CacheSize = base.CacheSize - coreConfig.PluginDirectory = base.PluginDirectory - coreConfig.Seal = base.Seal - coreConfig.UnwrapSeal = base.UnwrapSeal - coreConfig.DevToken = base.DevToken - coreConfig.EnableRaw = base.EnableRaw - coreConfig.DisableSealWrap = base.DisableSealWrap - coreConfig.DisableCache = base.DisableCache - coreConfig.LicensingConfig = base.LicensingConfig - coreConfig.License = base.License - coreConfig.LicensePath = base.LicensePath - coreConfig.DisablePerformanceStandby = base.DisablePerformanceStandby - coreConfig.MetricsHelper = base.MetricsHelper - coreConfig.MetricSink = base.MetricSink - coreConfig.SecureRandomReader = base.SecureRandomReader - coreConfig.DisableSentinelTrace = base.DisableSentinelTrace - coreConfig.ClusterName = base.ClusterName - coreConfig.DisableAutopilot = base.DisableAutopilot - - if base.BuiltinRegistry != nil { - coreConfig.BuiltinRegistry = base.BuiltinRegistry - } - - if !coreConfig.DisableMlock { - base.DisableMlock = false - } - - if base.Physical != nil { - coreConfig.Physical = base.Physical - } - - if base.HAPhysical != nil { - coreConfig.HAPhysical = base.HAPhysical - } - - // Used to set something non-working to test fallback - switch base.ClusterAddr { - case "empty": - coreConfig.ClusterAddr = "" - case "": - default: - coreConfig.ClusterAddr = 
base.ClusterAddr - } - - if base.LogicalBackends != nil { - for k, v := range base.LogicalBackends { - coreConfig.LogicalBackends[k] = v - } - } - if base.CredentialBackends != nil { - for k, v := range base.CredentialBackends { - coreConfig.CredentialBackends[k] = v - } - } - if base.AuditBackends != nil { - for k, v := range base.AuditBackends { - coreConfig.AuditBackends[k] = v - } - } - if base.Logger != nil { - coreConfig.Logger = base.Logger - } - - coreConfig.ClusterCipherSuites = base.ClusterCipherSuites - coreConfig.DisableCache = base.DisableCache - coreConfig.DevToken = base.DevToken - coreConfig.RecoveryMode = base.RecoveryMode - coreConfig.ActivityLogConfig = base.ActivityLogConfig - coreConfig.EnableResponseHeaderHostname = base.EnableResponseHeaderHostname - coreConfig.EnableResponseHeaderRaftNodeID = base.EnableResponseHeaderRaftNodeID - coreConfig.RollbackPeriod = base.RollbackPeriod - coreConfig.PendingRemovalMountsAllowed = base.PendingRemovalMountsAllowed - coreConfig.ExpirationRevokeRetryBase = base.ExpirationRevokeRetryBase - testApplyEntBaseConfig(coreConfig, base) - } - if coreConfig.ClusterName == "" { - coreConfig.ClusterName = t.Name() - } - - if coreConfig.ClusterName == "" { - coreConfig.ClusterName = t.Name() - } - - if coreConfig.ClusterHeartbeatInterval == 0 { - // Set this lower so that state populates quickly to standby nodes - coreConfig.ClusterHeartbeatInterval = 2 * time.Second - } - - if coreConfig.RawConfig == nil { - c := new(server.Config) - c.SharedConfig = &configutil.SharedConfig{LogFormat: logging.UnspecifiedFormat.String()} - coreConfig.RawConfig = c - } - - addAuditBackend := len(coreConfig.AuditBackends) == 0 - if addAuditBackend { - AddNoopAudit(coreConfig, nil) - } - - if coreConfig.Physical == nil && (opts == nil || opts.PhysicalFactory == nil) { - coreConfig.Physical, err = physInmem.NewInmem(nil, testCluster.Logger) - if err != nil { - t.Fatal(err) - } - } - if coreConfig.HAPhysical == nil && (opts == nil || 
opts.PhysicalFactory == nil) { - haPhys, err := physInmem.NewInmemHA(nil, testCluster.Logger) - if err != nil { - t.Fatal(err) - } - coreConfig.HAPhysical = haPhys.(physical.HABackend) - } - - if testCluster.LicensePublicKey == nil { - pubKey, priKey, err := GenerateTestLicenseKeys() - if err != nil { - t.Fatalf("err: %v", err) - } - testCluster.LicensePublicKey = pubKey - testCluster.LicensePrivateKey = priKey - } - - if opts != nil && opts.InmemClusterLayers { - if opts.ClusterLayers != nil { - t.Fatalf("cannot specify ClusterLayers when InmemClusterLayers is true") - } - inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", numCores, testCluster.Logger.Named("inmem-cluster")) - if err != nil { - t.Fatal(err) - } - opts.ClusterLayers = inmemCluster - } - - // Create cores - testCluster.cleanupFuncs = []func(){} - cores := []*Core{} - coreConfigs := []*CoreConfig{} - - for i := 0; i < numCores; i++ { - cleanup, c, localConfig, handler := testCluster.newCore(t, i, coreConfig, opts, listeners[i], testCluster.LicensePublicKey) - - testCluster.cleanupFuncs = append(testCluster.cleanupFuncs, cleanup) - cores = append(cores, c) - coreConfigs = append(coreConfigs, &localConfig) - - if handler != nil { - handlers[i] = handler - servers[i].Handler = handlers[i] - } - } - - // Clustering setup - for i := 0; i < numCores; i++ { - testCluster.setupClusterListener(t, i, cores[i], coreConfigs[i], opts, listeners[i], handlers[i]) - } - - // Create TestClusterCores - var ret []*TestClusterCore - for i := 0; i < numCores; i++ { - tcc := &TestClusterCore{ - Core: cores[i], - CoreConfig: coreConfigs[i], - ServerKey: certInfoSlice[i].key, - ServerKeyPEM: certInfoSlice[i].keyPEM, - ServerCert: certInfoSlice[i].cert, - ServerCertBytes: certInfoSlice[i].certBytes, - ServerCertPEM: certInfoSlice[i].certPEM, - Address: addresses[i], - Listeners: listeners[i], - Handler: handlers[i], - Server: servers[i], - TLSConfig: tlsConfigs[i], - Barrier: cores[i].barrier, - NodeID: 
fmt.Sprintf("core-%d", i), - UnderlyingRawStorage: coreConfigs[i].Physical, - UnderlyingHAStorage: coreConfigs[i].HAPhysical, - } - tcc.ReloadFuncs = &cores[i].reloadFuncs - tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock - tcc.ReloadFuncsLock.Lock() - (*tcc.ReloadFuncs)["listener|tcp"] = []reloadutil.ReloadFunc{certGetters[i].Reload} - tcc.ReloadFuncsLock.Unlock() - - testAdjustUnderlyingStorage(tcc) - - ret = append(ret, tcc) - } - testCluster.Cores = ret - - // Initialize cores - if opts == nil || !opts.SkipInit { - testCluster.initCores(t, opts, addAuditBackend) - } - - // Assign clients - for i := 0; i < numCores; i++ { - testCluster.Cores[i].Client = testCluster.getAPIClient(t, opts, listeners[i][0].Address.Port, tlsConfigs[i]) - } - - // Extra Setup - for _, tcc := range testCluster.Cores { - testExtraTestCoreSetup(t, testCluster.LicensePrivateKey, tcc) - } - - // Cleanup - testCluster.CleanupFunc = func() { - for _, c := range testCluster.cleanupFuncs { - c() - } - if l, ok := testCluster.Logger.(*TestLogger); ok { - if t.Failed() { - _ = l.File.Close() - } else { - _ = os.Remove(l.Path) - } - } - } - - // Setup - if opts != nil { - if opts.SetupFunc != nil { - testCluster.SetupFunc = func() { - opts.SetupFunc(t, &testCluster) - } - } - } - - testCluster.opts = opts - testCluster.start(t) - return &testCluster -} +const TestDeadlockDetection = "" diff --git a/vault/test_cluster_detect_deadlock.go b/vault/test_cluster_detect_deadlock.go index 07652938331a..154a948f463e 100644 --- a/vault/test_cluster_detect_deadlock.go +++ b/vault/test_cluster_detect_deadlock.go @@ -2,555 +2,4 @@ package vault -import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "fmt" - "io/ioutil" - "math/big" - mathrand "math/rand" - "net" - "net/http" - "os" - "path/filepath" - "time" - - "github.com/mitchellh/go-testing-interface" - - "github.com/hashicorp/go-secure-stdlib/reloadutil" - 
"github.com/hashicorp/vault/audit" - "github.com/hashicorp/vault/command/server" - "github.com/hashicorp/vault/internalshared/configutil" - "github.com/hashicorp/vault/sdk/helper/logging" - "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/sdk/physical" - physInmem "github.com/hashicorp/vault/sdk/physical/inmem" - "github.com/hashicorp/vault/vault/cluster" -) - -// NewTestCluster creates a new test cluster based on the provided core config -// and test cluster options. -// -// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a -// core config for each core it creates. If separate seal per core is desired, opts.SealFunc -// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be -// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the -// provided Seal in coreConfig (i.e. base.Seal) is nil. -// -// If opts.Logger is provided, it takes precedence and will be used as the cluster -// logger and will be the basis for each core's logger. If no opts.Logger is -// given, one will be generated based on t.Name() for the cluster logger, and if -// no base.Logger is given will also be used as the basis for each core's logger. 
- -func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { - base.DetectDeadlocks = "stateLock" // detect deadlocks because of build tag "deadlock" - - var err error - - var numCores int - if opts == nil || opts.NumCores == 0 { - numCores = DefaultNumCores - } else { - numCores = opts.NumCores - } - - certIPs := []net.IP{ - net.IPv6loopback, - net.ParseIP("127.0.0.1"), - } - var baseAddr *net.TCPAddr - if opts != nil && opts.BaseListenAddress != "" { - baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress) - - if err != nil { - t.Fatal("could not parse given base IP") - } - certIPs = append(certIPs, baseAddr.IP) - } else { - baseAddr = &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 0, - } - } - - var testCluster TestCluster - testCluster.base = base - - switch { - case opts != nil && opts.Logger != nil: - testCluster.Logger = opts.Logger - default: - testCluster.Logger = NewTestLogger(t) - } - - if opts != nil && opts.TempDir != "" { - if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) { - if err := os.MkdirAll(opts.TempDir, 0o700); err != nil { - t.Fatal(err) - } - } - testCluster.TempDir = opts.TempDir - } else { - tempDir, err := ioutil.TempDir("", "vault-test-cluster-") - if err != nil { - t.Fatal(err) - } - testCluster.TempDir = tempDir - } - - var caKey *ecdsa.PrivateKey - if opts != nil && opts.CAKey != nil { - caKey = opts.CAKey - } else { - caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatal(err) - } - } - testCluster.CAKey = caKey - var caBytes []byte - if opts != nil && len(opts.CACert) > 0 { - caBytes = opts.CACert - } else { - caCertTemplate := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "localhost", - }, - DNSNames: []string{"localhost"}, - IPAddresses: certIPs, - KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: 
time.Now().Add(262980 * time.Hour), - BasicConstraintsValid: true, - IsCA: true, - } - caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) - if err != nil { - t.Fatal(err) - } - } - caCert, err := x509.ParseCertificate(caBytes) - if err != nil { - t.Fatal(err) - } - testCluster.CACert = caCert - testCluster.CACertBytes = caBytes - testCluster.RootCAs = x509.NewCertPool() - testCluster.RootCAs.AddCert(caCert) - caCertPEMBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: caBytes, - } - testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock) - testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem") - err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0o755) - if err != nil { - t.Fatal(err) - } - marshaledCAKey, err := x509.MarshalECPrivateKey(caKey) - if err != nil { - t.Fatal(err) - } - caKeyPEMBlock := &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: marshaledCAKey, - } - testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock) - err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0o755) - if err != nil { - t.Fatal(err) - } - - var certInfoSlice []*certInfo - - // - // Certs generation - // - for i := 0; i < numCores; i++ { - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatal(err) - } - certTemplate := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "localhost", - }, - // Include host.docker.internal for the sake of benchmark-vault running on MacOS/Windows. - // This allows Prometheus running in docker to scrape the cluster for metrics. 
- DNSNames: []string{"localhost", "host.docker.internal"}, - IPAddresses: certIPs, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - x509.ExtKeyUsageClientAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - } - certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) - if err != nil { - t.Fatal(err) - } - cert, err := x509.ParseCertificate(certBytes) - if err != nil { - t.Fatal(err) - } - certPEMBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - } - certPEM := pem.EncodeToMemory(certPEMBlock) - marshaledKey, err := x509.MarshalECPrivateKey(key) - if err != nil { - t.Fatal(err) - } - keyPEMBlock := &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: marshaledKey, - } - keyPEM := pem.EncodeToMemory(keyPEMBlock) - - certInfoSlice = append(certInfoSlice, &certInfo{ - cert: cert, - certPEM: certPEM, - certBytes: certBytes, - key: key, - keyPEM: keyPEM, - }) - } - - // - // Listener setup - // - addresses := []*net.TCPAddr{} - listeners := [][]*TestListener{} - servers := []*http.Server{} - handlers := []http.Handler{} - tlsConfigs := []*tls.Config{} - certGetters := []*reloadutil.CertificateGetter{} - for i := 0; i < numCores; i++ { - addr := &net.TCPAddr{ - IP: baseAddr.IP, - Port: 0, - } - if baseAddr.Port != 0 { - addr.Port = baseAddr.Port + i - } - - ln, err := net.ListenTCP("tcp", addr) - if err != nil { - t.Fatal(err) - } - addresses = append(addresses, addr) - - certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) - keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) - err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0o755) - if err != nil { - 
t.Fatal(err) - } - err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0o755) - if err != nil { - t.Fatal(err) - } - tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM) - if err != nil { - t.Fatal(err) - } - certGetter := reloadutil.NewCertificateGetter(certFile, keyFile, "") - certGetters = append(certGetters, certGetter) - certGetter.Reload() - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - RootCAs: testCluster.RootCAs, - ClientCAs: testCluster.RootCAs, - ClientAuth: tls.RequestClientCert, - NextProtos: []string{"h2", "http/1.1"}, - GetCertificate: certGetter.GetCertificate, - } - if opts != nil && opts.RequireClientAuth { - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - testCluster.ClientAuthRequired = true - } - tlsConfigs = append(tlsConfigs, tlsConfig) - lns := []*TestListener{ - { - Listener: tls.NewListener(ln, tlsConfig), - Address: ln.Addr().(*net.TCPAddr), - }, - } - listeners = append(listeners, lns) - var handler http.Handler = http.NewServeMux() - handlers = append(handlers, handler) - server := &http.Server{ - Handler: handler, - ErrorLog: testCluster.Logger.StandardLogger(nil), - } - servers = append(servers, server) - } - - // Create three cores with the same physical and different redirect/cluster - // addrs. - // N.B.: On OSX, instead of random ports, it assigns new ports to new - // listeners sequentially. Aside from being a bad idea in a security sense, - // it also broke tests that assumed it was OK to just use the port above - // the redirect addr. This has now been changed to 105 ports above, but if - // we ever do more than three nodes in a cluster it may need to be bumped. - // Note: it's 105 so that we don't conflict with a running Consul by - // default. 
- coreConfig := &CoreConfig{ - LogicalBackends: make(map[string]logical.Factory), - CredentialBackends: make(map[string]logical.Factory), - AuditBackends: make(map[string]audit.Factory), - RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port), - ClusterAddr: "https://127.0.0.1:0", - DisableMlock: true, - EnableUI: true, - EnableRaw: true, - BuiltinRegistry: NewMockBuiltinRegistry(), - } - - if base != nil { - coreConfig.RawConfig = base.RawConfig - coreConfig.DisableCache = base.DisableCache - coreConfig.EnableUI = base.EnableUI - coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL - coreConfig.MaxLeaseTTL = base.MaxLeaseTTL - coreConfig.CacheSize = base.CacheSize - coreConfig.PluginDirectory = base.PluginDirectory - coreConfig.Seal = base.Seal - coreConfig.UnwrapSeal = base.UnwrapSeal - coreConfig.DevToken = base.DevToken - coreConfig.EnableRaw = base.EnableRaw - coreConfig.DisableSealWrap = base.DisableSealWrap - coreConfig.DisableCache = base.DisableCache - coreConfig.LicensingConfig = base.LicensingConfig - coreConfig.License = base.License - coreConfig.LicensePath = base.LicensePath - coreConfig.DisablePerformanceStandby = base.DisablePerformanceStandby - coreConfig.MetricsHelper = base.MetricsHelper - coreConfig.MetricSink = base.MetricSink - coreConfig.SecureRandomReader = base.SecureRandomReader - coreConfig.DisableSentinelTrace = base.DisableSentinelTrace - coreConfig.ClusterName = base.ClusterName - coreConfig.DisableAutopilot = base.DisableAutopilot - - if base.BuiltinRegistry != nil { - coreConfig.BuiltinRegistry = base.BuiltinRegistry - } - - if !coreConfig.DisableMlock { - base.DisableMlock = false - } - - if base.Physical != nil { - coreConfig.Physical = base.Physical - } - - if base.HAPhysical != nil { - coreConfig.HAPhysical = base.HAPhysical - } - - // Used to set something non-working to test fallback - switch base.ClusterAddr { - case "empty": - coreConfig.ClusterAddr = "" - case "": - default: - coreConfig.ClusterAddr = 
base.ClusterAddr - } - - if base.LogicalBackends != nil { - for k, v := range base.LogicalBackends { - coreConfig.LogicalBackends[k] = v - } - } - if base.CredentialBackends != nil { - for k, v := range base.CredentialBackends { - coreConfig.CredentialBackends[k] = v - } - } - if base.AuditBackends != nil { - for k, v := range base.AuditBackends { - coreConfig.AuditBackends[k] = v - } - } - if base.Logger != nil { - coreConfig.Logger = base.Logger - } - - coreConfig.ClusterCipherSuites = base.ClusterCipherSuites - coreConfig.DisableCache = base.DisableCache - coreConfig.DevToken = base.DevToken - coreConfig.RecoveryMode = base.RecoveryMode - coreConfig.ActivityLogConfig = base.ActivityLogConfig - coreConfig.EnableResponseHeaderHostname = base.EnableResponseHeaderHostname - coreConfig.EnableResponseHeaderRaftNodeID = base.EnableResponseHeaderRaftNodeID - coreConfig.RollbackPeriod = base.RollbackPeriod - coreConfig.PendingRemovalMountsAllowed = base.PendingRemovalMountsAllowed - coreConfig.ExpirationRevokeRetryBase = base.ExpirationRevokeRetryBase - testApplyEntBaseConfig(coreConfig, base) - } - if coreConfig.ClusterName == "" { - coreConfig.ClusterName = t.Name() - } - - if coreConfig.ClusterName == "" { - coreConfig.ClusterName = t.Name() - } - - if coreConfig.ClusterHeartbeatInterval == 0 { - // Set this lower so that state populates quickly to standby nodes - coreConfig.ClusterHeartbeatInterval = 2 * time.Second - } - - if coreConfig.RawConfig == nil { - c := new(server.Config) - c.SharedConfig = &configutil.SharedConfig{LogFormat: logging.UnspecifiedFormat.String()} - coreConfig.RawConfig = c - } - - addAuditBackend := len(coreConfig.AuditBackends) == 0 - if addAuditBackend { - AddNoopAudit(coreConfig, nil) - } - - if coreConfig.Physical == nil && (opts == nil || opts.PhysicalFactory == nil) { - coreConfig.Physical, err = physInmem.NewInmem(nil, testCluster.Logger) - if err != nil { - t.Fatal(err) - } - } - if coreConfig.HAPhysical == nil && (opts == nil || 
opts.PhysicalFactory == nil) { - haPhys, err := physInmem.NewInmemHA(nil, testCluster.Logger) - if err != nil { - t.Fatal(err) - } - coreConfig.HAPhysical = haPhys.(physical.HABackend) - } - - if testCluster.LicensePublicKey == nil { - pubKey, priKey, err := GenerateTestLicenseKeys() - if err != nil { - t.Fatalf("err: %v", err) - } - testCluster.LicensePublicKey = pubKey - testCluster.LicensePrivateKey = priKey - } - - if opts != nil && opts.InmemClusterLayers { - if opts.ClusterLayers != nil { - t.Fatalf("cannot specify ClusterLayers when InmemClusterLayers is true") - } - inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", numCores, testCluster.Logger.Named("inmem-cluster")) - if err != nil { - t.Fatal(err) - } - opts.ClusterLayers = inmemCluster - } - - // Create cores - testCluster.cleanupFuncs = []func(){} - cores := []*Core{} - coreConfigs := []*CoreConfig{} - - for i := 0; i < numCores; i++ { - cleanup, c, localConfig, handler := testCluster.newCore(t, i, coreConfig, opts, listeners[i], testCluster.LicensePublicKey) - - testCluster.cleanupFuncs = append(testCluster.cleanupFuncs, cleanup) - cores = append(cores, c) - coreConfigs = append(coreConfigs, &localConfig) - - if handler != nil { - handlers[i] = handler - servers[i].Handler = handlers[i] - } - } - - // Clustering setup - for i := 0; i < numCores; i++ { - testCluster.setupClusterListener(t, i, cores[i], coreConfigs[i], opts, listeners[i], handlers[i]) - } - - // Create TestClusterCores - var ret []*TestClusterCore - for i := 0; i < numCores; i++ { - tcc := &TestClusterCore{ - Core: cores[i], - CoreConfig: coreConfigs[i], - ServerKey: certInfoSlice[i].key, - ServerKeyPEM: certInfoSlice[i].keyPEM, - ServerCert: certInfoSlice[i].cert, - ServerCertBytes: certInfoSlice[i].certBytes, - ServerCertPEM: certInfoSlice[i].certPEM, - Address: addresses[i], - Listeners: listeners[i], - Handler: handlers[i], - Server: servers[i], - TLSConfig: tlsConfigs[i], - Barrier: cores[i].barrier, - NodeID: 
fmt.Sprintf("core-%d", i), - UnderlyingRawStorage: coreConfigs[i].Physical, - UnderlyingHAStorage: coreConfigs[i].HAPhysical, - } - tcc.ReloadFuncs = &cores[i].reloadFuncs - tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock - tcc.ReloadFuncsLock.Lock() - (*tcc.ReloadFuncs)["listener|tcp"] = []reloadutil.ReloadFunc{certGetters[i].Reload} - tcc.ReloadFuncsLock.Unlock() - - testAdjustUnderlyingStorage(tcc) - - ret = append(ret, tcc) - } - testCluster.Cores = ret - - // Initialize cores - if opts == nil || !opts.SkipInit { - testCluster.initCores(t, opts, addAuditBackend) - } - - // Assign clients - for i := 0; i < numCores; i++ { - testCluster.Cores[i].Client = testCluster.getAPIClient(t, opts, listeners[i][0].Address.Port, tlsConfigs[i]) - } - - // Extra Setup - for _, tcc := range testCluster.Cores { - testExtraTestCoreSetup(t, testCluster.LicensePrivateKey, tcc) - } - - // Cleanup - testCluster.CleanupFunc = func() { - for _, c := range testCluster.cleanupFuncs { - c() - } - if l, ok := testCluster.Logger.(*TestLogger); ok { - if t.Failed() { - _ = l.File.Close() - } else { - _ = os.Remove(l.Path) - } - } - } - - // Setup - if opts != nil { - if opts.SetupFunc != nil { - testCluster.SetupFunc = func() { - opts.SetupFunc(t, &testCluster) - } - } - } - - testCluster.opts = opts - testCluster.start(t) - return &testCluster -} +const TestDeadlockDetection = "statelock" diff --git a/vault/testing.go b/vault/testing.go index 86fca41050d3..f7a3ba17e9b3 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -4,15 +4,20 @@ import ( "bytes" "context" "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/sha256" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" "encoding/base64" + "encoding/pem" "errors" "fmt" "io" "io/ioutil" + "math/big" + mathrand "math/rand" "net" "net/http" "os" @@ -1348,6 +1353,528 @@ func (tl *TestLogger) StopLogging() { tl.Logger.(log.InterceptLogger).DeregisterSink(tl.sink) } +// NewTestCluster creates a new test cluster based on the provided core 
config +// and test cluster options. +// +// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a +// core config for each core it creates. If separate seal per core is desired, opts.SealFunc +// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be +// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the +// provided Seal in coreConfig (i.e. base.Seal) is nil. +// +// If opts.Logger is provided, it takes precedence and will be used as the cluster +// logger and will be the basis for each core's logger. If no opts.Logger is +// given, one will be generated based on t.Name() for the cluster logger, and if +// no base.Logger is given will also be used as the basis for each core's logger. + +func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { + base.DetectDeadlocks = TestDeadlockDetection + + var err error + + var numCores int + if opts == nil || opts.NumCores == 0 { + numCores = DefaultNumCores + } else { + numCores = opts.NumCores + } + + certIPs := []net.IP{ + net.IPv6loopback, + net.ParseIP("127.0.0.1"), + } + var baseAddr *net.TCPAddr + if opts != nil && opts.BaseListenAddress != "" { + baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress) + + if err != nil { + t.Fatal("could not parse given base IP") + } + certIPs = append(certIPs, baseAddr.IP) + } else { + baseAddr = &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + } + } + + var testCluster TestCluster + testCluster.base = base + + switch { + case opts != nil && opts.Logger != nil: + testCluster.Logger = opts.Logger + default: + testCluster.Logger = NewTestLogger(t) + } + + if opts != nil && opts.TempDir != "" { + if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) { + if err := os.MkdirAll(opts.TempDir, 0o700); err != nil { + t.Fatal(err) + } + } + testCluster.TempDir = opts.TempDir + } else { + tempDir, err := ioutil.TempDir("", 
"vault-test-cluster-") + if err != nil { + t.Fatal(err) + } + testCluster.TempDir = tempDir + } + + var caKey *ecdsa.PrivateKey + if opts != nil && opts.CAKey != nil { + caKey = opts.CAKey + } else { + caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + } + testCluster.CAKey = caKey + var caBytes []byte + if opts != nil && len(opts.CACert) > 0 { + caBytes = opts.CACert + } else { + caCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "localhost", + }, + DNSNames: []string{"localhost"}, + IPAddresses: certIPs, + KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey) + if err != nil { + t.Fatal(err) + } + } + caCert, err := x509.ParseCertificate(caBytes) + if err != nil { + t.Fatal(err) + } + testCluster.CACert = caCert + testCluster.CACertBytes = caBytes + testCluster.RootCAs = x509.NewCertPool() + testCluster.RootCAs.AddCert(caCert) + caCertPEMBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + } + testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock) + testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem") + err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0o755) + if err != nil { + t.Fatal(err) + } + marshaledCAKey, err := x509.MarshalECPrivateKey(caKey) + if err != nil { + t.Fatal(err) + } + caKeyPEMBlock := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: marshaledCAKey, + } + testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock) + err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0o755) + if err != nil { + t.Fatal(err) + } + + var certInfoSlice []*certInfo + + // + // Certs 
generation + // + for i := 0; i < numCores; i++ { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + certTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "localhost", + }, + // Include host.docker.internal for the sake of benchmark-vault running on MacOS/Windows. + // This allows Prometheus running in docker to scrape the cluster for metrics. + DNSNames: []string{"localhost", "host.docker.internal"}, + IPAddresses: certIPs, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + } + certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) + if err != nil { + t.Fatal(err) + } + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + t.Fatal(err) + } + certPEMBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + } + certPEM := pem.EncodeToMemory(certPEMBlock) + marshaledKey, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatal(err) + } + keyPEMBlock := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: marshaledKey, + } + keyPEM := pem.EncodeToMemory(keyPEMBlock) + + certInfoSlice = append(certInfoSlice, &certInfo{ + cert: cert, + certPEM: certPEM, + certBytes: certBytes, + key: key, + keyPEM: keyPEM, + }) + } + + // + // Listener setup + // + addresses := []*net.TCPAddr{} + listeners := [][]*TestListener{} + servers := []*http.Server{} + handlers := []http.Handler{} + tlsConfigs := []*tls.Config{} + certGetters := []*reloadutil.CertificateGetter{} + for i := 0; i < numCores; i++ { + addr := &net.TCPAddr{ + IP: baseAddr.IP, + Port: 0, + } + if baseAddr.Port != 0 { + addr.Port = baseAddr.Port + i + } + + ln, err := net.ListenTCP("tcp", addr) + if 
err != nil { + t.Fatal(err) + } + addresses = append(addresses, addr) + + certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) + keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) + err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0o755) + if err != nil { + t.Fatal(err) + } + err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0o755) + if err != nil { + t.Fatal(err) + } + tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM) + if err != nil { + t.Fatal(err) + } + certGetter := reloadutil.NewCertificateGetter(certFile, keyFile, "") + certGetters = append(certGetters, certGetter) + certGetter.Reload() + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + RootCAs: testCluster.RootCAs, + ClientCAs: testCluster.RootCAs, + ClientAuth: tls.RequestClientCert, + NextProtos: []string{"h2", "http/1.1"}, + GetCertificate: certGetter.GetCertificate, + } + if opts != nil && opts.RequireClientAuth { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + testCluster.ClientAuthRequired = true + } + tlsConfigs = append(tlsConfigs, tlsConfig) + lns := []*TestListener{ + { + Listener: tls.NewListener(ln, tlsConfig), + Address: ln.Addr().(*net.TCPAddr), + }, + } + listeners = append(listeners, lns) + var handler http.Handler = http.NewServeMux() + handlers = append(handlers, handler) + server := &http.Server{ + Handler: handler, + ErrorLog: testCluster.Logger.StandardLogger(nil), + } + servers = append(servers, server) + } + + // Create three cores with the same physical and different redirect/cluster + // addrs. + // N.B.: On OSX, instead of random ports, it assigns new ports to new + // listeners sequentially. Aside from being a bad idea in a security sense, + // it also broke tests that assumed it was OK to just use the port above + // the redirect addr. 
This has now been changed to 105 ports above, but if + // we ever do more than three nodes in a cluster it may need to be bumped. + // Note: it's 105 so that we don't conflict with a running Consul by + // default. + coreConfig := &CoreConfig{ + LogicalBackends: make(map[string]logical.Factory), + CredentialBackends: make(map[string]logical.Factory), + AuditBackends: make(map[string]audit.Factory), + RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port), + ClusterAddr: "https://127.0.0.1:0", + DisableMlock: true, + EnableUI: true, + EnableRaw: true, + BuiltinRegistry: NewMockBuiltinRegistry(), + } + + if base != nil { + coreConfig.RawConfig = base.RawConfig + coreConfig.DisableCache = base.DisableCache + coreConfig.EnableUI = base.EnableUI + coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL + coreConfig.MaxLeaseTTL = base.MaxLeaseTTL + coreConfig.CacheSize = base.CacheSize + coreConfig.PluginDirectory = base.PluginDirectory + coreConfig.Seal = base.Seal + coreConfig.UnwrapSeal = base.UnwrapSeal + coreConfig.DevToken = base.DevToken + coreConfig.EnableRaw = base.EnableRaw + coreConfig.DisableSealWrap = base.DisableSealWrap + coreConfig.DisableCache = base.DisableCache + coreConfig.LicensingConfig = base.LicensingConfig + coreConfig.License = base.License + coreConfig.LicensePath = base.LicensePath + coreConfig.DisablePerformanceStandby = base.DisablePerformanceStandby + coreConfig.MetricsHelper = base.MetricsHelper + coreConfig.MetricSink = base.MetricSink + coreConfig.SecureRandomReader = base.SecureRandomReader + coreConfig.DisableSentinelTrace = base.DisableSentinelTrace + coreConfig.ClusterName = base.ClusterName + coreConfig.DisableAutopilot = base.DisableAutopilot + + if base.BuiltinRegistry != nil { + coreConfig.BuiltinRegistry = base.BuiltinRegistry + } + + if !coreConfig.DisableMlock { + base.DisableMlock = false + } + + if base.Physical != nil { + coreConfig.Physical = base.Physical + } + + if base.HAPhysical != nil { + 
coreConfig.HAPhysical = base.HAPhysical + } + + // Used to set something non-working to test fallback + switch base.ClusterAddr { + case "empty": + coreConfig.ClusterAddr = "" + case "": + default: + coreConfig.ClusterAddr = base.ClusterAddr + } + + if base.LogicalBackends != nil { + for k, v := range base.LogicalBackends { + coreConfig.LogicalBackends[k] = v + } + } + if base.CredentialBackends != nil { + for k, v := range base.CredentialBackends { + coreConfig.CredentialBackends[k] = v + } + } + if base.AuditBackends != nil { + for k, v := range base.AuditBackends { + coreConfig.AuditBackends[k] = v + } + } + if base.Logger != nil { + coreConfig.Logger = base.Logger + } + + coreConfig.ClusterCipherSuites = base.ClusterCipherSuites + coreConfig.DisableCache = base.DisableCache + coreConfig.DevToken = base.DevToken + coreConfig.RecoveryMode = base.RecoveryMode + coreConfig.ActivityLogConfig = base.ActivityLogConfig + coreConfig.EnableResponseHeaderHostname = base.EnableResponseHeaderHostname + coreConfig.EnableResponseHeaderRaftNodeID = base.EnableResponseHeaderRaftNodeID + coreConfig.RollbackPeriod = base.RollbackPeriod + coreConfig.PendingRemovalMountsAllowed = base.PendingRemovalMountsAllowed + coreConfig.ExpirationRevokeRetryBase = base.ExpirationRevokeRetryBase + testApplyEntBaseConfig(coreConfig, base) + } + if coreConfig.ClusterName == "" { + coreConfig.ClusterName = t.Name() + } + + if coreConfig.ClusterName == "" { + coreConfig.ClusterName = t.Name() + } + + if coreConfig.ClusterHeartbeatInterval == 0 { + // Set this lower so that state populates quickly to standby nodes + coreConfig.ClusterHeartbeatInterval = 2 * time.Second + } + + if coreConfig.RawConfig == nil { + c := new(server.Config) + c.SharedConfig = &configutil.SharedConfig{LogFormat: logging.UnspecifiedFormat.String()} + coreConfig.RawConfig = c + } + + addAuditBackend := len(coreConfig.AuditBackends) == 0 + if addAuditBackend { + AddNoopAudit(coreConfig, nil) + } + + if coreConfig.Physical == 
nil && (opts == nil || opts.PhysicalFactory == nil) { + coreConfig.Physical, err = physInmem.NewInmem(nil, testCluster.Logger) + if err != nil { + t.Fatal(err) + } + } + if coreConfig.HAPhysical == nil && (opts == nil || opts.PhysicalFactory == nil) { + haPhys, err := physInmem.NewInmemHA(nil, testCluster.Logger) + if err != nil { + t.Fatal(err) + } + coreConfig.HAPhysical = haPhys.(physical.HABackend) + } + + if testCluster.LicensePublicKey == nil { + pubKey, priKey, err := GenerateTestLicenseKeys() + if err != nil { + t.Fatalf("err: %v", err) + } + testCluster.LicensePublicKey = pubKey + testCluster.LicensePrivateKey = priKey + } + + if opts != nil && opts.InmemClusterLayers { + if opts.ClusterLayers != nil { + t.Fatalf("cannot specify ClusterLayers when InmemClusterLayers is true") + } + inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", numCores, testCluster.Logger.Named("inmem-cluster")) + if err != nil { + t.Fatal(err) + } + opts.ClusterLayers = inmemCluster + } + + // Create cores + testCluster.cleanupFuncs = []func(){} + cores := []*Core{} + coreConfigs := []*CoreConfig{} + + for i := 0; i < numCores; i++ { + cleanup, c, localConfig, handler := testCluster.newCore(t, i, coreConfig, opts, listeners[i], testCluster.LicensePublicKey) + + testCluster.cleanupFuncs = append(testCluster.cleanupFuncs, cleanup) + cores = append(cores, c) + coreConfigs = append(coreConfigs, &localConfig) + + if handler != nil { + handlers[i] = handler + servers[i].Handler = handlers[i] + } + } + + // Clustering setup + for i := 0; i < numCores; i++ { + testCluster.setupClusterListener(t, i, cores[i], coreConfigs[i], opts, listeners[i], handlers[i]) + } + + // Create TestClusterCores + var ret []*TestClusterCore + for i := 0; i < numCores; i++ { + tcc := &TestClusterCore{ + Core: cores[i], + CoreConfig: coreConfigs[i], + ServerKey: certInfoSlice[i].key, + ServerKeyPEM: certInfoSlice[i].keyPEM, + ServerCert: certInfoSlice[i].cert, + ServerCertBytes: 
certInfoSlice[i].certBytes, + ServerCertPEM: certInfoSlice[i].certPEM, + Address: addresses[i], + Listeners: listeners[i], + Handler: handlers[i], + Server: servers[i], + TLSConfig: tlsConfigs[i], + Barrier: cores[i].barrier, + NodeID: fmt.Sprintf("core-%d", i), + UnderlyingRawStorage: coreConfigs[i].Physical, + UnderlyingHAStorage: coreConfigs[i].HAPhysical, + } + tcc.ReloadFuncs = &cores[i].reloadFuncs + tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock + tcc.ReloadFuncsLock.Lock() + (*tcc.ReloadFuncs)["listener|tcp"] = []reloadutil.ReloadFunc{certGetters[i].Reload} + tcc.ReloadFuncsLock.Unlock() + + testAdjustUnderlyingStorage(tcc) + + ret = append(ret, tcc) + } + testCluster.Cores = ret + + // Initialize cores + if opts == nil || !opts.SkipInit { + testCluster.initCores(t, opts, addAuditBackend) + } + + // Assign clients + for i := 0; i < numCores; i++ { + testCluster.Cores[i].Client = testCluster.getAPIClient(t, opts, listeners[i][0].Address.Port, tlsConfigs[i]) + } + + // Extra Setup + for _, tcc := range testCluster.Cores { + testExtraTestCoreSetup(t, testCluster.LicensePrivateKey, tcc) + } + + // Cleanup + testCluster.CleanupFunc = func() { + for _, c := range testCluster.cleanupFuncs { + c() + } + if l, ok := testCluster.Logger.(*TestLogger); ok { + if t.Failed() { + _ = l.File.Close() + } else { + _ = os.Remove(l.Path) + } + } + } + + // Setup + if opts != nil { + if opts.SetupFunc != nil { + testCluster.SetupFunc = func() { + opts.SetupFunc(t, &testCluster) + } + } + } + + testCluster.opts = opts + testCluster.start(t) + return &testCluster +} + // StopCore performs an orderly shutdown of a core. 
func (cluster *TestCluster) StopCore(t testing.T, idx int) { t.Helper() From e404ad0c64153621bc43e7e9006180976bf10119 Mon Sep 17 00:00:00 2001 From: Ellie Sterner Date: Fri, 6 Jan 2023 15:50:36 -0600 Subject: [PATCH 5/6] remove line --- vault/testing.go | 1 - 1 file changed, 1 deletion(-) diff --git a/vault/testing.go b/vault/testing.go index f7a3ba17e9b3..98734a206ab0 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -1366,7 +1366,6 @@ func (tl *TestLogger) StopLogging() { // logger and will be the basis for each core's logger. If no opts.Logger is // given, one will be generated based on t.Name() for the cluster logger, and if // no base.Logger is given will also be used as the basis for each core's logger. - func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { base.DetectDeadlocks = TestDeadlockDetection From a51ac377208fa0d9d716f830420eb533829210cb Mon Sep 17 00:00:00 2001 From: Ellie Sterner Date: Mon, 9 Jan 2023 10:40:36 -0600 Subject: [PATCH 6/6] rename file, and move where detect deadlock flag is set --- vault/core.go | 2 +- ...{test_cluster.go => test_cluster_do_not_detect_deadlock.go} | 0 vault/testing.go | 3 +-- 3 files changed, 2 insertions(+), 3 deletions(-) rename vault/{test_cluster.go => test_cluster_do_not_detect_deadlock.go} (100%) diff --git a/vault/core.go b/vault/core.go index 78952f682646..08fc01f23757 100644 --- a/vault/core.go +++ b/vault/core.go @@ -890,7 +890,7 @@ func CreateCore(conf *CoreConfig) (*Core, error) { // Use imported logging deadlock if requested var stateLock locking.RWMutex - if conf.DetectDeadlocks != "" && strings.Contains(conf.DetectDeadlocks, "statelock") { + if strings.Contains(conf.DetectDeadlocks, "statelock") { stateLock = &locking.DeadlockRWMutex{} } else { stateLock = &locking.SyncRWMutex{} diff --git a/vault/test_cluster.go b/vault/test_cluster_do_not_detect_deadlock.go similarity index 100% rename from vault/test_cluster.go rename to 
vault/test_cluster_do_not_detect_deadlock.go diff --git a/vault/testing.go b/vault/testing.go index 98734a206ab0..0b3d63dd74e2 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -1367,8 +1367,6 @@ func (tl *TestLogger) StopLogging() { // given, one will be generated based on t.Name() for the cluster logger, and if // no base.Logger is given will also be used as the basis for each core's logger. func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { - base.DetectDeadlocks = TestDeadlockDetection - var err error var numCores int @@ -1637,6 +1635,7 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te } if base != nil { + coreConfig.DetectDeadlocks = TestDeadlockDetection coreConfig.RawConfig = base.RawConfig coreConfig.DisableCache = base.DisableCache coreConfig.EnableUI = base.EnableUI