Skip to content

Commit

Permalink
Rename RefreshableCredentials to SessionCredentials (#116)
Browse files Browse the repository at this point in the history
Align more with other SDKs
  • Loading branch information
nfx committed May 17, 2023
1 parent 1afdc73 commit 1135557
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 27 deletions.
33 changes: 18 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ Works for both AWS and Azure. Not supported for GCP at the moment.

```python
from databricks.sdk.oauth import OAuthClient

oauth_client = OAuthClient(host='<workspace-url>',
client_id='<oauth client ID>',
redirect_url=f'http://host.domain/callback',
Expand All @@ -380,29 +381,31 @@ APP_NAME = 'flask-demo'
app = Flask(APP_NAME)
app.secret_key = secrets.token_urlsafe(32)


@app.route('/callback')
def callback():
from databricks.sdk.oauth import Consent
consent = Consent.from_dict(oauth_client, session['consent'])
session['creds'] = consent.exchange_callback_parameters(request.args).as_dict()
return redirect(url_for('index'))
from databricks.sdk.oauth import Consent
consent = Consent.from_dict(oauth_client, session['consent'])
session['creds'] = consent.exchange_callback_parameters(request.args).as_dict()
return redirect(url_for('index'))


@app.route('/')
def index():
if 'creds' not in session:
consent = oauth_client.initiate_consent()
session['consent'] = consent.as_dict()
return redirect(consent.auth_url)
if 'creds' not in session:
consent = oauth_client.initiate_consent()
session['consent'] = consent.as_dict()
return redirect(consent.auth_url)

from databricks.sdk import WorkspaceClient
from databricks.sdk.oauth import RefreshableCredentials
from databricks.sdk import WorkspaceClient
from databricks.sdk.oauth import SessionCredentials

credentials_provider = RefreshableCredentials.from_dict(oauth_client, session['creds'])
workspace_client = WorkspaceClient(host=oauth_client.host,
product=APP_NAME,
credentials_provider=credentials_provider)
credentials_provider = SessionCredentials.from_dict(oauth_client, session['creds'])
workspace_client = WorkspaceClient(host=oauth_client.host,
product=APP_NAME,
credentials_provider=credentials_provider)

return render_template_string('...', w=workspace_client)
return render_template_string('...', w=workspace_client)
```

### SSO for local scripts on development machines
Expand Down
20 changes: 10 additions & 10 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def do_GET(self):
self.wfile.write(b'You can close this tab.')


class RefreshableCredentials(Refreshable):
class SessionCredentials(Refreshable):

def __init__(self, client: 'OAuthClient', token: Token):
self._client = client
Expand All @@ -187,8 +187,8 @@ def as_dict(self) -> dict:
return {'token': self._token.as_dict()}

@staticmethod
def from_dict(client: 'OAuthClient', raw: dict) -> 'RefreshableCredentials':
return RefreshableCredentials(client=client, token=Token.from_dict(raw['token']))
def from_dict(client: 'OAuthClient', raw: dict) -> 'SessionCredentials':
return SessionCredentials(client=client, token=Token.from_dict(raw['token']))

def auth_type(self):
"""Implementing CredentialsProvider protocol"""
Expand Down Expand Up @@ -237,7 +237,7 @@ def as_dict(self) -> dict:
def from_dict(client: 'OAuthClient', raw: dict) -> 'Consent':
return Consent(client, raw['state'], raw['verifier'])

def launch_external_browser(self) -> RefreshableCredentials:
def launch_external_browser(self) -> SessionCredentials:
redirect_url = urllib.parse.urlparse(self._client.redirect_url)
if redirect_url.hostname not in ('localhost', '127.0.0.1'):
raise ValueError(f'cannot listen on {redirect_url.hostname}')
Expand All @@ -254,14 +254,14 @@ def launch_external_browser(self) -> RefreshableCredentials:
query = feedback.pop()
return self.exchange_callback_parameters(query)

def exchange_callback_parameters(self, query: Dict[str, str]) -> RefreshableCredentials:
def exchange_callback_parameters(self, query: Dict[str, str]) -> SessionCredentials:
if 'error' in query:
raise ValueError('{error}: {error_description}'.format(**query))
if 'code' not in query or 'state' not in query:
raise ValueError('No code returned in callback')
return self.exchange(query['code'], query['state'])

def exchange(self, code: str, state: str) -> RefreshableCredentials:
def exchange(self, code: str, state: str) -> SessionCredentials:
if self._state != state:
raise ValueError('state mismatch')
params = {
Expand All @@ -279,7 +279,7 @@ def exchange(self, code: str, state: str) -> RefreshableCredentials:
params=params,
headers=headers,
use_params=True)
return RefreshableCredentials(self._client, token)
return SessionCredentials(self._client, token)
except ValueError as e:
if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e):
# Retry in cases of 'Single-Page Application' client-type with
Expand Down Expand Up @@ -420,7 +420,7 @@ def filename(self) -> str:
hash.update(chunk.encode('utf-8'))
return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json"))

def load(self) -> Optional[RefreshableCredentials]:
def load(self) -> Optional[SessionCredentials]:
"""
Load credentials from cache file. Return None if the cache file does not exist or is invalid.
"""
Expand All @@ -430,11 +430,11 @@ def load(self) -> Optional[RefreshableCredentials]:
try:
with open(self.filename, 'r') as f:
raw = json.load(f)
return RefreshableCredentials.from_dict(self.client, raw)
return SessionCredentials.from_dict(self.client, raw)
except Exception:
return None

def save(self, credentials: RefreshableCredentials) -> None:
def save(self, credentials: SessionCredentials) -> None:
"""
Save credentials to cache file.
"""
Expand Down
4 changes: 2 additions & 2 deletions examples/flask_app_with_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def index():
return redirect(consent.auth_url)

from databricks.sdk import WorkspaceClient
from databricks.sdk.oauth import RefreshableCredentials
from databricks.sdk.oauth import SessionCredentials

credentials_provider = RefreshableCredentials.from_dict(oauth_client, session["creds"])
credentials_provider = SessionCredentials.from_dict(oauth_client, session["creds"])
workspace_client = WorkspaceClient(host=oauth_client.host,
product=APP_NAME,
credentials_provider=credentials_provider,
Expand Down

0 comments on commit 1135557

Please sign in to comment.