diff --git a/agent.go b/agent.go index c688623..5681f98 100644 --- a/agent.go +++ b/agent.go @@ -36,10 +36,8 @@ type Agent struct { clientID string userAgent string auth AuthProvider - authHandler authFunc + authHandler authFuncHandler authMechanisms []AuthMechanism - nextAuth func(mechanism AuthMechanism) - nextAuthLock sync.Mutex bucketName string bucketLock sync.Mutex tlsConfig *tls.Config @@ -131,7 +129,9 @@ func (agent *Agent) getErrorMap() *kvErrorMap { type AuthFunc func(client AuthClient, deadline time.Time, continueCb func(), completedCb func(error)) error // authFunc wraps AuthFunc to provide a better to the user. -type authFunc func(client AuthClient, deadline time.Time) (completedCh chan BytesAndError, continueCh chan bool, err error) +type authFunc func() (completedCh chan BytesAndError, continueCh chan bool, err error) + +type authFuncHandler func(client AuthClient, deadline time.Time, mechanism AuthMechanism) authFunc // CreateAgent creates an agent for performing normal operations. func CreateAgent(config *AgentConfig) (*Agent, error) { @@ -198,6 +198,7 @@ func createAgent(config *AgentConfig, initFn memdInitFunc) (*Agent, error) { GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { return config.Auth.Certificate(AuthCertRequest{}) }, + InsecureSkipVerify: config.TLSSkipVerify, } } @@ -369,24 +370,15 @@ func createAgent(config *AgentConfig, initFn memdInitFunc) (*Agent, error) { return c, nil } -func (agent *Agent) buildAuthHandler(address string) (func(mechanism AuthMechanism), error) { - - if len(agent.authMechanisms) == 0 { - // If we're using something like client auth then we might not want an auth handler. - return nil, nil - } - - var nextAuth func(mechanism AuthMechanism) - creds, err := getKvAuthCreds(agent.auth, address) - if err != nil { - return nil, err - } +func (agent *Agent) buildAuthHandler() authFuncHandler { + return func(client AuthClient, deadline time.Time, mechanism AuthMechanism) authFunc { + creds, err := getKvAuthCreds(agent.auth, client.Address()) + if err != nil { + return nil + } - if creds.Username != "" || creds.Password != "" { - // If we only have 1 auth mechanism then we've either we've already decided what mechanism to use - // or the user has only decided to support 1. Either way we don't need to check what the server supports. - getAuthFunc := func(mechanism AuthMechanism) authFunc { - return func(client AuthClient, deadline time.Time) (chan BytesAndError, chan bool, error) { + if creds.Username != "" || creds.Password != "" { + return func() (chan BytesAndError, chan bool, error) { continueCh := make(chan bool, 1) completedCh := make(chan BytesAndError, 1) hasContinued := int32(0) @@ -412,17 +404,8 @@ func (agent *Agent) buildAuthHandler(address string) (func(mechanism AuthMechani } } - if len(agent.authMechanisms) == 1 { - agent.authHandler = getAuthFunc(agent.authMechanisms[0]) - } else { - nextAuth = func(mechanism AuthMechanism) { - agent.authHandler = getAuthFunc(mechanism) - } - agent.authHandler = getAuthFunc(agent.authMechanisms[0]) - } + return nil } - - return nextAuth, nil } func (agent *Agent) connectWithBucket(memdAddrs, httpAddrs []string, deadline time.Time) error { @@ -446,11 +429,7 @@ func (agent *Agent) connectWithBucket(memdAddrs, httpAddrs []string, deadline ti } if agent.authHandler == nil { - agent.nextAuth, err = agent.buildAuthHandler(client.Address()) - if err != nil { - logDebugf("Building auth failed %p/%s! %v", agent, thisHostPort, err) - continue - } + agent.authHandler = agent.buildAuthHandler() } logDebugf("Trying to bootstrap agent %p against %s", agent, thisHostPort) @@ -569,11 +548,7 @@ func (agent *Agent) connectG3CP(memdAddrs, httpAddrs []string, deadline time.Tim } if agent.authHandler == nil { - agent.nextAuth, err = agent.buildAuthHandler(client.Address()) - if err != nil { - logDebugf("Building auth failed %p/%s! %v", agent, thisHostPort, err) - continue - } + agent.authHandler = agent.buildAuthHandler() } logDebugf("Trying to bootstrap agent %p against %s", agent, thisHostPort) @@ -718,11 +693,7 @@ func (agent *Agent) tryStartHTTPLooper(httpAddrs []string) error { return true } - _, err = agent.buildAuthHandler(srcServer) - if err != nil { - logDebugf("Building auth failed %p/%s! %v", agent, srcServer, err) - return false - } + agent.authHandler = agent.buildAuthHandler() if agent.useCollections { agent.supportsCollections = cfg.supports("collections") @@ -1280,15 +1251,3 @@ func (agent *Agent) newMemdClientMux(hostPorts []string) *memdClientMux { return newMemdClientMux(hostPorts, agent.kvPoolSize, agent.maxQueueSize, agent.slowDialMemdClient, agent.circuitBreakerConfig) } - -func (agent *Agent) getNextAuth() func(AuthMechanism) { - agent.nextAuthLock.Lock() - defer agent.nextAuthLock.Unlock() - return agent.nextAuth -} - -func (agent *Agent) setNextAuth(fn func(AuthMechanism)) { - agent.nextAuthLock.Lock() - agent.nextAuth = fn - defer agent.nextAuthLock.Unlock() -} diff --git a/agentrouting.go b/agentrouting.go index 957d4a9..37c6e63 100644 --- a/agentrouting.go +++ b/agentrouting.go @@ -99,7 +99,6 @@ func (agent *Agent) storeErrorMap(mapBytes []byte, client *memdClient) { } func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error { - sclient := syncClient{ client: client, } @@ -109,6 +108,7 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error { bucket := agent.bucket() features := agent.helloFeatures() clientInfoStr := agent.clientInfoString(client.connID) + authMechanisms := agent.authMechanisms helloCh, err := sclient.ExecHello(clientInfoStr, features, deadline) if err != nil { @@ -123,10 +123,10 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error { } var listMechsCh chan SaslListMechsCompleted - nextAuth := agent.getNextAuth() - if nextAuth != nil { + firstAuthMethod := agent.authHandler(&sclient, deadline, authMechanisms[0]) + // If the auth method is nil then we don't actually need to do any auth so no need to get the mechanisms. + if firstAuthMethod != nil { listMechsCh = make(chan SaslListMechsCompleted) - // We only need to list mechs if there's more than 1 way to do auth. err = sclient.SaslListMechs(deadline, func(mechs []AuthMechanism, err error) { if err != nil { logDebugf("Failed to fetch list auth mechs (%v)", err) @@ -143,8 +143,8 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error { var completedAuthCh chan BytesAndError var continueAuthCh chan bool - if agent.authHandler != nil { - completedAuthCh, continueAuthCh, err = agent.authHandler(&sclient, deadline) + if firstAuthMethod != nil { + completedAuthCh, continueAuthCh, err = firstAuthMethod() if err != nil { logDebugf("Failed to execute auth (%v)", err) return err @@ -188,15 +188,16 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error { } } + // If completedAuthCh isn't nil then we have attempted to do auth so we need to wait on the result of that. if completedAuthCh != nil { authResp := <-completedAuthCh if authResp.Err != nil { logDebugf("Failed to perform auth against server (%v)", authResp.Err) - if agent.getNextAuth() == nil || errors.Is(authResp.Err, ErrAuthenticationFailure) { + // If there's an auth failure or there was only 1 mechanism to use then fail. + if len(authMechanisms) == 1 || errors.Is(authResp.Err, ErrAuthenticationFailure) { return authResp.Err } - authMechanisms := agent.authMechanisms for { var found bool var mech AuthMechanism @@ -206,8 +207,13 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error { return authResp.Err } - nextAuth(mech) - completedAuthCh, continueAuthCh, err = agent.authHandler(&sclient, deadline) + nextAuthFunc := agent.authHandler(&sclient, deadline, mech) + if nextAuthFunc == nil { + // This can't really happen but just in case it somehow does. + logDebugf("Failed to authenticate, no available credentials") + return authResp.Err + } + completedAuthCh, continueAuthCh, err = nextAuthFunc() if err != nil { logDebugf("Failed to execute auth (%v)", err) return err @@ -234,8 +240,6 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error { } } } - // prevent the next client from attempting to figure out what auth to use as we already have. - agent.setNextAuth(nil) logDebugf("Authenticated successfully") }