Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ project_urls =
packages = find:
install_requires =
aiorwlock==1.1.0
appdirs>=1.4.4
appdirs-stubs>=0.1.0
async-property==0.2.1
cryptography>=36.0.1
httpx[http2]==0.21.3
pydantic[dotenv]==1.8.2
readerwriterlock==1.0.9
Expand All @@ -44,6 +47,7 @@ dev =
devtools==0.7.0
mypy==0.910
pre-commit==2.15.0
pyfakefs>=4.5.3
pytest==6.2.5
pytest-asyncio
pytest-cov==3.0.0
Expand Down
17 changes: 14 additions & 3 deletions src/firebolt/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from firebolt.client.constants import _REQUEST_ERRORS, DEFAULT_API_URL
from firebolt.common.exception import AuthenticationError
from firebolt.common.token_storage import TokenSecureStorage
from firebolt.common.urls import AUTH_URL
from firebolt.common.util import fix_url_schema

Expand Down Expand Up @@ -34,13 +35,18 @@ def from_token(token: str) -> "Auth":
return a

def __init__(
self, username: str, password: str, api_endpoint: str = DEFAULT_API_URL
self,
username: str,
password: str,
api_endpoint: str = DEFAULT_API_URL,
):
self.username = username
self.password = password
self._token_storage = TokenSecureStorage(username=username, password=password)

# Add schema to url if it's missing
self._api_endpoint = fix_url_schema(api_endpoint)
self._token: Optional[str] = None
self._token: Optional[str] = self._token_storage.get_cached_token()
self._expires: Optional[int] = None

def copy(self) -> "Auth":
Expand All @@ -56,7 +62,6 @@ def expired(self) -> Optional[int]:

def get_new_token_generator(self) -> Generator[Request, Response, None]:
"""Get new token using username and password"""

try:
response = yield Request(
"POST",
Expand All @@ -74,6 +79,9 @@ def get_new_token_generator(self) -> Generator[Request, Response, None]:

self._token = parsed["access_token"]
self._expires = int(time()) + int(parsed["expires_in"])

self._token_storage.cache_token(parsed["access_token"], self._expires)

except _REQUEST_ERRORS as e:
raise AuthenticationError(repr(e), self._api_endpoint)

Expand All @@ -83,8 +91,11 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]:

if not self.token or self.expired:
yield from self.get_new_token_generator()

request.headers["Authorization"] = f"Bearer {self.token}"

response = yield request

if response.status_code == codes.UNAUTHORIZED:
yield from self.get_new_token_generator()
request.headers["Authorization"] = f"Bearer {self.token}"
Expand Down
131 changes: 131 additions & 0 deletions src/firebolt/common/token_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from base64 import b64decode, b64encode, urlsafe_b64encode
from hashlib import sha256
from json import JSONDecodeError
from json import dump as json_dump
from json import load as json_load
from os import makedirs, path, urandom
from time import time
from typing import Optional

from appdirs import user_data_dir
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

APPNAME = "firebolt"


def generate_salt() -> str:
return b64encode(urandom(16)).decode("ascii")


def generate_file_name(username: str, password: str) -> str:
username_hash = sha256(username.encode("utf-8")).hexdigest()[:32]
password_hash = sha256(password.encode("utf-8")).hexdigest()[:32]

return f"{username_hash}{password_hash}.json"


class TokenSecureStorage:
def __init__(self, username: str, password: str):
"""
Class for permanent storage of token in the filesystem in encrypted way

:param username: username used for toke encryption
:param password: password used for toke encryption
"""
self._data_dir = user_data_dir(appname=APPNAME)
makedirs(self._data_dir, exist_ok=True)

self._token_file = path.join(
self._data_dir, generate_file_name(username, password)
)

self.salt = self._get_salt()
self.encrypter = FernetEncrypter(self.salt, username, password)

def _get_salt(self) -> str:
"""
Get salt from the file if exists, or generate a new one

:return: salt
"""
res = self._read_data_json()
return res.get("salt", generate_salt())

def _read_data_json(self) -> dict:
"""
Read json token file

:return: json object as dict
"""
if not path.exists(self._token_file):
return {}

with open(self._token_file) as f:
try:
return json_load(f)
except JSONDecodeError:
return {}

def get_cached_token(self) -> Optional[str]:
"""
Get decrypted token using username and password
If the token not found or token cannot be decrypted using username, password
None will be returned

:return: token or None
"""
res = self._read_data_json()
if "token" not in res:
return None

# Ignore expired tokens
if "expiration" in res and res["expiration"] <= int(time()):
return None

return self.encrypter.decrypt(res["token"])

def cache_token(self, token: str, expiration_ts: int) -> None:
"""

:param token:
:return:
"""
token = self.encrypter.encrypt(token)

with open(self._token_file, "w") as f:
json_dump(
{"token": token, "salt": self.salt, "expiration": expiration_ts}, f
)


class FernetEncrypter:
def __init__(self, salt: str, username: str, password: str):
"""

:param salt:
:param username:
:param password:
"""

kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
salt=b64decode(salt),
length=32,
iterations=39000,
)
self.fernet = Fernet(
urlsafe_b64encode(
kdf.derive(bytes(f"{username}{password}", encoding="utf-8"))
)
)

def encrypt(self, data: str) -> str:
return self.fernet.encrypt(bytes(data, encoding="utf-8")).decode("utf-8")

def decrypt(self, data: str) -> Optional[str]:
try:
return self.fernet.decrypt(bytes(data, encoding="utf-8")).decode("utf-8")
except InvalidToken:
return None
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pyfakefs.fake_filesystem_unittest import Patcher
from pytest import fixture


@fixture(autouse=True)
def global_fake_fs() -> None:
with Patcher():
yield
Comment on lines +5 to +8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I preferred explicitly adding fs to the relevant tests. That way there's no hidden functionality, you know which tests are hitting a fake file system and which do not.
But up to you, I'm not too worried about it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is required because now any authentication leaves token cache artifact file, which also causes tests to fail

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would need to add fake fs to almost every integration test and to a lot of unit tests

11 changes: 4 additions & 7 deletions tests/unit/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from httpx import Client, Request, StreamError, codes
from pyfakefs.fake_filesystem import FakeFilesystem
from pytest_httpx import HTTPXMock
from pytest_mock import MockerFixture

Expand Down Expand Up @@ -30,9 +31,7 @@ def test_auth_basic(


def test_auth_refresh_on_expiration(
httpx_mock: HTTPXMock,
test_token: str,
test_token2: str,
httpx_mock: HTTPXMock, test_token: str, test_token2: str, fs: FakeFilesystem
):
"""Auth refreshes the token on expiration."""

Expand All @@ -56,9 +55,7 @@ def test_auth_refresh_on_expiration(


def test_auth_uses_same_token_if_valid(
httpx_mock: HTTPXMock,
test_token: str,
test_token2: str,
httpx_mock: HTTPXMock, test_token: str, test_token2: str, fs: FakeFilesystem
):
"""Auth refreshes the token on expiration"""

Expand Down Expand Up @@ -92,7 +89,7 @@ def test_auth_uses_same_token_if_valid(
httpx_mock.reset(False)


def test_auth_error_handling(httpx_mock: HTTPXMock):
def test_auth_error_handling(httpx_mock: HTTPXMock, fs: FakeFilesystem):
"""Auth handles various errors properly."""

for api_endpoint in ("https://host", "host"):
Expand Down
12 changes: 4 additions & 8 deletions tests/unit/client/test_auth_async.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from httpx import AsyncClient, Request, codes
from pyfakefs.fake_filesystem import FakeFilesystem
from pytest import mark
from pytest_httpx import HTTPXMock

Expand All @@ -8,9 +9,7 @@

@mark.asyncio
async def test_auth_refresh_on_expiration(
httpx_mock: HTTPXMock,
test_token: str,
test_token2: str,
httpx_mock: HTTPXMock, test_token: str, test_token2: str, fs: FakeFilesystem
):
"""Auth refreshes the token on expiration."""

Expand Down Expand Up @@ -39,9 +38,7 @@ async def test_auth_refresh_on_expiration(

@mark.asyncio
async def test_auth_uses_same_token_if_valid(
httpx_mock: HTTPXMock,
test_token: str,
test_token2: str,
httpx_mock: HTTPXMock, test_token: str, test_token2: str, fs: FakeFilesystem
):
"""Auth refreshes the token on expiration"""

Expand Down Expand Up @@ -81,8 +78,7 @@ async def test_auth_uses_same_token_if_valid(

@mark.asyncio
async def test_auth_adds_header(
httpx_mock: HTTPXMock,
test_token: str,
httpx_mock: HTTPXMock, test_token: str, fs: FakeFilesystem
):
"""Auth adds required authentication headers to httpx.Request."""
httpx_mock.add_response(
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable

from httpx import codes
from pyfakefs.fake_filesystem import FakeFilesystem
from pytest import raises
from pytest_httpx import HTTPXMock

Expand All @@ -15,6 +16,7 @@ def test_client_retry(
test_username: str,
test_password: str,
test_token: str,
fs: FakeFilesystem,
):
"""
Client retries with new auth token
Expand Down Expand Up @@ -56,6 +58,7 @@ def test_client_different_auths(
test_username: str,
test_password: str,
test_token: str,
fs: FakeFilesystem,
):
"""
Client properly handles such auth types:
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/client/test_client_async.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable

from httpx import codes
from pyfakefs.fake_filesystem import FakeFilesystem
from pytest import mark, raises
from pytest_httpx import HTTPXMock

Expand All @@ -16,6 +17,7 @@ async def test_client_retry(
test_username: str,
test_password: str,
test_token: str,
fs: FakeFilesystem,
):
"""
Client retries with new auth token
Expand Down Expand Up @@ -58,6 +60,7 @@ async def test_client_different_auths(
test_username: str,
test_password: str,
test_token: str,
fs: FakeFilesystem,
):
"""
Client properly handles such auth types:
Expand Down
Loading