Skip to content

Commit

Permalink
GOCBC-682: Ensure that connections use the correct credentials
Browse files Browse the repository at this point in the history
Motivation
----------
When we initially setup authentication we tie the credentials that
we use for all further auth to those of the first endpoint tried.
This fails when different endpoints have different credentials.

Changes
-------
Slightly refactor how we perform authentication so that each time
we bootstrap on a connection we fetch the credentials for that
individual endpoint.

Change-Id: I91fe8897f8fc52a44ac8e7adeb05bc300674de43
Reviewed-on: http://review.couchbase.org/122030
Reviewed-by: Brett Lawson <brett19@gmail.com>
Reviewed-by: Abhinav Dangeti <abhinav@couchbase.com>
Tested-by: Abhinav Dangeti <abhinav@couchbase.com>
  • Loading branch information
chvck authored and abhinavdangeti committed Feb 11, 2020
1 parent 5f6bf0c commit a5c00dd
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 70 deletions.
75 changes: 17 additions & 58 deletions agent.go
Expand Up @@ -36,10 +36,8 @@ type Agent struct {
clientID string
userAgent string
auth AuthProvider
authHandler authFunc
authHandler authFuncHandler
authMechanisms []AuthMechanism
nextAuth func(mechanism AuthMechanism)
nextAuthLock sync.Mutex
bucketName string
bucketLock sync.Mutex
tlsConfig *tls.Config
Expand Down Expand Up @@ -131,7 +129,9 @@ func (agent *Agent) getErrorMap() *kvErrorMap {
type AuthFunc func(client AuthClient, deadline time.Time, continueCb func(), completedCb func(error)) error

// authFunc wraps AuthFunc to provide a better to the user.
type authFunc func(client AuthClient, deadline time.Time) (completedCh chan BytesAndError, continueCh chan bool, err error)
type authFunc func() (completedCh chan BytesAndError, continueCh chan bool, err error)

type authFuncHandler func(client AuthClient, deadline time.Time, mechanism AuthMechanism) authFunc

// CreateAgent creates an agent for performing normal operations.
func CreateAgent(config *AgentConfig) (*Agent, error) {
Expand Down Expand Up @@ -198,6 +198,7 @@ func createAgent(config *AgentConfig, initFn memdInitFunc) (*Agent, error) {
GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return config.Auth.Certificate(AuthCertRequest{})
},
InsecureSkipVerify: config.TLSSkipVerify,
}
}

Expand Down Expand Up @@ -369,24 +370,15 @@ func createAgent(config *AgentConfig, initFn memdInitFunc) (*Agent, error) {
return c, nil
}

func (agent *Agent) buildAuthHandler(address string) (func(mechanism AuthMechanism), error) {

if len(agent.authMechanisms) == 0 {
// If we're using something like client auth then we might not want an auth handler.
return nil, nil
}

var nextAuth func(mechanism AuthMechanism)
creds, err := getKvAuthCreds(agent.auth, address)
if err != nil {
return nil, err
}
func (agent *Agent) buildAuthHandler() authFuncHandler {
return func(client AuthClient, deadline time.Time, mechanism AuthMechanism) authFunc {
creds, err := getKvAuthCreds(agent.auth, client.Address())
if err != nil {
return nil
}

if creds.Username != "" || creds.Password != "" {
// If we only have 1 auth mechanism then we've either we've already decided what mechanism to use
// or the user has only decided to support 1. Either way we don't need to check what the server supports.
getAuthFunc := func(mechanism AuthMechanism) authFunc {
return func(client AuthClient, deadline time.Time) (chan BytesAndError, chan bool, error) {
if creds.Username != "" || creds.Password != "" {
return func() (chan BytesAndError, chan bool, error) {
continueCh := make(chan bool, 1)
completedCh := make(chan BytesAndError, 1)
hasContinued := int32(0)
Expand All @@ -412,17 +404,8 @@ func (agent *Agent) buildAuthHandler(address string) (func(mechanism AuthMechani
}
}

if len(agent.authMechanisms) == 1 {
agent.authHandler = getAuthFunc(agent.authMechanisms[0])
} else {
nextAuth = func(mechanism AuthMechanism) {
agent.authHandler = getAuthFunc(mechanism)
}
agent.authHandler = getAuthFunc(agent.authMechanisms[0])
}
return nil
}

return nextAuth, nil
}

func (agent *Agent) connectWithBucket(memdAddrs, httpAddrs []string, deadline time.Time) error {
Expand All @@ -446,11 +429,7 @@ func (agent *Agent) connectWithBucket(memdAddrs, httpAddrs []string, deadline ti
}

if agent.authHandler == nil {
agent.nextAuth, err = agent.buildAuthHandler(client.Address())
if err != nil {
logDebugf("Building auth failed %p/%s! %v", agent, thisHostPort, err)
continue
}
agent.authHandler = agent.buildAuthHandler()
}

logDebugf("Trying to bootstrap agent %p against %s", agent, thisHostPort)
Expand Down Expand Up @@ -569,11 +548,7 @@ func (agent *Agent) connectG3CP(memdAddrs, httpAddrs []string, deadline time.Tim
}

if agent.authHandler == nil {
agent.nextAuth, err = agent.buildAuthHandler(client.Address())
if err != nil {
logDebugf("Building auth failed %p/%s! %v", agent, thisHostPort, err)
continue
}
agent.authHandler = agent.buildAuthHandler()
}

logDebugf("Trying to bootstrap agent %p against %s", agent, thisHostPort)
Expand Down Expand Up @@ -718,11 +693,7 @@ func (agent *Agent) tryStartHTTPLooper(httpAddrs []string) error {
return true
}

_, err = agent.buildAuthHandler(srcServer)
if err != nil {
logDebugf("Building auth failed %p/%s! %v", agent, srcServer, err)
return false
}
agent.authHandler = agent.buildAuthHandler()

if agent.useCollections {
agent.supportsCollections = cfg.supports("collections")
Expand Down Expand Up @@ -1280,15 +1251,3 @@ func (agent *Agent) newMemdClientMux(hostPorts []string) *memdClientMux {

return newMemdClientMux(hostPorts, agent.kvPoolSize, agent.maxQueueSize, agent.slowDialMemdClient, agent.circuitBreakerConfig)
}

func (agent *Agent) getNextAuth() func(AuthMechanism) {
agent.nextAuthLock.Lock()
defer agent.nextAuthLock.Unlock()
return agent.nextAuth
}

func (agent *Agent) setNextAuth(fn func(AuthMechanism)) {
agent.nextAuthLock.Lock()
agent.nextAuth = fn
defer agent.nextAuthLock.Unlock()
}
28 changes: 16 additions & 12 deletions agentrouting.go
Expand Up @@ -99,7 +99,6 @@ func (agent *Agent) storeErrorMap(mapBytes []byte, client *memdClient) {
}

func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error {

sclient := syncClient{
client: client,
}
Expand All @@ -109,6 +108,7 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error {
bucket := agent.bucket()
features := agent.helloFeatures()
clientInfoStr := agent.clientInfoString(client.connID)
authMechanisms := agent.authMechanisms

helloCh, err := sclient.ExecHello(clientInfoStr, features, deadline)
if err != nil {
Expand All @@ -123,10 +123,10 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error {
}

var listMechsCh chan SaslListMechsCompleted
nextAuth := agent.getNextAuth()
if nextAuth != nil {
firstAuthMethod := agent.authHandler(&sclient, deadline, authMechanisms[0])
// If the auth method is nil then we don't actually need to do any auth so no need to get the mechanisms.
if firstAuthMethod != nil {
listMechsCh = make(chan SaslListMechsCompleted)
// We only need to list mechs if there's more than 1 way to do auth.
err = sclient.SaslListMechs(deadline, func(mechs []AuthMechanism, err error) {
if err != nil {
logDebugf("Failed to fetch list auth mechs (%v)", err)
Expand All @@ -143,8 +143,8 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error {

var completedAuthCh chan BytesAndError
var continueAuthCh chan bool
if agent.authHandler != nil {
completedAuthCh, continueAuthCh, err = agent.authHandler(&sclient, deadline)
if firstAuthMethod != nil {
completedAuthCh, continueAuthCh, err = firstAuthMethod()
if err != nil {
logDebugf("Failed to execute auth (%v)", err)
return err
Expand Down Expand Up @@ -188,15 +188,16 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error {
}
}

// If completedAuthCh isn't nil then we have attempted to do auth so we need to wait on the result of that.
if completedAuthCh != nil {
authResp := <-completedAuthCh
if authResp.Err != nil {
logDebugf("Failed to perform auth against server (%v)", authResp.Err)
if agent.getNextAuth() == nil || errors.Is(authResp.Err, ErrAuthenticationFailure) {
// If there's an auth failure or there was only 1 mechanism to use then fail.
if len(authMechanisms) == 1 || errors.Is(authResp.Err, ErrAuthenticationFailure) {
return authResp.Err
}

authMechanisms := agent.authMechanisms
for {
var found bool
var mech AuthMechanism
Expand All @@ -206,8 +207,13 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error {
return authResp.Err
}

nextAuth(mech)
completedAuthCh, continueAuthCh, err = agent.authHandler(&sclient, deadline)
nextAuthFunc := agent.authHandler(&sclient, deadline, mech)
if nextAuthFunc == nil {
// This can't really happen but just in case it somehow does.
logDebugf("Failed to authenticate, no available credentials")
return authResp.Err
}
completedAuthCh, continueAuthCh, err = nextAuthFunc()
if err != nil {
logDebugf("Failed to execute auth (%v)", err)
return err
Expand All @@ -234,8 +240,6 @@ func (agent *Agent) bootstrap(client *memdClient, deadline time.Time) error {
}
}
}
// prevent the next client from attempting to figure out what auth to use as we already have.
agent.setNextAuth(nil)
logDebugf("Authenticated successfully")
}

Expand Down

0 comments on commit a5c00dd

Please sign in to comment.