From 63f844204aceeaa259c680b6270d9f433b79d61d Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Sun, 27 Mar 2022 01:18:22 -0400 Subject: [PATCH 01/12] Clean up code --- .github/workflows/pull_request.yml | 12 ++++++------ .github/workflows/tag_release.yml | 6 +++--- .github/workflows/test_release.yml | 5 ++--- .gitignore | 8 +++++--- CONTRIBUTING.md | 20 ++++++++++---------- README.md | 6 +++--- deta/__init__.py | 3 +-- deta/_async/client.py | 5 +++-- deta/base.py | 19 ++++++++++--------- deta/drive.py | 30 +++++++++++++++++++----------- deta/service.py | 10 +++++----- deta/utils.py | 8 ++++---- env.sample | 3 ++- requirements.txt | 2 +- tests/test_async.py | 20 ++++++++++---------- tests/test_sync.py | 4 ++-- 16 files changed, 86 insertions(+), 75 deletions(-) diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 4a536b3..c5b72df 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -12,7 +12,7 @@ jobs: # tests can't run in parallel as they write and read data with same keys max-parallel: 1 matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.6", "3.7", "3.8", "3.9"] steps: # Get the code into the container - name: Checkout @@ -25,12 +25,12 @@ jobs: # Test the code - name: Test code env: - DETA_SDK_TEST_PROJECT_KEY: ${{secrets.DETA_SDK_TEST_PROJECT_KEY}} - DETA_SDK_TEST_BASE_NAME: ${{secrets.DETA_SDK_TEST_BASE_NAME}} - DETA_SDK_TEST_DRIVE_NAME: ${{secrets.DETA_SDK_TEST_DRIVE_NAME}} - DETA_SDK_TEST_DRIVE_HOST: ${{secrets.DETA_SDK_TEST_DRIVE_HOST}} + DETA_SDK_TEST_PROJECT_KEY: ${{ secrets.DETA_SDK_TEST_PROJECT_KEY }} + DETA_SDK_TEST_BASE_NAME: ${{ secrets.DETA_SDK_TEST_BASE_NAME }} + DETA_SDK_TEST_DRIVE_NAME: ${{ secrets.DETA_SDK_TEST_DRIVE_NAME }} + DETA_SDK_TEST_DRIVE_HOST: ${{ secrets.DETA_SDK_TEST_DRIVE_HOST }} DETA_SDK_TEST_TTL_ATTRIBUTE: __expires run: | python -m pip install --upgrade pip python -m pip install pytest pytest-asyncio aiohttp - pytest tests \ No newline at end of file + pytest tests diff --git a/.github/workflows/tag_release.yml b/.github/workflows/tag_release.yml index da880e0..22047d9 100644 --- a/.github/workflows/tag_release.yml +++ b/.github/workflows/tag_release.yml @@ -1,6 +1,6 @@ name: Lint and tag before release -on: +on: push: branches: - release @@ -16,10 +16,10 @@ jobs: - name: Setup Python uses: actions/setup-python@v2 with: - python-version: ^3.5 + python-version: ^3.6 # Install dependencies - name: Install dependencies run: "scripts/install" # Make tag - name: Make git tag - run: "scripts/tag" \ No newline at end of file + run: "scripts/tag" diff --git a/.github/workflows/test_release.yml b/.github/workflows/test_release.yml index cf59e6a..b2f142a 100644 --- a/.github/workflows/test_release.yml +++ b/.github/workflows/test_release.yml @@ -1,6 +1,6 @@ name: Test release -on: +on: workflow_dispatch: jobs: @@ -14,7 +14,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v2 with: - python-version: ^3.5 + python-version: ^3.6 # Install dependencies - name: Install dependencies run: "scripts/install" @@ -27,4 +27,3 @@ jobs: env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_TEST_TOKEN }} - \ No newline at end of file diff --git a/.gitignore b/.gitignore index 567f9df..d48fe4b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,13 @@ .* !.gitignore !.github/ -__pycache__ +__pycache__/ test.py build/ +dist/ *.egg-info -dist testEnv .env -venv/ \ No newline at end of file +env/ +venv/ +.venv/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index af12021..a817003 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -17,10 +17,10 @@ source .venv/bin/activate **Windows** ```powershell -venv/scripts/activate +.venv/scripts/activate ``` -### Installing the dependencies +### Install the dependencies ```sh pip install -r requirements.txt @@ -36,21 +36,21 @@ cp env.sample .env Then provide the values as follows: -- `DETA_SDK_TEST_PROJECT_KEY` – Test project key (create a new Deta project for testing and grab the generated key). -- `DETA_SDK_TEST_BASE_NAME` – Name of your Base, default is fine. -- `DETA_SDK_TEST_DRIVE_NAME` – Name of your Drive, default is fine. -- `DETA_SDK_TEST_DRIVE_HOST` – Host URL, default is fine. +- `DETA_SDK_TEST_PROJECT_KEY` - Test project key (create a new Deta project for testing and grab the generated key). +- `DETA_SDK_TEST_BASE_NAME` - Name of your Base, default is fine. +- `DETA_SDK_TEST_DRIVE_NAME` - Name of your Drive, default is fine. +- `DETA_SDK_TEST_DRIVE_HOST` - Host URL, default is fine. ### Run the tests ```sh -python tests.py +pytest tests ``` 🎉 Now you are ready to contribute! ### How to contribute -1. Git clone and make a feature branch -2. Make a draft PR +1. Clone this repo and make a feature branch +2. Make a draft pull request 3. Make your changes to the feature branch -4. Mark draft as ready for review +4. Mark your draft as ready for review diff --git a/README.md b/README.md index 247118d..ad556d6 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # Deta Python Library (SDK) -Supports Python 3.5+ only. [Read the docs.](https://docs.deta.sh/docs/base/sdk) +Supports Python 3.6+ only. [Read the docs.](https://docs.deta.sh/docs/base/sdk) -Install from PyPi +Install from PyPI ```sh pip install deta @@ -13,4 +13,4 @@ If you are interested in contributing, please look at [**CONTRIBUTING.md**](CONT ## How to release (for maintainers) 1. Add changes to `CHANGELOG.md` 2. Merge the `master` branch with the `release` branch. -3. After scripts finish, update release and tag with relevant info \ No newline at end of file +3. After scripts finish, update release and tag with relevant info diff --git a/deta/__init__.py b/deta/__init__.py index 1eec891..129e65b 100644 --- a/deta/__init__.py +++ b/deta/__init__.py @@ -23,7 +23,6 @@ __version__ = "1.1.0" - def Base(name: str): project_key, project_id = _get_project_key_id() return _Base(name, project_key, project_id) @@ -65,7 +64,7 @@ def send_email(to, subject, message, charset="UTF-8"): api_key = os.getenv("DETA_PROJECT_KEY") endpoint = f"{url}/mail/{pid}" - to = to if type(to) == list else [to] + to = to if isinstance(to, list) else [to] data = { "to": to, "subject": subject, diff --git a/deta/_async/client.py b/deta/_async/client.py index f183253..26547a2 100644 --- a/deta/_async/client.py +++ b/deta/_async/client.py @@ -17,7 +17,7 @@ def AsyncBase(name: str): class _AsyncBase: def __init__(self, name: str, project_key: str, project_id: str, host: str = None): if not project_key: - raise AssertionError("No Base name provided") + raise ValueError("Base name not provided or empty") host = host or os.getenv("DETA_BASE_HOST") or "database.deta.sh" self._base_url = f"https://{host}/v1/{project_id}/{name}" @@ -110,7 +110,8 @@ async def put_many( expire_at: typing.Union[int, float, datetime.datetime] = None, ): if len(items) > 25: - raise AssertionError("We can't put more than 25 items at a time.") + raise ValueError("Cannot put more than 25 items at a time") + _items = [] for i in items: data = i diff --git a/deta/base.py b/deta/base.py index c15c127..ac5e315 100644 --- a/deta/base.py +++ b/deta/base.py @@ -1,6 +1,5 @@ import os import datetime -from re import I import typing from urllib.parse import quote @@ -74,7 +73,8 @@ def prepend(self, value: typing.Union[dict, list, str, int, float, bool]): class _Base(_Service): def __init__(self, name: str, project_key: str, project_id: str, host: str = None): - assert name, "No Base name provided" + if not name: + raise ValueError("Base name not provided or empty") host = host or os.getenv("DETA_BASE_HOST") or "database.deta.sh" super().__init__( @@ -93,7 +93,7 @@ def get(self, key: str): # encode key key = quote(key, safe="") - _, res = self._request("/items/{}".format(key), "GET") + _, res = self._request(f"/items/{key}", "GET") return res or None def delete(self, key: str): @@ -105,7 +105,7 @@ def delete(self, key: str): # encode key key = quote(key, safe="") - self._request("/items/{}".format(key), "DELETE") + self._request(f"/items/{key}", "DELETE") return None def insert( @@ -131,7 +131,7 @@ def insert( if code == 201: return res elif code == 409: - raise Exception("Item with key '{4}' already exists".format(key)) + raise Exception(f"Item with key '{key}' already exists") def put( self, @@ -167,7 +167,8 @@ def put_many( expire_in: int = None, expire_at: typing.Union[int, float, datetime.datetime] = None, ): - assert len(items) <= 25, "We can't put more than 25 items at a time." + if len(items) > 25: + raise ValueError("Cannot put more than 25 items at a time") _items = [] for i in items: data = i @@ -264,17 +265,17 @@ def update( encoded_key = quote(key, safe="") code, _ = self._request( - "/items/{}".format(encoded_key), "PATCH", payload, content_type=JSON_MIME + f"/items/{encoded_key}", "PATCH", payload, content_type=JSON_MIME ) if code == 200: return None elif code == 404: - raise Exception("Key '{}' not found".format(key)) + raise Exception(f"Key '{key}' not found") def insert_ttl(item, ttl_attribute, expire_in=None, expire_at=None): if expire_in and expire_at: - raise ValueError("both expire_in and expire_at provided") + raise ValueError("Both expire_in and expire_at provided") if not expire_in and not expire_at: return diff --git a/deta/drive.py b/deta/drive.py index 2d4db38..0e2f6fe 100644 --- a/deta/drive.py +++ b/deta/drive.py @@ -29,7 +29,7 @@ def iter_chunks(self, chunk_size: int = 1024): if not chunk: break yield chunk - + def iter_lines(self, chunk_size: int = 1024): while True: chunk = self.__stream.readline(chunk_size) @@ -41,7 +41,7 @@ def close(self): # close stream try: self.__stream.close() - except: + except Exception: pass @@ -53,7 +53,8 @@ def __init__( project_id: str = None, host: str = None, ): - assert name, "No Drive name provided" + if not name: + raise ValueError("Drive name not provided or empty") host = host or os.getenv("DETA_DRIVE_HOST") or "drive.deta.sh" super().__init__( @@ -73,7 +74,8 @@ def get(self, name: str): `name` is the name of the file. Returns a DriveStreamingBody. """ - assert name, "No name provided" + if not name: + raise ValueError("Drive name not provided or empty") _, res = self._request( f"/files/download?name={self._quote(name)}", "GET", stream=True ) @@ -86,8 +88,10 @@ def delete_many(self, names: typing.List[str]): `names` are the names of the files to be deleted. Returns a dict with 'deleted' and 'failed' files. """ - assert names, "Names is empty" - assert len(names) <= 1000, "More than 1000 names to delete" + if not names: + raise ValueError("Names is empty") + if len(names) > 1000: + raise ValueError("More than 1000 names to delete") _, res = self._request( "/files", "DELETE", {"names": names}, content_type=JSON_MIME ) @@ -98,7 +102,8 @@ def delete(self, name: str): `name` is the name of the file. Returns the name of the file deleted. """ - assert name, "Name not provided or empty" + if not name: + raise ValueError("Name not provided or empty") payload = self.delete_many([name]) failed = payload.get("failed") if failed: @@ -168,9 +173,12 @@ def put( `content_type` is the mime type of the file. Returns the name of the file. """ - assert name, "No name provided" - assert path or data, "No data or path provided" - assert not (path and data), "Both path and data provided" + if not name: + raise ValueError("Name not provided or empty") + if not path and not data: + raise ValueError("No data or path provided") + if path and data: + raise ValueError("Both path and data provided") # start upload upload_id = self._start_upload(name) @@ -181,7 +189,7 @@ def put( # upload chunks while True: chunk = content_stream.read(UPLOAD_CHUNK_SIZE) - ## eof stop the loop + # eof stop the loop if not chunk: self._finish_upload(name, upload_id) content_stream.close() diff --git a/deta/service.py b/deta/service.py index e1f1cdf..e3d5b24 100644 --- a/deta/service.py +++ b/deta/service.py @@ -20,7 +20,7 @@ def __init__( keep_alive: bool = True, ): self.project_key = project_key - self.base_path = "/v1/{0}/{1}".format(project_id, name) + self.base_path = f"/v1/{project_id}/{name}" self.host = host self.timeout = timeout self.keep_alive = keep_alive @@ -66,7 +66,7 @@ def _request( and self._is_socket_closed() ): self.client.close() - except: + except Exception: pass # send request @@ -81,16 +81,16 @@ def _request( res.read() if not self.keep_alive: self.client.close() - ## return None if not found + # return None if not found if status == 404: return status, None raise urllib.error.HTTPError(url, status, res.reason, res.headers, res.fp) - ## if stream return the response and client without reading and closing the client + # if stream return the response and client without reading and closing the client if stream: return status, res - ## return json if application/json + # return json if application/json payload = ( json.loads(res.read()) if JSON_MIME in res.getheader("content-type") diff --git a/deta/utils.py b/deta/utils.py index f94c598..fb827fc 100644 --- a/deta/utils.py +++ b/deta/utils.py @@ -2,15 +2,15 @@ def _get_project_key_id(project_key: str = None, project_id: str = None): - project_key = project_key or os.getenv("DETA_PROJECT_KEY", "") + project_key = project_key or os.getenv("DETA_PROJECT_KEY") if not project_key: - raise AssertionError("No project key defined") + raise ValueError("No project key defined") if not project_id: project_id = project_key.split("_")[0] if project_id == project_key: - raise AssertionError("Bad project key provided") + raise ValueError("Bad project key provided") - return project_key, project_id \ No newline at end of file + return project_key, project_id diff --git a/env.sample b/env.sample index 0f3d86f..74ff164 100644 --- a/env.sample +++ b/env.sample @@ -1,4 +1,5 @@ DETA_SDK_TEST_PROJECT_KEY= DETA_SDK_TEST_DRIVE_NAME=testdrive DETA_SDK_TEST_DRIVE_HOST=drive.deta.sh -DETA_SDK_TEST_BASE_NAME=testbase \ No newline at end of file +DETA_SDK_TEST_BASE_NAME=testbase +DETA_SDK_TEST_TTL_ATTRIBUTE=__expires \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c0c3436..04043b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,4 @@ twine # Requirements for testing [async] extra aiohttp pytest-asyncio -pytest \ No newline at end of file +pytest diff --git a/tests/test_async.py b/tests/test_async.py index 17fbf16..72af8c3 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv load_dotenv() -except: +except ImportError: pass pytestmark = pytest.mark.asyncio @@ -18,7 +18,7 @@ PROJECT_KEY = os.getenv("DETA_SDK_TEST_PROJECT_KEY") BASE_NAME = os.getenv("DETA_SDK_TEST_BASE_NAME") -BASE_TEST_TTL_ATTRIBUTE = os.getenv("DETA_SDK_TEST_TTL_ATTRIBUTE") +BASE_TEST_TTL_ATTRIBUTE = os.getenv("DETA_SDK_TEST_TTL_ATTRIBUTE") or "__expires" @pytest.fixture() @@ -105,15 +105,15 @@ async def test_get(db, items): assert resp == items[0] resp = await db.get("key_does_not_exist") - assert resp == None + assert resp is None async def test_delete(db, items): resp = await db.delete(items[0]["key"]) - assert resp == None + assert resp is None resp = await db.delete("key_does_not_exist") - assert resp == None + assert resp is None async def test_fetch(db, items): @@ -186,7 +186,7 @@ async def test_fetch(db, items): async def test_update(db, items): resp = await db.update({"value.name": "spongebob"}, "existing4") - assert resp == None + assert resp is None resp = await db.get("existing4") expectedItem = {"key": "existing4", "value": {"name": "spongebob"}} @@ -194,7 +194,7 @@ async def test_update(db, items): resp = await db.update({"value.name": db.util.trim(), "value.age": 32}, "existing4") - assert resp == None + assert resp is None expectedItem = {"key": "existing4", "value": {"age": 32}} resp = await db.get("existing4") @@ -207,13 +207,13 @@ async def test_update(db, items): }, "%@#//#!#)#$_", ) - assert resp == None + assert resp is None resp = await db.update( {"list": db.util.prepend("x"), "value": db.util.increment(2)}, "%@#//#!#)#$_", ) - assert resp == None + assert resp is None expectedItem = {"key": "%@#//#!#)#$_", "list": ["x", "a", "b", "c"], "value": 3} resp = await db.get("%@#//#!#)#$_") assert resp == expectedItem @@ -258,7 +258,7 @@ def get_expire_in(expire_in): async def test_ttl(db, items): item1 = items[0] expire_in = 300 - expire_at = datetime.datetime(2022, 3, 1, 12, 30, 30) + expire_at = datetime.datetime.now() + datetime.timedelta(seconds=300) delta = 2 # allow time delta of 2 seconds test_cases = [ { diff --git a/tests/test_sync.py b/tests/test_sync.py index c4070b0..a83c95f 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -13,7 +13,7 @@ from dotenv import load_dotenv load_dotenv() -except: +except ImportError: pass @@ -382,7 +382,7 @@ def get_expire_in(self, expire_in): def test_ttl(self): expire_in = 300 - expire_at = datetime.datetime(2022, 3, 1, 12, 30, 30) + expire_at = datetime.datetime.now() + datetime.timedelta(seconds=300) delta = 2 # allow time delta of 2 seconds test_cases = [ { From 29eedd4a48a68db582607994e19472e6da720eaa Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Sun, 27 Mar 2022 01:28:52 -0400 Subject: [PATCH 02/12] Format using black --- deta/__init__.py | 1 + deta/_async/client.py | 20 +++++--------------- deta/base.py | 26 ++++++-------------------- deta/drive.py | 8 ++------ deta/service.py | 12 +++--------- tests/test_async.py | 12 +++--------- tests/test_sync.py | 36 +++++++++--------------------------- 7 files changed, 29 insertions(+), 86 deletions(-) diff --git a/deta/__init__.py b/deta/__init__.py index 129e65b..d093799 100644 --- a/deta/__init__.py +++ b/deta/__init__.py @@ -44,6 +44,7 @@ def Base(self, name: str, host: str = None): def AsyncBase(self, name: str, host: str = None): from ._async.client import _AsyncBase + return _AsyncBase(name, self.project_key, self.project_id, host) def Drive(self, name: str, host: str = None): diff --git a/deta/_async/client.py b/deta/_async/client.py index 26547a2..698ebc4 100644 --- a/deta/_async/client.py +++ b/deta/_async/client.py @@ -71,9 +71,7 @@ async def insert( data["key"] = key insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) - async with self._session.post( - f"{self._base_url}/items", json={"item": data} - ) as resp: + async with self._session.post(f"{self._base_url}/items", json={"item": data}) as resp: return await resp.json() async def put( @@ -93,9 +91,7 @@ async def put( data["key"] = key insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) - async with self._session.put( - f"{self._base_url}/items", json={"items": [data]} - ) as resp: + async with self._session.put(f"{self._base_url}/items", json={"items": [data]}) as resp: if resp.status == 207: resp_json = await resp.json() return resp_json["processed"]["items"][0] @@ -117,14 +113,10 @@ async def put_many( data = i if not isinstance(i, dict): data = {"value": i} - insert_ttl( - data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at - ) + insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) _items.append(data) - async with self._session.put( - f"{self._base_url}/items", json={"items": _items} - ) as resp: + async with self._session.put(f"{self._base_url}/items", json={"items": _items}) as resp: return await resp.json() async def fetch( @@ -144,9 +136,7 @@ async def fetch( async with self._session.post(f"{self._base_url}/query", json=payload) as resp: resp_json = await resp.json() paging = resp_json.get("paging") - return FetchResponse( - paging.get("size"), paging.get("last"), resp_json.get("items") - ) + return FetchResponse(paging.get("size"), paging.get("last"), resp_json.get("items")) async def update( self, diff --git a/deta/base.py b/deta/base.py index ac5e315..02ae0af 100644 --- a/deta/base.py +++ b/deta/base.py @@ -29,11 +29,7 @@ def items(self): return self._items def __eq__(self, other): - return ( - self.count == other.count - and self.last == other.last - and self.items == other.items - ) + return self.count == other.count and self.last == other.last and self.items == other.items class Util: @@ -125,9 +121,7 @@ def insert( data["key"] = key insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) - code, res = self._request( - "/items", "POST", {"item": data}, content_type=JSON_MIME - ) + code, res = self._request("/items", "POST", {"item": data}, content_type=JSON_MIME) if code == 201: return res elif code == 409: @@ -155,9 +149,7 @@ def put( data["key"] = key insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) - code, res = self._request( - "/items", "PUT", {"items": [data]}, content_type=JSON_MIME - ) + code, res = self._request("/items", "PUT", {"items": [data]}, content_type=JSON_MIME) return res["processed"]["items"][0] if res and code == 207 else None def put_many( @@ -174,14 +166,10 @@ def put_many( data = i if not isinstance(i, dict): data = {"value": i} - insert_ttl( - data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at - ) + insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) _items.append(data) - _, res = self._request( - "/items", "PUT", {"items": _items}, content_type=JSON_MIME - ) + _, res = self._request("/items", "PUT", {"items": _items}, content_type=JSON_MIME) return res def _fetch( @@ -264,9 +252,7 @@ def update( ) encoded_key = quote(key, safe="") - code, _ = self._request( - f"/items/{encoded_key}", "PATCH", payload, content_type=JSON_MIME - ) + code, _ = self._request(f"/items/{encoded_key}", "PATCH", payload, content_type=JSON_MIME) if code == 200: return None elif code == 404: diff --git a/deta/drive.py b/deta/drive.py index 0e2f6fe..ee537d7 100644 --- a/deta/drive.py +++ b/deta/drive.py @@ -76,9 +76,7 @@ def get(self, name: str): """ if not name: raise ValueError("Drive name not provided or empty") - _, res = self._request( - f"/files/download?name={self._quote(name)}", "GET", stream=True - ) + _, res = self._request(f"/files/download?name={self._quote(name)}", "GET", stream=True) if res: return DriveStreamingBody(res) return None @@ -92,9 +90,7 @@ def delete_many(self, names: typing.List[str]): raise ValueError("Names is empty") if len(names) > 1000: raise ValueError("More than 1000 names to delete") - _, res = self._request( - "/files", "DELETE", {"names": names}, content_type=JSON_MIME - ) + _, res = self._request("/files", "DELETE", {"names": names}, content_type=JSON_MIME) return res def delete(self, name: str): diff --git a/deta/service.py b/deta/service.py index e3d5b24..4361c73 100644 --- a/deta/service.py +++ b/deta/service.py @@ -24,9 +24,7 @@ def __init__( self.host = host self.timeout = timeout self.keep_alive = keep_alive - self.client = ( - http.client.HTTPSConnection(host, timeout=timeout) if keep_alive else None - ) + self.client = http.client.HTTPSConnection(host, timeout=timeout) if keep_alive else None def _is_socket_closed(self): if not self.client.sock: @@ -92,9 +90,7 @@ def _request( # return json if application/json payload = ( - json.loads(res.read()) - if JSON_MIME in res.getheader("content-type") - else res.read() + json.loads(res.read()) if JSON_MIME in res.getheader("content-type") else res.read() ) if not self.keep_alive: @@ -113,9 +109,7 @@ def _send_request_with_retry( while retry > 0: try: if not self.keep_alive or reinitializeConnection: - self.client = http.client.HTTPSConnection( - host=self.host, timeout=self.timeout - ) + self.client = http.client.HTTPSConnection(host=self.host, timeout=self.timeout) self.client.request( method, diff --git a/tests/test_async.py b/tests/test_async.py index 72af8c3..ae5d9f2 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -149,9 +149,7 @@ async def test_fetch(db, items): ) assert res3 == expectedItem - res4 = await db.fetch( - [{"value?gt": 6}, {"value?lt": 50}], limit=2, last="existing2" - ) + res4 = await db.fetch([{"value?gt": 6}, {"value?lt": 50}], limit=2, last="existing2") expectedItem = FetchResponse( 1, None, @@ -323,9 +321,7 @@ async def test_ttl(db, items): # update # only if one of expire_in or expire_at if cexp_in or cexp_at: - await db.update( - None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at - ) + await db.update(None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at) got = await db.get(item.get("key")) assert abs(expected - got.get(BASE_TEST_TTL_ATTRIBUTE)) <= cdelta else: @@ -336,6 +332,4 @@ async def test_ttl(db, items): with pytest.raises(error): await db.insert(item, expire_in=cexp_in, expire_at=cexp_at) with pytest.raises(error): - await db.update( - None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at - ) + await db.update(None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at) diff --git a/tests/test_sync.py b/tests/test_sync.py index a83c95f..a317ec6 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -283,9 +283,7 @@ def test_fetch(self): ) self.assertEqual(res3, expectedItem) - res4 = self.db.fetch( - [{"value?gt": 6}, {"value?lt": 50}], limit=2, last="existing2" - ) + res4 = self.db.fetch([{"value?gt": 6}, {"value?lt": 50}], limit=2, last="existing2") expectedItem = FetchResponse( 1, None, @@ -323,9 +321,7 @@ def test_update(self): self.assertEqual(self.db.get("existing4"), expectedItem) self.assertIsNone( - self.db.update( - {"value.name": self.db.util.trim(), "value.age": 32}, "existing4" - ) + self.db.update({"value.name": self.db.util.trim(), "value.age": 32}, "existing4") ) expectedItem = {"key": "existing4", "value": {"age": 32}} self.assertEqual(self.db.get("existing4"), expectedItem) @@ -438,43 +434,29 @@ def test_ttl(self): # put self.db.put(item, expire_in=cexp_in, expire_at=cexp_at) got = self.db.get(item.get("key")) - self.assertAlmostEqual( - expected, got.get(self.ttl_attribute), delta=cdelta - ) + self.assertAlmostEqual(expected, got.get(self.ttl_attribute), delta=cdelta) # insert # need to udpate key as insert does not allow pre existing key item["key"] = "".join(random.choices(string.ascii_lowercase, k=6)) self.db.insert(item, expire_in=cexp_in, expire_at=cexp_at) got = self.db.get(item.get("key")) - self.assertAlmostEqual( - expected, got.get(self.ttl_attribute), delta=cdelta - ) + self.assertAlmostEqual(expected, got.get(self.ttl_attribute), delta=cdelta) # put many self.db.put_many([item], expire_in=cexp_in, expire_at=cexp_at) got = self.db.get(item.get("key")) - self.assertAlmostEqual( - expected, got.get(self.ttl_attribute), delta=cdelta - ) + self.assertAlmostEqual(expected, got.get(self.ttl_attribute), delta=cdelta) # update # only if one of expire_in or expire_at if cexp_in or cexp_at: - self.db.update( - None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at - ) + self.db.update(None, item.get("key"), expire_in=cexp_in, expire_at=cexp_at) got = self.db.get(item.get("key")) - self.assertAlmostEqual( - expected, got.get(self.ttl_attribute), delta=cdelta - ) + self.assertAlmostEqual(expected, got.get(self.ttl_attribute), delta=cdelta) else: - self.assertRaises( - error, self.db.put, item, expire_in=cexp_in, expire_at=cexp_at - ) - self.assertRaises( - error, self.db.insert, item, expire_in=cexp_in, expire_at=cexp_at - ) + self.assertRaises(error, self.db.put, item, expire_in=cexp_in, expire_at=cexp_at) + self.assertRaises(error, self.db.insert, item, expire_in=cexp_in, expire_at=cexp_at) self.assertRaises( error, self.db.put_many, From c5157ca10d38a34ed4fdbb0f9bf44c00b1492415 Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Sun, 27 Mar 2022 01:59:02 -0400 Subject: [PATCH 03/12] Update __init__.py and env.sample --- deta/__init__.py | 8 ++++---- env.sample | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/deta/__init__.py b/deta/__init__.py index d093799..01bb4a8 100644 --- a/deta/__init__.py +++ b/deta/__init__.py @@ -23,14 +23,14 @@ __version__ = "1.1.0" -def Base(name: str): +def Base(name: str, host: str = None): project_key, project_id = _get_project_key_id() - return _Base(name, project_key, project_id) + return _Base(name, project_key, project_id, host) -def Drive(name: str): +def Drive(name: str, host: str = None): project_key, project_id = _get_project_key_id() - return _Drive(name, project_key, project_id) + return _Drive(name, project_key, project_id, host) class Deta: diff --git a/env.sample b/env.sample index 74ff164..dc4ffc5 100644 --- a/env.sample +++ b/env.sample @@ -2,4 +2,4 @@ DETA_SDK_TEST_PROJECT_KEY= DETA_SDK_TEST_DRIVE_NAME=testdrive DETA_SDK_TEST_DRIVE_HOST=drive.deta.sh DETA_SDK_TEST_BASE_NAME=testbase -DETA_SDK_TEST_TTL_ATTRIBUTE=__expires \ No newline at end of file +DETA_SDK_TEST_TTL_ATTRIBUTE=__expires From 740077abac2e435e8a623b4de28a904b630f1e6f Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Mon, 30 May 2022 12:12:11 -0400 Subject: [PATCH 04/12] Update type hints and general fixes --- deta/__init__.py | 58 ++++----- deta/_async/client.py | 198 ++++++++++++++++++++--------- deta/base.py | 286 +++++++++++++++++++++++++----------------- deta/drive.py | 192 +++++++++++++++------------- deta/service.py | 16 +-- deta/utils.py | 9 +- 6 files changed, 461 insertions(+), 298 deletions(-) diff --git a/deta/__init__.py b/deta/__init__.py index 01bb4a8..9ed7fe3 100644 --- a/deta/__init__.py +++ b/deta/__init__.py @@ -1,22 +1,18 @@ import os +import json +import typing import urllib.error import urllib.request -import json from .base import _Base +from ._async.client import _AsyncBase from .drive import _Drive from .utils import _get_project_key_id - try: - from detalib.app import App + from detalib.app import App # type: ignore app = App() -except Exception: - pass - -try: - from ._async.client import AsyncBase except ImportError: pass @@ -28,6 +24,11 @@ def Base(name: str, host: str = None): return _Base(name, project_key, project_id, host) +def AsyncBase(name: str, host: str = None): + project_key, project_id = _get_project_key_id() + return _AsyncBase(name, project_key, project_id, host) + + def Drive(name: str, host: str = None): project_key, project_id = _get_project_key_id() return _Drive(name, project_key, project_id, host) @@ -35,44 +36,45 @@ def Drive(name: str, host: str = None): class Deta: def __init__(self, project_key: str = None, *, project_id: str = None): - project_key, project_id = _get_project_key_id(project_key, project_id) - self.project_key = project_key - self.project_id = project_id + self.project_key, self.project_id = _get_project_key_id(project_key, project_id) def Base(self, name: str, host: str = None): return _Base(name, self.project_key, self.project_id, host) def AsyncBase(self, name: str, host: str = None): - from ._async.client import _AsyncBase - return _AsyncBase(name, self.project_key, self.project_id, host) def Drive(self, name: str, host: str = None): - return _Drive( - name=name, - project_key=self.project_key, - project_id=self.project_id, - host=host, - ) - - def send_email(self, to, subject, message, charset="UTF-8"): - return send_email(to, subject, message, charset) - - -def send_email(to, subject, message, charset="UTF-8"): + return _Drive(name, self.project_key, self.project_id, host) + + def send_email( + self, + to: typing.Union[str, typing.Sequence[str]], + subject: str, + message: str, + charset: str = "utf-8", + ): + send_email(to, subject, message, charset) + + +def send_email( + to: typing.Union[str, typing.Sequence[str]], + subject: str, + message: str, + charset: str = "utf-8", +): pid = os.getenv("AWS_LAMBDA_FUNCTION_NAME") url = os.getenv("DETA_MAILER_URL") api_key = os.getenv("DETA_PROJECT_KEY") endpoint = f"{url}/mail/{pid}" - to = to if isinstance(to, list) else [to] + to = list(to) if isinstance(to, (list, tuple, set)) else [to] data = { "to": to, "subject": subject, "message": message, "charset": charset, } - headers = {"X-API-Key": api_key} req = urllib.request.Request(endpoint, json.dumps(data).encode("utf-8"), headers) @@ -82,4 +84,4 @@ def send_email(to, subject, message, charset="UTF-8"): if resp.getcode() != 200: raise Exception(resp.read().decode("utf-8")) except urllib.error.URLError as e: - raise Exception(e.reason) + raise Exception(e.reason) from e diff --git a/deta/_async/client.py b/deta/_async/client.py index 698ebc4..11235a9 100644 --- a/deta/_async/client.py +++ b/deta/_async/client.py @@ -1,29 +1,31 @@ -import typing - -import datetime import os -import aiohttp +import datetime +import typing from urllib.parse import quote -from deta.utils import _get_project_key_id -from deta.base import FetchResponse, Util, insert_ttl, BASE_TTL_ATTTRIBUTE +try: + import aiohttp +except ImportError: + has_aiohttp = False +else: + has_aiohttp = True - -def AsyncBase(name: str): - project_key, project_id = _get_project_key_id() - return _AsyncBase(name, project_key, project_id) +from deta.base import FetchResponse, Util, insert_ttl, BASE_TTL_ATTRIBUTE class _AsyncBase: def __init__(self, name: str, project_key: str, project_id: str, host: str = None): - if not project_key: - raise ValueError("Base name not provided or empty") + if not has_aiohttp: + raise RuntimeError("aiohttp library is required for async support") + + if not name: + raise ValueError("parameter 'name' must be a non-empty string") host = host or os.getenv("DETA_BASE_HOST") or "database.deta.sh" self._base_url = f"https://{host}/v1/{project_id}/{name}" self.util = Util() - self.__ttl_attribute = BASE_TTL_ATTTRIBUTE + self._ttl_attribute = BASE_TTL_ATTRIBUTE self._session = aiohttp.ClientSession( headers={ @@ -33,12 +35,11 @@ def __init__(self, name: str, project_key: str, project_id: str, host: str = Non raise_for_status=True, ) - async def close(self) -> None: + async def close(self): await self._session.close() - async def get(self, key: str): + async def get(self, key: str) -> dict: key = quote(key, safe="") - try: async with self._session.get(f"{self._base_url}/items/{key}") as resp: return await resp.json() @@ -50,70 +51,123 @@ async def get(self, key: str): async def delete(self, key: str): key = quote(key, safe="") - async with self._session.delete(f"{self._base_url}/items/{key}"): return + @typing.overload + async def insert( + self, + data: typing.Union[dict, list, str, int, bool], + key: str = None, + ) -> dict: + ... + + @typing.overload async def insert( self, data: typing.Union[dict, list, str, int, bool], key: str = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): - if not isinstance(data, dict): - data = {"value": data} - else: - data = data.copy() + expire_in: int, + ) -> dict: + ... + + @typing.overload + async def insert( + self, + data: typing.Union[dict, list, str, int, bool], + key: str = None, + *, + expire_at: typing.Union[int, float, datetime.datetime], + ) -> dict: + ... + + async def insert(self, data, key=None, *, expire_in=None, expire_at=None): + data = data.copy() if isinstance(data, dict) else {"value": data} if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) async with self._session.post(f"{self._base_url}/items", json={"item": data}) as resp: return await resp.json() + @typing.overload + async def put( + self, + data: typing.Union[dict, list, str, int, bool], + key: str = None, + ) -> dict: + ... + + @typing.overload async def put( self, data: typing.Union[dict, list, str, int, bool], key: str = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): - if not isinstance(data, dict): - data = {"value": data} - else: - data = data.copy() + expire_in: int, + ) -> dict: + ... + + @typing.overload + async def put( + self, + data: typing.Union[dict, list, str, int, bool], + key: str = None, + *, + expire_at: typing.Union[int, float, datetime.datetime], + ) -> dict: + ... + + async def put(self, data, key=None, *, expire_in=None, expire_at=None): + data = data.copy() if isinstance(data, dict) else {"value": data} if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) async with self._session.put(f"{self._base_url}/items", json={"items": [data]}) as resp: - if resp.status == 207: - resp_json = await resp.json() - return resp_json["processed"]["items"][0] - else: + if resp.status != 207: return None + resp_json = await resp.json() + return resp_json["processed"]["items"][0] + @typing.overload async def put_many( self, - items: typing.List[typing.Union[dict, list, str, int, bool]], + items: typing.Sequence[typing.Union[dict, list, str, int, bool]], + ) -> dict: + ... + + @typing.overload + async def put_many( + self, + items: typing.Sequence[typing.Union[dict, list, str, int, bool]], *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): + expire_in: int, + ) -> dict: + ... + + @typing.overload + async def put_many( + self, + items: typing.Sequence[typing.Union[dict, list, str, int, bool]], + *, + expire_at: typing.Union[int, float, datetime.datetime], + ) -> dict: + ... + + async def put_many(self, items, *, expire_in=None, expire_at=None): if len(items) > 25: - raise ValueError("Cannot put more than 25 items at a time") + raise ValueError("cannot put more than 25 items at a time") _items = [] for i in items: data = i if not isinstance(i, dict): data = {"value": i} - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) _items.append(data) async with self._session.put(f"{self._base_url}/items", json={"items": _items}) as resp: @@ -126,28 +180,50 @@ async def fetch( limit: int = 1000, last: str = None, ): - payload = {} + payload = { + "limit": limit, + "last": last if not isinstance(last, bool) else None, + } + if query: payload["query"] = query if isinstance(query, list) else [query] - if limit: - payload["limit"] = limit - if last: - payload["last"] = last + async with self._session.post(f"{self._base_url}/query", json=payload) as resp: resp_json = await resp.json() paging = resp_json.get("paging") return FetchResponse(paging.get("size"), paging.get("last"), resp_json.get("items")) + @typing.overload + async def update( + self, + updates: typing.Mapping, + key: str, + ): + ... + + @typing.overload + async def update( + self, + updates: typing.Mapping, + key: str, + *, + expire_in: int, + ): + ... + + @typing.overload async def update( self, - updates: dict, + updates: typing.Mapping, key: str, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_at: typing.Union[int, float, datetime.datetime], ): - if key == "": - raise ValueError("Key is empty") + ... + + async def update(self, updates, key, *, expire_in=None, expire_at=None): + if not key: + raise ValueError("parameter 'key' must be a non-empty string") payload = { "set": {}, @@ -156,29 +232,29 @@ async def update( "prepend": {}, "delete": [], } + if updates: for attr, value in updates.items(): if isinstance(value, Util.Trim): payload["delete"].append(attr) elif isinstance(value, Util.Increment): - payload["increment"][attr] = value.val + payload["increment"][attr] = value.value elif isinstance(value, Util.Append): - payload["append"][attr] = value.val + payload["append"][attr] = value.value elif isinstance(value, Util.Prepend): - payload["prepend"][attr] = value.val + payload["prepend"][attr] = value.value else: payload["set"][attr] = value if not payload: - raise ValueError("Provide at least one update action.") + raise ValueError("must provide at least one update action") insert_ttl( payload["set"], - self.__ttl_attribute, + self._ttl_attribute, expire_in=expire_in, expire_at=expire_at, ) - key = quote(key, safe="") - - await self._session.patch(f"{self._base_url}/items/{key}", json=payload) + encoded_key = quote(key, safe="") + await self._session.patch(f"{self._base_url}/items/{encoded_key}", json=payload) diff --git a/deta/base.py b/deta/base.py index 02ae0af..0db86e1 100644 --- a/deta/base.py +++ b/deta/base.py @@ -7,26 +7,14 @@ # timeout for Base service in seconds BASE_SERVICE_TIMEOUT = 300 -BASE_TTL_ATTTRIBUTE = "__expires" +BASE_TTL_ATTRIBUTE = "__expires" class FetchResponse: - def __init__(self, count=0, last=None, items=[]): - self._count = count - self._last = last - self._items = items - - @property - def count(self): - return self._count - - @property - def last(self): - return self._last - - @property - def items(self): - return self._items + def __init__(self, count: int = 0, last: str = None, items: list = None): + self.count = count + self.last = last + self.items = items if items is not None else [] def __eq__(self, other): return self.count == other.count and self.last == other.last and self.items == other.items @@ -37,27 +25,21 @@ class Trim: pass class Increment: - def __init__(self, value=None): - self.val = value - if not value: - self.val = 1 + def __init__(self, value=1): + self.value = value class Append: def __init__(self, value): - self.val = value - if not isinstance(value, list): - self.val = [value] + self.value = value if isinstance(value, list) else [value] class Prepend: def __init__(self, value): - self.val = value - if not isinstance(value, list): - self.val = [value] + self.value = value if isinstance(value, list) else [value] def trim(self): return self.Trim() - def increment(self, value: typing.Union[int, float] = None): + def increment(self, value: typing.Union[int, float] = 1): return self.Increment(value) def append(self, value: typing.Union[dict, list, str, int, float, bool]): @@ -70,7 +52,7 @@ def prepend(self, value: typing.Union[dict, list, str, int, float, bool]): class _Base(_Service): def __init__(self, name: str, project_key: str, project_id: str, host: str = None): if not name: - raise ValueError("Base name not provided or empty") + raise ValueError("parameter 'name' must be a non-empty string") host = host or os.getenv("DETA_BASE_HOST") or "database.deta.sh" super().__init__( @@ -80,115 +62,168 @@ def __init__(self, name: str, project_key: str, project_id: str, host: str = Non name=name, timeout=BASE_SERVICE_TIMEOUT, ) - self.__ttl_attribute = "__expires" + self._ttl_attribute = BASE_TTL_ATTRIBUTE self.util = Util() - def get(self, key: str): - if key == "": - raise ValueError("Key is empty") + def get(self, key: str) -> dict: + if not key: + raise ValueError("parameter 'key' must be a non-empty string") - # encode key key = quote(key, safe="") _, res = self._request(f"/items/{key}", "GET") - return res or None + return res def delete(self, key: str): """Delete an item from the database - key: the key of item to be deleted + + Args: + key: The key of item to be deleted. """ - if key == "": - raise ValueError("Key is empty") + if not key: + raise ValueError("parameter 'key' must be a non-empty string") - # encode key key = quote(key, safe="") self._request(f"/items/{key}", "DELETE") - return None + @typing.overload + def insert( + self, + data: typing.Union[dict, list, str, int, bool], + key: str = None, + ) -> dict: + ... + + @typing.overload def insert( self, data: typing.Union[dict, list, str, int, bool], key: str = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): - if not isinstance(data, dict): - data = {"value": data} - else: - data = data.copy() + expire_in: int, + ) -> dict: + ... + @typing.overload + def insert( + self, + data: typing.Union[dict, list, str, int, bool], + key: str = None, + *, + expire_at: typing.Union[int, float, datetime.datetime], + ) -> dict: + ... + + def insert(self, data, key=None, *, expire_in=None, expire_at=None): + data = data.copy() if isinstance(data, dict) else {"value": data} if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) code, res = self._request("/items", "POST", {"item": data}, content_type=JSON_MIME) + if code == 201: return res elif code == 409: - raise Exception(f"Item with key '{key}' already exists") + raise ValueError(f"item with key '{key}' already exists") + @typing.overload def put( self, data: typing.Union[dict, list, str, int, bool], key: str = None, - *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): - """store (put) an item in the database. Overrides an item if key already exists. - `key` could be provided as function argument or a field in the data dict. - If `key` is not provided, the server will generate a random 12 chars key. - """ + ) -> dict: + ... - if not isinstance(data, dict): - data = {"value": data} - else: - data = data.copy() + @typing.overload + def put( + self, + data: typing.Union[dict, list, str, int, bool], + key: str = None, + *, + expire_in: int, + ) -> dict: + ... + @typing.overload + def put( + self, + data: typing.Union[dict, list, str, int, bool], + key: str = None, + *, + expire_at: typing.Union[int, float, datetime.datetime], + ) -> dict: + ... + + def put(self, data, key=None, *, expire_in=None, expire_at=None): + """Store (put) an item in the database. Overrides an item if key already exists. + `key` could be provided as an argument or a field in the data dict. + If `key` is not provided, the server will generate a random 12-character key. + """ + data = data.copy() if isinstance(data, dict) else {"value": data} if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) code, res = self._request("/items", "PUT", {"items": [data]}, content_type=JSON_MIME) return res["processed"]["items"][0] if res and code == 207 else None + @typing.overload def put_many( self, - items: typing.List[typing.Union[dict, list, str, int, bool]], + items: typing.Sequence[typing.Union[dict, list, str, int, bool]], + ) -> dict: + ... + + @typing.overload + def put_many( + self, + items: typing.Sequence[typing.Union[dict, list, str, int, bool]], *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, - ): + expire_in: int, + ) -> dict: + ... + + @typing.overload + def put_many( + self, + items: typing.Sequence[typing.Union[dict, list, str, int, bool]], + *, + expire_at: typing.Union[int, float, datetime.datetime], + ) -> dict: + ... + + def put_many(self, items, *, expire_in=None, expire_at=None): if len(items) > 25: - raise ValueError("Cannot put more than 25 items at a time") + raise ValueError("cannot put more than 25 items at a time") + _items = [] - for i in items: - data = i - if not isinstance(i, dict): - data = {"value": i} - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + for item in items: + data = item + if not isinstance(item, dict): + data = {"value": item} + insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) _items.append(data) _, res = self._request("/items", "PUT", {"items": _items}, content_type=JSON_MIME) return res - def _fetch( - self, - query: typing.Union[dict, list] = None, - buffer: int = None, - last: str = None, - ) -> typing.Optional[typing.Tuple[int, list]]: - """This is where actual fetch happens.""" - payload = { - "limit": buffer, - "last": last if not isinstance(last, bool) else None, - } + # def _fetch( + # self, + # query: typing.Union[dict, list] = None, + # limit: int = None, + # last: str = None, + # ): + # """This is where actual fetch happens.""" + # payload = { + # "limit": limit, + # "last": last if not isinstance(last, bool) else None, + # } - if query: - payload["query"] = query if isinstance(query, list) else [query] + # if query: + # payload["query"] = query if isinstance(query, list) else [query] - code, res = self._request("/query", "POST", payload, content_type=JSON_MIME) - return code, res + # code, res = self._request("/query", "POST", payload, content_type=JSON_MIME) + # return code, res def fetch( self, @@ -197,32 +232,56 @@ def fetch( limit: int = 1000, last: str = None, ): + """Fetch items from the database. `query` is an optional filter or list of filters. + Without a filter, it will return the whole db. """ - fetch items from the database. - `query` is an optional filter or list of filters. Without filter, it will return the whole db. - """ - _, res = self._fetch(query, limit, last) + payload = { + "limit": limit, + "last": last if not isinstance(last, bool) else None, + } - paging = res.get("paging") + if query: + payload["query"] = query if isinstance(query, list) else [query] + _, res = self._request("/query", "POST", payload, content_type=JSON_MIME) + paging = res.get("paging") return FetchResponse(paging.get("size"), paging.get("last"), res.get("items")) + @typing.overload def update( self, - updates: dict, + updates: typing.Mapping, + key: str, + ): + ... + + @typing.overload + def update( + self, + updates: typing.Mapping, key: str, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: int, ): - """ - update an item in the database - `updates` specifies the attribute names and values to update,add or remove - `key` is the kye of the item to be updated - """ + ... + + @typing.overload + def update( + self, + updates: typing.Mapping, + key: str, + *, + expire_at: typing.Union[int, float, datetime.datetime], + ): + ... - if key == "": - raise ValueError("Key is empty") + def update(self, updates, key, *, expire_in=None, expire_at=None): + """Update an item in the database. + `updates` specifies the attribute names and values to update, add or remove. + `key` is the key of the item to be updated. + """ + if not key: + raise ValueError("parameter 'key' must be a non-empty string") payload = { "set": {}, @@ -231,37 +290,37 @@ def update( "prepend": {}, "delete": [], } + if updates: for attr, value in updates.items(): if isinstance(value, Util.Trim): payload["delete"].append(attr) elif isinstance(value, Util.Increment): - payload["increment"][attr] = value.val + payload["increment"][attr] = value.value elif isinstance(value, Util.Append): - payload["append"][attr] = value.val + payload["append"][attr] = value.value elif isinstance(value, Util.Prepend): - payload["prepend"][attr] = value.val + payload["prepend"][attr] = value.value else: payload["set"][attr] = value insert_ttl( payload["set"], - self.__ttl_attribute, + self._ttl_attribute, expire_in=expire_in, expire_at=expire_at, ) encoded_key = quote(key, safe="") code, _ = self._request(f"/items/{encoded_key}", "PATCH", payload, content_type=JSON_MIME) - if code == 200: - return None - elif code == 404: - raise Exception(f"Key '{key}' not found") + if code == 404: + raise ValueError(f"key '{key}' not found") def insert_ttl(item, ttl_attribute, expire_in=None, expire_at=None): if expire_in and expire_at: - raise ValueError("Both expire_in and expire_at provided") + raise ValueError("both 'expire_in' and 'expire_at' provided") + if not expire_in and not expire_at: return @@ -270,8 +329,7 @@ def insert_ttl(item, ttl_attribute, expire_in=None, expire_at=None): if isinstance(expire_at, datetime.datetime): expire_at = expire_at.replace(microsecond=0).timestamp() - - if not isinstance(expire_at, (int, float)): - raise TypeError("expire_at should one one of int, float or datetime") + elif not isinstance(expire_at, (int, float)): + raise TypeError("'expire_at' must be of type 'int', 'float' or 'datetime'") item[ttl_attribute] = int(expire_at) diff --git a/deta/drive.py b/deta/drive.py index ee537d7..1f513ad 100644 --- a/deta/drive.py +++ b/deta/drive.py @@ -14,33 +14,32 @@ class DriveStreamingBody: def __init__(self, res: BufferedIOBase): - self.__stream = res + self._stream = res @property def closed(self): - return self.__stream.closed + return self._stream.closed - def read(self, size: int = None): - return self.__stream.read(size) + def read(self, size: int = None) -> bytes: + return self._stream.read(size) - def iter_chunks(self, chunk_size: int = 1024): + def iter_chunks(self, chunk_size: int = 1024) -> typing.Iterator[bytes]: while True: - chunk = self.__stream.read(chunk_size) + chunk = self._stream.read(chunk_size) if not chunk: break yield chunk - def iter_lines(self, chunk_size: int = 1024): + def iter_lines(self, chunk_size: int = 1024) -> typing.Iterator[bytes]: while True: - chunk = self.__stream.readline(chunk_size) + chunk = self._stream.readline(chunk_size) if not chunk: break yield chunk def close(self): - # close stream try: - self.__stream.close() + self._stream.close() except Exception: pass @@ -48,15 +47,15 @@ def close(self): class _Drive(_Service): def __init__( self, - name: str = None, - project_key: str = None, - project_id: str = None, + name: str, + project_key: str, + project_id: str, host: str = None, ): if not name: - raise ValueError("Drive name not provided or empty") - host = host or os.getenv("DETA_DRIVE_HOST") or "drive.deta.sh" + raise ValueError("parameter 'name' must be a non-empty string") + host = host or os.getenv("DETA_DRIVE_HOST") or "drive.deta.sh" super().__init__( project_key=project_key, project_id=project_id, @@ -66,103 +65,71 @@ def __init__( keep_alive=False, ) - def _quote(self, param: str): + def _quote(self, param: str) -> str: return quote_plus(param) - def get(self, name: str): - """Get/Download a file from drive. + def get(self, name: str) -> DriveStreamingBody: + """Download a file from drive. `name` is the name of the file. Returns a DriveStreamingBody. """ if not name: - raise ValueError("Drive name not provided or empty") + raise ValueError("parameter 'name' must be a non-empty string") + _, res = self._request(f"/files/download?name={self._quote(name)}", "GET", stream=True) if res: return DriveStreamingBody(res) - return None - - def delete_many(self, names: typing.List[str]): - """Delete many files from drive in single request. - `names` are the names of the files to be deleted. - Returns a dict with 'deleted' and 'failed' files. - """ - if not names: - raise ValueError("Names is empty") - if len(names) > 1000: - raise ValueError("More than 1000 names to delete") - _, res = self._request("/files", "DELETE", {"names": names}, content_type=JSON_MIME) - return res - def delete(self, name: str): + def delete(self, name: str) -> str: """Delete a file from drive. `name` is the name of the file. Returns the name of the file deleted. """ if not name: - raise ValueError("Name not provided or empty") + raise ValueError("parameter 'name' must be a non-empty string") + payload = self.delete_many([name]) failed = payload.get("failed") if failed: - raise Exception(f"Failed to delete '{name}':{failed[name]}") + raise Exception(f"failed to delete '{name}': {failed[name]}") + return name - def list(self, limit: int = 1000, prefix: str = None, last: str = None): - """List file names from drive. - `limit` is the limit of number of file names to get, defaults to 1000. - `prefix` is the prefix of file names. - `last` is the last name seen in the a previous paginated response. - Returns a dict with 'paging' and 'names'. + def delete_many(self, names: typing.Sequence[str]) -> dict: + """Delete many files from drive in single request. + `names` are the names of the files to be deleted. + Returns a dict with 'deleted' and 'failed' files. """ - url = f"/files?limit={limit}" - if prefix: - url += f"&prefix={prefix}" - if last: - url += f"&last={last}" - _, res = self._request(url, "GET") - return res - - def _start_upload(self, name: str): - _, res = self._request(f"/uploads?name={self._quote(name)}", "POST") - return res["upload_id"] + if not names: + raise ValueError("parameter 'names' must be a non-empty list") - def _finish_upload(self, name: str, upload_id: str): - self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "PATCH") + if len(names) > 1000: + raise ValueError("cannot delete more than 1000 items") - def _abort_upload(self, name: str, upload_id: str): - self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "DELETE") + _, res = self._request("/files", "DELETE", {"names": names}, content_type=JSON_MIME) + return res - def _upload_part( + @typing.overload + def put( self, name: str, - chunk: bytes, - upload_id: str, - part: int, - content_type: str = None, - ): - self._request( - f"/uploads/{upload_id}/parts?name={self._quote(name)}&part={part}", - "POST", - data=chunk, - content_type=content_type, - ) - - def _get_content_stream( - self, data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase] - ): - if isinstance(data, str): - return StringIO(data) - elif isinstance(data, bytes): - return BytesIO(data) - return data + data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase], + *, + content_type: str, + ) -> str: + ... + @typing.overload def put( self, name: str, - data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase] = None, *, - path: str = None, - content_type: str = None, + path: str, + content_type: str, ) -> str: + ... + + def put(self, name, data=None, *, path=None, content_type=None): """Put a file in drive. `name` is the name of the file. `data` is the data to be put. @@ -170,15 +137,14 @@ def put( Returns the name of the file. """ if not name: - raise ValueError("Name not provided or empty") + raise ValueError("parameter 'name' must be a non-empty string") if not path and not data: - raise ValueError("No data or path provided") + raise ValueError("must provide either 'data' or 'path'") if path and data: - raise ValueError("Both path and data provided") + raise ValueError("'data' and 'path' are exclusive parameters") # start upload upload_id = self._start_upload(name) - content_stream = open(path, "rb") if path else self._get_content_stream(data) part = 1 @@ -201,3 +167,61 @@ def put( self._abort_upload(name, upload_id) content_stream.close() raise e + + def list( + self, + limit: int = 1000, + prefix: str = None, + last: str = None, + ) -> dict: + """List file names from drive. + `limit` is the number of file names to get, defaults to 1000. + `prefix` is the prefix of file names. + `last` is the last name seen in a previous paginated response. + Returns a dict with 'paging' and 'names'. + """ + url = f"/files?limit={limit}" + if prefix: + url += f"&prefix={prefix}" + + if last: + url += f"&last={last}" + + _, res = self._request(url, "GET") + return res + + def _start_upload(self, name: str): + _, res = self._request(f"/uploads?name={self._quote(name)}", "POST") + return res["upload_id"] + + def _finish_upload(self, name: str, upload_id: str): + self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "PATCH") + + def _abort_upload(self, name: str, upload_id: str): + self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "DELETE") + + def _upload_part( + self, + name: str, + chunk: bytes, + upload_id: str, + part: int, + content_type: str = None, + ): + self._request( + f"/uploads/{upload_id}/parts?name={self._quote(name)}&part={part}", + "POST", + data=chunk, + content_type=content_type, + ) + + def _get_content_stream( + self, + data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase], + ): + if isinstance(data, str): + return StringIO(data) + elif isinstance(data, bytes): + return BytesIO(data) + else: + return data diff --git a/deta/service.py b/deta/service.py index 4361c73..b44331c 100644 --- a/deta/service.py +++ b/deta/service.py @@ -1,7 +1,7 @@ -import http.client import os import json import socket +import http.client import struct import typing import urllib.error @@ -43,10 +43,10 @@ def _request( path: str, method: str, data: typing.Union[str, bytes, dict] = None, - headers: dict = None, + headers: typing.Mapping = None, content_type: str = None, stream: bool = False, - ): + ) -> typing.Tuple[int, dict]: url = self.base_path + path headers = headers or {} headers["X-Api-Key"] = self.project_key @@ -101,14 +101,14 @@ def _send_request_with_retry( self, method: str, url: str, - headers: dict = None, + headers: typing.Mapping[str, str] = None, body: typing.Union[str, bytes, dict] = None, - retry=2, # try at least twice to regain a new connection + retry: int = 2, # try at least twice to regain a new connection ): - reinitializeConnection = False + reinit_connection = False while retry > 0: try: - if not self.keep_alive or reinitializeConnection: + if not self.keep_alive or reinit_connection: self.client = http.client.HTTPSConnection(host=self.host, timeout=self.timeout) self.client.request( @@ -120,5 +120,5 @@ def _send_request_with_retry( res = self.client.getresponse() return res except http.client.RemoteDisconnected: - reinitializeConnection = True + reinit_connection = True retry -= 1 diff --git a/deta/utils.py b/deta/utils.py index fb827fc..6290a14 100644 --- a/deta/utils.py +++ b/deta/utils.py @@ -1,16 +1,19 @@ import os -def _get_project_key_id(project_key: str = None, project_id: str = None): +def _get_project_key_id( + project_key: str = None, + project_id: str = None, +): project_key = project_key or os.getenv("DETA_PROJECT_KEY") if not project_key: - raise ValueError("No project key defined") + raise ValueError("no project key defined") if not project_id: project_id = project_key.split("_")[0] if project_id == project_key: - raise ValueError("Bad project key provided") + raise ValueError("bad project key provided") return project_key, project_id From 3bed4da1c3ec9660db60521a3423248b7d2c81c7 Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Mon, 30 May 2022 12:12:26 -0400 Subject: [PATCH 05/12] Update tests --- tests/test_async.py | 6 +++--- tests/test_sync.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_async.py b/tests/test_async.py index ae5d9f2..d6400db 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -61,7 +61,7 @@ async def test_put(db): for input in ["Hello", 1, True, False, 3.14159265359]: resp = await db.put(input) - assert set(resp.keys()) == set(["key", "value"]) + assert set(resp.keys()) == {"key", "value"} async def test_put_fail(db): @@ -86,13 +86,13 @@ async def test_put_many_fail(db): async def test_put_many_fail_limit(db): with pytest.raises(Exception): - await db.put_many([i for i in range(26)]) + await db.put_many(list(range(26))) async def test_insert(db): item = {"msg": "hello"} resp = await db.insert(item) - assert set(resp.keys()) == set(["key", "msg"]) + assert set(resp.keys()) == {"key", "msg"} async def test_insert_fail(db, items): diff --git a/tests/test_sync.py b/tests/test_sync.py index a317ec6..c58b65e 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -209,11 +209,11 @@ def test_put(self): self.assertEqual(self.db.put(item, "one"), resp) self.assertEqual(self.db.put(item, "one"), resp) self.assertEqual({"msg": "hello"}, item) - self.assertEqual(set(self.db.put("Hello").keys()), set(["key", "value"])) - self.assertEqual(set(self.db.put(1).keys()), set(["key", "value"])) - self.assertEqual(set(self.db.put(True).keys()), set(["key", "value"])) - self.assertEqual(set(self.db.put(False).keys()), set(["key", "value"])) - self.assertEqual(set(self.db.put(3.14159265359).keys()), set(["key", "value"])) + self.assertEqual(set(self.db.put("Hello").keys()), {"key", "value"}) + self.assertEqual(set(self.db.put(1).keys()), {"key", "value"}) + self.assertEqual(set(self.db.put(True).keys()), {"key", "value"}) + self.assertEqual(set(self.db.put(False).keys()), {"key", "value"}) + self.assertEqual(set(self.db.put(3.14159265359).keys()), {"key", "value"}) @unittest.expectedFailure def test_put_fail(self): @@ -231,11 +231,11 @@ def test_put_many_fail(self): @unittest.expectedFailure def test_put_many_fail_limit(self): - self.db.put_many([i for i in range(26)]) + self.db.put_many(list(range(26))) def test_insert(self): item = {"msg": "hello"} - self.assertEqual(set(self.db.insert(item).keys()), set(["key", "msg"])) + self.assertEqual(set(self.db.insert(item).keys()), {"key", "msg"}) self.assertEqual({"msg": "hello"}, item) @unittest.expectedFailure From a1ff1050dac678e34e83457d195ada85f347b922 Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Mon, 30 May 2022 12:12:38 -0400 Subject: [PATCH 06/12] Update scripts --- scripts/build | 4 ++-- scripts/install | 2 +- scripts/publish | 10 +++++----- scripts/release | 10 +++++----- scripts/tag | 2 +- scripts/test_publish | 8 ++++---- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/scripts/build b/scripts/build index 2dd0a9a..6f9d97a 100755 --- a/scripts/build +++ b/scripts/build @@ -1,7 +1,7 @@ #!/bin/sh -e -if [ -d 'venv' ] ; then - PREFIX="venv/bin/" +if [ -d '.venv' ] ; then + PREFIX=".venv/bin/" else PREFIX="" fi diff --git a/scripts/install b/scripts/install index a150cff..123cb74 100755 --- a/scripts/install +++ b/scripts/install @@ -4,7 +4,7 @@ [ "$1" = "-p" ] && PYTHON=$2 || PYTHON="python3" REQUIREMENTS="requirements.txt" -VENV="venv" +VENV=".venv" #set -x diff --git a/scripts/publish b/scripts/publish index 61e80af..8bd2bba 100755 --- a/scripts/publish +++ b/scripts/publish @@ -3,18 +3,18 @@ VERSION_FILE="deta/__init__.py" SETUP_FILE="setup.py" -if [ -d 'venv' ] ; then - PREFIX="venv/bin/" +if [ -d '.venv' ] ; then + PREFIX=".venv/bin/" else PREFIX="" fi if [ ! -z "$GITHUB_ACTIONS" ]; then - git config --local user.email "action@github.com" - git config --local user.name "GitHub Action" + git config --local user.email "actions@github.com" + git config --local user.name "GitHub Actions" VERSION=`grep __version__ ${VERSION_FILE} | grep -o '[0-9][^"]*'` - VERSION_SETUP=`grep version ${SETUP_FILE}| grep -o '[0-9][^"]*'` + VERSION_SETUP=`grep version ${SETUP_FILE} | grep -o '[0-9][^"]*'` if [ "${VERSION}" != "${VERSION_SETUP}" ] ; then echo "__init__.py version '${VERSION}' did not match setup version '${VERSION_SETUP}'" diff --git a/scripts/release b/scripts/release index 925c7b8..c585a54 100755 --- a/scripts/release +++ b/scripts/release @@ -3,18 +3,18 @@ VERSION_FILE="deta/__init__.py" SETUP_FILE="setup.py" -if [ -d 'venv' ] ; then - PREFIX="venv/bin/" +if [ -d '.venv' ] ; then + PREFIX=".venv/bin/" else PREFIX="" fi if [ ! -z "$GITHUB_ACTIONS" ]; then - git config --local user.email "action@github.com" - git config --local user.name "GitHub Action" + git config --local user.email "actions@github.com" + git config --local user.name "GitHub Actions" VERSION=`grep __version__ ${VERSION_FILE} | grep -o '[0-9][^"]*'` - VERSION_SETUP=`grep version ${SETUP_FILE}| grep -o '[0-9][^"]*'` + VERSION_SETUP=`grep version ${SETUP_FILE} | grep -o '[0-9][^"]*'` if [ "${VERSION}" != "${VERSION_SETUP}" ] ; then echo "__init__.py version '${VERSION}' did not match setup version '${VERSION_SETUP}'" diff --git a/scripts/tag b/scripts/tag index 5cb1bca..8b6af1e 100755 --- a/scripts/tag +++ b/scripts/tag @@ -10,7 +10,7 @@ if [ ! -z "$GITHUB_ACTIONS" ]; then if [ "${VERSION}" != "${VERSION_SETUP}" ] ; then echo "__init__.py version '${VERSION}' did not match setup version '${VERSION_SETUP}'" - exit 15 + exit 1 fi fi diff --git a/scripts/test_publish b/scripts/test_publish index 255af16..81b3321 100755 --- a/scripts/test_publish +++ b/scripts/test_publish @@ -2,15 +2,15 @@ VERSION_FILE="deta/__init__.py" -if [ -d 'venv' ] ; then - PREFIX="venv/bin/" +if [ -d '.venv' ] ; then + PREFIX=".venv/bin/" else PREFIX="" fi if [ ! -z "$GITHUB_ACTIONS" ]; then - git config --local user.email "action@github.com" - git config --local user.name "GitHub Action" + git config --local user.email "actions@github.com" + git config --local user.name "GitHub Actions" VERSION=`grep __version__ ${VERSION_FILE} | grep -o '[0-9][^"]*'` From 1430de76e039b0f4e6f11fef8ea04c7c92a98f42 Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Mon, 30 May 2022 12:21:41 -0400 Subject: [PATCH 07/12] Remove commented code --- deta/base.py | 18 ------------------ deta/utils.py | 5 +---- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/deta/base.py b/deta/base.py index 0db86e1..f025644 100644 --- a/deta/base.py +++ b/deta/base.py @@ -207,24 +207,6 @@ def put_many(self, items, *, expire_in=None, expire_at=None): _, res = self._request("/items", "PUT", {"items": _items}, content_type=JSON_MIME) return res - # def _fetch( - # self, - # query: typing.Union[dict, list] = None, - # limit: int = None, - # last: str = None, - # ): - # """This is where actual fetch happens.""" - # payload = { - # "limit": limit, - # "last": last if not isinstance(last, bool) else None, - # } - - # if query: - # payload["query"] = query if isinstance(query, list) else [query] - - # code, res = self._request("/query", "POST", payload, content_type=JSON_MIME) - # return code, res - def fetch( self, query: typing.Union[dict, list] = None, diff --git a/deta/utils.py b/deta/utils.py index 6290a14..51cf5d9 100644 --- a/deta/utils.py +++ b/deta/utils.py @@ -1,10 +1,7 @@ import os -def _get_project_key_id( - project_key: str = None, - project_id: str = None, -): +def _get_project_key_id(project_key: str = None, project_id: str = None): project_key = project_key or os.getenv("DETA_PROJECT_KEY") if not project_key: From 79a99f2577184cb9c052255c02019c357264254b Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Sun, 5 Jun 2022 15:25:12 -0400 Subject: [PATCH 08/12] Minor formatting changes --- deta/_async/client.py | 13 ++++--------- deta/base.py | 29 ++++++----------------------- deta/drive.py | 24 +++--------------------- 3 files changed, 13 insertions(+), 53 deletions(-) diff --git a/deta/_async/client.py b/deta/_async/client.py index 11235a9..2683f93 100644 --- a/deta/_async/client.py +++ b/deta/_async/client.py @@ -88,7 +88,7 @@ async def insert(self, data, key=None, *, expire_in=None, expire_at=None): if key: data["key"] = key - insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) async with self._session.post(f"{self._base_url}/items", json={"item": data}) as resp: return await resp.json() @@ -126,7 +126,7 @@ async def put(self, data, key=None, *, expire_in=None, expire_at=None): if key: data["key"] = key - insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) async with self._session.put(f"{self._base_url}/items", json={"items": [data]}) as resp: if resp.status != 207: return None @@ -167,7 +167,7 @@ async def put_many(self, items, *, expire_in=None, expire_at=None): data = i if not isinstance(i, dict): data = {"value": i} - insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) _items.append(data) async with self._session.put(f"{self._base_url}/items", json={"items": _items}) as resp: @@ -249,12 +249,7 @@ async def update(self, updates, key, *, expire_in=None, expire_at=None): if not payload: raise ValueError("must provide at least one update action") - insert_ttl( - payload["set"], - self._ttl_attribute, - expire_in=expire_in, - expire_at=expire_at, - ) + insert_ttl(payload["set"], self._ttl_attribute, expire_in, expire_at) encoded_key = quote(key, safe="") await self._session.patch(f"{self._base_url}/items/{encoded_key}", json=payload) diff --git a/deta/base.py b/deta/base.py index f025644..25b5526 100644 --- a/deta/base.py +++ b/deta/base.py @@ -55,13 +55,7 @@ def __init__(self, name: str, project_key: str, project_id: str, host: str = Non raise ValueError("parameter 'name' must be a non-empty string") host = host or os.getenv("DETA_BASE_HOST") or "database.deta.sh" - super().__init__( - project_key=project_key, - project_id=project_id, - host=host, - name=name, - timeout=BASE_SERVICE_TIMEOUT, - ) + super().__init__(project_key, project_id, host, name, BASE_SERVICE_TIMEOUT) self._ttl_attribute = BASE_TTL_ATTRIBUTE self.util = Util() @@ -118,7 +112,7 @@ def insert(self, data, key=None, *, expire_in=None, expire_at=None): if key: data["key"] = key - insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) code, res = self._request("/items", "POST", {"item": data}, content_type=JSON_MIME) if code == 201: @@ -163,7 +157,7 @@ def put(self, data, key=None, *, expire_in=None, expire_at=None): if key: data["key"] = key - insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) code, res = self._request("/items", "PUT", {"items": [data]}, content_type=JSON_MIME) return res["processed"]["items"][0] if res and code == 207 else None @@ -201,19 +195,13 @@ def put_many(self, items, *, expire_in=None, expire_at=None): data = item if not isinstance(item, dict): data = {"value": item} - insert_ttl(data, self._ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self._ttl_attribute, expire_in, expire_at) _items.append(data) _, res = self._request("/items", "PUT", {"items": _items}, content_type=JSON_MIME) return res - def fetch( - self, - query: typing.Union[dict, list] = None, - *, - limit: int = 1000, - last: str = None, - ): + def fetch(self, query: typing.Union[dict, list] = None, *, limit: int = 1000, last: str = None): """Fetch items from the database. `query` is an optional filter or list of filters. Without a filter, it will return the whole db. """ @@ -286,12 +274,7 @@ def update(self, updates, key, *, expire_in=None, expire_at=None): else: payload["set"][attr] = value - insert_ttl( - payload["set"], - self._ttl_attribute, - expire_in=expire_in, - expire_at=expire_at, - ) + insert_ttl(payload["set"], self._ttl_attribute, expire_in, expire_at) encoded_key = quote(key, safe="") code, _ = self._request(f"/items/{encoded_key}", "PATCH", payload, content_type=JSON_MIME) diff --git a/deta/drive.py b/deta/drive.py index 1f513ad..3106457 100644 --- a/deta/drive.py +++ b/deta/drive.py @@ -45,25 +45,12 @@ def close(self): class _Drive(_Service): - def __init__( - self, - name: str, - project_key: str, - project_id: str, - host: str = None, - ): + def __init__(self, name: str, project_key: str, project_id: str, host: str = None): if not name: raise ValueError("parameter 'name' must be a non-empty string") host = host or os.getenv("DETA_DRIVE_HOST") or "drive.deta.sh" - super().__init__( - project_key=project_key, - project_id=project_id, - host=host, - name=name, - timeout=DRIVE_SERVICE_TIMEOUT, - keep_alive=False, - ) + super().__init__(project_key, project_id, host, name, DRIVE_SERVICE_TIMEOUT, False) def _quote(self, param: str) -> str: return quote_plus(param) @@ -168,12 +155,7 @@ def put(self, name, data=None, *, path=None, content_type=None): content_stream.close() raise e - def list( - self, - limit: int = 1000, - prefix: str = None, - last: str = None, - ) -> dict: + def list(self, limit: int = 1000, prefix: str = None, last: str = None) -> dict: """List file names from drive. `limit` is the number of file names to get, defaults to 1000. `prefix` is the prefix of file names. From f8fb9254c20cd623e66f5def16bca6f98a9ff571 Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Sun, 5 Jun 2022 15:46:26 -0400 Subject: [PATCH 09/12] Minor async fixes --- deta/_async/client.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/deta/_async/client.py b/deta/_async/client.py index 2683f93..034e59c 100644 --- a/deta/_async/client.py +++ b/deta/_async/client.py @@ -6,16 +6,14 @@ try: import aiohttp except ImportError: - has_aiohttp = False -else: - has_aiohttp = True + aiohttp = None from deta.base import FetchResponse, Util, insert_ttl, BASE_TTL_ATTRIBUTE class _AsyncBase: def __init__(self, name: str, project_key: str, project_id: str, host: str = None): - if not has_aiohttp: + if aiohttp is None: raise RuntimeError("aiohttp library is required for async support") if not name: @@ -163,10 +161,10 @@ async def put_many(self, items, *, expire_in=None, expire_at=None): raise ValueError("cannot put more than 25 items at a time") _items = [] - for i in items: - data = i - if not isinstance(i, dict): - data = {"value": i} + for item in items: + data = item + if not isinstance(item, dict): + data = {"value": item} insert_ttl(data, self._ttl_attribute, expire_in, expire_at) _items.append(data) From ea41b85c7e90b1daebd4365d99591f2d20a0b4cd Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Fri, 19 Aug 2022 14:45:51 -0400 Subject: [PATCH 10/12] Add Optional typings, more typing fixes --- deta/__init__.py | 27 ++++++----- deta/_async/client.py | 82 ++++++++++++++++---------------- deta/base.py | 106 +++++++++++++++++++++++------------------- deta/drive.py | 29 ++++++------ deta/service.py | 15 +++--- deta/utils.py | 3 +- 6 files changed, 142 insertions(+), 120 deletions(-) diff --git a/deta/__init__.py b/deta/__init__.py index 9ed7fe3..2b8f348 100644 --- a/deta/__init__.py +++ b/deta/__init__.py @@ -1,8 +1,8 @@ import os import json -import typing import urllib.error import urllib.request +from typing import Optional, Sequence, Union from .base import _Base from ._async.client import _AsyncBase @@ -19,37 +19,37 @@ __version__ = "1.1.0" -def Base(name: str, host: str = None): +def Base(name: str, host: Optional[str] = None): project_key, project_id = _get_project_key_id() return _Base(name, project_key, project_id, host) -def AsyncBase(name: str, host: str = None): +def AsyncBase(name: str, host: Optional[str] = None): project_key, project_id = _get_project_key_id() return _AsyncBase(name, project_key, project_id, host) -def Drive(name: str, host: str = None): +def Drive(name: str, host: Optional[str] = None): project_key, project_id = _get_project_key_id() return _Drive(name, project_key, project_id, host) class Deta: - def __init__(self, project_key: str = None, *, project_id: str = None): + def __init__(self, project_key: Optional[str] = None, *, project_id: Optional[str] = None): self.project_key, self.project_id = _get_project_key_id(project_key, project_id) - def Base(self, name: str, host: str = None): + def Base(self, name: str, host: Optional[str] = None): return _Base(name, self.project_key, self.project_id, host) - def AsyncBase(self, name: str, host: str = None): + def AsyncBase(self, name: str, host: Optional[str] = None): return _AsyncBase(name, self.project_key, self.project_id, host) - def Drive(self, name: str, host: str = None): + def Drive(self, name: str, host: Optional[str] = None): return _Drive(name, self.project_key, self.project_id, host) def send_email( self, - to: typing.Union[str, typing.Sequence[str]], + to: Union[str, Sequence[str]], subject: str, message: str, charset: str = "utf-8", @@ -58,17 +58,22 @@ def send_email( def send_email( - to: typing.Union[str, typing.Sequence[str]], + to: Union[str, Sequence[str]], subject: str, message: str, charset: str = "utf-8", ): + # should function continue if these are not present? pid = os.getenv("AWS_LAMBDA_FUNCTION_NAME") url = os.getenv("DETA_MAILER_URL") api_key = os.getenv("DETA_PROJECT_KEY") endpoint = f"{url}/mail/{pid}" - to = list(to) if isinstance(to, (list, tuple, set)) else [to] + if isinstance(to, str): + to = [to] + else: + to = list(to) + data = { "to": to, "subject": subject, diff --git a/deta/_async/client.py b/deta/_async/client.py index 034e59c..79863c3 100644 --- a/deta/_async/client.py +++ b/deta/_async/client.py @@ -1,6 +1,6 @@ import os import datetime -import typing +from typing import Mapping, Optional, Sequence, Union, overload from urllib.parse import quote try: @@ -12,7 +12,7 @@ class _AsyncBase: - def __init__(self, name: str, project_key: str, project_id: str, host: str = None): + def __init__(self, name: str, project_key: str, project_id: str, host: Optional[str] = None): if aiohttp is None: raise RuntimeError("aiohttp library is required for async support") @@ -36,12 +36,12 @@ def __init__(self, name: str, project_key: str, project_id: str, host: str = Non async def close(self): await self._session.close() - async def get(self, key: str) -> dict: + async def get(self, key: str) -> Optional[dict]: key = quote(key, safe="") try: async with self._session.get(f"{self._base_url}/items/{key}") as resp: return await resp.json() - except aiohttp.ClientResponseError as e: + except aiohttp.ClientResponseError as e: # type: ignore if e.status == 404: return else: @@ -52,31 +52,31 @@ async def delete(self, key: str): async with self._session.delete(f"{self._base_url}/items/{key}"): return - @typing.overload + @overload async def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, ) -> dict: ... - @typing.overload + @overload async def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, *, expire_in: int, ) -> dict: ... - @typing.overload + @overload async def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, *, - expire_at: typing.Union[int, float, datetime.datetime], + expire_at: Union[int, float, datetime.datetime], ) -> dict: ... @@ -90,31 +90,31 @@ async def insert(self, data, key=None, *, expire_in=None, expire_at=None): async with self._session.post(f"{self._base_url}/items", json={"item": data}) as resp: return await resp.json() - @typing.overload + @overload async def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, ) -> dict: ... - @typing.overload + @overload async def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, *, expire_in: int, ) -> dict: ... - @typing.overload + @overload async def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, *, - expire_at: typing.Union[int, float, datetime.datetime], + expire_at: Union[int, float, datetime.datetime], ) -> dict: ... @@ -131,28 +131,28 @@ async def put(self, data, key=None, *, expire_in=None, expire_at=None): resp_json = await resp.json() return resp_json["processed"]["items"][0] - @typing.overload + @overload async def put_many( self, - items: typing.Sequence[typing.Union[dict, list, str, int, bool]], + items: Sequence[Union[dict, list, str, int, bool]], ) -> dict: ... - @typing.overload + @overload async def put_many( self, - items: typing.Sequence[typing.Union[dict, list, str, int, bool]], + items: Sequence[Union[dict, list, str, int, bool]], *, expire_in: int, ) -> dict: ... - @typing.overload + @overload async def put_many( self, - items: typing.Sequence[typing.Union[dict, list, str, int, bool]], + items: Sequence[Union[dict, list, str, int, bool]], *, - expire_at: typing.Union[int, float, datetime.datetime], + expire_at: Union[int, float, datetime.datetime], ) -> dict: ... @@ -173,10 +173,10 @@ async def put_many(self, items, *, expire_in=None, expire_at=None): async def fetch( self, - query: typing.Union[dict, list] = None, + query: Optional[Union[Mapping, Sequence[Mapping]]] = None, *, limit: int = 1000, - last: str = None, + last: Optional[str] = None, ): payload = { "limit": limit, @@ -184,38 +184,38 @@ async def fetch( } if query: - payload["query"] = query if isinstance(query, list) else [query] + payload["query"] = query if isinstance(query, Sequence) else [query] async with self._session.post(f"{self._base_url}/query", json=payload) as resp: resp_json = await resp.json() paging = resp_json.get("paging") return FetchResponse(paging.get("size"), paging.get("last"), resp_json.get("items")) - @typing.overload + @overload async def update( self, - updates: typing.Mapping, + updates: Mapping, key: str, ): ... - @typing.overload + @overload async def update( self, - updates: typing.Mapping, + updates: Mapping, key: str, *, expire_in: int, ): ... - @typing.overload + @overload async def update( self, - updates: typing.Mapping, + updates: Mapping, key: str, *, - expire_at: typing.Union[int, float, datetime.datetime], + expire_at: Union[int, float, datetime.datetime], ): ... diff --git a/deta/base.py b/deta/base.py index 25b5526..fa38f67 100644 --- a/deta/base.py +++ b/deta/base.py @@ -1,6 +1,6 @@ import os import datetime -import typing +from typing import Any, Dict, Mapping, Optional, Sequence, Union, overload from urllib.parse import quote from .service import _Service, JSON_MIME @@ -11,46 +11,52 @@ class FetchResponse: - def __init__(self, count: int = 0, last: str = None, items: list = None): + def __init__(self, count: int = 0, last: Optional[str] = None, items: Optional[list] = None): self.count = count self.last = last self.items = items if items is not None else [] - def __eq__(self, other): + def __eq__(self, other: "FetchResponse"): return self.count == other.count and self.last == other.last and self.items == other.items + def __iter__(self): + return iter(self.items) + + def __len__(self): + return len(self.items) + class Util: class Trim: pass class Increment: - def __init__(self, value=1): + def __init__(self, value: Union[int, float] = 1): self.value = value class Append: - def __init__(self, value): + def __init__(self, value: Union[dict, list, str, int, float, bool]): self.value = value if isinstance(value, list) else [value] class Prepend: - def __init__(self, value): + def __init__(self, value: Union[dict, list, str, int, float, bool]): self.value = value if isinstance(value, list) else [value] def trim(self): return self.Trim() - def increment(self, value: typing.Union[int, float] = 1): + def increment(self, value: Union[int, float] = 1): return self.Increment(value) - def append(self, value: typing.Union[dict, list, str, int, float, bool]): + def append(self, value: Union[dict, list, str, int, float, bool]): return self.Append(value) - def prepend(self, value: typing.Union[dict, list, str, int, float, bool]): + def prepend(self, value: Union[dict, list, str, int, float, bool]): return self.Prepend(value) class _Base(_Service): - def __init__(self, name: str, project_key: str, project_id: str, host: str = None): + def __init__(self, name: str, project_key: str, project_id: str, host: Optional[str] = None): if not name: raise ValueError("parameter 'name' must be a non-empty string") @@ -59,7 +65,7 @@ def __init__(self, name: str, project_key: str, project_id: str, host: str = Non self._ttl_attribute = BASE_TTL_ATTRIBUTE self.util = Util() - def get(self, key: str) -> dict: + def get(self, key: str) -> Dict[str, Any]: if not key: raise ValueError("parameter 'key' must be a non-empty string") @@ -79,31 +85,31 @@ def delete(self, key: str): key = quote(key, safe="") self._request(f"/items/{key}", "DELETE") - @typing.overload + @overload def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, ) -> dict: ... - @typing.overload + @overload def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, *, expire_in: int, ) -> dict: ... - @typing.overload + @overload def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, *, - expire_at: typing.Union[int, float, datetime.datetime], + expire_at: Union[int, float, datetime.datetime], ) -> dict: ... @@ -120,31 +126,31 @@ def insert(self, data, key=None, *, expire_in=None, expire_at=None): elif code == 409: raise ValueError(f"item with key '{key}' already exists") - @typing.overload + @overload def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, ) -> dict: ... - @typing.overload + @overload def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, *, expire_in: int, ) -> dict: ... - @typing.overload + @overload def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Optional[str] = None, *, - expire_at: typing.Union[int, float, datetime.datetime], + expire_at: Union[int, float, datetime.datetime], ) -> dict: ... @@ -161,28 +167,28 @@ def put(self, data, key=None, *, expire_in=None, expire_at=None): code, res = self._request("/items", "PUT", {"items": [data]}, content_type=JSON_MIME) return res["processed"]["items"][0] if res and code == 207 else None - @typing.overload + @overload def put_many( self, - items: typing.Sequence[typing.Union[dict, list, str, int, bool]], + items: Sequence[Union[dict, list, str, int, bool]], ) -> dict: ... - @typing.overload + @overload def put_many( self, - items: typing.Sequence[typing.Union[dict, list, str, int, bool]], + items: Sequence[Union[dict, list, str, int, bool]], *, expire_in: int, ) -> dict: ... - @typing.overload + @overload def put_many( self, - items: typing.Sequence[typing.Union[dict, list, str, int, bool]], + items: Sequence[Union[dict, list, str, int, bool]], *, - expire_at: typing.Union[int, float, datetime.datetime], + expire_at: Union[int, float, datetime.datetime], ) -> dict: ... @@ -201,7 +207,13 @@ def put_many(self, items, *, expire_in=None, expire_at=None): _, res = self._request("/items", "PUT", {"items": _items}, content_type=JSON_MIME) return res - def fetch(self, query: typing.Union[dict, list] = None, *, limit: int = 1000, last: str = None): + def fetch( + self, + query: Optional[Union[Mapping, Sequence[Mapping]]] = None, + *, + limit: int = 1000, + last: Optional[str] = None, + ): """Fetch items from the database. `query` is an optional filter or list of filters. Without a filter, it will return the whole db. """ @@ -211,37 +223,37 @@ def fetch(self, query: typing.Union[dict, list] = None, *, limit: int = 1000, la } if query: - payload["query"] = query if isinstance(query, list) else [query] + payload["query"] = query if isinstance(query, Sequence) else [query] _, res = self._request("/query", "POST", payload, content_type=JSON_MIME) paging = res.get("paging") return FetchResponse(paging.get("size"), paging.get("last"), res.get("items")) - @typing.overload + @overload def update( self, - updates: typing.Mapping, + updates: Mapping, key: str, ): ... - @typing.overload + @overload def update( self, - updates: typing.Mapping, + updates: Mapping, key: str, *, expire_in: int, ): ... - @typing.overload + @overload def update( self, - updates: typing.Mapping, + updates: Mapping, key: str, *, - expire_at: typing.Union[int, float, datetime.datetime], + expire_at: Union[int, float, datetime.datetime], ): ... diff --git a/deta/drive.py b/deta/drive.py index 3106457..79974e4 100644 --- a/deta/drive.py +++ b/deta/drive.py @@ -1,6 +1,6 @@ import os -import typing from io import BufferedIOBase, TextIOBase, RawIOBase, StringIO, BytesIO +from typing import Any, Dict, Iterator, Optional, Sequence, Union, overload from urllib.parse import quote_plus from .service import JSON_MIME, _Service @@ -20,17 +20,17 @@ def __init__(self, res: BufferedIOBase): def closed(self): return self._stream.closed - def read(self, size: int = None) -> bytes: + def read(self, size: Optional[int] = None) -> bytes: return self._stream.read(size) - def iter_chunks(self, chunk_size: int = 1024) -> typing.Iterator[bytes]: + def iter_chunks(self, chunk_size: int = 1024) -> Iterator[bytes]: while True: chunk = self._stream.read(chunk_size) if not chunk: break yield chunk - def iter_lines(self, chunk_size: int = 1024) -> typing.Iterator[bytes]: + def iter_lines(self, chunk_size: int = 1024) -> Iterator[bytes]: while True: chunk = self._stream.readline(chunk_size) if not chunk: @@ -45,7 +45,7 @@ def close(self): class _Drive(_Service): - def __init__(self, name: str, project_key: str, project_id: str, host: str = None): + def __init__(self, name: str, project_key: str, project_id: str, host: Optional[str] = None): if not name: raise ValueError("parameter 'name' must be a non-empty string") @@ -82,7 +82,7 @@ def delete(self, name: str) -> str: return name - def delete_many(self, names: typing.Sequence[str]) -> dict: + def delete_many(self, names: Sequence[str]) -> dict: """Delete many files from drive in single request. `names` are the names of the files to be deleted. Returns a dict with 'deleted' and 'failed' files. @@ -96,17 +96,17 @@ def delete_many(self, names: typing.Sequence[str]) -> dict: _, res = self._request("/files", "DELETE", {"names": names}, content_type=JSON_MIME) return res - @typing.overload + @overload def put( self, name: str, - data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase], + data: Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase], *, content_type: str, ) -> str: ... - @typing.overload + @overload def put( self, name: str, @@ -132,7 +132,10 @@ def put(self, name, data=None, *, path=None, content_type=None): # start upload upload_id = self._start_upload(name) - content_stream = open(path, "rb") if path else self._get_content_stream(data) + if data: + content_stream = self._get_content_stream(data) + else: + content_stream = open(path, "rb") part = 1 # upload chunks @@ -155,7 +158,7 @@ def put(self, name, data=None, *, path=None, content_type=None): content_stream.close() raise e - def list(self, limit: int = 1000, prefix: str = None, last: str = None) -> dict: + def list(self, limit: int = 1000, prefix: Optional[str] = None, last: Optional[str] = None) -> Dict[str, Any]: """List file names from drive. `limit` is the number of file names to get, defaults to 1000. `prefix` is the prefix of file names. @@ -188,7 +191,7 @@ def _upload_part( chunk: bytes, upload_id: str, part: int, - content_type: str = None, + content_type: Optional[str] = None, ): self._request( f"/uploads/{upload_id}/parts?name={self._quote(name)}&part={part}", @@ -199,7 +202,7 @@ def _upload_part( def _get_content_stream( self, - data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase], + data: Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase], ): if isinstance(data, str): return StringIO(data) diff --git a/deta/service.py b/deta/service.py index b44331c..67b4d20 100644 --- a/deta/service.py +++ b/deta/service.py @@ -3,8 +3,8 @@ import socket import http.client import struct -import typing import urllib.error +from typing import Mapping, MutableMapping, Optional, Tuple, Union JSON_MIME = "application/json" @@ -42,11 +42,11 @@ def _request( self, path: str, method: str, - data: typing.Union[str, bytes, dict] = None, - headers: typing.Mapping = None, - content_type: str = None, + data: Optional[Union[str, bytes, dict]] = None, + headers: Optional[MutableMapping[str, str]] = None, + content_type: Optional[str] = None, stream: bool = False, - ) -> typing.Tuple[int, dict]: + ) -> Tuple[int, Optional[dict]]: url = self.base_path + path headers = headers or {} headers["X-Api-Key"] = self.project_key @@ -101,10 +101,11 @@ def _send_request_with_retry( self, method: str, url: str, - headers: typing.Mapping[str, str] = None, - body: typing.Union[str, bytes, dict] = None, + headers: Optional[Mapping[str, str]] = None, + body: Optional[Union[str, bytes, dict]] = None, retry: int = 2, # try at least twice to regain a new connection ): + headers = headers if headers is not None else {} reinit_connection = False while retry > 0: try: diff --git a/deta/utils.py b/deta/utils.py index 51cf5d9..95c0dd0 100644 --- a/deta/utils.py +++ b/deta/utils.py @@ -1,7 +1,8 @@ import os +from typing import Optional -def _get_project_key_id(project_key: str = None, project_id: str = None): +def _get_project_key_id(project_key: Optional[str] = None, project_id: Optional[str] = None): project_key = project_key or os.getenv("DETA_PROJECT_KEY") if not project_key: From 3677fc6b3bc1dbe9c7c38fbb6b8ba8c532c7dc65 Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Tue, 30 Aug 2022 20:35:15 -0400 Subject: [PATCH 11/12] Improve type hinting, add kw-only args --- deta/__init__.py | 20 +++++++++++--------- deta/_async/client.py | 21 ++++++++++++--------- deta/base.py | 13 +++++++++---- deta/drive.py | 11 ++++++++--- 4 files changed, 40 insertions(+), 25 deletions(-) diff --git a/deta/__init__.py b/deta/__init__.py index 2b8f348..ad81c03 100644 --- a/deta/__init__.py +++ b/deta/__init__.py @@ -21,17 +21,18 @@ def Base(name: str, host: Optional[str] = None): project_key, project_id = _get_project_key_id() - return _Base(name, project_key, project_id, host) + return _Base(name, project_key, project_id, host=host) -def AsyncBase(name: str, host: Optional[str] = None): +# TODO: type hint for session +def AsyncBase(name: str, host: Optional[str] = None, session=None): project_key, project_id = _get_project_key_id() - return _AsyncBase(name, project_key, project_id, host) + return _AsyncBase(name, project_key, project_id, host=host, session=session) def Drive(name: str, host: Optional[str] = None): project_key, project_id = _get_project_key_id() - return _Drive(name, project_key, project_id, host) + return _Drive(name, project_key, project_id, host=host) class Deta: @@ -39,13 +40,14 @@ def __init__(self, project_key: Optional[str] = None, *, project_id: Optional[st self.project_key, self.project_id = _get_project_key_id(project_key, project_id) def Base(self, name: str, host: Optional[str] = None): - return _Base(name, self.project_key, self.project_id, host) + return _Base(name, self.project_key, self.project_id, host=host) - def AsyncBase(self, name: str, host: Optional[str] = None): - return _AsyncBase(name, self.project_key, self.project_id, host) + # TODO: type hint for session + def AsyncBase(self, name: str, host: Optional[str] = None, session=None): + return _AsyncBase(name, self.project_key, self.project_id, host=host, session=session) def Drive(self, name: str, host: Optional[str] = None): - return _Drive(name, self.project_key, self.project_id, host) + return _Drive(name, self.project_key, self.project_id, host=host) def send_email( self, @@ -63,7 +65,7 @@ def send_email( message: str, charset: str = "utf-8", ): - # should function continue if these are not present? + # FIXME: should function continue if these are not present? pid = os.getenv("AWS_LAMBDA_FUNCTION_NAME") url = os.getenv("DETA_MAILER_URL") api_key = os.getenv("DETA_PROJECT_KEY") diff --git a/deta/_async/client.py b/deta/_async/client.py index f8d875f..7e54e43 100644 --- a/deta/_async/client.py +++ b/deta/_async/client.py @@ -21,7 +21,7 @@ def __init__( project_id: str, *, host: Optional[str] = None, - session: Optional[aiohttp.ClientSession] = None, + session: "Optional[aiohttp.ClientSession]" = None, ): if not has_aiohttp: raise RuntimeError("aiohttp library is required for async support") @@ -35,13 +35,16 @@ def __init__( self.util = Util() self._ttl_attribute = BASE_TTL_ATTRIBUTE - self._session = session or aiohttp.ClientSession( - headers={ - "Content-type": "application/json", - "X-API-Key": project_key, - }, - raise_for_status=True, - ) + if session is not None: + self._session = session + else: + self._session = aiohttp.ClientSession( + headers={ + "Content-type": "application/json", + "X-API-Key": project_key, + }, + raise_for_status=True, + ) async def close(self): await self._session.close() @@ -51,7 +54,7 @@ async def get(self, key: str) -> Optional[dict]: try: async with self._session.get(f"{self._base_url}/items/{key}") as resp: return await resp.json() - except aiohttp.ClientResponseError as e: # type: ignore + except aiohttp.ClientResponseError as e: if e.status == 404: return else: diff --git a/deta/base.py b/deta/base.py index fa38f67..225a0f8 100644 --- a/deta/base.py +++ b/deta/base.py @@ -22,7 +22,7 @@ def __eq__(self, other: "FetchResponse"): def __iter__(self): return iter(self.items) - def __len__(self): + def __len__(self) -> int: return len(self.items) @@ -56,7 +56,7 @@ def prepend(self, value: Union[dict, list, str, int, float, bool]): class _Base(_Service): - def __init__(self, name: str, project_key: str, project_id: str, host: Optional[str] = None): + def __init__(self, name: str, project_key: str, project_id: str, *, host: Optional[str] = None): if not name: raise ValueError("parameter 'name' must be a non-empty string") @@ -294,9 +294,14 @@ def update(self, updates, key, *, expire_in=None, expire_at=None): raise ValueError(f"key '{key}' not found") -def insert_ttl(item, ttl_attribute, expire_in=None, expire_at=None): +def insert_ttl( + item, + ttl_attribute: str, + expire_in: Optional[Union[int, float]] = None, + expire_at: Optional[Union[int, float, datetime.datetime]] = None, +): if expire_in and expire_at: - raise ValueError("both 'expire_in' and 'expire_at' provided") + raise ValueError("'expire_in' and 'expire_at' are mutually exclusive parameters") if not expire_in and not expire_at: return diff --git a/deta/drive.py b/deta/drive.py index 79974e4..73875e6 100644 --- a/deta/drive.py +++ b/deta/drive.py @@ -45,7 +45,7 @@ def close(self): class _Drive(_Service): - def __init__(self, name: str, project_key: str, project_id: str, host: Optional[str] = None): + def __init__(self, name: str, project_key: str, project_id: str, *, host: Optional[str] = None): if not name: raise ValueError("parameter 'name' must be a non-empty string") @@ -55,7 +55,7 @@ def __init__(self, name: str, project_key: str, project_id: str, host: Optional[ def _quote(self, param: str) -> str: return quote_plus(param) - def get(self, name: str) -> DriveStreamingBody: + def get(self, name: str) -> Optional[DriveStreamingBody]: """Download a file from drive. `name` is the name of the file. Returns a DriveStreamingBody. @@ -158,7 +158,12 @@ def put(self, name, data=None, *, path=None, content_type=None): content_stream.close() raise e - def list(self, limit: int = 1000, prefix: Optional[str] = None, last: Optional[str] = None) -> Dict[str, Any]: + def list( + self, + limit: int = 1000, + prefix: Optional[str] = None, + last: Optional[str] = None, + ) -> Dict[str, Any]: """List file names from drive. `limit` is the number of file names to get, defaults to 1000. `prefix` is the prefix of file names. From 3e2e4a355ca4d38ed777df378d8126c17d75dee2 Mon Sep 17 00:00:00 2001 From: LemonPi314 <49930425+LemonPi314@users.noreply.github.com> Date: Thu, 15 Sep 2022 17:43:51 -0400 Subject: [PATCH 12/12] Add py.typed marker file --- deta/py.typed | 0 setup.py | 4 ++++ 2 files changed, 4 insertions(+) create mode 100644 deta/py.typed diff --git a/deta/py.typed b/deta/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py index 646e8f6..045b2e0 100644 --- a/setup.py +++ b/setup.py @@ -17,4 +17,8 @@ extras_require={ "async": ["aiohttp>=3,<4"], }, + package_data={ + 'deta': ['py.typed'], + }, + zip_safe=False, )