Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-33085: Specify explicit caching option for SQLAlchemy classes #624

Merged
merged 1 commit into from
Jan 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/lsst/daf/butler/core/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ class Base64Region(Base64Bytes):
Maps Python `sphgeom.Region` to a base64-encoded `sqlalchemy.String`.
"""

cache_ok = True # have to be set explicitly in each class

def process_bind_param(
self, value: Optional[Region], dialect: sqlalchemy.engine.Dialect
) -> Optional[str]:
Expand Down
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/registry/databases/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class _Replace(sqlalchemy.sql.Insert):
on the primary key constraint for the table.
"""

pass
inherit_cache = True # make it cacheable


# SQLite and PostgreSQL use similar syntax for their ON CONFLICT extension,
Expand Down Expand Up @@ -95,7 +95,7 @@ class _Ensure(sqlalchemy.sql.Insert):
``INSERT ... ON CONFLICT DO NOTHING``.
"""

pass
inherit_cache = True # make it cacheable


@sqlalchemy.ext.compiler.compiles(_Ensure, "sqlite")
Expand Down
74 changes: 49 additions & 25 deletions python/lsst/daf/butler/registry/queries/expressions/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from lsst.utils.iteration import ensure_iterable
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import func
from sqlalchemy.sql.visitors import InternalTraversal

from ....core import (
Dimension,
Expand Down Expand Up @@ -125,54 +126,77 @@ class ExpressionTypeError(TypeError):
"""


class _TimestampColumnElement(sqlalchemy.sql.ColumnElement):
"""Special ColumnElement type used for TIMESTAMP columns or literals in
expressions.
class _TimestampLiteral(sqlalchemy.sql.ColumnElement):
"""Special ColumnElement type used for TIMESTAMP literals in expressions.

SQLite stores timestamps as strings which sometimes can cause issues when
comparing strings. For more reliable comparison SQLite needs DATETIME()
wrapper for those strings. For PostgreSQL it works better if we add
TIMESTAMP to string literals.
"""

inherit_cache = True
_traverse_internals = [("_literal", InternalTraversal.dp_plain_obj)]

def __init__(self, literal: datetime):
super().__init__()
self._literal = literal


@compiles(_TimestampLiteral, "sqlite")
def compile_timestamp_literal_sqlite(element: Any, compiler: Any, **kw: Mapping[str, Any]) -> str:
"""Compilation of TIMESTAMP literal for SQLite.

SQLite defines ``datetiem`` function that can be used to convert timestamp
value to Unix seconds.
"""
return compiler.process(func.datetime(sqlalchemy.sql.literal(element._literal)), **kw)


@compiles(_TimestampLiteral, "postgresql")
def compile_timestamp_literal_pg(element: Any, compiler: Any, **kw: Mapping[str, Any]) -> str:
"""Compilation of TIMESTAMP literal for PostgreSQL.

For PostgreSQL it works better if we add TIMESTAMP to string literals.
"""
literal = element._literal.isoformat(sep=" ", timespec="microseconds")
return "TIMESTAMP " + compiler.process(sqlalchemy.sql.literal(literal), **kw)


class _TimestampColumnElement(sqlalchemy.sql.ColumnElement):
"""Special ColumnElement type used for TIMESTAMP columns or in expressions.

SQLite stores timestamps as strings which sometimes can cause issues when
comparing strings. For more reliable comparison SQLite needs DATETIME()
wrapper for columns.

This mechanism is only used for expressions in WHERE clause, values of the
TIMESTAMP columns returned from queries are still handled by standard
mechanism and they are converted to `datetime` instances.
"""

def __init__(
self, column: Optional[sqlalchemy.sql.ColumnElement] = None, literal: Optional[datetime] = None
):
inherit_cache = True
_traverse_internals = [("_column", InternalTraversal.dp_clauseelement)]

def __init__(self, column: sqlalchemy.sql.ColumnElement):
super().__init__()
self._column = column
self._literal = literal


@compiles(_TimestampColumnElement, "sqlite")
def compile_timestamp_sqlite(element: Any, compiler: Any, **kw: Mapping[str, Any]) -> str:
"""Compilation of TIMESTAMP column for SQLite.

SQLite defines ``strftime`` function that can be used to convert timestamp
SQLite defines ``datetime`` function that can be used to convert timestamp
value to Unix seconds.
"""
assert element._column is not None or element._literal is not None, "Must have column or literal"
if element._column is not None:
return compiler.process(func.datetime(element._column), **kw)
else:
return compiler.process(func.datetime(sqlalchemy.sql.literal(element._literal)), **kw)
return compiler.process(func.datetime(element._column), **kw)


@compiles(_TimestampColumnElement, "postgresql")
def compile_timestamp_pg(element: Any, compiler: Any, **kw: Mapping[str, Any]) -> str:
"""Compilation of TIMESTAMP column for PostgreSQL.

PostgreSQL can use `EXTRACT(epoch FROM timestamp)` function.
"""
assert element._column is not None or element._literal is not None, "Must have column or literal"
if element._column is not None:
return compiler.process(element._column, **kw)
else:
literal = element._literal.isoformat(sep=" ", timespec="microseconds")
return "TIMESTAMP " + compiler.process(sqlalchemy.sql.literal(literal), **kw)
"""Compilation of TIMESTAMP column for PostgreSQL."""
return compiler.process(element._column, **kw)


class WhereClauseConverter(ABC):
Expand Down Expand Up @@ -323,7 +347,7 @@ def fromLiteral(cls, value: Any) -> ScalarWhereClauseConverter:
"""
dtype = type(value)
if dtype is datetime:
column = _TimestampColumnElement(literal=value)
column = _TimestampLiteral(value)
else:
column = sqlalchemy.sql.literal(value, type_=ddl.AstropyTimeNsecTai if dtype is Time else None)
return cls(column, value, dtype)
Expand Down Expand Up @@ -1034,7 +1058,7 @@ def visitIdentifier(self, name: str, node: Node) -> WhereClauseConverter:
assert self.columns.datasets is not None
assert self.columns.datasets.ingestDate is not None, "dataset.ingest_date is not in the query"
return ScalarWhereClauseConverter.fromExpression(
_TimestampColumnElement(column=self.columns.datasets.ingestDate),
_TimestampColumnElement(self.columns.datasets.ingestDate),
datetime,
)
elif constant is ExpressionConstant.NULL:
Expand Down