Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #186 - token and api endpoint not optional #187

Merged
merged 3 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,17 +1858,17 @@ class AstraDB:

def __init__(
self,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
token: str,
api_endpoint: str,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
) -> None:
"""
Initialize an Astra DB instance.
Args:
token (str, optional): Authentication token for Astra DB.
api_endpoint (str, optional): API endpoint URL.
token (str): Authentication token for Astra DB.
api_endpoint (str): API endpoint URL.
namespace (str, optional): Namespace for the database.
"""
if token is None or api_endpoint is None:
Expand Down Expand Up @@ -2090,17 +2090,17 @@ def truncate_collection(self, collection_name: str) -> AstraDBCollection:
class AsyncAstraDB:
def __init__(
self,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
token: str,
api_endpoint: str,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
) -> None:
"""
Initialize an Astra DB instance.
Args:
token (str, optional): Authentication token for Astra DB.
api_endpoint (str, optional): API endpoint URL.
token (str): Authentication token for Astra DB.
api_endpoint (str): API endpoint URL.
namespace (str, optional): Namespace for the database.
"""
self.client = httpx.AsyncClient()
Expand Down
30 changes: 15 additions & 15 deletions tests/astrapy/test_async_db_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,40 +35,40 @@
async def test_path_handling(
astra_db_credentials_kwargs: Dict[str, Optional[str]]
) -> None:
async with AsyncAstraDB(**astra_db_credentials_kwargs) as astra_db_1:
token = astra_db_credentials_kwargs["token"]
api_endpoint = astra_db_credentials_kwargs["api_endpoint"]
namespace = astra_db_credentials_kwargs.get("namespace")

if token is None or api_endpoint is None:
raise ValueError("Required ASTRA DB configuration is missing")

async with AsyncAstraDB(
token=token, api_endpoint=api_endpoint, namespace=namespace
) as astra_db_1:
url_1 = astra_db_1.base_path

async with AsyncAstraDB(
**astra_db_credentials_kwargs,
api_version="v1",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="v1"
) as astra_db_2:
url_2 = astra_db_2.base_path

async with AsyncAstraDB(
**astra_db_credentials_kwargs,
api_version="/v1",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="/v1"
) as astra_db_3:
url_3 = astra_db_3.base_path

async with AsyncAstraDB(
**astra_db_credentials_kwargs,
api_version="/v1/",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="/v1/"
) as astra_db_4:
url_4 = astra_db_4.base_path

assert url_1 == url_2 == url_3 == url_4

# autofill of the default keyspace name
async with AsyncAstraDB(
**{
**astra_db_credentials_kwargs,
**{"namespace": DEFAULT_KEYSPACE_NAME},
}
token=token, api_endpoint=api_endpoint, namespace=DEFAULT_KEYSPACE_NAME
) as unspecified_ks_client, AsyncAstraDB(
**{
**astra_db_credentials_kwargs,
**{"namespace": None},
}
token=token, api_endpoint=api_endpoint, namespace=None
) as explicit_ks_client:
assert unspecified_ks_client.base_path == explicit_ks_client.base_path

Expand Down
33 changes: 13 additions & 20 deletions tests/astrapy/test_db_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,46 +33,39 @@

@pytest.mark.describe("should confirm path handling in constructor")
def test_path_handling(astra_db_credentials_kwargs: Dict[str, Optional[str]]) -> None:
astra_db_1 = AstraDB(**astra_db_credentials_kwargs)
token = astra_db_credentials_kwargs["token"]
api_endpoint = astra_db_credentials_kwargs["api_endpoint"]
namespace = astra_db_credentials_kwargs.get("namespace")

if token is None or api_endpoint is None:
raise ValueError("Required ASTRA DB configuration is missing")

astra_db_1 = AstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace)
url_1 = astra_db_1.base_path

astra_db_2 = AstraDB(
**astra_db_credentials_kwargs,
api_version="v1",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="v1"
)

url_2 = astra_db_2.base_path

astra_db_3 = AstraDB(
**astra_db_credentials_kwargs,
api_version="/v1",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="/v1"
)

url_3 = astra_db_3.base_path

astra_db_4 = AstraDB(
**astra_db_credentials_kwargs,
api_version="/v1/",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="/v1/"
)

url_4 = astra_db_4.base_path

assert url_1 == url_2 == url_3 == url_4

# autofill of the default keyspace name
unspecified_ks_client = AstraDB(
**{
**astra_db_credentials_kwargs,
**{"namespace": DEFAULT_KEYSPACE_NAME},
}
)
explicit_ks_client = AstraDB(
**{
**astra_db_credentials_kwargs,
**{"namespace": None},
}
token=token, api_endpoint=api_endpoint, namespace=DEFAULT_KEYSPACE_NAME
)
explicit_ks_client = AstraDB(token=token, api_endpoint=api_endpoint, namespace=None)

assert unspecified_ks_client.base_path == explicit_ks_client.base_path


Expand Down
63 changes: 53 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
import math

import pytest
from typing import AsyncIterable, Dict, Iterable, List, Optional, Set, TypeVar
from typing import (
AsyncIterable,
Dict,
Iterable,
List,
Optional,
Set,
TypeVar,
TypedDict,
)

import pytest_asyncio

Expand All @@ -15,8 +24,9 @@
T = TypeVar("T")


ASTRA_DB_APPLICATION_TOKEN = os.environ.get("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_API_ENDPOINT = os.environ.get("ASTRA_DB_API_ENDPOINT")
ASTRA_DB_APPLICATION_TOKEN = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
ASTRA_DB_API_ENDPOINT = os.environ["ASTRA_DB_API_ENDPOINT"]

ASTRA_DB_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE", DEFAULT_KEYSPACE_NAME)

# fixed
Expand Down Expand Up @@ -49,6 +59,12 @@
]


class AstraDBCredentials(TypedDict, total=False):
token: str
api_endpoint: str
namespace: Optional[str]


def _batch_iterable(iterable: Iterable[T], batch_size: int) -> Iterable[Iterable[T]]:
this_batch = []
for entry in iterable:
Expand All @@ -61,41 +77,68 @@ def _batch_iterable(iterable: Iterable[T], batch_size: int) -> Iterable[Iterable


@pytest.fixture(scope="session")
def astra_db_credentials_kwargs() -> Dict[str, Optional[str]]:
return {
def astra_db_credentials_kwargs() -> AstraDBCredentials:
astra_db_creds: AstraDBCredentials = {
"token": ASTRA_DB_APPLICATION_TOKEN,
"api_endpoint": ASTRA_DB_API_ENDPOINT,
"namespace": ASTRA_DB_KEYSPACE,
}

return astra_db_creds


@pytest.fixture(scope="session")
def astra_invalid_db_credentials_kwargs() -> Dict[str, Optional[str]]:
return {
def astra_invalid_db_credentials_kwargs() -> AstraDBCredentials:
astra_db_creds: AstraDBCredentials = {
"token": ASTRA_DB_APPLICATION_TOKEN,
"api_endpoint": "http://localhost:1234",
"namespace": ASTRA_DB_KEYSPACE,
}

return astra_db_creds


@pytest.fixture(scope="session")
def db(astra_db_credentials_kwargs: Dict[str, Optional[str]]) -> AstraDB:
return AstraDB(**astra_db_credentials_kwargs)
token = astra_db_credentials_kwargs["token"]
api_endpoint = astra_db_credentials_kwargs["api_endpoint"]
namespace = astra_db_credentials_kwargs.get("namespace")

if token is None or api_endpoint is None:
raise ValueError("Required ASTRA DB configuration is missing")

return AstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace)


@pytest_asyncio.fixture(scope="function")
async def async_db(
astra_db_credentials_kwargs: Dict[str, Optional[str]]
) -> AsyncIterable[AsyncAstraDB]:
async with AsyncAstraDB(**astra_db_credentials_kwargs) as db:
token = astra_db_credentials_kwargs["token"]
api_endpoint = astra_db_credentials_kwargs["api_endpoint"]
namespace = astra_db_credentials_kwargs.get("namespace")

if token is None or api_endpoint is None:
raise ValueError("Required ASTRA DB configuration is missing")

async with AsyncAstraDB(
token=token, api_endpoint=api_endpoint, namespace=namespace
) as db:
yield db


@pytest.fixture(scope="module")
def invalid_db(
astra_invalid_db_credentials_kwargs: Dict[str, Optional[str]]
) -> AstraDB:
return AstraDB(**astra_invalid_db_credentials_kwargs)
token = astra_invalid_db_credentials_kwargs["token"]
api_endpoint = astra_invalid_db_credentials_kwargs["api_endpoint"]
namespace = astra_invalid_db_credentials_kwargs.get("namespace")

if token is None or api_endpoint is None:
raise ValueError("Required ASTRA DB configuration is missing")

return AstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace)


@pytest.fixture(scope="session")
Expand Down
Loading