Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RESTClient: implement AuthConfigBase.__bool__ + update docs #1398

Merged
merged 3 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ class AuthConfigBase(AuthBase, CredentialsConfiguration):
configurable via env variables or toml files
"""

pass
def __bool__(self) -> bool:
# This is needed to avoid AuthConfigBase-derived classes
# which do not implement CredentialsConfiguration interface
# to be evaluated as False in requests.sessions.Session.prepare_request()
return True


@configspec
Expand Down
8 changes: 4 additions & 4 deletions docs/website/docs/general-usage/http/rest-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ The available authentication methods are defined in the `dlt.sources.helpers.res
- [APIKeyAuth](#api-key-authentication)
- [HttpBasicAuth](#http-basic-authentication)

For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthConfigBase` class.
For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library.

### Bearer token authentication

Expand Down Expand Up @@ -479,12 +479,12 @@ response = client.get("/protected/resource")

### Implementing custom authentication

You can implement custom authentication by subclassing the `AuthConfigBase` class and implementing the `__call__` method:
You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method:

```py
from dlt.sources.helpers.rest_client.auth import AuthConfigBase
from requests.auth import AuthBase

class CustomAuth(AuthConfigBase):
class CustomAuth(AuthBase):
def __init__(self, token):
self.token = token

Expand Down
43 changes: 42 additions & 1 deletion tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pytest
from typing import Any, cast
from requests.auth import AuthBase
from dlt.common.typing import TSecretStrValue
from dlt.sources.helpers.requests import Response, Request
from dlt.sources.helpers.rest_client import RESTClient
Expand Down Expand Up @@ -57,7 +58,6 @@ def test_page_context(self, rest_client: RESTClient) -> None:
for page in rest_client.paginate(
"/posts",
paginator=JSONResponsePaginator(next_url_path="next_page"),
auth=AuthConfigBase(),
):
# response that produced data
assert isinstance(page.response, Response)
Expand Down Expand Up @@ -183,3 +183,44 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient):
)

assert_pagination(list(pages_iter))

def test_custom_auth_success(self, rest_client: RESTClient):
class CustomAuthConfigBase(AuthConfigBase):
def __init__(self, token: str):
self.token = token

def __call__(self, request: Request) -> Request:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

class CustomAuthAuthBase(AuthBase):
def __init__(self, token: str):
self.token = token

def __call__(self, request: Request) -> Request:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

auth_list = [
CustomAuthConfigBase("test-token"),
CustomAuthAuthBase("test-token"),
]

for auth in auth_list:
response = rest_client.get(
"/protected/posts/bearer-token",
auth=auth,
)

assert response.status_code == 200
assert response.json()["data"][0] == {"id": 0, "title": "Post 0"}

pages_iter = rest_client.paginate(
"/protected/posts/bearer-token",
auth=auth,
)

pages_list = list(pages_iter)
assert_pagination(pages_list)

assert pages_list[0].response.request.headers["Authorization"] == "Bearer test-token"
Loading