Skip to content

Commit

Permalink
Update: type hints for functions
Browse files Browse the repository at this point in the history
  • Loading branch information
yankeexe committed Apr 15, 2021
1 parent cc7d58a commit 9f1e0ec
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 22 deletions.
17 changes: 13 additions & 4 deletions deta/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import typing
import urllib.error
import urllib.request

from .base import Base

try:
Expand All @@ -19,21 +19,30 @@

class Deta:
def __init__(
self, project_key: str = None, *, project_id: str = None, host: str = None
self,
project_key: typing.Optional[str] = None,
*,
project_id: typing.Optional[str] = None,
host: typing.Optional[str] = None,
):
self.project_key = project_key or os.getenv("DETA_PROJECT_KEY")
self.project_id = project_id
if not self.project_id:
self.project_id = self.project_key.split("_")[0]

def Base(self, name: str, host: str = None):
def Base(self, name: str, host: typing.Optional[str] = None):
return Base(name, self.project_key, self.project_id, host)

def send_email(self, to, subject, message, charset="UTF-8"):
return send_email(to, subject, message, charset)


def send_email(to, subject, message, charset="UTF-8"):
def send_email(
to: typing.Union[str, typing.List[str]],
subject: str,
message: str,
charset: str = "UTF-8",
):
pid = os.getenv("AWS_LAMBDA_FUNCTION_NAME")
url = os.getenv("DETA_MAILER_URL")
api_key = os.getenv("DETA_PROJECT_KEY")
Expand Down
33 changes: 21 additions & 12 deletions deta/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def _is_socket_closed(self):
return True
return False

def _request(self, path: str, method: str, data: dict = None):
def _request(
self, path: str, method: str, data: typing.Optional[dict] = None
) -> typing.Tuple[int, typing.Mapping]:
url = self.base_path + path

# close connection if socket is closed
Expand All @@ -93,7 +95,7 @@ def _request(self, path: str, method: str, data: dict = None):
return status, json.loads(payload) if status != 404 else None
raise urllib.error.HTTPError(url, status, res.reason, res.headers, res.fp)

def get(self, key: str) -> dict:
def get(self, key: str) -> typing.Mapping:
if key == "":
raise ValueError("Key is empty")

Expand All @@ -114,7 +116,11 @@ def delete(self, key: str) -> bool:
_, _ = self._request("/items/{}".format(key), "DELETE")
return None

def insert(self, data: typing.Union[dict, list, str, int, bool], key: str = None):
def insert(
self,
data: typing.Union[dict, list, str, int, bool],
key: typing.Optional[str] = None,
):
if not isinstance(data, dict):
data = {"value": data}
else:
Expand All @@ -129,10 +135,14 @@ def insert(self, data: typing.Union[dict, list, str, int, bool], key: str = None
elif code == 409:
raise Exception("Item with key '{4}' already exists".format(key))

def put(self, data: typing.Union[dict, list, str, int, bool], key: str = None):
""" store (put) an item in the database. Overrides an item if key already exists.
`key` could be provided as function argument or a field in the data dict.
If `key` is not provided, the server will generate a random 12 chars key.
def put(
self,
data: typing.Union[dict, list, str, int, bool],
key: typing.Optional[str] = None,
):
"""store (put) an item in the database. Overrides an item if key already exists.
`key` could be provided as function argument or a field in the data dict.
If `key` is not provided, the server will generate a random 12 chars key.
"""

if not isinstance(data, dict):
Expand Down Expand Up @@ -163,9 +173,8 @@ def _fetch(
query: typing.Union[dict, list] = None,
buffer: int = None,
last: str = None,
) -> typing.Tuple[int, list]:
"""This is where actual fetch happens.
"""
) -> typing.Tuple[int, typing.Mapping]:
"""This is where actual fetch happens."""
payload = {
"limit": buffer,
"last": last if not isinstance(last, bool) else None,
Expand All @@ -186,7 +195,7 @@ def fetch(
) -> typing.Generator:
"""
fetch items from the database.
`query` is an optional filter or list of filters. Without filter, it will return the whole db.
`query` is an optional filter or list of filters. Without filter, it will return the whole db.
Returns a generator with all the result, We will paginate the request based on `buffer`.
"""
last = True
Expand All @@ -202,7 +211,7 @@ def fetch(
def update(self, updates: dict, key: str):
"""
update an item in the database
`updates` specifies the attribute names and values to update,add or remove
`updates` specifies the attribute names and values to update,add or remove
`key` is the kye of the item to be updated
"""

Expand Down
9 changes: 3 additions & 6 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

try:
from dotenv import load_dotenv

load_dotenv()
except:
pass
Expand Down Expand Up @@ -80,9 +81,7 @@ def test_put_many_fail_limit(self):

def test_insert(self):
item = {"msg": "hello"}
self.assertEqual(
set(self.db.insert(item).keys()), set(["key", "msg"])
)
self.assertEqual(set(self.db.insert(item).keys()), set(["key", "msg"]))
self.assertEqual({"msg": "hello"}, item)

@unittest.expectedFailure
Expand Down Expand Up @@ -119,9 +118,7 @@ def test_update(self):
self.assertEqual(self.db.get("existing4"), expectedItem)

self.assertIsNone(
self.db.update(
{"value.name": self.db.util.trim(), "value.age": 32}, "existing4"
)
self.db.update({"value.name": self.db.util.trim(), "value.age": 32}, "existing4")
)
expectedItem = {"key": "existing4", "value": {"age": 32}}
self.assertEqual(self.db.get("existing4"), expectedItem)
Expand Down

0 comments on commit 9f1e0ec

Please sign in to comment.