Skip to content

Commit 7635258

Browse files
committed
Add type annotations to CursorWrapper fetch methods
Explicit fetchone, fetchmany, and fetchall methods with return type annotations so type checkers know the return types. Also fixes several places that assumed fetchone() always returns a row.
1 parent 0cb5a84 commit 7635258

File tree

6 files changed

+34
-16
lines changed

6 files changed

+34
-16
lines changed

plain-models/plain/models/backends/base/features.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ def supports_transactions(self) -> bool:
184184
self.connection.rollback()
185185
self.connection.set_autocommit(True)
186186
cursor.execute("SELECT COUNT(X) FROM ROLLBACK_TEST")
187-
(count,) = cursor.fetchone()
187+
row = cursor.fetchone()
188+
assert row is not None
189+
(count,) = row
188190
cursor.execute("DROP TABLE ROLLBACK_TEST")
189191
return count == 0

plain-models/plain/models/backends/mysql/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ def mysql_server_data(self) -> dict[str, Any]:
357357
"""
358358
)
359359
row = cursor.fetchone()
360+
assert row is not None
360361
return {
361362
"version": row[0],
362363
"sql_mode": row[1],

plain-models/plain/models/backends/sqlite3/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,9 @@ def disable_constraint_checking(self) -> bool:
230230
# Foreign key constraints cannot be turned off while in a multi-
231231
# statement transaction. Fetch the current state of the pragma
232232
# to determine if constraints are effectively disabled.
233-
enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
233+
row = cursor.execute("PRAGMA foreign_keys").fetchone()
234+
assert row is not None
235+
enabled = row[0]
234236
return not bool(enabled)
235237

236238
def enable_constraint_checking(self) -> None:
@@ -273,10 +275,12 @@ def check_constraints(self, table_names: list[str] | None = None) -> None:
273275
assert primary_key_column_name is not None, (
274276
f"Table {table_name} must have a primary key"
275277
)
276-
primary_key_value, bad_value = cursor.execute(
278+
row = cursor.execute(
277279
f"SELECT {self.ops.quote_name(primary_key_column_name)}, {self.ops.quote_name(column_name)} FROM {self.ops.quote_name(table_name)} WHERE rowid = %s",
278280
(rowid,),
279281
).fetchone()
282+
assert row is not None
283+
primary_key_value, bad_value = row
280284
raise IntegrityError(
281285
f"The row in table '{table_name}' with primary key '{primary_key_value}' has an "
282286
f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_value}' that "

plain-models/plain/models/backends/sqlite3/introspection.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -338,14 +338,11 @@ def get_constraints(
338338
"""
339339
constraints: dict[str, dict[str, Any]] = {}
340340
# Find inline check constraints.
341-
try:
342-
table_schema = cursor.execute(
343-
f"SELECT sql FROM sqlite_master WHERE type='table' and name={self.connection.ops.quote_name(table_name)}"
344-
).fetchone()[0]
345-
except TypeError:
346-
# table_name is a view.
347-
pass
348-
else:
341+
row = cursor.execute(
342+
f"SELECT sql FROM sqlite_master WHERE type='table' and name={self.connection.ops.quote_name(table_name)}"
343+
).fetchone()
344+
if row is not None:
345+
table_schema = row[0]
349346
columns = {
350347
info.name for info in self.get_table_description(cursor, table_name)
351348
}

plain-models/plain/models/backends/sqlite3/schema.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ def alter_field(
168168
# that don't affect the on-disk content.
169169
# https://sqlite.org/lang_altertable.html#otheralter
170170
with self.connection.cursor() as cursor:
171-
schema_version = cursor.execute("PRAGMA schema_version").fetchone()[
172-
0
173-
]
171+
row = cursor.execute("PRAGMA schema_version").fetchone()
172+
assert row is not None
173+
schema_version = row[0]
174174
cursor.execute("PRAGMA writable_schema = 1")
175175
references_template = f' REFERENCES "{table_name}" ("%s") '
176176
new_column_name = new_field.get_attname_column()[1]

plain-models/plain/models/backends/utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def fetchone(self) -> tuple[Any, ...] | None:
6161
"""Fetch the next row of a query result set."""
6262
...
6363

64-
def fetchmany(self, size: int = 0) -> list[tuple[Any, ...]]:
64+
def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
6565
"""Fetch the next set of rows of a query result set."""
6666
...
6767

@@ -79,7 +79,7 @@ def __init__(self, cursor: DBAPICursor, db: BaseDatabaseWrapper) -> None:
7979
self.cursor = cursor
8080
self.db = db
8181

82-
WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])
82+
WRAP_ERROR_ATTRS = frozenset(["nextset"])
8383

8484
def __getattr__(self, attr: str) -> Any:
8585
cursor_attr = getattr(self.cursor, attr)
@@ -92,6 +92,20 @@ def __iter__(self) -> Iterator[tuple[Any, ...]]:
9292
with self.db.wrap_database_errors:
9393
yield from self.cursor
9494

95+
def fetchone(self) -> tuple[Any, ...] | None:
96+
with self.db.wrap_database_errors:
97+
return self.cursor.fetchone()
98+
99+
def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
100+
with self.db.wrap_database_errors:
101+
if size is None:
102+
return self.cursor.fetchmany()
103+
return self.cursor.fetchmany(size)
104+
105+
def fetchall(self) -> list[tuple[Any, ...]]:
106+
with self.db.wrap_database_errors:
107+
return self.cursor.fetchall()
108+
95109
def __enter__(self) -> Self:
96110
return self
97111

0 commit comments

Comments
 (0)