Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: OR Query implementation #698

Merged
merged 10 commits into from
Apr 3, 2023
90 changes: 54 additions & 36 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@


from typing import (
Optional,
Any,
AsyncGenerator,
Coroutine,
Expand Down Expand Up @@ -113,7 +114,7 @@ def _query(self) -> BaseQuery:
def _aggregation_query(self) -> BaseAggregationQuery:
raise NotImplementedError

def document(self, document_id: str = None) -> DocumentReference:
def document(self, document_id: Optional[str] = None) -> DocumentReference:
"""Create a sub-document underneath the current collection.

Args:
Expand Down Expand Up @@ -160,9 +161,9 @@ def _parent_info(self) -> Tuple[Any, str]:
def _prep_add(
self,
document_data: dict,
document_id: str = None,
retry: retries.Retry = None,
timeout: float = None,
document_id: Optional[str] = None,
Mariatta marked this conversation as resolved.
Show resolved Hide resolved
retry: Optional[retries.Retry] = None,
timeout: Optional[float] = None,
) -> Tuple[DocumentReference, dict]:
"""Shared setup for async / sync :method:`add`"""
if document_id is None:
Expand All @@ -176,17 +177,17 @@ def _prep_add(
def add(
self,
document_data: dict,
document_id: str = None,
retry: retries.Retry = None,
timeout: float = None,
document_id: Optional[str] = None,
retry: Optional[retries.Retry] = None,
timeout: Optional[float] = None,
) -> Union[Tuple[Any, Any], Coroutine[Any, Any, Tuple[Any, Any]]]:
raise NotImplementedError

def _prep_list_documents(
self,
page_size: int = None,
retry: retries.Retry = None,
timeout: float = None,
page_size: Optional[int] = None,
retry: Optional[retries.Retry] = None,
timeout: Optional[float] = None,
) -> Tuple[dict, dict]:
"""Shared setup for async / sync :method:`list_documents`"""
parent, _ = self._parent_info()
Expand All @@ -206,9 +207,9 @@ def _prep_list_documents(

def list_documents(
self,
page_size: int = None,
retry: retries.Retry = None,
timeout: float = None,
page_size: Optional[int] = None,
retry: Optional[retries.Retry] = None,
timeout: Optional[float] = None,
) -> Union[
Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any]
]:
Expand Down Expand Up @@ -236,7 +237,14 @@ def select(self, field_paths: Iterable[str]) -> BaseQuery:
query = self._query()
return query.select(field_paths)

def where(self, field_path: str, op_string: str, value) -> BaseQuery:
def where(
self,
field_path: Optional[str] = None,
op_string: Optional[str] = None,
value=None,
*,
filter=None
) -> BaseQuery:
"""Create a "where" query with this collection as parent.

See
Expand All @@ -245,33 +253,43 @@ def where(self, field_path: str, op_string: str, value) -> BaseQuery:

Args:
field_path (str): A field path (``.``-delimited list of
field names) for the field to filter on.
field names) for the field to filter on. Optional.
op_string (str): A comparison operation in the form of a string.
Acceptable values are ``<``, ``<=``, ``==``, ``>=``, ``>``,
and ``in``.
and ``in``. Optional.
value (Any): The value to compare the field against in the filter.
If ``value`` is :data:`None` or a NaN, then ``==`` is the only
allowed operation. If ``op_string`` is ``in``, ``value``
must be a sequence of values.

must be a sequence of values. Optional.
filter (class:`~google.cloud.firestore_v1.base_query.BaseFilter`): an instance of a Filter.
Either a FieldFilter or a CompositeFilter.
Returns:
:class:`~google.cloud.firestore_v1.query.Query`:
A filtered query.
Raises:
ValueError, if both the positional arguments (field_path, op_string, value)
and the filter keyword argument are passed at the same time.
"""
if field_path == "__name__" and op_string == "in":
wrapped_names = []

for name in value:
query = self._query()
if field_path and op_string:
if filter is not None:
raise ValueError(
"Can't pass in both the positional arguments and 'filter' at the same time"
)
if field_path == "__name__" and op_string == "in":
wrapped_names = []

if isinstance(name, str):
name = self.document(name)
for name in value:

wrapped_names.append(name)
if isinstance(name, str):
name = self.document(name)

value = wrapped_names
wrapped_names.append(name)

query = self._query()
return query.where(field_path, op_string, value)
value = wrapped_names
return query.where(field_path, op_string, value)
else:
return query.where(filter=filter)

def order_by(self, field_path: str, **kwargs) -> BaseQuery:
"""Create an "order by" query with this collection as parent.
Expand Down Expand Up @@ -450,8 +468,8 @@ def end_at(

def _prep_get_or_stream(
self,
retry: retries.Retry = None,
timeout: float = None,
retry: Optional[retries.Retry] = None,
timeout: Optional[float] = None,
) -> Tuple[Any, dict]:
"""Shared setup for async / sync :meth:`get` / :meth:`stream`"""
query = self._query()
Expand All @@ -461,19 +479,19 @@ def _prep_get_or_stream(

def get(
self,
transaction: Transaction = None,
retry: retries.Retry = None,
timeout: float = None,
transaction: Optional[Transaction] = None,
retry: Optional[retries.Retry] = None,
timeout: Optional[float] = None,
) -> Union[
Generator[DocumentSnapshot, Any, Any], AsyncGenerator[DocumentSnapshot, Any]
]:
raise NotImplementedError

def stream(
self,
transaction: Transaction = None,
retry: retries.Retry = None,
timeout: float = None,
transaction: Optional[Transaction] = None,
retry: Optional[retries.Retry] = None,
timeout: Optional[float] = None,
) -> Union[Iterator[DocumentSnapshot], AsyncIterator[DocumentSnapshot]]:
raise NotImplementedError

Expand Down