Skip to content

Commit

Permalink
Unify checks for write access in database operations.
Browse files Browse the repository at this point in the history
  • Loading branch information
TallJimbo committed Oct 16, 2020
1 parent 75baaf0 commit 9ae1c2a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 19 deletions.
8 changes: 3 additions & 5 deletions python/lsst/daf/butler/registry/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import psycopg2
import sqlalchemy.dialects.postgresql

from ..interfaces import Database, ReadOnlyDatabaseError
from ..interfaces import Database
from ..nameShrinker import NameShrinker
from ...core import DatabaseTimespanRepresentation, ddl, Timespan, time_utils

Expand Down Expand Up @@ -144,8 +144,7 @@ def getTimespanRepresentation(cls) -> Type[DatabaseTimespanRepresentation]:
return _RangeTimespanRepresentation

def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
if not (self.isWriteable() or table.key in self._tempTables):
raise ReadOnlyDatabaseError(f"Attempt to replace into read-only database '{self}'.")
self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
if not rows:
return
# This uses special support for UPSERT in PostgreSQL backend:
Expand All @@ -163,8 +162,7 @@ def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:

def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
# Docstring inherited.
if not (self.isWriteable() or table.key in self._tempTables):
raise ReadOnlyDatabaseError(f"Attempt to esnure into read-only database '{self}'.")
self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
if not rows:
return 0
# Like `replace`, this uses UPSERT, but it's a bit simpler because
Expand Down
9 changes: 4 additions & 5 deletions python/lsst/daf/butler/registry/databases/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import sqlalchemy
import sqlalchemy.ext.compiler

from ..interfaces import Database, ReadOnlyDatabaseError, StaticTablesContext
from ..interfaces import Database, StaticTablesContext
from ...core import ddl


Expand Down Expand Up @@ -365,6 +365,7 @@ def insert(self, table: sqlalchemy.schema.Table, *rows: dict, returnIds: bool =
select: Optional[sqlalchemy.sql.Select] = None,
names: Optional[Iterable[str]] = None,
) -> Optional[List[int]]:
self.assertTableWriteable(table, f"Cannot insert into read-only table {table}.")
autoincr = self._autoincr.get(table.name)
if autoincr is not None:
if select is not None:
Expand Down Expand Up @@ -423,8 +424,7 @@ def insert(self, table: sqlalchemy.schema.Table, *rows: dict, returnIds: bool =
return super().insert(table, *rows, select=select, names=names, returnIds=returnIds)

def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
if not (self.isWriteable() or table.key in self._tempTables):
raise ReadOnlyDatabaseError(f"Attempt to replace into read-only database '{self}'.")
self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
if not rows:
return
if table.name in self._autoincr:
Expand All @@ -434,8 +434,7 @@ def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
self._connection.execute(_Replace(table), *rows)

def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
if not (self.isWriteable() or table.key in self._tempTables):
raise ReadOnlyDatabaseError(f"Attempt to ensure into read-only database '{self}'.")
self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
if not rows:
return 0
if table.name in self._autoincr:
Expand Down
46 changes: 37 additions & 9 deletions python/lsst/daf/butler/registry/interfaces/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,37 @@ def _lockTables(self, tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
"""
raise NotImplementedError()

def isTableWriteable(self, table: sqlalchemy.schema.Table) -> bool:
"""Check whether a table is writeable, either because the database
connection is read-write or the table is a temporary table.
Parameters
----------
table : `sqlalchemy.schema.Table`
SQLAlchemy table object to check.
Returns
-------
writeable : `bool`
Whether this table is writeable.
"""
return self.isWriteable() or table.key in self._tempTables

def assertTableWriteable(self, table: sqlalchemy.schema.Table, msg: str) -> None:
"""Raise if the given table is not writeable, either because the
database connection is read-write or the table is a temporary table.
Parameters
----------
table : `sqlalchemy.schema.Table`
SQLAlchemy table object to check.
msg : `str`, optional
If provided, raise `ReadOnlyDatabaseError` instead of returning
`False`, with this message.
"""
if not self.isTableWriteable(table):
raise ReadOnlyDatabaseError(msg)

@contextmanager
def declareStaticTables(self, *, create: bool) -> Iterator[StaticTablesContext]:
"""Return a context manager in which the database's static DDL schema
Expand Down Expand Up @@ -1071,9 +1102,9 @@ def safeNotEqual(a: Any, b: Any) -> bool:
toReturn = None
return 1, inconsistencies, toReturn

if self.isWriteable() or table.key in self._tempTables:
# Database is writeable. Try an insert first, but allow it to fail
# (in only specific ways).
if self.isTableWriteable(table):
# Try an insert first, but allow it to fail (in only specific
# ways).
row = keys.copy()
if compared is not None:
row.update(compared)
Expand Down Expand Up @@ -1174,8 +1205,7 @@ def insert(self, table: sqlalchemy.schema.Table, *rows: dict, returnIds: bool =
May be used inside transaction contexts, so implementations may not
perform operations that interrupt transactions.
"""
if not (self.isWriteable() or table.key in self._tempTables):
raise ReadOnlyDatabaseError(f"Attempt to insert into read-only database '{self}'.")
self.assertTableWriteable(table, f"Cannot insert into read-only table {table}.")
if select is not None and (rows or returnIds):
raise TypeError("'select' is incompatible with passing value rows or returnIds=True.")
if not rows and select is None:
Expand Down Expand Up @@ -1299,8 +1329,7 @@ def delete(self, table: sqlalchemy.schema.Table, columns: Iterable[str], *rows:
The default implementation should be sufficient for most derived
classes.
"""
if not (self.isWriteable() or table.key in self._tempTables):
raise ReadOnlyDatabaseError(f"Attempt to delete from read-only database '{self}'.")
self.assertTableWriteable(table, f"Cannot delete from read-only table {table}.")
if columns and not rows:
# If there are no columns, this operation is supposed to delete
# everything (so we proceed as usual). But if there are columns,
Expand Down Expand Up @@ -1351,8 +1380,7 @@ def update(self, table: sqlalchemy.schema.Table, where: Dict[str, str], *rows: d
The default implementation should be sufficient for most derived
classes.
"""
if not (self.isWriteable() or table.key in self._tempTables):
raise ReadOnlyDatabaseError(f"Attempt to update read-only database '{self}'.")
self.assertTableWriteable(table, f"Cannot update read-only table {table}.")
if not rows:
return 0
sql = table.update().where(
Expand Down

0 comments on commit 9ae1c2a

Please sign in to comment.