Skip to content

Commit d1c730d

Browse files
authored
fix(types): improve typing (#1136)
A few types were missing, which makes it hard to use Firestore with strict typechecking.
1 parent d637aee commit d1c730d

File tree

9 files changed

+39
-18
lines changed

9 files changed

+39
-18
lines changed

google/cloud/firestore_v1/async_batch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from google.api_core import retry_async as retries
2020

2121
from google.cloud.firestore_v1.base_batch import BaseWriteBatch
22+
from google.cloud.firestore_v1.types.write import WriteResult
2223

2324

2425
class AsyncWriteBatch(BaseWriteBatch):
@@ -40,7 +41,7 @@ async def commit(
4041
self,
4142
retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
4243
timeout: float | None = None,
43-
) -> list:
44+
) -> list[WriteResult]:
4445
"""Commit the changes accumulated in this batch.
4546
4647
Args:

google/cloud/firestore_v1/async_client.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,15 @@
2525
"""
2626
from __future__ import annotations
2727

28-
from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, List, Optional, Union
28+
from typing import (
29+
TYPE_CHECKING,
30+
Any,
31+
AsyncGenerator,
32+
Iterable,
33+
List,
34+
Optional,
35+
Union,
36+
)
2937

3038
from google.api_core import gapic_v1
3139
from google.api_core import retry_async as retries
@@ -40,6 +48,7 @@
4048
from google.cloud.firestore_v1.async_transaction import AsyncTransaction
4149
from google.cloud.firestore_v1.base_client import _parse_batch_get # type: ignore
4250
from google.cloud.firestore_v1.base_client import _CLIENT_INFO, BaseClient, _path_helper
51+
from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS
4352
from google.cloud.firestore_v1.field_path import FieldPath
4453
from google.cloud.firestore_v1.services.firestore import (
4554
async_client as firestore_client,
@@ -410,7 +419,9 @@ def batch(self) -> AsyncWriteBatch:
410419
"""
411420
return AsyncWriteBatch(self)
412421

413-
def transaction(self, **kwargs) -> AsyncTransaction:
422+
def transaction(
423+
self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False
424+
) -> AsyncTransaction:
414425
"""Get a transaction that uses this client.
415426
416427
See :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction` for
@@ -426,4 +437,4 @@ def transaction(self, **kwargs) -> AsyncTransaction:
426437
:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`:
427438
A transaction attached to this client.
428439
"""
429-
return AsyncTransaction(self, **kwargs)
440+
return AsyncTransaction(self, max_attempts=max_attempts, read_only=read_only)

google/cloud/firestore_v1/async_collection.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Classes for representing collections for the Google Cloud Firestore API."""
1616
from __future__ import annotations
1717

18-
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple
18+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple, cast
1919

2020
from google.api_core import gapic_v1
2121
from google.api_core import retry_async as retries
@@ -153,7 +153,8 @@ def document(self, document_id: str | None = None) -> AsyncDocumentReference:
153153
:class:`~google.cloud.firestore_v1.document.async_document.AsyncDocumentReference`:
154154
The child document.
155155
"""
156-
return super(AsyncCollectionReference, self).document(document_id)
156+
doc = super(AsyncCollectionReference, self).document(document_id)
157+
return cast("AsyncDocumentReference", doc)
157158

158159
async def list_documents(
159160
self,

google/cloud/firestore_v1/base_batch.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Helpers for batch requests to the Google Cloud Firestore API."""
1616
from __future__ import annotations
1717
import abc
18-
from typing import Dict, Union
18+
from typing import Any, Dict, Union
1919

2020
# Types needed only for Type Hints
2121
from google.api_core import retry as retries
@@ -67,7 +67,9 @@ def commit(self):
6767
write depend on the implementing class."""
6868
raise NotImplementedError()
6969

70-
def create(self, reference: BaseDocumentReference, document_data: dict) -> None:
70+
def create(
71+
self, reference: BaseDocumentReference, document_data: dict[str, Any]
72+
) -> None:
7173
"""Add a "change" to this batch to create a document.
7274
7375
If the document given by ``reference`` already exists, then this
@@ -120,7 +122,7 @@ def set(
120122
def update(
121123
self,
122124
reference: BaseDocumentReference,
123-
field_updates: dict,
125+
field_updates: dict[str, Any],
124126
option: _helpers.WriteOption | None = None,
125127
) -> None:
126128
"""Add a "change" to update a document.

google/cloud/firestore_v1/base_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
DocumentSnapshot,
5858
)
5959
from google.cloud.firestore_v1.base_query import BaseQuery
60-
from google.cloud.firestore_v1.base_transaction import BaseTransaction
60+
from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS, BaseTransaction
6161
from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions
6262
from google.cloud.firestore_v1.field_path import render_field_path
6363
from google.cloud.firestore_v1.services.firestore import client as firestore_client
@@ -497,7 +497,9 @@ def collections(
497497
def batch(self) -> BaseWriteBatch:
498498
raise NotImplementedError
499499

500-
def transaction(self, **kwargs) -> BaseTransaction:
500+
def transaction(
501+
self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False
502+
) -> BaseTransaction:
501503
raise NotImplementedError
502504

503505

google/cloud/firestore_v1/base_collection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from google.api_core import retry as retries
3636

3737
from google.cloud.firestore_v1 import _helpers
38+
from google.cloud.firestore_v1.base_document import BaseDocumentReference
3839
from google.cloud.firestore_v1.base_query import QueryType
3940

4041
if TYPE_CHECKING: # pragma: NO COVER
@@ -133,7 +134,7 @@ def _aggregation_query(self) -> BaseAggregationQuery:
133134
def _vector_query(self) -> BaseVectorQuery:
134135
raise NotImplementedError
135136

136-
def document(self, document_id: Optional[str] = None):
137+
def document(self, document_id: Optional[str] = None) -> BaseDocumentReference:
137138
"""Create a sub-document underneath the current collection.
138139
139140
Args:

google/cloud/firestore_v1/base_document.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def _client(self):
418418
return self._reference._client
419419

420420
@property
421-
def exists(self):
421+
def exists(self) -> bool:
422422
"""Existence flag.
423423
424424
Indicates if the document existed at the time this snapshot
@@ -430,7 +430,7 @@ def exists(self):
430430
return self._exists
431431

432432
@property
433-
def id(self):
433+
def id(self) -> str:
434434
"""The document identifier (within its collection).
435435
436436
Returns:
@@ -439,7 +439,7 @@ def id(self):
439439
return self._reference.id
440440

441441
@property
442-
def reference(self):
442+
def reference(self) -> BaseDocumentReference:
443443
"""Document reference corresponding to document that owns this data.
444444
445445
Returns:

google/cloud/firestore_v1/client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
# Types needed only for Type Hints
4141
from google.cloud.firestore_v1.base_document import DocumentSnapshot
42+
from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS
4243
from google.cloud.firestore_v1.batch import WriteBatch
4344
from google.cloud.firestore_v1.collection import CollectionReference
4445
from google.cloud.firestore_v1.document import DocumentReference
@@ -391,7 +392,9 @@ def batch(self) -> WriteBatch:
391392
"""
392393
return WriteBatch(self)
393394

394-
def transaction(self, **kwargs) -> Transaction:
395+
def transaction(
396+
self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False
397+
) -> Transaction:
395398
"""Get a transaction that uses this client.
396399
397400
See :class:`~google.cloud.firestore_v1.transaction.Transaction` for
@@ -407,4 +410,4 @@ def transaction(self, **kwargs) -> Transaction:
407410
:class:`~google.cloud.firestore_v1.transaction.Transaction`:
408411
A transaction attached to this client.
409412
"""
410-
return Transaction(self, **kwargs)
413+
return Transaction(self, max_attempts=max_attempts, read_only=read_only)

google/cloud/firestore_v1/document.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def set(
169169

170170
def update(
171171
self,
172-
field_updates: dict,
172+
field_updates: dict[str, Any],
173173
option: _helpers.WriteOption | None = None,
174174
retry: retries.Retry | object | None = gapic_v1.method.DEFAULT,
175175
timeout: float | None = None,

0 commit comments

Comments
 (0)