Skip to content

Commit

Permalink
🏷️ (io) type better
Browse files Browse the repository at this point in the history
  • Loading branch information
simonwoerpel committed Feb 16, 2024
1 parent 53272f8 commit 75d574b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
13 changes: 8 additions & 5 deletions anystore/io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import contextlib
import logging
from pathlib import Path
import sys
from typing import Any, BinaryIO, Generator, TextIO
from typing import Any, BinaryIO, Generator, TextIO, TypeAlias

from fsspec import open
from fsspec.core import OpenFile
Expand All @@ -12,6 +13,8 @@

DEFAULT_MODE = "rb"

Uri: TypeAlias = Path | BinaryIO | TextIO | str


def _get_sysio(mode: str | None = DEFAULT_MODE) -> TextIO | BinaryIO:
if mode and mode.startswith("r"):
Expand All @@ -22,7 +25,7 @@ def _get_sysio(mode: str | None = DEFAULT_MODE) -> TextIO | BinaryIO:
class SmartHandler:
def __init__(
self,
uri: Any,
uri: Uri,
*args,
**kwargs,
) -> None:
Expand Down Expand Up @@ -57,7 +60,7 @@ def __exit__(self, *args, **kwargs) -> None:

@contextlib.contextmanager
def smart_open(
uri: Any,
uri: Uri,
mode: str | None = None,
*args,
**kwargs,
Expand All @@ -73,13 +76,13 @@ def smart_open(
handler.close()


def smart_stream(uri, *args, **kwargs) -> Generator[str | bytes, None, None]:
def smart_stream(uri: Uri, *args, **kwargs) -> Generator[str | bytes, None, None]:
with smart_open(uri, *args, **kwargs) as fh:
while line := fh.readline():
yield line


def smart_read(uri, *args, **kwargs) -> Any:
def smart_read(uri: Uri, *args, **kwargs) -> Any:
with smart_open(uri, *args, **kwargs) as fh:
return fh.read()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def get_data2(*args, **kwargs):
assert get_data2("x") == "data2"
assert store.get("X") == b"data2"

# not yet existing
# not yet existing store
@anycache(uri=tmp_path / "foo", key_func=lambda x: x)
def get_data3(data):
return data
Expand Down

0 comments on commit 75d574b

Please sign in to comment.