Skip to content

Commit c3b00a6

Browse files
committed
Type annotations for querysets using ClassVar
1 parent dc1f9c4 commit c3b00a6

File tree

14 files changed

+94
-22
lines changed

14 files changed

+94
-22
lines changed

example/app/users/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from __future__ import annotations
2+
3+
from typing import ClassVar
4+
15
from plain import models
26
from plain.models import types
37
from plain.passwords.types import PasswordField
@@ -8,3 +12,5 @@ class User(models.Model):
812
email: str = types.EmailField()
913
password: str = PasswordField()
1014
is_admin: bool = types.BooleanField(default=False)
15+
16+
query: ClassVar[models.QuerySet[User]] = models.QuerySet()

plain-api/plain/api/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from __future__ import annotations
2+
13
import binascii
24
import os
35
import uuid
46
from datetime import datetime
7+
from typing import ClassVar
58
from uuid import UUID
69

710
from plain import models
@@ -26,6 +29,8 @@ class APIKey(models.Model):
2629

2730
api_version: str = types.CharField(max_length=255, required=False)
2831

32+
query: ClassVar[models.QuerySet[APIKey]] = models.QuerySet()
33+
2934
model_options = models.Options(
3035
constraints=[
3136
models.UniqueConstraint(

plain-cache/plain/cache/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from datetime import datetime
4-
from typing import Any, Self
4+
from typing import Any, ClassVar, Self
55

66
from plain import models
77
from plain.models import types
@@ -27,7 +27,7 @@ class CachedItem(models.Model):
2727
created_at: datetime = types.DateTimeField(auto_now_add=True)
2828
updated_at: datetime = types.DateTimeField(auto_now=True)
2929

30-
query = CachedItemQuerySet()
30+
query: ClassVar[CachedItemQuerySet] = CachedItemQuerySet()
3131

3232
model_options = models.Options(
3333
indexes=[

plain-flags/plain/flags/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import re
44
from datetime import datetime
5+
from typing import ClassVar
56

67
from plain import models
78
from plain.exceptions import ValidationError
@@ -21,6 +22,8 @@ class FlagResult(models.Model):
2122
key: str = types.CharField(max_length=255)
2223
value = types.JSONField()
2324

25+
query: ClassVar[models.QuerySet[FlagResult]] = models.QuerySet()
26+
2427
model_options = models.Options(
2528
constraints=[
2629
models.UniqueConstraint(
@@ -49,6 +52,8 @@ class Flag(models.Model):
4952
# To provide an easier way to see if a flag is still being used
5053
used_at: datetime | None = types.DateTimeField(required=False, allow_null=True)
5154

55+
query: ClassVar[models.QuerySet[Flag]] = models.QuerySet()
56+
5257
model_options = models.Options(
5358
constraints=[
5459
models.UniqueConstraint(

plain-jobs/plain/jobs/models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import datetime
44
import logging
55
import traceback
6-
from typing import TYPE_CHECKING, Any, Self
6+
from typing import TYPE_CHECKING, Any, ClassVar, Self
77
from uuid import UUID, uuid4
88

99
from opentelemetry import trace
@@ -72,6 +72,8 @@ class JobRequest(models.Model):
7272

7373
# expires_at = models.DateTimeField(required=False, allow_null=True)
7474

75+
query: ClassVar[models.QuerySet[JobRequest]] = models.QuerySet()
76+
7577
model_options = models.Options(
7678
ordering=["priority", "-created_at"],
7779
indexes=[
@@ -180,7 +182,7 @@ class JobProcess(models.Model):
180182
max_length=18, required=False, allow_null=True
181183
)
182184

183-
query = JobQuerySet()
185+
query: ClassVar[JobQuerySet] = JobQuerySet()
184186

185187
model_options = models.Options(
186188
ordering=["-created_at"],
@@ -499,7 +501,7 @@ class JobResult(models.Model):
499501
max_length=18, required=False, allow_null=True
500502
)
501503

502-
query = JobResultQuerySet()
504+
query: ClassVar[JobResultQuerySet] = JobResultQuerySet()
503505

504506
model_options = models.Options(
505507
ordering=["-created_at"],

plain-models/plain/models/README.md

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
- [Querying](#querying)
88
- [Migrations](#migrations)
99
- [Fields](#fields)
10-
- [Typed field definitions](#typed-field-definitions)
10+
- [Typing](#typing)
1111
- [Validation](#validation)
1212
- [Indexes and constraints](#indexes-and-constraints)
1313
- [Custom QuerySets](#custom-querysets)
@@ -174,7 +174,7 @@ Common field types include:
174174
- [`URLField`](./fields/__init__.py#URLField)
175175
- [`UUIDField`](./fields/__init__.py#UUIDField)
176176

177-
## Typed field definitions
177+
## Typing
178178

179179
For better IDE support and type checking, use `plain.models.types` with type annotations:
180180

@@ -203,6 +203,28 @@ author: Author = types.ForeignKey(Author, on_delete=models.CASCADE)
203203

204204
All field types from the [Fields](#fields) section are available through [`types`](./types.py). Typed and untyped fields can be mixed in the same model. The database behavior is identical - typed fields only add type checking.
205205

206+
### Typing QuerySets
207+
208+
For better type checking of query results, you can explicitly type the `query` attribute using `ClassVar`:
209+
210+
```python
211+
from __future__ import annotations
212+
213+
from typing import ClassVar
214+
215+
from plain import models
216+
from plain.models import types
217+
218+
@models.register_model
219+
class User(models.Model):
220+
email: str = types.EmailField()
221+
is_admin: bool = types.BooleanField(default=False)
222+
223+
query: ClassVar[models.QuerySet[User]] = models.QuerySet()
224+
```
225+
226+
With this annotation, type checkers will know that `User.query.get()` returns a `User` instance and `User.query.filter()` returns `QuerySet[User]`. The `ClassVar` annotation tells type checkers that `query` is a class-level attribute, not an instance field. This is optional - the query attribute works without the annotation, but adding it improves IDE autocomplete and type checking.
227+
206228
## Validation
207229

208230
Models can be validated before saving:

plain-models/plain/models/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from collections.abc import Iterable, Iterator, Sequence
66
from itertools import chain
7-
from typing import TYPE_CHECKING, Any, dataclass_transform
7+
from typing import TYPE_CHECKING, Any, ClassVar, dataclass_transform
88

99
if TYPE_CHECKING:
1010
from plain.models.meta import Meta
@@ -95,7 +95,7 @@ class Model(metaclass=ModelBase):
9595
id = PrimaryKeyField()
9696

9797
# Descriptors for other model behavior
98-
query = QuerySet()
98+
query: ClassVar[QuerySet[Model]] = QuerySet()
9999
model_options = Options()
100100
_model_meta = Meta()
101101
DoesNotExist = DoesNotExistDescriptor()

plain-models/plain/models/query.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Callable, Iterator
1212
from functools import cached_property
1313
from itertools import chain, islice
14-
from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar, overload
14+
from typing import TYPE_CHECKING, Any, Generic, Never, Self, TypeVar, overload
1515

1616
import plain.runtime
1717
from plain.exceptions import ValidationError
@@ -338,6 +338,12 @@ def from_model(cls, model: type[T], query: sql.Query | None = None) -> Self:
338338
instance._deferred_filter = None
339339
return instance
340340

341+
@overload
342+
def __get__(self, instance: None, owner: type[T]) -> Self: ...
343+
344+
@overload
345+
def __get__(self, instance: Model, owner: type[T]) -> Never: ...
346+
341347
def __get__(self, instance: Any, owner: type[T]) -> Self:
342348
"""Descriptor protocol - return a new QuerySet bound to the model."""
343349
if instance is not None:

plain-oauth/plain/oauth/models.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from __future__ import annotations
2+
13
import datetime
2-
from typing import TYPE_CHECKING, Any
4+
from typing import TYPE_CHECKING, Any, ClassVar
35

46
from plain import models
57
from plain.auth import get_user_model
@@ -42,6 +44,8 @@ class OAuthConnection(models.Model):
4244
required=False, allow_null=True
4345
)
4446

47+
query: ClassVar[models.QuerySet[OAuthConnection]] = models.QuerySet()
48+
4549
model_options = models.Options(
4650
constraints=[
4751
models.UniqueConstraint(
@@ -71,13 +75,13 @@ def refresh_access_token(self) -> None:
7175
self.set_token_fields(refreshed_oauth_token)
7276
self.save()
7377

74-
def set_token_fields(self, oauth_token: "OAuthToken") -> None:
78+
def set_token_fields(self, oauth_token: OAuthToken) -> None:
7579
self.access_token = oauth_token.access_token
7680
self.refresh_token = oauth_token.refresh_token
7781
self.access_token_expires_at = oauth_token.access_token_expires_at
7882
self.refresh_token_expires_at = oauth_token.refresh_token_expires_at
7983

80-
def set_user_fields(self, oauth_user: "OAuthUser") -> None:
84+
def set_user_fields(self, oauth_user: OAuthUser) -> None:
8185
self.provider_user_id = oauth_user.provider_id
8286

8387
def access_token_expired(self) -> bool:
@@ -94,8 +98,8 @@ def refresh_token_expired(self) -> bool:
9498

9599
@classmethod
96100
def get_or_create_user(
97-
cls, *, provider_key: str, oauth_token: "OAuthToken", oauth_user: "OAuthUser"
98-
) -> "OAuthConnection":
101+
cls, *, provider_key: str, oauth_token: OAuthToken, oauth_user: OAuthUser
102+
) -> OAuthConnection:
99103
try:
100104
connection = cls.query.get(
101105
provider_key=provider_key,
@@ -129,9 +133,9 @@ def connect(
129133
*,
130134
user: Any,
131135
provider_key: str,
132-
oauth_token: "OAuthToken",
133-
oauth_user: "OAuthUser",
134-
) -> "OAuthConnection":
136+
oauth_token: OAuthToken,
137+
oauth_user: OAuthUser,
138+
) -> OAuthConnection:
135139
"""
136140
Connect will either create a new connection or update an existing connection
137141
"""

plain-observer/plain/observer/models.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Iterable, Mapping, Sequence
77
from datetime import UTC, datetime
88
from functools import cached_property
9-
from typing import TYPE_CHECKING, Any, cast
9+
from typing import TYPE_CHECKING, Any, ClassVar, cast
1010

1111
import sqlparse
1212
from opentelemetry.sdk.trace import ReadableSpan
@@ -66,6 +66,8 @@ class Trace(models.Model):
6666
spans: BaseRelatedManager
6767
logs: BaseRelatedManager
6868

69+
query: ClassVar[models.QuerySet[Trace]] = models.QuerySet()
70+
6971
model_options = models.Options(
7072
ordering=["-start_time"],
7173
constraints=[
@@ -335,7 +337,7 @@ class Span(models.Model):
335337
status: str = types.CharField(max_length=50, default="", required=False)
336338
span_data: dict = types.JSONField(default=dict, required=False)
337339

338-
query = SpanQuerySet()
340+
query: ClassVar[SpanQuerySet] = SpanQuerySet()
339341

340342
model_options = models.Options(
341343
ordering=["-start_time"],
@@ -523,6 +525,8 @@ class Log(models.Model):
523525
level: str = types.CharField(max_length=20)
524526
message: str = types.TextField()
525527

528+
query: ClassVar[models.QuerySet[Log]] = models.QuerySet()
529+
526530
model_options = models.Options(
527531
ordering=["timestamp"],
528532
indexes=[

0 commit comments

Comments
 (0)