Skip to content

Commit

Permalink
Bug fix for OAuth m2m scopes (#178)
Browse files Browse the repository at this point in the history
Updated m2m authenticator to use "all-apis" scope.
Added a new constructor function for m2m authenticator that allows
client to pass in additional scopes.
Signed-off-by: Raymond Cypher <raymond.cypher@databricks.com>
  • Loading branch information
rcypher-databricks committed Nov 17, 2023
2 parents 714e264 + 91dced9 commit 750c8a0
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 11 deletions.
14 changes: 13 additions & 1 deletion auth/oauth/m2m/m2m.go
Expand Up @@ -17,7 +17,11 @@ import (
)

func NewAuthenticator(clientID, clientSecret, hostName string) auth.Authenticator {
scopes := oauth.GetScopes(hostName, []string{})
return NewAuthenticatorWithScopes(clientID, clientSecret, hostName, []string{})
}

func NewAuthenticatorWithScopes(clientID, clientSecret, hostName string, scopes []string) auth.Authenticator {
scopes = GetScopes(hostName, scopes)
return &authClient{
clientID: clientID,
clientSecret: clientSecret,
Expand Down Expand Up @@ -89,3 +93,11 @@ func GetConfig(ctx context.Context, issuerURL, clientID, clientSecret string, sc

return config, nil
}

func GetScopes(hostName string, scopes []string) []string {
if !oauth.HasScope(scopes, "all-apis") {
scopes = append(scopes, "all-apis")
}

return scopes
}
40 changes: 40 additions & 0 deletions auth/oauth/m2m/m2m_test.go
@@ -0,0 +1,40 @@
package m2m

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestM2MScopes(t *testing.T) {
t.Run("default should be [all-apis]", func(t *testing.T) {
auth := NewAuthenticator("id", "secret", "staging.cloud.company.com").(*authClient)
assert.Equal(t, "id", auth.clientID)
assert.Equal(t, "secret", auth.clientSecret)
assert.Equal(t, []string{"all-apis"}, auth.scopes)

auth = NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", nil).(*authClient)
assert.Equal(t, "id", auth.clientID)
assert.Equal(t, "secret", auth.clientSecret)
assert.Equal(t, []string{"all-apis"}, auth.scopes)

auth = NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", []string{}).(*authClient)
assert.Equal(t, "id", auth.clientID)
assert.Equal(t, "secret", auth.clientSecret)
assert.Equal(t, []string{"all-apis"}, auth.scopes)
})

t.Run("should add all-apis to passed scopes", func(t *testing.T) {
auth := NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", []string{"my-scope"}).(*authClient)
assert.Equal(t, "id", auth.clientID)
assert.Equal(t, "secret", auth.clientSecret)
assert.Equal(t, []string{"my-scope", "all-apis"}, auth.scopes)
})

t.Run("should not add all-apis if already in passed scopes", func(t *testing.T) {
auth := NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", []string{"all-apis", "my-scope"}).(*authClient)
assert.Equal(t, "id", auth.clientID)
assert.Equal(t, "secret", auth.clientSecret)
assert.Equal(t, []string{"all-apis", "my-scope"}, auth.scopes)
})
}
8 changes: 4 additions & 4 deletions auth/oauth/oauth.go
Expand Up @@ -45,27 +45,27 @@ func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error)

func GetScopes(hostName string, scopes []string) []string {
for _, s := range []string{oidc.ScopeOfflineAccess} {
if !hasScope(scopes, s) {
if !HasScope(scopes, s) {
scopes = append(scopes, s)
}
}

cloudType := InferCloudFromHost(hostName)
if cloudType == Azure {
userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenantId)
if !hasScope(scopes, userImpersonationScope) {
if !HasScope(scopes, userImpersonationScope) {
scopes = append(scopes, userImpersonationScope)
}
} else {
if !hasScope(scopes, "sql") {
if !HasScope(scopes, "sql") {
scopes = append(scopes, "sql")
}
}

return scopes
}

func hasScope(scopes []string, scope string) bool {
func HasScope(scopes []string, scope string) bool {
for _, s := range scopes {
if s == scope {
return true
Expand Down
12 changes: 6 additions & 6 deletions auth/oauth/u2m/authenticator.go
Expand Up @@ -25,11 +25,11 @@ import (
)

const (
azureClientId = "96eecda7-19ea-49cc-abb5-240097d554f5"
azureRedirctURL = "localhost:8030"
azureClientId = "96eecda7-19ea-49cc-abb5-240097d554f5"
azureRedirectURL = "localhost:8030"

awsClientId = "databricks-sql-connector"
awsRedirctURL = "localhost:8030"
awsClientId = "databricks-sql-connector"
awsRedirectURL = "localhost:8030"
)

func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticator, error) {
Expand All @@ -39,10 +39,10 @@ func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticato
var clientID, redirectURL string
if cloud == oauth.AWS {
clientID = awsClientId
redirectURL = awsRedirctURL
redirectURL = awsRedirectURL
} else if cloud == oauth.Azure {
clientID = azureClientId
redirectURL = azureRedirctURL
redirectURL = azureRedirectURL
} else {
return nil, errors.New("unhandled cloud type: " + cloud.String())
}
Expand Down
2 changes: 2 additions & 0 deletions driverctx/ctx.go
Expand Up @@ -38,6 +38,7 @@ func CorrelationIdFromContext(ctx context.Context) string {
}

// NewContextWithConnId creates a new context with connectionId value.
// The connection ID will be displayed in log messages and other dianostic information.
func NewContextWithConnId(ctx context.Context, connId string) context.Context {
if callback, ok := ctx.Value(ConnIdCallbackKey).(IdCallbackFunc); ok {
callback(connId)
Expand All @@ -59,6 +60,7 @@ func ConnIdFromContext(ctx context.Context) string {
}

// NewContextWithQueryId creates a new context with queryId value.
// The query id will be displayed in log messages and other diagnostic information.
func NewContextWithQueryId(ctx context.Context, queryId string) context.Context {
if callback, ok := ctx.Value(QueryIdCallbackKey).(IdCallbackFunc); ok {
callback(queryId)
Expand Down

0 comments on commit 750c8a0

Please sign in to comment.