Skip to content

Commit a43145e

Browse files
committed
Type annotations for plain-models
1 parent 5b4bdf4 commit a43145e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+740
-414
lines changed

plain-models/plain/models/aggregates.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33
from typing import TYPE_CHECKING, Any
44

55
from plain.models.exceptions import FieldError, FullResultSet
6-
from plain.models.expressions import Case, Func, Star, Value, When
6+
from plain.models.expressions import (
7+
Case,
8+
Func,
9+
ResolvableExpression,
10+
Star,
11+
Value,
12+
When,
13+
)
714
from plain.models.fields import IntegerField
815
from plain.models.functions.comparison import Coalesce
916
from plain.models.functions.mixins import (
@@ -99,7 +106,7 @@ def resolve_expression(
99106
)
100107
if (default := c.default) is None:
101108
return c
102-
if hasattr(default, "resolve_expression"):
109+
if isinstance(default, ResolvableExpression):
103110
default = default.resolve_expression(query, allow_joins, reuse, summarize)
104111
if default._output_field_or_none is None:
105112
default.output_field = c._output_field_or_none
@@ -132,7 +139,7 @@ def as_sql(
132139
if self.filter is not None:
133140
if connection.features.supports_aggregate_filter_clause:
134141
try:
135-
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
142+
filter_sql, filter_params = self.filter.as_sql(compiler, connection) # type: ignore[possibly-missing-attribute]
136143
except FullResultSet:
137144
pass
138145
else:

plain-models/plain/models/backends/base/base.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ class BaseDatabaseWrapper(ABC):
4848
data_types_suffix: dict[str, str] = {}
4949
# Mapping of Field objects to their SQL for CHECK constraints.
5050
data_type_check_constraints: dict[str, str] = {}
51+
# Mapping of lookup operators to SQL templates (defined on backend subclasses)
52+
operators: dict[str, str]
53+
# Mapping of pattern lookup operators to SQL templates using str.format syntax (defined on backend subclasses)
54+
pattern_ops: dict[str, str]
55+
# SQL template for escaping patterns in LIKE queries using str.format syntax (defined on backend subclasses)
56+
pattern_esc: str
5157
# Instance attributes - always assigned in __init__
5258
ops: BaseDatabaseOperations
5359
client: BaseDatabaseClient
@@ -268,7 +274,7 @@ def ensure_connection(self) -> None:
268274

269275
# ##### Backend-specific wrappers for PEP-249 connection methods #####
270276

271-
def _prepare_cursor(self, cursor: Any) -> utils.CursorWrapper:
277+
def _prepare_cursor(self, cursor: utils.DBAPICursor) -> utils.CursorWrapper:
272278
"""
273279
Validate the connection is usable and perform database cursor wrapping.
274280
"""
@@ -460,7 +466,7 @@ def set_autocommit(
460466
)
461467

462468
if start_transaction_under_autocommit:
463-
self._start_transaction_under_autocommit()
469+
self._start_transaction_under_autocommit() # type: ignore[attr-defined]
464470
elif autocommit:
465471
self._set_autocommit(autocommit)
466472
else:
@@ -629,11 +635,11 @@ def chunked_cursor(self) -> utils.CursorWrapper:
629635
"""
630636
return self.cursor()
631637

632-
def make_debug_cursor(self, cursor: Any) -> utils.CursorDebugWrapper:
638+
def make_debug_cursor(self, cursor: utils.DBAPICursor) -> utils.CursorDebugWrapper:
633639
"""Create a cursor that logs all queries in self.queries_log."""
634640
return utils.CursorDebugWrapper(cursor, self)
635641

636-
def make_cursor(self, cursor: Any) -> utils.CursorWrapper:
642+
def make_cursor(self, cursor: utils.DBAPICursor) -> utils.CursorWrapper:
637643
"""Create a cursor without debug logging."""
638644
return utils.CursorWrapper(cursor, self)
639645

plain-models/plain/models/backends/base/introspection.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from collections.abc import Generator
55
from typing import TYPE_CHECKING, Any, NamedTuple
66

7+
from plain.models.backends.utils import CursorWrapper
8+
79
if TYPE_CHECKING:
810
from plain.models.backends.base.base import BaseDatabaseWrapper
911

@@ -55,15 +57,17 @@ def identifier_converter(self, name: str) -> str:
5557
"""
5658
return name
5759

58-
def table_names(self, cursor: Any = None, include_views: bool = False) -> list[str]:
60+
def table_names(
61+
self, cursor: CursorWrapper | None = None, include_views: bool = False
62+
) -> list[str]:
5963
"""
6064
Return a list of names of all tables that exist in the database.
6165
Sort the returned table list by Python's default sorting. Do NOT use
6266
the database's ORDER BY here to avoid subtle differences in sorting
6367
order between databases.
6468
"""
6569

66-
def get_names(cursor: Any) -> list[str]:
70+
def get_names(cursor: CursorWrapper) -> list[str]:
6771
return sorted(
6872
ti.name
6973
for ti in self.get_table_list(cursor)
@@ -76,15 +80,17 @@ def get_names(cursor: Any) -> list[str]:
7680
return get_names(cursor)
7781

7882
@abstractmethod
79-
def get_table_list(self, cursor: Any) -> list[TableInfo]:
83+
def get_table_list(self, cursor: CursorWrapper) -> list[TableInfo]:
8084
"""
8185
Return an unsorted list of TableInfo named tuples of all tables and
8286
views that exist in the database.
8387
"""
8488
...
8589

8690
@abstractmethod
87-
def get_table_description(self, cursor: Any, table_name: str) -> list[FieldInfo]:
91+
def get_table_description(
92+
self, cursor: CursorWrapper, table_name: str
93+
) -> list[FieldInfo]:
8894
"""
8995
Return a description of the table with the DB-API cursor.description
9096
interface.
@@ -146,7 +152,7 @@ def sequence_list(self) -> list[dict[str, Any]]:
146152

147153
@abstractmethod
148154
def get_sequences(
149-
self, cursor: Any, table_name: str, table_fields: tuple[Any, ...] = ()
155+
self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
150156
) -> list[dict[str, Any]]:
151157
"""
152158
Return a list of introspected sequences for table_name. Each sequence
@@ -156,21 +162,27 @@ def get_sequences(
156162
...
157163

158164
@abstractmethod
159-
def get_relations(self, cursor: Any, table_name: str) -> dict[str, tuple[str, str]]:
165+
def get_relations(
166+
self, cursor: CursorWrapper, table_name: str
167+
) -> dict[str, tuple[str, str]]:
160168
"""
161169
Return a dictionary of {field_name: (field_name_other_table, other_table)}
162170
representing all foreign keys in the given table.
163171
"""
164172
...
165173

166-
def get_primary_key_column(self, cursor: Any, table_name: str) -> str | None:
174+
def get_primary_key_column(
175+
self, cursor: CursorWrapper, table_name: str
176+
) -> str | None:
167177
"""
168178
Return the name of the primary key column for the given table.
169179
"""
170180
columns = self.get_primary_key_columns(cursor, table_name)
171181
return columns[0] if columns else None
172182

173-
def get_primary_key_columns(self, cursor: Any, table_name: str) -> list[str] | None:
183+
def get_primary_key_columns(
184+
self, cursor: CursorWrapper, table_name: str
185+
) -> list[str] | None:
174186
"""Return a list of primary key columns for the given table."""
175187
for constraint in self.get_constraints(cursor, table_name).values():
176188
if constraint["primary_key"]:
@@ -179,7 +191,7 @@ def get_primary_key_columns(self, cursor: Any, table_name: str) -> list[str] | N
179191

180192
@abstractmethod
181193
def get_constraints(
182-
self, cursor: Any, table_name: str
194+
self, cursor: CursorWrapper, table_name: str
183195
) -> dict[str, dict[str, Any]]:
184196
"""
185197
Retrieve any constraints or keys (unique, pk, fk, check, index)

plain-models/plain/models/backends/base/operations.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
import sqlparse
1212

1313
from plain.models.backends import utils
14+
from plain.models.backends.utils import CursorWrapper
1415
from plain.models.db import NotSupportedError
16+
from plain.models.expressions import ResolvableExpression
1517
from plain.utils import timezone
1618
from plain.utils.encoding import force_str
1719

@@ -223,7 +225,9 @@ def distinct_sql(
223225
else:
224226
return ["DISTINCT"], []
225227

226-
def fetch_returned_insert_columns(self, cursor: Any, returning_params: Any) -> Any:
228+
def fetch_returned_insert_columns(
229+
self, cursor: CursorWrapper, returning_params: Any
230+
) -> Any:
227231
"""
228232
Given a cursor object that has just performed an INSERT...RETURNING
229233
statement into a table, return the newly created data.
@@ -287,7 +291,7 @@ def limit_offset_sql(self, low_mark: int | None, high_mark: int | None) -> str:
287291

288292
def last_executed_query(
289293
self,
290-
cursor: Any,
294+
cursor: CursorWrapper,
291295
sql: str,
292296
params: list[Any] | tuple[Any, ...] | dict[str, Any] | None,
293297
) -> str:
@@ -315,7 +319,9 @@ def to_string(s: Any) -> str:
315319

316320
return f"QUERY = {sql!r} - PARAMS = {u_params!r}"
317321

318-
def last_insert_id(self, cursor: Any, table_name: str, pk_name: str) -> int:
322+
def last_insert_id(
323+
self, cursor: CursorWrapper, table_name: str, pk_name: str
324+
) -> int:
319325
"""
320326
Given a cursor object that has just performed an INSERT statement into
321327
a table that has an auto-incrementing ID, return the newly created ID.
@@ -396,7 +402,7 @@ def bulk_insert_sql(
396402
...
397403

398404
@abstractmethod
399-
def fetch_returned_insert_rows(self, cursor: Any) -> list[Any]:
405+
def fetch_returned_insert_rows(self, cursor: CursorWrapper) -> list[Any]:
400406
"""
401407
Given a cursor object that has just performed an INSERT...RETURNING
402408
statement into a table, return the list of returned data.
@@ -521,7 +527,7 @@ def adapt_datetimefield_value(
521527
if value is None:
522528
return None
523529
# Expression values are adapted by the database.
524-
if hasattr(value, "resolve_expression"):
530+
if isinstance(value, ResolvableExpression):
525531
return value
526532

527533
return str(value)
@@ -536,7 +542,7 @@ def adapt_timefield_value(
536542
if value is None:
537543
return None
538544
# Expression values are adapted by the database.
539-
if hasattr(value, "resolve_expression"):
545+
if isinstance(value, ResolvableExpression):
540546
return value
541547

542548
if timezone.is_aware(value):

plain-models/plain/models/backends/mysql/base.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from __future__ import annotations
88

99
from functools import cached_property
10-
from typing import Any, cast
10+
from typing import TYPE_CHECKING, Any, cast
1111

1212
import MySQLdb as Database
1313
from MySQLdb.constants import CLIENT, FIELD_TYPE
@@ -19,11 +19,14 @@
1919
from plain.models.db import IntegrityError
2020
from plain.utils.regex_helper import _lazy_re_compile
2121

22-
# With mysqlclient stubs, we can now type the connection
23-
try:
22+
# Type checkers always see the proper type; runtime falls back to Any if needed
23+
if TYPE_CHECKING:
2424
from MySQLdb.connections import Connection as MySQLConnection
25-
except ImportError:
26-
MySQLConnection: type[Any] = Any # type: ignore[misc]
25+
else:
26+
try:
27+
from MySQLdb.connections import Connection as MySQLConnection
28+
except ImportError:
29+
MySQLConnection = Any # type: ignore[misc]
2730

2831
from .client import DatabaseClient
2932
from .creation import DatabaseCreation

plain-models/plain/models/backends/mysql/introspection.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from plain.models.backends.base.introspection import (
99
BaseDatabaseIntrospection,
1010
)
11+
from plain.models.backends.utils import CursorWrapper
1112
from plain.models.indexes import Index
1213
from plain.utils.datastructures import OrderedSet
1314

@@ -106,7 +107,7 @@ def get_field_type(self, data_type: Any, description: Any) -> str:
106107
return "JSONField"
107108
return field_type
108109

109-
def get_table_list(self, cursor: Any) -> list[TableInfo]:
110+
def get_table_list(self, cursor: CursorWrapper) -> list[TableInfo]:
110111
"""Return a list of table and view names in the current database."""
111112
cursor.execute(
112113
"""
@@ -119,11 +120,13 @@ def get_table_list(self, cursor: Any) -> list[TableInfo]:
119120
"""
120121
)
121122
return [
122-
TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1]), row[2])
123+
TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1], "t"), row[2])
123124
for row in cursor.fetchall()
124125
]
125126

126-
def get_table_description(self, cursor: Any, table_name: str) -> list[FieldInfo]:
127+
def get_table_description(
128+
self, cursor: CursorWrapper, table_name: str
129+
) -> list[FieldInfo]:
127130
"""
128131
Return a description of the table with the DB-API cursor.description
129132
interface."
@@ -216,15 +219,17 @@ def to_int(i: Any) -> Any:
216219
return fields
217220

218221
def get_sequences(
219-
self, cursor: Any, table_name: str, table_fields: tuple[Any, ...] = ()
222+
self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
220223
) -> list[dict[str, Any]]:
221224
for field_info in self.get_table_description(cursor, table_name):
222225
if "auto_increment" in field_info.extra:
223226
# MySQL allows only one auto-increment column per table.
224227
return [{"table": table_name, "column": field_info.name}]
225228
return []
226229

227-
def get_relations(self, cursor: Any, table_name: str) -> dict[str, tuple[str, str]]:
230+
def get_relations(
231+
self, cursor: CursorWrapper, table_name: str
232+
) -> dict[str, tuple[str, str]]:
228233
"""
229234
Return a dictionary of {field_name: (field_name_other_table, other_table)}
230235
representing all foreign keys in the given table.
@@ -245,7 +250,7 @@ def get_relations(self, cursor: Any, table_name: str) -> dict[str, tuple[str, st
245250
for field_name, other_field, other_table in cursor.fetchall()
246251
}
247252

248-
def get_storage_engine(self, cursor: Any, table_name: str) -> str:
253+
def get_storage_engine(self, cursor: CursorWrapper, table_name: str) -> str:
249254
"""
250255
Retrieve the storage engine for a given table. Return the default
251256
storage engine if the table doesn't exist.
@@ -281,7 +286,7 @@ def _parse_constraint_columns(
281286
return check_columns
282287

283288
def get_constraints(
284-
self, cursor: Any, table_name: str
289+
self, cursor: CursorWrapper, table_name: str
285290
) -> dict[str, dict[str, Any]]:
286291
"""
287292
Retrieve any constraints or keys (unique, pk, fk, check, index) across

plain-models/plain/models/backends/mysql/operations.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from typing import TYPE_CHECKING, Any
77

88
from plain.models.backends.base.operations import BaseDatabaseOperations
9-
from plain.models.backends.utils import split_tzname_delta
9+
from plain.models.backends.utils import CursorWrapper, split_tzname_delta
1010
from plain.models.constants import OnConflict
11-
from plain.models.expressions import Exists, ExpressionWrapper
11+
from plain.models.expressions import Exists, ExpressionWrapper, ResolvableExpression
1212
from plain.models.lookups import Lookup
1313
from plain.utils import timezone
1414
from plain.utils.encoding import force_str
@@ -191,7 +191,7 @@ def time_trunc_sql(
191191
else:
192192
return f"TIME({sql})", params
193193

194-
def fetch_returned_insert_rows(self, cursor: Any) -> list[Any]:
194+
def fetch_returned_insert_rows(self, cursor: CursorWrapper) -> list[Any]:
195195
"""
196196
Given a cursor object that has just performed an INSERT...RETURNING
197197
statement into a table, return the tuple of returned data.
@@ -217,7 +217,9 @@ def adapt_decimalfield_value(
217217
) -> Any:
218218
return value
219219

220-
def last_executed_query(self, cursor: Any, sql: str, params: Any) -> str | None:
220+
def last_executed_query(
221+
self, cursor: CursorWrapper, sql: str, params: Any
222+
) -> str | None:
221223
# With MySQLdb, cursor objects have an (undocumented) "_executed"
222224
# attribute where the exact query sent to the database is saved.
223225
# See MySQLdb/cursors.py in the source distribution.
@@ -260,7 +262,7 @@ def adapt_datetimefield_value(
260262
return None
261263

262264
# Expression values are adapted by the database.
263-
if hasattr(value, "resolve_expression"):
265+
if isinstance(value, ResolvableExpression):
264266
return value
265267

266268
# MySQL doesn't support tz-aware datetimes
@@ -275,7 +277,7 @@ def adapt_timefield_value(
275277
return None
276278

277279
# Expression values are adapted by the database.
278-
if hasattr(value, "resolve_expression"):
280+
if isinstance(value, ResolvableExpression):
279281
return value
280282

281283
# MySQL doesn't support tz-aware times

0 commit comments

Comments
 (0)