From 91dced95fb01a37487293dfb948985f85daffb6b Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Fri, 17 Nov 2023 10:51:04 -0700 Subject: [PATCH] Bug fix for OAuth m2m scopes Updated m2m authenticator to use "all-apis" scope. Added a new constructor function for m2m authenticator that allows client to pass in additional scopes. Signed-off-by: Raymond Cypher --- auth/oauth/m2m/m2m.go | 14 +++++++++++- auth/oauth/m2m/m2m_test.go | 40 +++++++++++++++++++++++++++++++++ auth/oauth/oauth.go | 8 +++---- auth/oauth/u2m/authenticator.go | 12 +++++----- driverctx/ctx.go | 2 ++ 5 files changed, 65 insertions(+), 11 deletions(-) create mode 100644 auth/oauth/m2m/m2m_test.go diff --git a/auth/oauth/m2m/m2m.go b/auth/oauth/m2m/m2m.go index aa72eb9..7b68404 100644 --- a/auth/oauth/m2m/m2m.go +++ b/auth/oauth/m2m/m2m.go @@ -17,7 +17,11 @@ import ( ) func NewAuthenticator(clientID, clientSecret, hostName string) auth.Authenticator { - scopes := oauth.GetScopes(hostName, []string{}) + return NewAuthenticatorWithScopes(clientID, clientSecret, hostName, []string{}) +} + +func NewAuthenticatorWithScopes(clientID, clientSecret, hostName string, scopes []string) auth.Authenticator { + scopes = GetScopes(hostName, scopes) return &authClient{ clientID: clientID, clientSecret: clientSecret, @@ -89,3 +93,11 @@ func GetConfig(ctx context.Context, issuerURL, clientID, clientSecret string, sc return config, nil } + +func GetScopes(hostName string, scopes []string) []string { + if !oauth.HasScope(scopes, "all-apis") { + scopes = append(scopes, "all-apis") + } + + return scopes +} diff --git a/auth/oauth/m2m/m2m_test.go b/auth/oauth/m2m/m2m_test.go new file mode 100644 index 0000000..fe1e51c --- /dev/null +++ b/auth/oauth/m2m/m2m_test.go @@ -0,0 +1,40 @@ +package m2m + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestM2MScopes(t *testing.T) { + t.Run("default should be [all-apis]", func(t *testing.T) { + auth := NewAuthenticator("id", "secret", "staging.cloud.company.com").(*authClient) + assert.Equal(t, "id", auth.clientID) + assert.Equal(t, "secret", auth.clientSecret) + assert.Equal(t, []string{"all-apis"}, auth.scopes) + + auth = NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", nil).(*authClient) + assert.Equal(t, "id", auth.clientID) + assert.Equal(t, "secret", auth.clientSecret) + assert.Equal(t, []string{"all-apis"}, auth.scopes) + + auth = NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", []string{}).(*authClient) + assert.Equal(t, "id", auth.clientID) + assert.Equal(t, "secret", auth.clientSecret) + assert.Equal(t, []string{"all-apis"}, auth.scopes) + }) + + t.Run("should add all-apis to passed scopes", func(t *testing.T) { + auth := NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", []string{"my-scope"}).(*authClient) + assert.Equal(t, "id", auth.clientID) + assert.Equal(t, "secret", auth.clientSecret) + assert.Equal(t, []string{"my-scope", "all-apis"}, auth.scopes) + }) + + t.Run("should not add all-apis if already in passed scopes", func(t *testing.T) { + auth := NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", []string{"all-apis", "my-scope"}).(*authClient) + assert.Equal(t, "id", auth.clientID) + assert.Equal(t, "secret", auth.clientSecret) + assert.Equal(t, []string{"all-apis", "my-scope"}, auth.scopes) + }) +} diff --git a/auth/oauth/oauth.go b/auth/oauth/oauth.go index 2e94dba..2ffd28a 100644 --- a/auth/oauth/oauth.go +++ b/auth/oauth/oauth.go @@ -45,7 +45,7 @@ func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error) func GetScopes(hostName string, scopes []string) []string { for _, s := range []string{oidc.ScopeOfflineAccess} { - if !hasScope(scopes, s) { + if !HasScope(scopes, s) { scopes = append(scopes, s) } } @@ -53,11 +53,11 @@ func GetScopes(hostName string, scopes []string) []string { cloudType := InferCloudFromHost(hostName) if cloudType == Azure { userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenantId) - if !hasScope(scopes, userImpersonationScope) { + if !HasScope(scopes, userImpersonationScope) { scopes = append(scopes, userImpersonationScope) } } else { - if !hasScope(scopes, "sql") { + if !HasScope(scopes, "sql") { scopes = append(scopes, "sql") } } @@ -65,7 +65,7 @@ func GetScopes(hostName string, scopes []string) []string { return scopes } -func hasScope(scopes []string, scope string) bool { +func HasScope(scopes []string, scope string) bool { for _, s := range scopes { if s == scope { return true diff --git a/auth/oauth/u2m/authenticator.go b/auth/oauth/u2m/authenticator.go index 29a973c..5a5aabb 100644 --- a/auth/oauth/u2m/authenticator.go +++ b/auth/oauth/u2m/authenticator.go @@ -25,11 +25,11 @@ import ( ) const ( - azureClientId = "96eecda7-19ea-49cc-abb5-240097d554f5" - azureRedirctURL = "localhost:8030" + azureClientId = "96eecda7-19ea-49cc-abb5-240097d554f5" + azureRedirectURL = "localhost:8030" - awsClientId = "databricks-sql-connector" - awsRedirctURL = "localhost:8030" + awsClientId = "databricks-sql-connector" + awsRedirectURL = "localhost:8030" ) func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticator, error) { @@ -39,10 +39,10 @@ func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticato var clientID, redirectURL string if cloud == oauth.AWS { clientID = awsClientId - redirectURL = awsRedirctURL + redirectURL = awsRedirectURL } else if cloud == oauth.Azure { clientID = azureClientId - redirectURL = azureRedirctURL + redirectURL = azureRedirectURL } else { return nil, errors.New("unhandled cloud type: " + cloud.String()) } diff --git a/driverctx/ctx.go b/driverctx/ctx.go index 4ccbbbe..f8f4674 100644 --- a/driverctx/ctx.go +++ b/driverctx/ctx.go @@ -38,6 +38,7 @@ func CorrelationIdFromContext(ctx context.Context) string { } // NewContextWithConnId creates a new context with connectionId value. +// The connection ID will be displayed in log messages and other dianostic information. func NewContextWithConnId(ctx context.Context, connId string) context.Context { if callback, ok := ctx.Value(ConnIdCallbackKey).(IdCallbackFunc); ok { callback(connId) @@ -59,6 +60,7 @@ func ConnIdFromContext(ctx context.Context) string { } // NewContextWithQueryId creates a new context with queryId value. +// The query id will be displayed in log messages and other diagnostic information. func NewContextWithQueryId(ctx context.Context, queryId string) context.Context { if callback, ok := ctx.Value(QueryIdCallbackKey).(IdCallbackFunc); ok { callback(queryId)