diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 4a536b3..c5b72df 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -12,7 +12,7 @@ jobs: # tests can't run in parallel as they write and read data with same keys max-parallel: 1 matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.6", "3.7", "3.8", "3.9"] steps: # Get the code into the container - name: Checkout @@ -25,12 +25,12 @@ jobs: # Test the code - name: Test code env: - DETA_SDK_TEST_PROJECT_KEY: ${{secrets.DETA_SDK_TEST_PROJECT_KEY}} - DETA_SDK_TEST_BASE_NAME: ${{secrets.DETA_SDK_TEST_BASE_NAME}} - DETA_SDK_TEST_DRIVE_NAME: ${{secrets.DETA_SDK_TEST_DRIVE_NAME}} - DETA_SDK_TEST_DRIVE_HOST: ${{secrets.DETA_SDK_TEST_DRIVE_HOST}} + DETA_SDK_TEST_PROJECT_KEY: ${{ secrets.DETA_SDK_TEST_PROJECT_KEY }} + DETA_SDK_TEST_BASE_NAME: ${{ secrets.DETA_SDK_TEST_BASE_NAME }} + DETA_SDK_TEST_DRIVE_NAME: ${{ secrets.DETA_SDK_TEST_DRIVE_NAME }} + DETA_SDK_TEST_DRIVE_HOST: ${{ secrets.DETA_SDK_TEST_DRIVE_HOST }} DETA_SDK_TEST_TTL_ATTRIBUTE: __expires run: | python -m pip install --upgrade pip python -m pip install pytest pytest-asyncio aiohttp - pytest tests \ No newline at end of file + pytest tests diff --git a/.github/workflows/tag_release.yml b/.github/workflows/tag_release.yml index da880e0..22047d9 100644 --- a/.github/workflows/tag_release.yml +++ b/.github/workflows/tag_release.yml @@ -1,6 +1,6 @@ name: Lint and tag before release -on: +on: push: branches: - release @@ -16,10 +16,10 @@ jobs: - name: Setup Python uses: actions/setup-python@v2 with: - python-version: ^3.5 + python-version: ^3.6 # Install dependencies - name: Install dependencies run: "scripts/install" # Make tag - name: Make git tag - run: "scripts/tag" \ No newline at end of file + run: "scripts/tag" diff --git a/.github/workflows/test_release.yml b/.github/workflows/test_release.yml index cf59e6a..b2f142a 100644 --- a/.github/workflows/test_release.yml +++ b/.github/workflows/test_release.yml @@ -1,6 +1,6 @@ name: Test release -on: +on: workflow_dispatch: jobs: @@ -14,7 +14,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v2 with: - python-version: ^3.5 + python-version: ^3.6 # Install dependencies - name: Install dependencies run: "scripts/install" @@ -27,4 +27,3 @@ jobs: env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_TEST_TOKEN }} - \ No newline at end of file diff --git a/.gitignore b/.gitignore index 567f9df..d48fe4b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,13 @@ .* !.gitignore !.github/ -__pycache__ +__pycache__/ test.py build/ +dist/ *.egg-info -dist testEnv .env -venv/ \ No newline at end of file +env/ +venv/ +.venv/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index af12021..a817003 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -17,10 +17,10 @@ source .venv/bin/activate **Windows** ```powershell -venv/scripts/activate +.venv/scripts/activate ``` -### Installing the dependencies +### Install the dependencies ```sh pip install -r requirements.txt @@ -36,21 +36,21 @@ cp env.sample .env Then provide the values as follows: -- `DETA_SDK_TEST_PROJECT_KEY` – Test project key (create a new Deta project for testing and grab the generated key). -- `DETA_SDK_TEST_BASE_NAME` – Name of your Base, default is fine. -- `DETA_SDK_TEST_DRIVE_NAME` – Name of your Drive, default is fine. -- `DETA_SDK_TEST_DRIVE_HOST` – Host URL, default is fine. 
+- `DETA_SDK_TEST_PROJECT_KEY` - Test project key (create a new Deta project for testing and grab the generated key). +- `DETA_SDK_TEST_BASE_NAME` - Name of your Base, default is fine. +- `DETA_SDK_TEST_DRIVE_NAME` - Name of your Drive, default is fine. +- `DETA_SDK_TEST_DRIVE_HOST` - Host URL, default is fine. ### Run the tests ```sh -python tests.py +pytest tests ``` 🎉 Now you are ready to contribute! ### How to contribute -1. Git clone and make a feature branch -2. Make a draft PR +1. Clone this repo and make a feature branch +2. Make a draft pull request 3. Make your changes to the feature branch -4. Mark draft as ready for review +4. Mark your draft as ready for review diff --git a/README.md b/README.md index 247118d..ad556d6 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # Deta Python Library (SDK) -Supports Python 3.5+ only. [Read the docs.](https://docs.deta.sh/docs/base/sdk) +Supports Python 3.6+ only. [Read the docs.](https://docs.deta.sh/docs/base/sdk) -Install from PyPi +Install from PyPI ```sh pip install deta @@ -13,4 +13,4 @@ If you are interested in contributing, please look at [**CONTRIBUTING.md**](CONT ## How to release (for maintainers) 1. Add changes to `CHANGELOG.md` 2. Merge the `master` branch with the `release` branch. -3. After scripts finish, update release and tag with relevant info \ No newline at end of file +3. After scripts finish, update release and tag with relevant info diff --git a/deta/__init__.py b/deta/__init__.py index 1eec891..ad81c03 100644 --- a/deta/__init__.py +++ b/deta/__init__.py @@ -1,78 +1,87 @@ import os +import json import urllib.error import urllib.request -import json +from typing import Optional, Sequence, Union from .base import _Base +from ._async.client import _AsyncBase from .drive import _Drive from .utils import _get_project_key_id - try: - from detalib.app import App + from detalib.app import App # type: ignore app = App() -except Exception: - pass - -try: - from ._async.client import AsyncBase except ImportError: pass __version__ = "1.1.0" - -def Base(name: str): +def Base(name: str, host: Optional[str] = None): project_key, project_id = _get_project_key_id() - return _Base(name, project_key, project_id) + return _Base(name, project_key, project_id, host=host) -def Drive(name: str): +# TODO: type hint for session +def AsyncBase(name: str, host: Optional[str] = None, session=None): project_key, project_id = _get_project_key_id() - return _Drive(name, project_key, project_id) - + return _AsyncBase(name, project_key, project_id, host=host, session=session) -class Deta: - def __init__(self, project_key: str = None, *, project_id: str = None): - project_key, project_id = _get_project_key_id(project_key, project_id) - self.project_key = project_key - self.project_id = project_id - - def Base(self, name: str, host: str = None): - return _Base(name, self.project_key, self.project_id, host) - - def AsyncBase(self, name: str, host: str = None): - from ._async.client import _AsyncBase - return _AsyncBase(name, self.project_key, self.project_id, host) - - def Drive(self, name: str, host: str = None): - return _Drive( - name=name, - project_key=self.project_key, - project_id=self.project_id, - host=host, - ) - def send_email(self, to, subject, message, charset="UTF-8"): - return send_email(to, subject, message, charset) +def Drive(name: str, host: Optional[str] = None): + project_key, project_id = _get_project_key_id() + return _Drive(name, project_key, project_id, host=host) -def send_email(to, subject, message, charset="UTF-8"): 
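As a quick illustration of the refactored public surface in `deta/__init__.py` above, a minimal usage sketch — the Base/Drive names, item keys, and the project-key value are placeholders, and the module-level factories assume `DETA_PROJECT_KEY` is set in the environment:

```python
import asyncio
from deta import Deta, Base, AsyncBase

# Module-level factories read DETA_PROJECT_KEY from the environment.
books = Base("books")
books.put({"title": "Dune", "key": "book-1"})
print(books.get("book-1"))

# The Deta class takes an explicit project key instead (placeholder value).
deta = Deta("c0abcxyz_aSecretValue")
drive = deta.Drive("files")
drive.put("hello.txt", "Hello from Drive")

async def main():
    # AsyncBase needs the aiohttp extra; create it inside a running event loop.
    abooks = AsyncBase("books")
    await abooks.put({"title": "Dune"}, "book-1")
    print(await abooks.get("book-1"))
    await abooks.close()

asyncio.run(main())
```

Both the factories and the `Deta` methods now pass a custom host through the same `host=` keyword, so overriding it works the same way in either style.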
+class Deta: + def __init__(self, project_key: Optional[str] = None, *, project_id: Optional[str] = None): + self.project_key, self.project_id = _get_project_key_id(project_key, project_id) + + def Base(self, name: str, host: Optional[str] = None): + return _Base(name, self.project_key, self.project_id, host=host) + + # TODO: type hint for session + def AsyncBase(self, name: str, host: Optional[str] = None, session=None): + return _AsyncBase(name, self.project_key, self.project_id, host=host, session=session) + + def Drive(self, name: str, host: Optional[str] = None): + return _Drive(name, self.project_key, self.project_id, host=host) + + def send_email( + self, + to: Union[str, Sequence[str]], + subject: str, + message: str, + charset: str = "utf-8", + ): + send_email(to, subject, message, charset) + + +def send_email( + to: Union[str, Sequence[str]], + subject: str, + message: str, + charset: str = "utf-8", +): + # FIXME: should function continue if these are not present? pid = os.getenv("AWS_LAMBDA_FUNCTION_NAME") url = os.getenv("DETA_MAILER_URL") api_key = os.getenv("DETA_PROJECT_KEY") endpoint = f"{url}/mail/{pid}" - to = to if type(to) == list else [to] + if isinstance(to, str): + to = [to] + else: + to = list(to) + data = { "to": to, "subject": subject, "message": message, "charset": charset, } - headers = {"X-API-Key": api_key} req = urllib.request.Request(endpoint, json.dumps(data).encode("utf-8"), headers) @@ -82,4 +91,4 @@ def send_email(to, subject, message, charset="UTF-8"): if resp.getcode() != 200: raise Exception(resp.read().decode("utf-8")) except urllib.error.URLError as e: - raise Exception(e.reason) + raise Exception(e.reason) from e diff --git a/deta/_async/client.py b/deta/_async/client.py index 18c21d5..7e54e43 100644 --- a/deta/_async/client.py +++ b/deta/_async/client.py @@ -1,52 +1,56 @@ -import typing - -import datetime import os -import aiohttp +import datetime +from typing import Mapping, Optional, Sequence, Union, overload from urllib.parse import quote -from deta.utils import _get_project_key_id -from deta.base import FetchResponse, Util, insert_ttl, BASE_TTL_ATTRIBUTE - +try: + import aiohttp +except ImportError: + has_aiohttp = False +else: + has_aiohttp = True -def AsyncBase(name: str, *, session: aiohttp.ClientSession = None): - project_key, project_id = _get_project_key_id() - return _AsyncBase(name, project_key, project_id, session=session) +from deta.base import FetchResponse, Util, insert_ttl, BASE_TTL_ATTRIBUTE class _AsyncBase: def __init__( - self, - name: str, - project_key: str, - project_id: str, - *, - host: str = None, - session: aiohttp.ClientSession = None + self, + name: str, + project_key: str, + project_id: str, + *, + host: Optional[str] = None, + session: "Optional[aiohttp.ClientSession]" = None, ): - if not project_key: - raise AssertionError("No Base name provided") + if not has_aiohttp: + raise RuntimeError("aiohttp library is required for async support") + + if not name: + raise ValueError("parameter 'name' must be a non-empty string") host = host or os.getenv("DETA_BASE_HOST") or "database.deta.sh" self._base_url = f"https://{host}/v1/{project_id}/{name}" self.util = Util() - self.__ttl_attribute = BASE_TTL_ATTRIBUTE + self._ttl_attribute = BASE_TTL_ATTRIBUTE - self._session = session or aiohttp.ClientSession( - headers={ - "Content-type": "application/json", - "X-API-Key": project_key, - }, - raise_for_status=True, - ) + if session is not None: + self._session = session + else: + self._session = aiohttp.ClientSession( + headers={ 
+ "Content-type": "application/json", + "X-API-Key": project_key, + }, + raise_for_status=True, + ) - async def close(self) -> None: + async def close(self): await self._session.close() - async def get(self, key: str): + async def get(self, key: str) -> Optional[dict]: key = quote(key, safe="") - try: async with self._session.get(f"{self._base_url}/items/{key}") as resp: return await resp.json() @@ -58,113 +62,179 @@ async def get(self, key: str): async def delete(self, key: str): key = quote(key, safe="") - async with self._session.delete(f"{self._base_url}/items/{key}"): return + @overload async def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, + ) -> dict: + ... + + @overload + async def insert( + self, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): - if not isinstance(data, dict): - data = {"value": data} - else: - data = data.copy() + expire_in: int, + ) -> dict: + ... + + @overload + async def insert( + self, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, + *, + expire_at: Union[int, float, datetime.datetime], + ) -> dict: + ... + + async def insert(self, data, key=None, *, expire_in=None, expire_at=None): + data = data.copy() if isinstance(data, dict) else {"value": data} if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) - async with self._session.post( - f"{self._base_url}/items", json={"item": data} - ) as resp: + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) + async with self._session.post(f"{self._base_url}/items", json={"item": data}) as resp: return await resp.json() + @overload + async def put( + self, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, + ) -> dict: + ... + + @overload async def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): - if not isinstance(data, dict): - data = {"value": data} - else: - data = data.copy() + expire_in: int, + ) -> dict: + ... + + @overload + async def put( + self, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, + *, + expire_at: Union[int, float, datetime.datetime], + ) -> dict: + ... + + async def put(self, data, key=None, *, expire_in=None, expire_at=None): + data = data.copy() if isinstance(data, dict) else {"value": data} if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) - async with self._session.put( - f"{self._base_url}/items", json={"items": [data]} - ) as resp: - if resp.status == 207: - resp_json = await resp.json() - return resp_json["processed"]["items"][0] - else: + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) + async with self._session.put(f"{self._base_url}/items", json={"items": [data]}) as resp: + if resp.status != 207: return None + resp_json = await resp.json() + return resp_json["processed"]["items"][0] + + @overload + async def put_many( + self, + items: Sequence[Union[dict, list, str, int, bool]], + ) -> dict: + ... 
+ @overload async def put_many( self, - items: typing.List[typing.Union[dict, list, str, int, bool]], + items: Sequence[Union[dict, list, str, int, bool]], *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): + expire_in: int, + ) -> dict: + ... + + @overload + async def put_many( + self, + items: Sequence[Union[dict, list, str, int, bool]], + *, + expire_at: Union[int, float, datetime.datetime], + ) -> dict: + ... + + async def put_many(self, items, *, expire_in=None, expire_at=None): if len(items) > 25: - raise AssertionError("We can't put more than 25 items at a time.") + raise ValueError("cannot put more than 25 items at a time") + _items = [] - for i in items: - data = i - if not isinstance(i, dict): - data = {"value": i} - insert_ttl( - data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at - ) + for item in items: + data = item + if not isinstance(item, dict): + data = {"value": item} + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) _items.append(data) - async with self._session.put( - f"{self._base_url}/items", json={"items": _items} - ) as resp: + async with self._session.put(f"{self._base_url}/items", json={"items": _items}) as resp: return await resp.json() async def fetch( self, - query: typing.Union[dict, list] = None, + query: Optional[Union[Mapping, Sequence[Mapping]]] = None, *, limit: int = 1000, - last: str = None, + last: Optional[str] = None, ): - payload = {} + payload = { + "limit": limit, + "last": last if not isinstance(last, bool) else None, + } + if query: - payload["query"] = query if isinstance(query, list) else [query] - if limit: - payload["limit"] = limit - if last: - payload["last"] = last + payload["query"] = query if isinstance(query, Sequence) else [query] + async with self._session.post(f"{self._base_url}/query", json=payload) as resp: resp_json = await resp.json() paging = resp_json.get("paging") - return FetchResponse( - paging.get("size"), paging.get("last"), resp_json.get("items") - ) + return FetchResponse(paging.get("size"), paging.get("last"), resp_json.get("items")) + + @overload + async def update( + self, + updates: Mapping, + key: str, + ): + ... + + @overload + async def update( + self, + updates: Mapping, + key: str, + *, + expire_in: int, + ): + ... + @overload async def update( self, - updates: dict, + updates: Mapping, key: str, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_at: Union[int, float, datetime.datetime], ): - if key == "": - raise ValueError("Key is empty") + ... 
+ + async def update(self, updates, key, *, expire_in=None, expire_at=None): + if not key: + raise ValueError("parameter 'key' must be a non-empty string") payload = { "set": {}, @@ -173,29 +243,24 @@ async def update( "prepend": {}, "delete": [], } + if updates: for attr, value in updates.items(): if isinstance(value, Util.Trim): payload["delete"].append(attr) elif isinstance(value, Util.Increment): - payload["increment"][attr] = value.val + payload["increment"][attr] = value.value elif isinstance(value, Util.Append): - payload["append"][attr] = value.val + payload["append"][attr] = value.value elif isinstance(value, Util.Prepend): - payload["prepend"][attr] = value.val + payload["prepend"][attr] = value.value else: payload["set"][attr] = value if not payload: - raise ValueError("Provide at least one update action.") + raise ValueError("must provide at least one update action") - insert_ttl( - payload["set"], - self.__ttl_attribute, - expire_in=expire_in, - expire_at=expire_at, - ) - - key = quote(key, safe="") + insert_ttl(payload["set"], self._ttl_attribute, expire_in, expire_at) - await self._session.patch(f"{self._base_url}/items/{key}", json=payload) + encoded_key = quote(key, safe="") + await self._session.patch(f"{self._base_url}/items/{encoded_key}", json=payload) diff --git a/deta/base.py b/deta/base.py index 706235c..225a0f8 100644 --- a/deta/base.py +++ b/deta/base.py @@ -1,6 +1,6 @@ import os import datetime -import typing +from typing import Any, Dict, Mapping, Optional, Sequence, Union, overload from urllib.parse import quote from .service import _Service, JSON_MIME @@ -11,29 +11,19 @@ class FetchResponse: - def __init__(self, count=0, last=None, items=[]): - self._count = count - self._last = last - self._items = items + def __init__(self, count: int = 0, last: Optional[str] = None, items: Optional[list] = None): + self.count = count + self.last = last + self.items = items if items is not None else [] - @property - def count(self): - return self._count + def __eq__(self, other: "FetchResponse"): + return self.count == other.count and self.last == other.last and self.items == other.items - @property - def last(self): - return self._last + def __iter__(self): + return iter(self.items) - @property - def items(self): - return self._items - - def __eq__(self, other): - return ( - self.count == other.count - and self.last == other.last - and self.items == other.items - ) + def __len__(self) -> int: + return len(self.items) class Util: @@ -41,198 +31,239 @@ class Trim: pass class Increment: - def __init__(self, value=None): - self.val = value - if not value: - self.val = 1 + def __init__(self, value: Union[int, float] = 1): + self.value = value class Append: - def __init__(self, value): - self.val = value - if not isinstance(value, list): - self.val = [value] + def __init__(self, value: Union[dict, list, str, int, float, bool]): + self.value = value if isinstance(value, list) else [value] class Prepend: - def __init__(self, value): - self.val = value - if not isinstance(value, list): - self.val = [value] + def __init__(self, value: Union[dict, list, str, int, float, bool]): + self.value = value if isinstance(value, list) else [value] def trim(self): return self.Trim() - def increment(self, value: typing.Union[int, float] = None): + def increment(self, value: Union[int, float] = 1): return self.Increment(value) - def append(self, value: typing.Union[dict, list, str, int, float, bool]): + def append(self, value: Union[dict, list, str, int, float, bool]): return self.Append(value) - 
def prepend(self, value: typing.Union[dict, list, str, int, float, bool]): + def prepend(self, value: Union[dict, list, str, int, float, bool]): return self.Prepend(value) class _Base(_Service): - def __init__(self, name: str, project_key: str, project_id: str, host: str = None): - assert name, "No Base name provided" + def __init__(self, name: str, project_key: str, project_id: str, *, host: Optional[str] = None): + if not name: + raise ValueError("parameter 'name' must be a non-empty string") host = host or os.getenv("DETA_BASE_HOST") or "database.deta.sh" - super().__init__( - project_key=project_key, - project_id=project_id, - host=host, - name=name, - timeout=BASE_SERVICE_TIMEOUT, - ) - self.__ttl_attribute = "__expires" + super().__init__(project_key, project_id, host, name, BASE_SERVICE_TIMEOUT) + self._ttl_attribute = BASE_TTL_ATTRIBUTE self.util = Util() - def get(self, key: str): - if key == "": - raise ValueError("Key is empty") + def get(self, key: str) -> Dict[str, Any]: + if not key: + raise ValueError("parameter 'key' must be a non-empty string") - # encode key key = quote(key, safe="") - _, res = self._request("/items/{}".format(key), "GET") - return res or None + _, res = self._request(f"/items/{key}", "GET") + return res def delete(self, key: str): """Delete an item from the database - key: the key of item to be deleted + + Args: + key: The key of item to be deleted. """ - if key == "": - raise ValueError("Key is empty") + if not key: + raise ValueError("parameter 'key' must be a non-empty string") - # encode key key = quote(key, safe="") - self._request("/items/{}".format(key), "DELETE") - return None + self._request(f"/items/{key}", "DELETE") + + @overload + def insert( + self, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, + ) -> dict: + ... + @overload def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): - if not isinstance(data, dict): - data = {"value": data} - else: - data = data.copy() + expire_in: int, + ) -> dict: + ... + @overload + def insert( + self, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, + *, + expire_at: Union[int, float, datetime.datetime], + ) -> dict: + ... + + def insert(self, data, key=None, *, expire_in=None, expire_at=None): + data = data.copy() if isinstance(data, dict) else {"value": data} if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) - code, res = self._request( - "/items", "POST", {"item": data}, content_type=JSON_MIME - ) + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) + code, res = self._request("/items", "POST", {"item": data}, content_type=JSON_MIME) + if code == 201: return res elif code == 409: - raise Exception(f"Item with key '{key}' already exists") + raise ValueError(f"item with key '{key}' already exists") + @overload def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, - *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): - """store (put) an item in the database. Overrides an item if key already exists. - `key` could be provided as function argument or a field in the data dict. - If `key` is not provided, the server will generate a random 12 chars key. 
- """ + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, + ) -> dict: + ... - if not isinstance(data, dict): - data = {"value": data} - else: - data = data.copy() + @overload + def put( + self, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, + *, + expire_in: int, + ) -> dict: + ... + @overload + def put( + self, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, + *, + expire_at: Union[int, float, datetime.datetime], + ) -> dict: + ... + + def put(self, data, key=None, *, expire_in=None, expire_at=None): + """Store (put) an item in the database. Overrides an item if key already exists. + `key` could be provided as an argument or a field in the data dict. + If `key` is not provided, the server will generate a random 12-character key. + """ + data = data.copy() if isinstance(data, dict) else {"value": data} if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) - code, res = self._request( - "/items", "PUT", {"items": [data]}, content_type=JSON_MIME - ) + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) + code, res = self._request("/items", "PUT", {"items": [data]}, content_type=JSON_MIME) return res["processed"]["items"][0] if res and code == 207 else None + @overload def put_many( self, - items: typing.List[typing.Union[dict, list, str, int, bool]], + items: Sequence[Union[dict, list, str, int, bool]], + ) -> dict: + ... + + @overload + def put_many( + self, + items: Sequence[Union[dict, list, str, int, bool]], *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): - assert len(items) <= 25, "We can't put more than 25 items at a time." + expire_in: int, + ) -> dict: + ... + + @overload + def put_many( + self, + items: Sequence[Union[dict, list, str, int, bool]], + *, + expire_at: Union[int, float, datetime.datetime], + ) -> dict: + ... + + def put_many(self, items, *, expire_in=None, expire_at=None): + if len(items) > 25: + raise ValueError("cannot put more than 25 items at a time") + _items = [] - for i in items: - data = i - if not isinstance(i, dict): - data = {"value": i} - insert_ttl( - data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at - ) + for item in items: + data = item + if not isinstance(item, dict): + data = {"value": item} + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) _items.append(data) - _, res = self._request( - "/items", "PUT", {"items": _items}, content_type=JSON_MIME - ) + _, res = self._request("/items", "PUT", {"items": _items}, content_type=JSON_MIME) return res - def _fetch( + def fetch( self, - query: typing.Union[dict, list] = None, - buffer: int = None, - last: str = None, - ) -> typing.Optional[typing.Tuple[int, list]]: - """This is where actual fetch happens.""" + query: Optional[Union[Mapping, Sequence[Mapping]]] = None, + *, + limit: int = 1000, + last: Optional[str] = None, + ): + """Fetch items from the database. `query` is an optional filter or list of filters. + Without a filter, it will return the whole db. 
+ """ payload = { - "limit": buffer, + "limit": limit, "last": last if not isinstance(last, bool) else None, } if query: - payload["query"] = query if isinstance(query, list) else [query] + payload["query"] = query if isinstance(query, Sequence) else [query] - code, res = self._request("/query", "POST", payload, content_type=JSON_MIME) - return code, res + _, res = self._request("/query", "POST", payload, content_type=JSON_MIME) + paging = res.get("paging") + return FetchResponse(paging.get("size"), paging.get("last"), res.get("items")) - def fetch( + @overload + def update( self, - query: typing.Union[dict, list] = None, - *, - limit: int = 1000, - last: str = None, + updates: Mapping, + key: str, ): - """ - fetch items from the database. - `query` is an optional filter or list of filters. Without filter, it will return the whole db. - """ - _, res = self._fetch(query, limit, last) - - paging = res.get("paging") + ... - return FetchResponse(paging.get("size"), paging.get("last"), res.get("items")) + @overload + def update( + self, + updates: Mapping, + key: str, + *, + expire_in: int, + ): + ... + @overload def update( self, - updates: dict, + updates: Mapping, key: str, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_at: Union[int, float, datetime.datetime], ): - """ - update an item in the database - `updates` specifies the attribute names and values to update,add or remove - `key` is the kye of the item to be updated - """ + ... - if key == "": - raise ValueError("Key is empty") + def update(self, updates, key, *, expire_in=None, expire_at=None): + """Update an item in the database. + `updates` specifies the attribute names and values to update, add or remove. + `key` is the key of the item to be updated. 
+ """ + if not key: + raise ValueError("parameter 'key' must be a non-empty string") payload = { "set": {}, @@ -241,39 +272,37 @@ def update( "prepend": {}, "delete": [], } + if updates: for attr, value in updates.items(): if isinstance(value, Util.Trim): payload["delete"].append(attr) elif isinstance(value, Util.Increment): - payload["increment"][attr] = value.val + payload["increment"][attr] = value.value elif isinstance(value, Util.Append): - payload["append"][attr] = value.val + payload["append"][attr] = value.value elif isinstance(value, Util.Prepend): - payload["prepend"][attr] = value.val + payload["prepend"][attr] = value.value else: payload["set"][attr] = value - insert_ttl( - payload["set"], - self.__ttl_attribute, - expire_in=expire_in, - expire_at=expire_at, - ) + insert_ttl(payload["set"], self._ttl_attribute, expire_in, expire_at) encoded_key = quote(key, safe="") - code, _ = self._request( - "/items/{}".format(encoded_key), "PATCH", payload, content_type=JSON_MIME - ) - if code == 200: - return None - elif code == 404: - raise Exception("Key '{}' not found".format(key)) + code, _ = self._request(f"/items/{encoded_key}", "PATCH", payload, content_type=JSON_MIME) + if code == 404: + raise ValueError(f"key '{key}' not found") -def insert_ttl(item, ttl_attribute, expire_in=None, expire_at=None): +def insert_ttl( + item, + ttl_attribute: str, + expire_in: Optional[Union[int, float]] = None, + expire_at: Optional[Union[int, float, datetime.datetime]] = None, +): if expire_in and expire_at: - raise ValueError("both expire_in and expire_at provided") + raise ValueError("'expire_in' and 'expire_at' are mutually exclusive parameters") + if not expire_in and not expire_at: return @@ -282,8 +311,7 @@ def insert_ttl(item, ttl_attribute, expire_in=None, expire_at=None): if isinstance(expire_at, datetime.datetime): expire_at = expire_at.replace(microsecond=0).timestamp() - - if not isinstance(expire_at, (int, float)): - raise TypeError("expire_at should one one of int, float or datetime") + elif not isinstance(expire_at, (int, float)): + raise TypeError("'expire_at' must be of type 'int', 'float' or 'datetime'") item[ttl_attribute] = int(expire_at) diff --git a/deta/drive.py b/deta/drive.py index 2d4db38..73875e6 100644 --- a/deta/drive.py +++ b/deta/drive.py @@ -1,6 +1,6 @@ import os -import typing from io import BufferedIOBase, TextIOBase, RawIOBase, StringIO, BytesIO +from typing import Any, Dict, Iterator, Optional, Sequence, Union, overload from urllib.parse import quote_plus from .service import JSON_MIME, _Service @@ -14,174 +14,134 @@ class DriveStreamingBody: def __init__(self, res: BufferedIOBase): - self.__stream = res + self._stream = res @property def closed(self): - return self.__stream.closed + return self._stream.closed - def read(self, size: int = None): - return self.__stream.read(size) + def read(self, size: Optional[int] = None) -> bytes: + return self._stream.read(size) - def iter_chunks(self, chunk_size: int = 1024): + def iter_chunks(self, chunk_size: int = 1024) -> Iterator[bytes]: while True: - chunk = self.__stream.read(chunk_size) + chunk = self._stream.read(chunk_size) if not chunk: break yield chunk - - def iter_lines(self, chunk_size: int = 1024): + + def iter_lines(self, chunk_size: int = 1024) -> Iterator[bytes]: while True: - chunk = self.__stream.readline(chunk_size) + chunk = self._stream.readline(chunk_size) if not chunk: break yield chunk def close(self): - # close stream try: - self.__stream.close() - except: + self._stream.close() + except Exception: 
pass class _Drive(_Service): - def __init__( - self, - name: str = None, - project_key: str = None, - project_id: str = None, - host: str = None, - ): - assert name, "No Drive name provided" - host = host or os.getenv("DETA_DRIVE_HOST") or "drive.deta.sh" + def __init__(self, name: str, project_key: str, project_id: str, *, host: Optional[str] = None): + if not name: + raise ValueError("parameter 'name' must be a non-empty string") - super().__init__( - project_key=project_key, - project_id=project_id, - host=host, - name=name, - timeout=DRIVE_SERVICE_TIMEOUT, - keep_alive=False, - ) + host = host or os.getenv("DETA_DRIVE_HOST") or "drive.deta.sh" + super().__init__(project_key, project_id, host, name, DRIVE_SERVICE_TIMEOUT, False) - def _quote(self, param: str): + def _quote(self, param: str) -> str: return quote_plus(param) - def get(self, name: str): - """Get/Download a file from drive. + def get(self, name: str) -> Optional[DriveStreamingBody]: + """Download a file from drive. `name` is the name of the file. Returns a DriveStreamingBody. """ - assert name, "No name provided" - _, res = self._request( - f"/files/download?name={self._quote(name)}", "GET", stream=True - ) + if not name: + raise ValueError("parameter 'name' must be a non-empty string") + + _, res = self._request(f"/files/download?name={self._quote(name)}", "GET", stream=True) if res: return DriveStreamingBody(res) - return None - def delete_many(self, names: typing.List[str]): - """Delete many files from drive in single request. - `names` are the names of the files to be deleted. - Returns a dict with 'deleted' and 'failed' files. - """ - assert names, "Names is empty" - assert len(names) <= 1000, "More than 1000 names to delete" - _, res = self._request( - "/files", "DELETE", {"names": names}, content_type=JSON_MIME - ) - return res - - def delete(self, name: str): + def delete(self, name: str) -> str: """Delete a file from drive. `name` is the name of the file. Returns the name of the file deleted. """ - assert name, "Name not provided or empty" + if not name: + raise ValueError("parameter 'name' must be a non-empty string") + payload = self.delete_many([name]) failed = payload.get("failed") if failed: - raise Exception(f"Failed to delete '{name}':{failed[name]}") + raise Exception(f"failed to delete '{name}': {failed[name]}") + return name - def list(self, limit: int = 1000, prefix: str = None, last: str = None): - """List file names from drive. - `limit` is the limit of number of file names to get, defaults to 1000. - `prefix` is the prefix of file names. - `last` is the last name seen in the a previous paginated response. - Returns a dict with 'paging' and 'names'. + def delete_many(self, names: Sequence[str]) -> dict: + """Delete many files from drive in single request. + `names` are the names of the files to be deleted. + Returns a dict with 'deleted' and 'failed' files. 
""" - url = f"/files?limit={limit}" - if prefix: - url += f"&prefix={prefix}" - if last: - url += f"&last={last}" - _, res = self._request(url, "GET") - return res + if not names: + raise ValueError("parameter 'names' must be a non-empty list") - def _start_upload(self, name: str): - _, res = self._request(f"/uploads?name={self._quote(name)}", "POST") - return res["upload_id"] - - def _finish_upload(self, name: str, upload_id: str): - self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "PATCH") + if len(names) > 1000: + raise ValueError("cannot delete more than 1000 items") - def _abort_upload(self, name: str, upload_id: str): - self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "DELETE") + _, res = self._request("/files", "DELETE", {"names": names}, content_type=JSON_MIME) + return res - def _upload_part( + @overload + def put( self, name: str, - chunk: bytes, - upload_id: str, - part: int, - content_type: str = None, - ): - self._request( - f"/uploads/{upload_id}/parts?name={self._quote(name)}&part={part}", - "POST", - data=chunk, - content_type=content_type, - ) - - def _get_content_stream( - self, data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase] - ): - if isinstance(data, str): - return StringIO(data) - elif isinstance(data, bytes): - return BytesIO(data) - return data + data: Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase], + *, + content_type: str, + ) -> str: + ... + @overload def put( self, name: str, - data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase] = None, *, - path: str = None, - content_type: str = None, + path: str, + content_type: str, ) -> str: + ... + + def put(self, name, data=None, *, path=None, content_type=None): """Put a file in drive. `name` is the name of the file. `data` is the data to be put. `content_type` is the mime type of the file. Returns the name of the file. """ - assert name, "No name provided" - assert path or data, "No data or path provided" - assert not (path and data), "Both path and data provided" + if not name: + raise ValueError("parameter 'name' must be a non-empty string") + if not path and not data: + raise ValueError("must provide either 'data' or 'path'") + if path and data: + raise ValueError("'data' and 'path' are exclusive parameters") # start upload upload_id = self._start_upload(name) - - content_stream = open(path, "rb") if path else self._get_content_stream(data) + if data: + content_stream = self._get_content_stream(data) + else: + content_stream = open(path, "rb") part = 1 # upload chunks while True: chunk = content_stream.read(UPLOAD_CHUNK_SIZE) - ## eof stop the loop + # eof stop the loop if not chunk: self._finish_upload(name, upload_id) content_stream.close() @@ -197,3 +157,61 @@ def put( self._abort_upload(name, upload_id) content_stream.close() raise e + + def list( + self, + limit: int = 1000, + prefix: Optional[str] = None, + last: Optional[str] = None, + ) -> Dict[str, Any]: + """List file names from drive. + `limit` is the number of file names to get, defaults to 1000. + `prefix` is the prefix of file names. + `last` is the last name seen in a previous paginated response. + Returns a dict with 'paging' and 'names'. 
+ """ + url = f"/files?limit={limit}" + if prefix: + url += f"&prefix={prefix}" + + if last: + url += f"&last={last}" + + _, res = self._request(url, "GET") + return res + + def _start_upload(self, name: str): + _, res = self._request(f"/uploads?name={self._quote(name)}", "POST") + return res["upload_id"] + + def _finish_upload(self, name: str, upload_id: str): + self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "PATCH") + + def _abort_upload(self, name: str, upload_id: str): + self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "DELETE") + + def _upload_part( + self, + name: str, + chunk: bytes, + upload_id: str, + part: int, + content_type: Optional[str] = None, + ): + self._request( + f"/uploads/{upload_id}/parts?name={self._quote(name)}&part={part}", + "POST", + data=chunk, + content_type=content_type, + ) + + def _get_content_stream( + self, + data: Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase], + ): + if isinstance(data, str): + return StringIO(data) + elif isinstance(data, bytes): + return BytesIO(data) + else: + return data diff --git a/deta/py.typed b/deta/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/deta/service.py b/deta/service.py index 260bc16..67b4d20 100644 --- a/deta/service.py +++ b/deta/service.py @@ -1,10 +1,10 @@ -import http.client import os import json import socket +import http.client import struct -import typing import urllib.error +from typing import Mapping, MutableMapping, Optional, Tuple, Union JSON_MIME = "application/json" @@ -20,13 +20,11 @@ def __init__( keep_alive: bool = True, ): self.project_key = project_key - self.base_path = "/v1/{0}/{1}".format(project_id, name) + self.base_path = f"/v1/{project_id}/{name}" self.host = host self.timeout = timeout self.keep_alive = keep_alive - self.client = ( - http.client.HTTPSConnection(host, timeout=timeout) if keep_alive else None - ) + self.client = http.client.HTTPSConnection(host, timeout=timeout) if keep_alive else None def _is_socket_closed(self): if not self.client.sock: @@ -44,11 +42,11 @@ def _request( self, path: str, method: str, - data: typing.Union[str, bytes, dict] = None, - headers: dict = None, - content_type: str = None, + data: Optional[Union[str, bytes, dict]] = None, + headers: Optional[MutableMapping[str, str]] = None, + content_type: Optional[str] = None, stream: bool = False, - ): + ) -> Tuple[int, Optional[dict]]: url = self.base_path + path headers = headers or {} headers["X-Api-Key"] = self.project_key @@ -66,7 +64,7 @@ def _request( and self._is_socket_closed() ): self.client.close() - except: + except Exception: pass # send request @@ -92,9 +90,7 @@ def _request( # return json if application/json payload = ( - json.loads(res.read()) - if JSON_MIME in res.getheader("content-type") - else res.read() + json.loads(res.read()) if JSON_MIME in res.getheader("content-type") else res.read() ) if not self.keep_alive: @@ -105,17 +101,16 @@ def _send_request_with_retry( self, method: str, url: str, - headers: dict = None, - body: typing.Union[str, bytes, dict] = None, - retry=2, # try at least twice to regain a new connection + headers: Optional[Mapping[str, str]] = None, + body: Optional[Union[str, bytes, dict]] = None, + retry: int = 2, # try at least twice to regain a new connection ): - reinitialize_connection = False + headers = headers if headers is not None else {} + reinit_connection = False while retry > 0: try: - if not self.keep_alive or reinitialize_connection: - self.client = http.client.HTTPSConnection( - host=self.host, 
timeout=self.timeout - ) + if not self.keep_alive or reinit_connection: + self.client = http.client.HTTPSConnection(host=self.host, timeout=self.timeout) self.client.request( method, @@ -126,5 +121,5 @@ def _send_request_with_retry( res = self.client.getresponse() return res except http.client.RemoteDisconnected: - reinitialize_connection = True + reinit_connection = True retry -= 1 diff --git a/deta/utils.py b/deta/utils.py index 5bd127d..95c0dd0 100644 --- a/deta/utils.py +++ b/deta/utils.py @@ -1,16 +1,17 @@ import os +from typing import Optional -def _get_project_key_id(project_key: str = None, project_id: str = None): - project_key = project_key or os.getenv("DETA_PROJECT_KEY", "") +def _get_project_key_id(project_key: Optional[str] = None, project_id: Optional[str] = None): + project_key = project_key or os.getenv("DETA_PROJECT_KEY") if not project_key: - raise AssertionError("No project key defined") + raise ValueError("no project key defined") if not project_id: project_id = project_key.split("_")[0] if project_id == project_key: - raise AssertionError("Bad project key provided") + raise ValueError("bad project key provided") return project_key, project_id diff --git a/env.sample b/env.sample index 0f3d86f..dc4ffc5 100644 --- a/env.sample +++ b/env.sample @@ -1,4 +1,5 @@ DETA_SDK_TEST_PROJECT_KEY= DETA_SDK_TEST_DRIVE_NAME=testdrive DETA_SDK_TEST_DRIVE_HOST=drive.deta.sh -DETA_SDK_TEST_BASE_NAME=testbase \ No newline at end of file +DETA_SDK_TEST_BASE_NAME=testbase +DETA_SDK_TEST_TTL_ATTRIBUTE=__expires diff --git a/requirements.txt b/requirements.txt index c0c3436..04043b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,4 @@ twine # Requirements for testing [async] extra aiohttp pytest-asyncio -pytest \ No newline at end of file +pytest diff --git a/scripts/build b/scripts/build index 2dd0a9a..6f9d97a 100755 --- a/scripts/build +++ b/scripts/build @@ -1,7 +1,7 @@ #!/bin/sh -e -if [ -d 'venv' ] ; then - PREFIX="venv/bin/" +if [ -d '.venv' ] ; then + PREFIX=".venv/bin/" else PREFIX="" fi diff --git a/scripts/install b/scripts/install index a150cff..123cb74 100755 --- a/scripts/install +++ b/scripts/install @@ -4,7 +4,7 @@ [ "$1" = "-p" ] && PYTHON=$2 || PYTHON="python3" REQUIREMENTS="requirements.txt" -VENV="venv" +VENV=".venv" #set -x diff --git a/scripts/publish b/scripts/publish index 61e80af..8bd2bba 100755 --- a/scripts/publish +++ b/scripts/publish @@ -3,18 +3,18 @@ VERSION_FILE="deta/__init__.py" SETUP_FILE="setup.py" -if [ -d 'venv' ] ; then - PREFIX="venv/bin/" +if [ -d '.venv' ] ; then + PREFIX=".venv/bin/" else PREFIX="" fi if [ ! -z "$GITHUB_ACTIONS" ]; then - git config --local user.email "action@github.com" - git config --local user.name "GitHub Action" + git config --local user.email "actions@github.com" + git config --local user.name "GitHub Actions" VERSION=`grep __version__ ${VERSION_FILE} | grep -o '[0-9][^"]*'` - VERSION_SETUP=`grep version ${SETUP_FILE}| grep -o '[0-9][^"]*'` + VERSION_SETUP=`grep version ${SETUP_FILE} | grep -o '[0-9][^"]*'` if [ "${VERSION}" != "${VERSION_SETUP}" ] ; then echo "__init__.py version '${VERSION}' did not match setup version '${VERSION_SETUP}'" diff --git a/scripts/release b/scripts/release index 925c7b8..c585a54 100755 --- a/scripts/release +++ b/scripts/release @@ -3,18 +3,18 @@ VERSION_FILE="deta/__init__.py" SETUP_FILE="setup.py" -if [ -d 'venv' ] ; then - PREFIX="venv/bin/" +if [ -d '.venv' ] ; then + PREFIX=".venv/bin/" else PREFIX="" fi if [ ! 
-z "$GITHUB_ACTIONS" ]; then - git config --local user.email "action@github.com" - git config --local user.name "GitHub Action" + git config --local user.email "actions@github.com" + git config --local user.name "GitHub Actions" VERSION=`grep __version__ ${VERSION_FILE} | grep -o '[0-9][^"]*'` - VERSION_SETUP=`grep version ${SETUP_FILE}| grep -o '[0-9][^"]*'` + VERSION_SETUP=`grep version ${SETUP_FILE} | grep -o '[0-9][^"]*'` if [ "${VERSION}" != "${VERSION_SETUP}" ] ; then echo "__init__.py version '${VERSION}' did not match setup version '${VERSION_SETUP}'" diff --git a/scripts/tag b/scripts/tag index 5cb1bca..8b6af1e 100755 --- a/scripts/tag +++ b/scripts/tag @@ -10,7 +10,7 @@ if [ ! -z "$GITHUB_ACTIONS" ]; then if [ "${VERSION}" != "${VERSION_SETUP}" ] ; then echo "__init__.py version '${VERSION}' did not match setup version '${VERSION_SETUP}'" - exit 15 + exit 1 fi fi diff --git a/scripts/test_publish b/scripts/test_publish index 255af16..81b3321 100755 --- a/scripts/test_publish +++ b/scripts/test_publish @@ -2,15 +2,15 @@ VERSION_FILE="deta/__init__.py" -if [ -d 'venv' ] ; then - PREFIX="venv/bin/" +if [ -d '.venv' ] ; then + PREFIX=".venv/bin/" else PREFIX="" fi if [ ! -z "$GITHUB_ACTIONS" ]; then - git config --local user.email "action@github.com" - git config --local user.name "GitHub Action" + git config --local user.email "actions@github.com" + git config --local user.name "GitHub Actions" VERSION=`grep __version__ ${VERSION_FILE} | grep -o '[0-9][^"]*'` diff --git a/setup.py b/setup.py index 646e8f6..045b2e0 100644 --- a/setup.py +++ b/setup.py @@ -17,4 +17,8 @@ extras_require={ "async": ["aiohttp>=3,<4"], }, + package_data={ + 'deta': ['py.typed'], + }, + zip_safe=False, ) diff --git a/tests/test_async.py b/tests/test_async.py index 17fbf16..d6400db 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv load_dotenv() -except: +except ImportError: pass pytestmark = pytest.mark.asyncio @@ -18,7 +18,7 @@ PROJECT_KEY = os.getenv("DETA_SDK_TEST_PROJECT_KEY") BASE_NAME = os.getenv("DETA_SDK_TEST_BASE_NAME") -BASE_TEST_TTL_ATTRIBUTE = os.getenv("DETA_SDK_TEST_TTL_ATTRIBUTE") +BASE_TEST_TTL_ATTRIBUTE = os.getenv("DETA_SDK_TEST_TTL_ATTRIBUTE") or "__expires" @pytest.fixture() @@ -61,7 +61,7 @@ async def test_put(db): for input in ["Hello", 1, True, False, 3.14159265359]: resp = await db.put(input) - assert set(resp.keys()) == set(["key", "value"]) + assert set(resp.keys()) == {"key", "value"} async def test_put_fail(db): @@ -86,13 +86,13 @@ async def test_put_many_fail(db): async def test_put_many_fail_limit(db): with pytest.raises(Exception): - await db.put_many([i for i in range(26)]) + await db.put_many(list(range(26))) async def test_insert(db): item = {"msg": "hello"} resp = await db.insert(item) - assert set(resp.keys()) == set(["key", "msg"]) + assert set(resp.keys()) == {"key", "msg"} async def test_insert_fail(db, items): @@ -105,15 +105,15 @@ async def test_get(db, items): assert resp == items[0] resp = await db.get("key_does_not_exist") - assert resp == None + assert resp is None async def test_delete(db, items): resp = await db.delete(items[0]["key"]) - assert resp == None + assert resp is None resp = await db.delete("key_does_not_exist") - assert resp == None + assert resp is None async def test_fetch(db, items): @@ -149,9 +149,7 @@ async def test_fetch(db, items): ) assert res3 == expectedItem - res4 = await db.fetch( - [{"value?gt": 6}, {"value?lt": 50}], limit=2, last="existing2" - ) + res4 = await 
db.fetch([{"value?gt": 6}, {"value?lt": 50}], limit=2, last="existing2") expectedItem = FetchResponse( 1, None, @@ -186,7 +184,7 @@ async def test_fetch(db, items): async def test_update(db, items): resp = await db.update({"value.name": "spongebob"}, "existing4") - assert resp == None + assert resp is None resp = await db.get("existing4") expectedItem = {"key": "existing4", "value": {"name": "spongebob"}} @@ -194,7 +192,7 @@ async def test_update(db, items): resp = await db.update({"value.name": db.util.trim(), "value.age": 32}, "existing4") - assert resp == None + assert resp is None expectedItem = {"key": "existing4", "value": {"age": 32}} resp = await db.get("existing4") @@ -207,13 +205,13 @@ async def test_update(db, items): }, "%@#//#!#)#$_", ) - assert resp == None + assert resp is None resp = await db.update( {"list": db.util.prepend("x"), "value": db.util.increment(2)}, "%@#//#!#)#$_", ) - assert resp == None + assert resp is None expectedItem = {"key": "%@#//#!#)#$_", "list": ["x", "a", "b", "c"], "value": 3} resp = await db.get("%@#//#!#)#$_") assert resp == expectedItem @@ -258,7 +256,7 @@ def get_expire_in(expire_in): async def test_ttl(db, items): item1 = items[0] expire_in = 300 - expire_at = datetime.datetime(2022, 3, 1, 12, 30, 30) + expire_at = datetime.datetime.now() + datetime.timedelta(seconds=300) delta = 2 # allow time delta of 2 seconds test_cases = [ { @@ -323,9 +321,7 @@ async def test_ttl(db, items): # update # only if one of expire_in or expire_at if cexp_in or cexp_at: - await db.update( - None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at - ) + await db.update(None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at) got = await db.get(item.get("key")) assert abs(expected - got.get(BASE_TEST_TTL_ATTRIBUTE)) <= cdelta else: @@ -336,6 +332,4 @@ async def test_ttl(db, items): with pytest.raises(error): await db.insert(item, expire_in=cexp_in, expire_at=cexp_at) with pytest.raises(error): - await db.update( - None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at - ) + await db.update(None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at) diff --git a/tests/test_sync.py b/tests/test_sync.py index c4070b0..c58b65e 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -13,7 +13,7 @@ from dotenv import load_dotenv load_dotenv() -except: +except ImportError: pass @@ -209,11 +209,11 @@ def test_put(self): self.assertEqual(self.db.put(item, "one"), resp) self.assertEqual(self.db.put(item, "one"), resp) self.assertEqual({"msg": "hello"}, item) - self.assertEqual(set(self.db.put("Hello").keys()), set(["key", "value"])) - self.assertEqual(set(self.db.put(1).keys()), set(["key", "value"])) - self.assertEqual(set(self.db.put(True).keys()), set(["key", "value"])) - self.assertEqual(set(self.db.put(False).keys()), set(["key", "value"])) - self.assertEqual(set(self.db.put(3.14159265359).keys()), set(["key", "value"])) + self.assertEqual(set(self.db.put("Hello").keys()), {"key", "value"}) + self.assertEqual(set(self.db.put(1).keys()), {"key", "value"}) + self.assertEqual(set(self.db.put(True).keys()), {"key", "value"}) + self.assertEqual(set(self.db.put(False).keys()), {"key", "value"}) + self.assertEqual(set(self.db.put(3.14159265359).keys()), {"key", "value"}) @unittest.expectedFailure def test_put_fail(self): @@ -231,11 +231,11 @@ def test_put_many_fail(self): @unittest.expectedFailure def test_put_many_fail_limit(self): - self.db.put_many([i for i in range(26)]) + self.db.put_many(list(range(26))) def test_insert(self): item = {"msg": "hello"} - 
self.assertEqual(set(self.db.insert(item).keys()), set(["key", "msg"])) + self.assertEqual(set(self.db.insert(item).keys()), {"key", "msg"}) self.assertEqual({"msg": "hello"}, item) @unittest.expectedFailure @@ -283,9 +283,7 @@ def test_fetch(self): ) self.assertEqual(res3, expectedItem) - res4 = self.db.fetch( - [{"value?gt": 6}, {"value?lt": 50}], limit=2, last="existing2" - ) + res4 = self.db.fetch([{"value?gt": 6}, {"value?lt": 50}], limit=2, last="existing2") expectedItem = FetchResponse( 1, None, @@ -323,9 +321,7 @@ def test_update(self): self.assertEqual(self.db.get("existing4"), expectedItem) self.assertIsNone( - self.db.update( - {"value.name": self.db.util.trim(), "value.age": 32}, "existing4" - ) + self.db.update({"value.name": self.db.util.trim(), "value.age": 32}, "existing4") ) expectedItem = {"key": "existing4", "value": {"age": 32}} self.assertEqual(self.db.get("existing4"), expectedItem) @@ -382,7 +378,7 @@ def get_expire_in(self, expire_in): def test_ttl(self): expire_in = 300 - expire_at = datetime.datetime(2022, 3, 1, 12, 30, 30) + expire_at = datetime.datetime.now() + datetime.timedelta(seconds=300) delta = 2 # allow time delta of 2 seconds test_cases = [ { @@ -438,43 +434,29 @@ def test_ttl(self): # put self.db.put(item, expire_in=cexp_in, expire_at=cexp_at) got = self.db.get(item.get("key")) - self.assertAlmostEqual( - expected, got.get(self.ttl_attribute), delta=cdelta - ) + self.assertAlmostEqual(expected, got.get(self.ttl_attribute), delta=cdelta) # insert # need to udpate key as insert does not allow pre existing key item["key"] = "".join(random.choices(string.ascii_lowercase, k=6)) self.db.insert(item, expire_in=cexp_in, expire_at=cexp_at) got = self.db.get(item.get("key")) - self.assertAlmostEqual( - expected, got.get(self.ttl_attribute), delta=cdelta - ) + self.assertAlmostEqual(expected, got.get(self.ttl_attribute), delta=cdelta) # put many self.db.put_many([item], expire_in=cexp_in, expire_at=cexp_at) got = self.db.get(item.get("key")) - self.assertAlmostEqual( - expected, got.get(self.ttl_attribute), delta=cdelta - ) + self.assertAlmostEqual(expected, got.get(self.ttl_attribute), delta=cdelta) # update # only if one of expire_in or expire_at if cexp_in or cexp_at: - self.db.update( - None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at - ) + self.db.update(None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at) got = self.db.get(item.get("key")) - self.assertAlmostEqual( - expected, got.get(self.ttl_attribute), delta=cdelta - ) + self.assertAlmostEqual(expected, got.get(self.ttl_attribute), delta=cdelta) else: - self.assertRaises( - error, self.db.put, item, expire_in=cexp_in, expire_at=cexp_at - ) - self.assertRaises( - error, self.db.insert, item, expire_in=cexp_in, expire_at=cexp_at - ) + self.assertRaises(error, self.db.put, item, expire_in=cexp_in, expire_at=cexp_at) + self.assertRaises(error, self.db.insert, item, expire_in=cexp_in, expire_at=cexp_at) self.assertRaises( error, self.db.put_many,
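For reference, a small sketch of the TTL parameters these tests exercise — `expire_in` and `expire_at` are mutually exclusive and resolve into the `__expires` attribute; the Base name and items are placeholders:

```python
import datetime
from deta import Base

sessions = Base("sessions")  # placeholder Base name

# expire_in: seconds from now
sessions.put({"token": "abc", "key": "s-1"}, expire_in=300)

# expire_at: an absolute time, as int/float epoch seconds or a datetime
sessions.insert(
    {"token": "def"},
    expire_at=datetime.datetime.now() + datetime.timedelta(hours=1),
)

print(sessions.get("s-1")["__expires"])  # stored as an int epoch timestamp

# Passing both expire_in and expire_at raises ValueError.
```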