Skip to content

Commit

Permalink
Fixed the AzureTenantId in oauth.go (#210)
Browse files Browse the repository at this point in the history
AzureTenantId was hardcoded as a constant which was linked to
".staging.azuredatabricks.net". Therefore, if the host URL was for
example from ".azuredatabricks.net", it failed to detect ClientId. I
replaced the const value with a map to find the AzureTenantId based on
the DSN host.

---------

Signed-off-by: Erfan Mahmoodnejad <erfan.mahmoodnejad@gmail.com>
Co-authored-by: Erfan Mahmoodnejad <erfan.mahmoodnejad@gmail.com>
  • Loading branch information
tubiskasarus and tubiskasaroos authored Apr 16, 2024
1 parent aeb5e5d commit f06515c
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions auth/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ import (
"golang.org/x/oauth2"
)

const (
azureTenantId = "4a67d088-db5c-48f1-9ff2-0aace800ae68"
)
var azureTenants = map[string]string{
".dev.azuredatabricks.net": "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc",
".staging.azuredatabricks.net": "4a67d088-db5c-48f1-9ff2-0aace800ae68",
".azuredatabricks.net": "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d",
".databricks.azure.us": "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d",
".databricks.azure.cn": "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d",
}

func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error) {
if ctx == nil {
Expand Down Expand Up @@ -52,7 +56,7 @@ func GetScopes(hostName string, scopes []string) []string {

cloudType := InferCloudFromHost(hostName)
if cloudType == Azure {
userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenantId)
userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenants[GetAzureDnsZone(hostName)])
if !HasScope(scopes, userImpersonationScope) {
scopes = append(scopes, userImpersonationScope)
}
Expand Down Expand Up @@ -133,3 +137,12 @@ func InferCloudFromHost(hostname string) CloudType {
}
return Unknown
}

func GetAzureDnsZone(hostname string) string {
for _, d := range databricksAzureDomains {
if strings.Contains(hostname, d) {
return d
}
}
return ""
}

0 comments on commit f06515c

Please sign in to comment.