Skip to content

Commit

Permalink
feat: use literal in type hints (#1827)
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Shaver <benpshaver@gmail.com>
  • Loading branch information
bpshaver committed Oct 24, 2023
1 parent d5d928b commit 522811f
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 36 deletions.
43 changes: 23 additions & 20 deletions docarray/array/doc_list/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
_dict_to_access_paths,
)
from docarray.utils._internal.compress import _decompress_bytes, _get_compress_ctx
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.misc import import_library, ProtocolType

if TYPE_CHECKING:
import pandas as pd
Expand All @@ -57,9 +57,9 @@

def _protocol_and_compress_from_file_path(
file_path: Union[pathlib.Path, str],
default_protocol: Optional[str] = None,
default_protocol: Optional[ProtocolType] = None,
default_compress: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
) -> Tuple[Optional[ProtocolType], Optional[str]]:
"""Extract protocol and compression algorithm from a string, use defaults if not found.
:param file_path: path of a file.
:param default_protocol: default serialization protocol used in case not found.
Expand All @@ -79,7 +79,7 @@ def _protocol_and_compress_from_file_path(
file_extensions = [e.replace('.', '') for e in pathlib.Path(file_path).suffixes]
for extension in file_extensions:
if extension in ALLOWED_PROTOCOLS:
protocol = extension
protocol = cast(ProtocolType, extension)
elif extension in ALLOWED_COMPRESSIONS:
compress = extension

Expand Down Expand Up @@ -135,7 +135,7 @@ def to_protobuf(self) -> 'DocListProto':
def from_bytes(
cls: Type[T],
data: bytes,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> T:
Expand All @@ -157,7 +157,7 @@ def from_bytes(
def _write_bytes(
self,
bf: BinaryIO,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> None:
Expand Down Expand Up @@ -201,7 +201,7 @@ def _write_bytes(

def _to_binary_stream(
self,
protocol: str = 'protobuf',
protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Iterator[bytes]:
Expand Down Expand Up @@ -241,7 +241,7 @@ def _to_binary_stream(

def to_bytes(
self,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
file_ctx: Optional[BinaryIO] = None,
show_progress: bool = False,
Expand Down Expand Up @@ -273,7 +273,7 @@ def to_bytes(
def from_base64(
cls: Type[T],
data: str,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> T:
Expand All @@ -294,7 +294,7 @@ def from_base64(

def to_base64(
self,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> str:
Expand Down Expand Up @@ -383,7 +383,6 @@ def _from_csv_file(
file: Union[StringIO, TextIOWrapper],
dialect: Union[str, csv.Dialect],
) -> 'T':

rows = csv.DictReader(file, dialect=dialect)

doc_type = cls.doc_type
Expand Down Expand Up @@ -576,7 +575,7 @@ def _get_proto_class(cls: Type[T]):
def _load_binary_all(
cls: Type[T],
file_ctx: Union[ContextManager[io.BufferedReader], ContextManager[bytes]],
protocol: Optional[str],
protocol: Optional[ProtocolType],
compress: Optional[str],
show_progress: bool,
tensor_type: Optional[Type['AbstractTensor']] = None,
Expand Down Expand Up @@ -659,7 +658,9 @@ def _load_binary_all(
start_pos = end_doc_pos

# variable length bytes doc
load_protocol: str = protocol or 'protobuf'
load_protocol: ProtocolType = protocol or cast(
ProtocolType, 'protobuf'
)
doc = cls.doc_type.from_bytes(
d[start_doc_pos:end_doc_pos],
protocol=load_protocol,
Expand All @@ -680,7 +681,7 @@ def _load_binary_all(
def _load_binary_stream(
cls: Type[T],
file_ctx: ContextManager[io.BufferedReader],
protocol: str = 'protobuf',
protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Generator['T_doc', None, None]:
Expand Down Expand Up @@ -728,7 +729,7 @@ def _load_binary_stream(
len_current_doc_in_bytes = int.from_bytes(
f.read(4), 'big', signed=False
)
load_protocol: str = protocol
load_protocol: ProtocolType = protocol
yield cls.doc_type.from_bytes(
f.read(len_current_doc_in_bytes),
protocol=load_protocol,
Expand All @@ -743,10 +744,12 @@ def _load_binary_stream(
@staticmethod
def _get_file_context(
file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
protocol: str,
protocol: ProtocolType,
compress: Optional[str] = None,
) -> Tuple[Union[nullcontext, io.BufferedReader], Optional[str], Optional[str]]:
load_protocol: Optional[str] = protocol
) -> Tuple[
Union[nullcontext, io.BufferedReader], Optional[ProtocolType], Optional[str]
]:
load_protocol: Optional[ProtocolType] = protocol
load_compress: Optional[str] = compress
file_ctx: Union[nullcontext, io.BufferedReader]
if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)):
Expand All @@ -765,7 +768,7 @@ def _get_file_context(
def load_binary(
cls: Type[T],
file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
streaming: bool = False,
Expand Down Expand Up @@ -814,7 +817,7 @@ def load_binary(
def save_binary(
self,
file: Union[str, pathlib.Path],
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> None:
Expand Down
8 changes: 4 additions & 4 deletions docarray/array/doc_vec/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.pydantic import is_pydantic_v2
from docarray.utils._internal.misc import ProtocolType

if TYPE_CHECKING:
import csv
Expand Down Expand Up @@ -134,7 +135,6 @@ def _from_json_col_dict(
json_columns: Dict[str, Any],
tensor_type: Type[AbstractTensor] = NdArray,
) -> T:

tensor_cols = json_columns['tensor_columns']
doc_cols = json_columns['doc_columns']
docs_vec_cols = json_columns['docs_vec_columns']
Expand Down Expand Up @@ -351,7 +351,7 @@ def from_csv(
def from_base64(
cls: Type[T],
data: str,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
tensor_type: Type['AbstractTensor'] = NdArray,
Expand All @@ -377,7 +377,7 @@ def from_base64(
def from_bytes(
cls: Type[T],
data: bytes,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
tensor_type: Type['AbstractTensor'] = NdArray,
Expand Down Expand Up @@ -454,7 +454,7 @@ class Person(BaseDoc):
def load_binary(
cls: Type[T],
file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
streaming: bool = False,
Expand Down
11 changes: 4 additions & 7 deletions docarray/base_doc/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS
from docarray.utils._internal._typing import safe_issubclass
from docarray.utils._internal.compress import _compress_bytes, _decompress_bytes
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.misc import ProtocolType, import_library
from docarray.utils._internal.pydantic import is_pydantic_v2

if TYPE_CHECKING:
Expand All @@ -37,7 +37,6 @@
from docarray.proto import DocProto, NodeProto
from docarray.typing import TensorFlowTensor, TorchTensor


else:
tf = import_library('tensorflow', raise_error=False)
if tf is not None:
Expand Down Expand Up @@ -150,7 +149,7 @@ def __bytes__(self) -> bytes:
return self.to_bytes()

def to_bytes(
self, protocol: str = 'protobuf', compress: Optional[str] = None
self, protocol: ProtocolType = 'protobuf', compress: Optional[str] = None
) -> bytes:
"""Serialize itself into bytes.
Expand All @@ -177,7 +176,7 @@ def to_bytes(
def from_bytes(
cls: Type[T],
data: bytes,
protocol: str = 'protobuf',
protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
) -> T:
"""Build Document object from binary bytes
Expand All @@ -203,7 +202,7 @@ def from_bytes(
)

def to_base64(
self, protocol: str = 'protobuf', compress: Optional[str] = None
self, protocol: ProtocolType = 'protobuf', compress: Optional[str] = None
) -> str:
"""Serialize a Document object into a base64 string
Expand Down Expand Up @@ -329,7 +328,6 @@ def _get_content_from_node_proto(
return_field = getattr(value, content_key)

elif content_key in arg_to_container.keys():

if field_name and field_name in cls._docarray_fields():
field_type = cls._get_field_inner_type(field_name)
else:
Expand All @@ -347,7 +345,6 @@ def _get_content_from_node_proto(
deser_dict: Dict[str, Any] = dict()

if field_name and field_name in cls._docarray_fields():

if is_pydantic_v2:
dict_args = get_args(
cls._docarray_fields()[field_name].annotation
Expand Down
9 changes: 5 additions & 4 deletions docarray/store/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rich import filesize
from typing_extensions import TYPE_CHECKING, Protocol

from docarray.utils._internal.misc import ProtocolType
from docarray.utils._internal.progress_bar import _get_progressbar

if TYPE_CHECKING:
Expand Down Expand Up @@ -112,12 +113,12 @@ def raise_req_error(resp: 'requests.Response') -> NoReturn:
class Streamable(Protocol):
"""A protocol for streamable objects."""

def to_bytes(self, protocol: str, compress: Optional[str]) -> bytes:
def to_bytes(self, protocol: ProtocolType, compress: Optional[str]) -> bytes:
...

@classmethod
def from_bytes(
cls: Type[T_Elem], bytes: bytes, protocol: str, compress: Optional[str]
cls: Type[T_Elem], bytes: bytes, protocol: ProtocolType, compress: Optional[str]
) -> 'T_Elem':
...

Expand All @@ -133,7 +134,7 @@ def close(self):
def _to_binary_stream(
iterator: Iterator['Streamable'],
total: Optional[int] = None,
protocol: str = 'protobuf',
protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Iterator[bytes]:
Expand Down Expand Up @@ -170,7 +171,7 @@ def _from_binary_stream(
cls: Type[T],
stream: ReadableBytes,
total: Optional[int] = None,
protocol: str = 'protobuf',
protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Iterator['T']:
Expand Down
6 changes: 5 additions & 1 deletion docarray/utils/_internal/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import re
import types
from typing import Any, Optional
from typing import Any, Optional, Literal

import numpy as np

Expand Down Expand Up @@ -52,6 +52,10 @@
'pymilvus': '"docarray[milvus]"',
}

# Type alias enumerating the serialization protocol names that may be passed
# as a `protocol` argument. Declared as a `Literal` so that static type
# checkers can flag protocol strings outside this set; it is used only in
# annotations and is not itself a runtime validation.
ProtocolType = Literal[
'protobuf', 'pickle', 'json', 'json-array', 'protobuf-array', 'pickle-array'
]


def import_library(
package: str, raise_error: bool = True
Expand Down

0 comments on commit 522811f

Please sign in to comment.