Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-38514: Extend obscore query method to return fewer columns #811

Merged
merged 2 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 7 additions & 6 deletions python/lsst/daf/butler/registry/interfaces/_obscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,17 @@ def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[i

@abstractmethod
@contextmanager
def query(self, **kwargs: Any) -> Iterator[sqlalchemy.engine.CursorResult]:
def query(
self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
) -> Iterator[sqlalchemy.engine.CursorResult]:
"""Run a SELECT query against obscore table and return result rows.

Parameters
----------
columns : `~collections.abc.Iterable` [`str`]
Columns to return from query. It is a sequence which can include
column names or any other column elements (e.g.
`sqlalchemy.sql.functions.count` function).
**kwargs
Restriction on values of individual obscore columns. Key is the
column name, value is the required value of the column. Multiple
Expand All @@ -248,10 +254,5 @@ def query(self, **kwargs: Any) -> Iterator[sqlalchemy.engine.CursorResult]:
result_context : `sqlalchemy.engine.CursorResult`
Context manager that returns the query result object when entered.
These results are invalidated when the context is exited.

Notes
-----
This method is intended mostly for tests that need to check the
contents of obscore table.
"""
raise NotImplementedError()
26 changes: 15 additions & 11 deletions python/lsst/daf/butler/registry/obscore/_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,17 +336,21 @@ def update_exposure_regions(self, instrument: str, region_data: Iterable[tuple[i
return count

@contextmanager
def query(self, **kwargs: Any) -> Iterator[sqlalchemy.engine.CursorResult]:
"""Run a SELECT query against obscore table and return result rows.

Parameters
----------
**kwargs
Restriction on values of individual obscore columns. Key is the
column name, value is the required value of the column. Multiple
restrictions are ANDed together.
"""
query = self.table.select()
def query(
self, columns: Iterable[str | sqlalchemy.sql.expression.ColumnElement] | None = None, /, **kwargs: Any
) -> Iterator[sqlalchemy.engine.CursorResult]:
# Docstring inherited from base class.
if columns is not None:
column_elements: list[sqlalchemy.sql.ColumnElement] = []
for column in columns:
if isinstance(column, str):
column_elements.append(self.table.columns[column])
else:
column_elements.append(column)
query = sqlalchemy.sql.select(*column_elements).select_from(self.table)
else:
query = self.table.select()

if kwargs:
query = query.where(
sqlalchemy.sql.expression.and_(
Expand Down
7 changes: 6 additions & 1 deletion tests/test_obscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,11 @@ def test_insert_existing_collection(self):
rows = list(result)
self.assertEqual(len(rows), count)

# Also check `query` method with COUNT(*)
with obscore.query([sqlalchemy.sql.func.count()]) as result:
scalar = result.scalar_one()
self.assertEqual(scalar, count)

def test_drop_datasets(self):
"""Test for dropping datasets after obscore insert."""

Expand Down Expand Up @@ -479,7 +484,7 @@ def test_update_exposure_region(self) -> None:
)
self.assertEqual(count, 2)

with obscore.query() as result:
with obscore.query(["s_ra", "s_dec", "s_region", "lsst_detector"]) as result:
rows = list(result)
self.assertEqual(len(rows), 4)
for row in rows:
Expand Down