Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 41 additions & 26 deletions src/firebolt/common/settings.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
from typing import Optional
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Optional

from pydantic import BaseSettings, Field, SecretStr, root_validator

from firebolt.client.auth import Auth
from firebolt.client.auth import Auth, UsernamePassword

logger = logging.getLogger(__name__)

Expand All @@ -18,8 +18,31 @@
>>> ...
>>> settings = Settings(auth=Token(access_token), ...)"""

USERNAME_ENV = "FIREBOLT_USER"
PASSWORD_ENV = "FIREBOLT_PASSWORD"
AUTH_TOKEN_ENV = "FIREBOLT_AUTH_TOKEN"
ACCOUNT_ENV = "FIREBOLT_ACCOUNT"
SERVER_ENV = "FIREBOLT_SERVER"
DEFAULT_REGION_ENV = "FIREBOLT_DEFAULT_REGION"


def from_env(var_name: str, default: Any = None) -> Callable:
def inner() -> Any:
os.environ.get(var_name, default)

return inner


class Settings(BaseSettings):
def auth_from_env() -> Optional[Auth]:
username = os.environ.get(USERNAME_ENV, None)
password = os.environ.get(PASSWORD_ENV, None)
if username and password:
return UsernamePassword(username, password)
return None


@dataclass
class Settings:
"""Settings for Firebolt SDK.

Attributes:
Expand All @@ -34,25 +57,19 @@ class Settings(BaseSettings):
default_region (str): Default region for provisioning
"""

auth: Optional[Auth] = Field(None)
auth: Optional[Auth] = field(default_factory=auth_from_env)
# Authorization
user: Optional[str] = Field(None, env="FIREBOLT_USER")
password: Optional[SecretStr] = Field(None, env="FIREBOLT_PASSWORD")
user: Optional[str] = field(default=None)
password: Optional[str] = field(default=None)
# Or
access_token: Optional[str] = Field(None, env="FIREBOLT_AUTH_TOKEN")

account_name: Optional[str] = Field(None, env="FIREBOLT_ACCOUNT")
server: str = Field(..., env="FIREBOLT_SERVER")
default_region: str = Field(..., env="FIREBOLT_DEFAULT_REGION")
use_token_cache: bool = Field(True)
access_token: Optional[str] = field(default_factory=from_env(AUTH_TOKEN_ENV))

class Config:
"""Internal pydantic config."""
account_name: Optional[str] = field(default_factory=from_env(ACCOUNT_ENV))
server: str = field(default_factory=from_env(SERVER_ENV))
default_region: str = field(default_factory=from_env(DEFAULT_REGION_ENV))
use_token_cache: bool = field(default=True)

env_file = ".env"

@root_validator
def mutual_exclusive_with_creds(cls, values: dict) -> dict:
def __post_init__(self) -> None:
"""Validate that either creds or token is provided.

Args:
Expand All @@ -66,17 +83,15 @@ def mutual_exclusive_with_creds(cls, values: dict) -> dict:
"""

params_present = (
values.get("user") is not None or values.get("password") is not None,
values.get("access_token") is not None,
values.get("auth") is not None,
self.user is not None or self.password is not None,
self.access_token is not None,
self.auth is not None,
)
if sum(params_present) == 0:
raise ValueError(
"Provide at least one of auth, user/password or access_token."
)
if sum(params_present) > 1:
raise ValueError("Provide only one of auth, user/password or access_token")
if any(values.get(f) for f in ("user", "password", "access_token")):
if any((self.user, self.password, self.access_token)):
logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE)

return values
2 changes: 1 addition & 1 deletion src/firebolt/service/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, settings: Optional[Settings] = None):
assert self.settings.password
auth = UsernamePassword(
self.settings.user,
self.settings.password.get_secret_value(),
self.settings.password,
self.settings.use_token_cache,
)

Expand Down
14 changes: 5 additions & 9 deletions tests/unit/async_db/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ async def test_connection_token_caching(
async with await connect(
database=db_name,
username=settings.user,
password=settings.password.get_secret_value(),
password=settings.password,
engine_url=settings.server,
account_name=settings.account_name,
api_endpoint=settings.server,
Expand All @@ -281,17 +281,15 @@ async def test_connection_token_caching(
assert await connection.cursor().execute("select*") == len(
python_query_data
)
ts = TokenSecureStorage(
username=settings.user, password=settings.password.get_secret_value()
)
ts = TokenSecureStorage(username=settings.user, password=settings.password)
assert ts.get_cached_token() == access_token, "Invalid token value cached"

# Do the same, but with use_token_cache=False
with Patcher():
async with await connect(
database=db_name,
username=settings.user,
password=settings.password.get_secret_value(),
password=settings.password,
engine_url=settings.server,
account_name=settings.account_name,
api_endpoint=settings.server,
Expand All @@ -300,9 +298,7 @@ async def test_connection_token_caching(
assert await connection.cursor().execute("select*") == len(
python_query_data
)
ts = TokenSecureStorage(
username=settings.user, password=settings.password.get_secret_value()
)
ts = TokenSecureStorage(username=settings.user, password=settings.password)
assert (
ts.get_cached_token() is None
), "Token is cached even though caching is disabled"
Expand All @@ -324,7 +320,7 @@ async def test_connect_with_auth(
for auth in (
UsernamePassword(
settings.user,
settings.password.get_secret_value(),
settings.password,
use_token_cache=False,
),
Token(access_token),
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/client/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_username(settings: Settings) -> str:

@fixture
def test_password(settings: Settings) -> str:
return settings.password.get_secret_value()
return settings.password


@fixture
Expand Down
22 changes: 16 additions & 6 deletions tests/unit/common/test_settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Tuple
from unittest.mock import Mock, patch

from pydantic import ValidationError
from pytest import mark, raises

from firebolt.client.auth import Auth
Expand All @@ -21,8 +22,6 @@ def test_settings_happy_path(fields: Tuple[str]) -> None:

for f in fields:
field = getattr(s, f)
if hasattr(field, "get_secret_value"):
field = field.get_secret_value()
assert (
(field == f) if f != "auth" else isinstance(field, Auth)
), f"Invalid settings value {f}"
Expand All @@ -41,8 +40,19 @@ def test_settings_happy_path(fields: Tuple[str]) -> None:
),
)
def test_settings_auth_credentials(kwargs) -> None:
with raises(ValidationError) as exc_info:
with raises(ValueError) as exc_info:
Settings(**kwargs)

err = exc_info.value
assert len(err.errors()) > 0

@patch("firebolt.common.settings.logger")
def test_no_deprecation_warning_with_env(logger_mock: Mock):
with patch.dict(
os.environ,
{"FIREBOLT_USER": "user", "FIREBOLT_PASSWORD": "password"},
clear=True,
):
s = Settings(server="server", default_region="region")
logger_mock.warning.assert_not_called()
assert s.auth is not None, "Settings.auth wasn't populated from env variables"
assert s.auth.username == "user", "Invalid username in Settings.auth"
assert s.auth.password == "password", "Invalid password in Settings.auth"
7 changes: 2 additions & 5 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import httpx
from httpx import Request, Response
from pydantic import SecretStr
from pyfakefs.fake_filesystem_unittest import Patcher
from pytest import fixture

Expand Down Expand Up @@ -122,7 +121,7 @@ def settings(server: str, region_1: str, username: str, password: str) -> Settin
return Settings(
server=server,
user=username,
password=SecretStr(password),
password=password,
default_region=region_1.name,
account_name=None,
)
Expand Down Expand Up @@ -354,9 +353,7 @@ def check_credentials(
assert "username" in body, "Missing username"
assert body["username"] == settings.user, "Invalid username"
assert "password" in body, "Missing password"
assert (
body["password"] == settings.password.get_secret_value()
), "Invalid password"
assert body["password"] == settings.password, "Invalid password"

return Response(
status_code=httpx.codes.OK,
Expand Down
14 changes: 5 additions & 9 deletions tests/unit/db/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,33 +290,29 @@ def test_connection_token_caching(
with connect(
database=db_name,
username=settings.user,
password=settings.password.get_secret_value(),
password=settings.password,
engine_url=settings.server,
account_name=settings.account_name,
api_endpoint=settings.server,
use_token_cache=True,
) as connection:
assert connection.cursor().execute("select*") == len(python_query_data)
ts = TokenSecureStorage(
username=settings.user, password=settings.password.get_secret_value()
)
ts = TokenSecureStorage(username=settings.user, password=settings.password)
assert ts.get_cached_token() == access_token, "Invalid token value cached"

# Do the same, but with use_token_cache=False
with Patcher():
with connect(
database=db_name,
username=settings.user,
password=settings.password.get_secret_value(),
password=settings.password,
engine_url=settings.server,
account_name=settings.account_name,
api_endpoint=settings.server,
use_token_cache=False,
) as connection:
assert connection.cursor().execute("select*") == len(python_query_data)
ts = TokenSecureStorage(
username=settings.user, password=settings.password.get_secret_value()
)
ts = TokenSecureStorage(username=settings.user, password=settings.password)
assert (
ts.get_cached_token() is None
), "Token is cached even though caching is disabled"
Expand All @@ -338,7 +334,7 @@ def test_connect_with_auth(
for auth in (
UsernamePassword(
settings.user,
settings.password.get_secret_value(),
settings.password,
use_token_cache=False,
),
Token(access_token),
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/service/test_resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_rm_credentials(
rm.client.get(url)

auth_username_password_settings = Settings(
auth=UsernamePassword(settings.user, settings.password.get_secret_value()),
auth=UsernamePassword(settings.user, settings.password),
server=settings.server,
default_region=settings.default_region,
)
Expand Down Expand Up @@ -87,30 +87,30 @@ def test_rm_token_cache(
with Patcher():
local_settings = Settings(
user=settings.user,
password=settings.password.get_secret_value(),
password=settings.password,
server=settings.server,
default_region=settings.default_region,
use_token_cache=True,
)
rm = ResourceManager(local_settings)
rm.client.get(url)

ts = TokenSecureStorage(settings.user, settings.password.get_secret_value())
ts = TokenSecureStorage(settings.user, settings.password)
assert ts.get_cached_token() == access_token, "Invalid token value cached"

# Do the same, but with use_token_cache=False
with Patcher():
local_settings = Settings(
user=settings.user,
password=settings.password.get_secret_value(),
password=settings.password,
server=settings.server,
default_region=settings.default_region,
use_token_cache=False,
)
rm = ResourceManager(local_settings)
rm.client.get(url)

ts = TokenSecureStorage(settings.user, settings.password.get_secret_value())
ts = TokenSecureStorage(settings.user, settings.password)
assert (
ts.get_cached_token() is None
), "Token is cached even though caching is disabled"
Expand Down