Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer Azure tenant ID if not set #910

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (c AzureCliCredentials) Name() string {
// implementing azureHostResolver for ensureWorkspaceUrl to work
func (c AzureCliCredentials) tokenSourceFor(
ctx context.Context, cfg *Config, _, resource string) oauth2.TokenSource {
return NewAzureCliTokenSource(resource)
return NewAzureCliTokenSource(ctx, resource, cfg.AzureTenantID)
}

// There are three scenarios:
Expand All @@ -43,7 +43,7 @@ func (c AzureCliCredentials) tokenSourceFor(
// If the user can't access the service management endpoint, we assume they are in case 2 and do not pass the service
// management token. Otherwise, we always pass the service management token.
func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner oauth2.TokenSource) (func(*http.Request) error, error) {
ts := &azureCliTokenSource{cfg.Environment().AzureServiceManagementEndpoint(), ""}
ts := &azureCliTokenSource{ctx, cfg.Environment().AzureServiceManagementEndpoint(), cfg.AzureResourceID, cfg.AzureTenantID}
t, err := ts.Token()
if err != nil {
logger.Debugf(ctx, "Not including service management token in headers: %v", err)
Expand All @@ -57,8 +57,13 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
if !cfg.IsAzure() {
return nil, nil
}
// Set the azure tenant ID from host if available
err := cfg.loadAzureTenantId(ctx)
if err != nil {
return nil, fmt.Errorf("load tenant id: %w", err)
}
// Eagerly get a token to fail fast in case the user is not logged in with the Azure CLI.
ts := &azureCliTokenSource{cfg.Environment().AzureApplicationID, cfg.AzureResourceID}
ts := &azureCliTokenSource{ctx, cfg.Environment().AzureApplicationID, cfg.AzureResourceID, cfg.AzureTenantID}
t, err := ts.Token()
if err != nil {
if strings.Contains(err.Error(), "No subscription found") {
Expand All @@ -85,15 +90,19 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
}

// NewAzureCliTokenSource returns [oauth2.TokenSource] for a passwordless authentication via Azure CLI (`az login`)
func NewAzureCliTokenSource(resource string) oauth2.TokenSource {
func NewAzureCliTokenSource(ctx context.Context, resource, azureTenantId string) oauth2.TokenSource {
return &azureCliTokenSource{
resource: resource,
ctx: ctx,
resource: resource,
azureTenantId: azureTenantId,
}
}

type azureCliTokenSource struct {
ctx context.Context
resource string
workspaceResourceId string
azureTenantId string
}

type internalCliToken struct {
Expand Down Expand Up @@ -129,8 +138,12 @@ func (ts *azureCliTokenSource) Token() (*oauth2.Token, error) {
if err != nil {
return nil, fmt.Errorf("cannot parse expiry: %w", err)
}
logger.Infof(context.Background(), "Refreshed OAuth token for %s from Azure CLI, which expires on %s",
ts.resource, it.ExpiresOn)
tenantIdDebug := ""
if ts.azureTenantId != "" {
tenantIdDebug = fmt.Sprintf(" for tenant %s", ts.azureTenantId)
}
logger.Infof(context.Background(), "Refreshed OAuth token for %s%s from Azure CLI, which expires on %s",
ts.resource, tenantIdDebug, it.ExpiresOn)

var extra map[string]interface{}
err = json.Unmarshal(tokenBytes, &extra)
Expand All @@ -146,23 +159,24 @@ func (ts *azureCliTokenSource) Token() (*oauth2.Token, error) {
}

func (ts *azureCliTokenSource) getTokenBytes() ([]byte, error) {
subscription := ts.getSubscription()
args := []string{"account", "get-access-token", "--resource",
ts.resource, "--output", "json"}
if subscription != "" {
extendedArgs := make([]string, len(args))
copy(extendedArgs, args)
extendedArgs = append(extendedArgs, "--subscription", subscription)
if ts.azureTenantId != "" {
args = append(args, "--tenant", ts.azureTenantId)
}
subscription := ts.getSubscription()
if subscription != "" && ts.azureTenantId == "" {
// This will fail if the user has access to the workspace, but not to the subscription
// itself.
// In such case, we fall back to not using the subscription.
result, err := exec.Command("az", extendedArgs...).Output()
// This should only be attempted when the tenant ID is not known.
result, err := runCommand(ts.ctx, "az", append(args, "--subscription", subscription))
if err == nil {
return result, nil
}
logger.Warnf(context.Background(), "Failed to get token for subscription. Using resource only token.")
logger.Infof(ts.ctx, "Failed to get token for subscription. Using resource only token.")
}
result, err := exec.Command("az", args...).Output()
result, err := runCommand(ts.ctx, "az", args)
if ee, ok := err.(*exec.ExitError); ok {
return nil, fmt.Errorf("cannot get access token: %s", string(ee.Stderr))
}
Expand Down
64 changes: 61 additions & 3 deletions config/auth_azure_cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"context"
"errors"
"net/http"
"os"
"path/filepath"
Expand All @@ -13,9 +14,53 @@ import (
"github.com/stretchr/testify/require"
)

var azDummy = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/"}
var azDummyWithResourceId = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/", AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123"}
var azDummyWitInvalidResourceId = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/", AzureResourceID: "invalidResourceId"}
type mockTransport struct {
resp *http.Response
err error
}

func (m mockTransport) RoundTrip(*http.Request) (*http.Response, error) {
if m.err != nil {
return nil, m.err
}
return m.resp, nil
}

func makeClient(r *http.Response) *http.Client {
return &http.Client{
Transport: mockTransport{resp: r},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}

func makeFailingClient(err error) *http.Client {
return &http.Client{
Transport: mockTransport{err: err},
}
}

var redirectResponse = &http.Response{
StatusCode: 302,
Header: http.Header{"Location": []string{"https://login.microsoftonline.com/123-abc/oauth2/token"}},
}
var errDummy = errors.New("failed to get login endpoint")

var azDummy = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
azureTenantIdFetchClient: makeClient(redirectResponse),
}
var azDummyWithResourceId = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123",
azureTenantIdFetchClient: makeClient(redirectResponse),
}
var azDummyWitInvalidResourceId = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
AzureResourceID: "invalidResourceId",
azureTenantIdFetchClient: makeClient(redirectResponse),
}

// testdataPath returns the PATH to use for the duration of a test.
// It must only return absolute directories because Go refuses to run
Expand Down Expand Up @@ -187,6 +232,19 @@ func TestAzureCliCredentials_CorruptExpire(t *testing.T) {
assert.EqualError(t, err, "cannot parse expiry: parsing time \"\" as \"2006-01-02 15:04:05.999999\": cannot parse \"\" as \"2006\"")
}

func TestAzureCliCredentials_DoNotFetchIfTenantIdAlreadySet(t *testing.T) {
env.CleanupEnvironment(t)
os.Setenv("PATH", testdataPath())
aa := AzureCliCredentials{}
_, err := aa.Configure(context.Background(), &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
AzureTenantID: "123",
AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123",
azureTenantIdFetchClient: makeFailingClient(errDummy),
})
assert.NoError(t, err)
}

// TODO: this test should rather be on sequencing
// func TestConfigureWithAzureCLI_SP(t *testing.T) {
// aa := DatabricksClient{
Expand Down
6 changes: 5 additions & 1 deletion config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
if !cfg.IsAzure() {
return nil, nil
}
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
err := cfg.loadAzureTenantId(ctx)
if err != nil {
return nil, fmt.Errorf("load tenant id: %w", err)
}
err = cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
}
Expand Down
9 changes: 5 additions & 4 deletions config/auth_databricks_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (f
return nil, nil
}

ts, err := newDatabricksCliTokenSource(cfg)
ts, err := newDatabricksCliTokenSource(ctx, cfg)
if err != nil {
if errors.Is(err, exec.ErrNotFound) {
logger.Debugf(ctx, "Most likely the Databricks CLI is not installed")
Expand Down Expand Up @@ -60,11 +60,12 @@ func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (f
var errLegacyDatabricksCli = errors.New("legacy Databricks CLI detected")

type databricksCliTokenSource struct {
ctx context.Context
name string
args []string
}

func newDatabricksCliTokenSource(cfg *Config) (*databricksCliTokenSource, error) {
func newDatabricksCliTokenSource(ctx context.Context, cfg *Config) (*databricksCliTokenSource, error) {
args := []string{"auth", "token", "--host", cfg.Host}

if cfg.IsAccountClient() {
Expand Down Expand Up @@ -100,11 +101,11 @@ func newDatabricksCliTokenSource(cfg *Config) (*databricksCliTokenSource, error)
return nil, errLegacyDatabricksCli
}

return &databricksCliTokenSource{name: path, args: args}, nil
return &databricksCliTokenSource{ctx: ctx, name: path, args: args}, nil
}

func (ts *databricksCliTokenSource) Token() (*oauth2.Token, error) {
out, err := exec.Command(ts.name, ts.args...).Output()
out, err := runCommand(ts.ctx, ts.name, ts.args)
if ee, ok := err.(*exec.ExitError); ok {
return nil, fmt.Errorf("cannot get access token: %s", string(ee.Stderr))
}
Expand Down
4 changes: 4 additions & 0 deletions config/auth_permutations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ func (cf configFixture) configureProviderAndReturnConfig(t *testing.T) (*Config,
AzureTenantID: cf.AzureTenantID,
AzureResourceID: cf.AzureResourceID,
AuthType: cf.AuthType,
azureTenantIdFetchClient: makeClient(&http.Response{
StatusCode: http.StatusTemporaryRedirect,
Header: http.Header{"Location": []string{"https://login.microsoftonline.com/tenant_id/abc"}},
}),
}
if client.IsAzure() {
client.DatabricksEnvironment = &DatabricksEnvironment{
Expand Down
31 changes: 31 additions & 0 deletions config/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"

Expand Down Expand Up @@ -169,3 +171,32 @@ func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResol
logger.Debugf(ctx, "Discovered workspace url: %s", c.Host)
return nil
}

// loadAzureTenantId fetches the Azure tenant ID from the Azure AD endpoint.
// The tenant ID isn't directly exposed by any API today, but it can be inferred
// from the URL that a user is redirected to after accessing /aad/auth (the
// Azure Databricks login endpoint). Here, the redirect is not followed, but the
// tenant ID is extracted from the URL.
func (c *Config) loadAzureTenantId(ctx context.Context) error {
if !c.IsAzure() || c.AzureTenantID != "" || c.Host == "" {
return nil
}
req, err := http.NewRequestWithContext(ctx, "GET", c.CanonicalHostName()+"/aad/auth", nil)
if err != nil {
return err
}
res, err := c.azureTenantIdFetchClient.Do(req)
if err != nil && !errors.Is(err, http.ErrUseLastResponse) {
return err
}
location := res.Header.Get("Location")
parsedUrl, err := url.ParseRequestURI(location)
if err != nil {
return err
}
// Response URL is of the form https://login.microsoftonline.com/<tenantID>/oauth2/authorize?...
// or corresponding in other Azure clouds
splitPath := strings.SplitN(parsedUrl.Path, "/", 3)
c.AzureTenantID = splitPath[1]
return nil
}
44 changes: 44 additions & 0 deletions config/azure_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package config

import (
"context"
"errors"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestLoadAzureTenantId(t *testing.T) {
c := &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
azureTenantIdFetchClient: makeClient(&http.Response{
StatusCode: 302,
Header: http.Header{"Location": []string{"https://login.microsoftonline.com/123-abc/oauth2/token"}},
}),
}
err := c.loadAzureTenantId(context.Background())
assert.NoError(t, err)
assert.Equal(t, c.AzureTenantID, "123-abc")
}

func TestLoadAzureTenantId_Failure(t *testing.T) {
testErr := errors.New("Failed to fetch login page")
c := &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
azureTenantIdFetchClient: makeFailingClient(testErr),
}
err := c.loadAzureTenantId(context.Background())
assert.ErrorIs(t, err, testErr)
}

func TestLoadAzureTenantId_SkipNotInAzure(t *testing.T) {
testErr := errors.New("Failed to fetch login page")
c := &Config{
Host: "https://test.cloud.databricks.com/",
azureTenantIdFetchClient: makeFailingClient(testErr),
}
err := c.loadAzureTenantId(context.Background())
assert.NoError(t, err)
assert.Empty(t, c.AzureTenantID)
}
15 changes: 15 additions & 0 deletions config/command.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package config

import (
"context"
"os/exec"
"strings"

"github.com/databricks/databricks-sdk-go/logger"
)

// Run a command and return the output.
func runCommand(ctx context.Context, cmd string, args []string) ([]byte, error) {
logger.Debugf(ctx, "Running command: %s %v", cmd, strings.Join(args, " "))
Copy link
Contributor

@tanmay-db tanmay-db May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be good to allow passing debug level defaulting to DEBUG.

    logMessage := fmt.Sprintf("Running command: %s %v", cmd, strings.Join(args, " "))

    // logLevel is enum for log levels, default to debug if not provided. 
    switch logLevel {
    case LogLevelDebug:
        logger.Debugf(ctx, logMessage)
    case LogLevelInfo:
        logger.Infof(ctx, logMessage)
    case LogLevelWarn:
        logger.Warnf(ctx, logMessage)
    case LogLevelError:
        logger.Errorf(ctx, logMessage)
    default:
        logger.Infof(ctx, "Unspecified log level for: %s", logMessage)
    }

return exec.Command(cmd, args...).Output()
}
11 changes: 11 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ type Config struct {
// internal background context used for authentication purposes together with refreshClient
refreshCtx context.Context

// internal client used to fetch Azure Tenant ID from Databricks Login endpoint
azureTenantIdFetchClient *http.Client

// marker for testing fixture
isTesting bool

Expand Down Expand Up @@ -288,6 +291,14 @@ func (c *Config) EnsureResolved() error {
"rate limit",
},
})
if c.azureTenantIdFetchClient == nil {
c.azureTenantIdFetchClient = &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Do not follow redirects
return http.ErrUseLastResponse
},
}
}
c.resolved = true
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion examples/default-auth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func main() {
if err != nil {
panic(err)
}
for _, c := range all {
for _, c := range all[:10] {
println(c.ClusterName)
}
}