Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(application): add application authentification
MP-685
- Loading branch information
1 parent
f5a3ead
commit 0382ae6
Showing
6 changed files
with
202 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .application import Application, ApplicationClient # noqa | ||
from .base import IClient, Request, Response # noqa | ||
from .exceptions import InvalidLogin # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, Callable | ||
|
||
from oauthlib.oauth2 import BackendApplicationClient, OAuth2Error, TokenExpiredError | ||
from requests_oauthlib import OAuth2Session | ||
|
||
from .base import IClient, Request, Response | ||
from .exceptions import InvalidLogin | ||
|
||
|
||
@dataclass(frozen=True) | ||
class Application: | ||
client_id: str | ||
client_secret: str | ||
|
||
|
||
def retry_on_expired_token(func: Callable[..., Response]) -> Callable[..., Response]: | ||
def inner(client: "ApplicationClient", request: Request, **kwargs: Any) -> Response: | ||
try: | ||
return func(client, request, **kwargs) | ||
except TokenExpiredError: | ||
client.fetch_token() | ||
return func(client, request, **kwargs) | ||
|
||
return inner | ||
|
||
|
||
class ApplicationClient(IClient): | ||
def __init__( | ||
self, base_url: str, organization_id: str, application: Application | ||
) -> None: | ||
""" | ||
Args: | ||
base_url: The API base url, i.e your Haussmann cell. | ||
e.g: https://XX-cell-YYY.api.lumapps.com | ||
organization_id: The ID of the given customer / organization. | ||
application: A LumApps application of the same customer. | ||
""" | ||
self.base_url = base_url.rstrip("/") | ||
self.organization_id = organization_id | ||
self.application = application | ||
self.session = OAuth2Session( | ||
client=BackendApplicationClient( | ||
client_id=application.client_id, scope=None, | ||
) | ||
) | ||
self.organization_url = ( | ||
f"{self.base_url}/v2/organizations/{self.organization_id}" | ||
) | ||
|
||
@retry_on_expired_token | ||
def request(self, request: Request, **_: Any) -> Response: | ||
if not self.session.token: | ||
# Ensure token in request | ||
self.fetch_token() | ||
response = self.session.request( | ||
request.method, | ||
f"{self.organization_url}/{request.url.lstrip('/')}", | ||
params=request.params, | ||
headers={**request.headers, "User-Agent": "lumapps-sdk"}, | ||
json=request.json, | ||
) | ||
return Response( | ||
status_code=response.status_code, | ||
headers=dict(response.headers), | ||
json=response.json() if response.text else None, | ||
) | ||
|
||
def fetch_token(self) -> None: | ||
try: | ||
self.session.fetch_token( | ||
f"{self.organization_url}/application-token", | ||
client_secret=self.application.client_secret, | ||
) | ||
except OAuth2Error as err: | ||
raise InvalidLogin("Could not fetch token from application") from err |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass, field | ||
from typing import Any, Mapping | ||
|
||
|
||
@dataclass(frozen=True) | ||
class Request: | ||
# The HTTP method, usually GET, POST, PUT or PATCH | ||
method: str | ||
# The requested URL | ||
url: str | ||
# The query parameters (?key=value) | ||
params: Mapping[str, str] = field(default_factory=dict) | ||
# The extra headers required to process the request | ||
headers: Mapping[str, str] = field(default_factory=dict) | ||
# The JSON content of the request | ||
json: Any = None | ||
|
||
|
||
@dataclass(frozen=True) | ||
class Response: | ||
status_code: int | ||
headers: Mapping[str, str] | ||
json: Any | ||
|
||
|
||
class IClient(ABC): | ||
""" | ||
The generic HTTP client for LumApps | ||
The implementation must handle authentification and specifics if necessary | ||
""" | ||
|
||
@abstractmethod | ||
def request(self, request: Request, **kwargs: Any) -> Response: # pragma: no cover | ||
""" | ||
kwargs should be used for very niche behavior and not relied on extensively | ||
Most, if not all, implementations should NOT need to use it | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
class InvalidLogin(Exception): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from lumapps.latest.client import ApplicationClient, Application, Request, InvalidLogin | ||
from oauthlib.oauth2 import TokenExpiredError | ||
from pytest import raises | ||
|
||
|
||
def test_request(requests_mock): | ||
# Given | ||
token_mock = requests_mock.post( | ||
"https://mock/v2/organizations/123/application-token", | ||
request_headers={"Authorization": "Basic Y2xpZW50OnNlY3JldA=="}, | ||
json={"access_token": "123"}, | ||
) | ||
call_mock = requests_mock.get( | ||
"https://mock/v2/organizations/123/test", | ||
request_headers={"Authorization": "Bearer 123"}, | ||
json="response" | ||
) | ||
client = ApplicationClient( | ||
"https://mock", "123", Application(client_id="client", client_secret="secret") | ||
) | ||
|
||
# When | ||
response = client.request(Request(method="GET", url="/test")) | ||
|
||
# Then | ||
assert response.status_code == 200 | ||
assert response.json == "response" | ||
assert token_mock.call_count == 1 | ||
assert call_mock.call_count == 1 | ||
|
||
|
||
def test_request_token_expired(requests_mock): | ||
# Given | ||
token_mock = requests_mock.post( | ||
"https://mock/v2/organizations/123/application-token", | ||
[ | ||
{"json": {"access_token": "123"}}, | ||
{"json": {"access_token": "456"}}, | ||
] | ||
) | ||
call_token_expired_mock = requests_mock.get( | ||
"https://mock/v2/organizations/123/test", | ||
request_headers={"Authorization": "Bearer 123"}, | ||
exc=TokenExpiredError | ||
) | ||
call_token_updated_mock = requests_mock.get( | ||
"https://mock/v2/organizations/123/test", | ||
request_headers={"Authorization": "Bearer 456"}, | ||
json="response" | ||
) | ||
client = ApplicationClient( | ||
"https://mock", "123", Application(client_id="client", client_secret="secret") | ||
) | ||
|
||
# When | ||
response = client.request(Request(method="GET", url="/test")) | ||
|
||
# Then | ||
assert response.status_code == 200 | ||
assert response.json == "response" | ||
assert token_mock.call_count == 2 | ||
assert call_token_expired_mock.call_count == 1 | ||
assert call_token_updated_mock.call_count == 1 | ||
|
||
|
||
def test_request_no_token(requests_mock): | ||
# Given | ||
token_mock = requests_mock.post( | ||
"https://mock/v2/organizations/123/application-token", | ||
status_code=400, | ||
json={"error": "invalid_request"}, | ||
) | ||
client = ApplicationClient( | ||
"https://mock", "123", Application(client_id="client", client_secret="secret") | ||
) | ||
|
||
# When | ||
with raises(InvalidLogin): | ||
client.request(Request(method="GET", url="/test")) | ||
|
||
# Then | ||
assert token_mock.call_count == 1 |