Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-43769: Improve Pydantic support for sphgeom.Region and Timespan #997

Merged
merged 10 commits into from
Apr 13, 2024
1 change: 1 addition & 0 deletions doc/changes/DM-43769.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `Timespan` a Pydantic model and add a `SerializableRegion` type alias that allows `lsst.sphgeom.Region` to be used directly as a Pydantic model field.
144 changes: 51 additions & 93 deletions python/lsst/daf/butler/_timespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,17 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = (
"SerializedTimespan",
"Timespan",
)
__all__ = ("Timespan",)

import enum
import warnings
from collections.abc import Generator
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, TypeAlias
from typing import Any, ClassVar, TypeAlias

import astropy.time
import astropy.utils.exceptions
import pydantic
import yaml
from pydantic import Field

# As of astropy 4.2, the erfa interface is shipped independently and
# ErfaWarning is no longer an AstropyWarning
Expand All @@ -50,14 +47,8 @@

from lsst.utils.classes import cached_getter

from .json import from_json_generic, to_json_generic
from .time_utils import TimeConverter

if TYPE_CHECKING: # Imports needed only for type annotations; may be circular.
from .dimensions import DimensionUniverse
from .registry import Registry


_ONE_DAY = astropy.time.TimeDelta("1d", scale="tai")


Expand All @@ -76,13 +67,8 @@

TimespanBound: TypeAlias = astropy.time.Time | _SpecialTimespanBound | None

SerializedTimespan = Annotated[list[int], Field(min_length=2, max_length=2)]
"""JSON-serializable representation of the Timespan class, as a list of two
integers ``[begin, end]`` in nanoseconds since the epoch.
"""


class Timespan:
class Timespan(pydantic.BaseModel):
"""A half-open time interval with nanosecond precision.

Parameters
Expand Down Expand Up @@ -196,14 +182,22 @@
# here simplifies all other operations (including interactions
# with TimespanDatabaseRepresentation implementations).
_nsec = (converter.max_nsec, converter.min_nsec)
self._nsec = _nsec
super().__init__(nsec=_nsec)

__slots__ = ("_nsec", "_cached_begin", "_cached_end")
nsec: tuple[int, int] = pydantic.Field(frozen=True)

model_config = pydantic.ConfigDict(
json_schema_extra={
"description": (
"A [begin, end) TAI timespan with bounds as integer nanoseconds since 1970-01-01 00:00:00."
)
}
)

EMPTY: ClassVar[_SpecialTimespanBound] = _SpecialTimespanBound.EMPTY

# YAML tag name for Timespan
yaml_tag = "!lsst.daf.butler.Timespan"
yaml_tag: ClassVar[str] = "!lsst.daf.butler.Timespan"

@classmethod
def makeEmpty(cls) -> Timespan:
Expand Down Expand Up @@ -294,10 +288,10 @@
"""
if self.isEmpty():
return self.EMPTY
elif self._nsec[0] == TimeConverter().min_nsec:
elif self.nsec[0] == TimeConverter().min_nsec:
return None
else:
return TimeConverter().nsec_to_astropy(self._nsec[0])
return TimeConverter().nsec_to_astropy(self.nsec[0])

@property
@cached_getter
Expand All @@ -310,14 +304,14 @@
"""
if self.isEmpty():
return self.EMPTY
elif self._nsec[1] == TimeConverter().max_nsec:
elif self.nsec[1] == TimeConverter().max_nsec:
return None
else:
return TimeConverter().nsec_to_astropy(self._nsec[1])
return TimeConverter().nsec_to_astropy(self.nsec[1])

def isEmpty(self) -> bool:
    """Test whether ``self`` is the empty timespan (`bool`).

    A timespan is empty when its half-open nanosecond bounds do not
    enclose any instant, i.e. ``begin >= end``.  ``__init__``
    standardizes all empty timespans to a single ``(max_nsec,
    min_nsec)`` value, so this comparison is sufficient.
    """
    # Diff residue removed: only the post-refactor `nsec` Pydantic field
    # access is kept (the `_nsec` slot no longer exists in this PR).
    return self.nsec[0] >= self.nsec[1]

def __str__(self) -> str:
if self.isEmpty():
Expand Down Expand Up @@ -356,15 +350,15 @@
return False
# Correctness of this simple implementation depends on __init__
# standardizing all empty timespans to a single value.
return self._nsec == other._nsec
return self.nsec == other.nsec

def __hash__(self) -> int:
    """Hash on the nanosecond bounds tuple.

    Correctness of this simple implementation depends on ``__init__``
    standardizing all empty timespans to a single value, so that equal
    timespans always hash equal.
    """
    # Diff residue removed: only the post-refactor `nsec` field access
    # is kept (the `_nsec` slot no longer exists in this PR).
    return hash(self.nsec)

def __reduce__(self) -> tuple:
    """Support pickling by re-invoking the constructor with the raw
    ``_nsec`` bounds (``begin``/``end`` of `None` are ignored when
    ``_nsec`` is supplied).
    """
    # Diff residue removed: only the post-refactor `nsec` field access
    # is kept (the `_nsec` slot no longer exists in this PR).
    return (Timespan, (None, None, False, self.nsec))

def __lt__(self, other: astropy.time.Time | Timespan) -> bool:
"""Test if a Timespan's bounds are strictly less than the given time.
Expand All @@ -389,9 +383,9 @@
# first term is also false.
if isinstance(other, astropy.time.Time):
nsec = TimeConverter().astropy_to_nsec(other)
return self._nsec[1] <= nsec and self._nsec[0] < nsec
return self.nsec[1] <= nsec and self.nsec[0] < nsec
else:
return self._nsec[1] <= other._nsec[0] and self._nsec[0] < other._nsec[1]
return self.nsec[1] <= other.nsec[0] and self.nsec[0] < other.nsec[1]

def __gt__(self, other: astropy.time.Time | Timespan) -> bool:
"""Test if a Timespan's bounds are strictly greater than given time.
Expand All @@ -416,9 +410,9 @@
# first term is also false.
if isinstance(other, astropy.time.Time):
nsec = TimeConverter().astropy_to_nsec(other)
return self._nsec[0] > nsec and self._nsec[1] > nsec
return self.nsec[0] > nsec and self.nsec[1] > nsec
else:
return self._nsec[0] >= other._nsec[1] and self._nsec[1] > other._nsec[0]
return self.nsec[0] >= other.nsec[1] and self.nsec[1] > other.nsec[0]

def overlaps(self, other: Timespan | astropy.time.Time) -> bool:
"""Test if the intersection of this Timespan with another is empty.
Expand All @@ -442,7 +436,7 @@
"""
if isinstance(other, astropy.time.Time):
return self.contains(other)
return self._nsec[1] > other._nsec[0] and other._nsec[1] > self._nsec[0]
return self.nsec[1] > other.nsec[0] and other.nsec[1] > self.nsec[0]

def contains(self, other: astropy.time.Time | Timespan) -> bool:
"""Test if the supplied timespan is within this one.
Expand Down Expand Up @@ -475,9 +469,9 @@
"""
if isinstance(other, astropy.time.Time):
nsec = TimeConverter().astropy_to_nsec(other)
return self._nsec[0] <= nsec and self._nsec[1] > nsec
return self.nsec[0] <= nsec and self.nsec[1] > nsec
else:
return self._nsec[0] <= other._nsec[0] and self._nsec[1] >= other._nsec[1]
return self.nsec[0] <= other.nsec[0] and self.nsec[1] >= other.nsec[1]

def intersection(self, *args: Timespan) -> Timespan:
"""Return a new `Timespan` that is contained by all of the given ones.
Expand All @@ -494,10 +488,10 @@
"""
if not args:
return self
lowers = [self._nsec[0]]
lowers.extend(ts._nsec[0] for ts in args)
uppers = [self._nsec[1]]
uppers.extend(ts._nsec[1] for ts in args)
lowers = [self.nsec[0]]
lowers.extend(ts.nsec[0] for ts in args)
uppers = [self.nsec[1]]
uppers.extend(ts.nsec[1] for ts in args)
nsec = (max(*lowers), min(*uppers))
return Timespan(begin=None, end=None, _nsec=nsec)

Expand Down Expand Up @@ -527,59 +521,10 @@
elif intersection == self:
yield from ()
else:
if intersection._nsec[0] > self._nsec[0]:
yield Timespan(None, None, _nsec=(self._nsec[0], intersection._nsec[0]))
if intersection._nsec[1] < self._nsec[1]:
yield Timespan(None, None, _nsec=(intersection._nsec[1], self._nsec[1]))

def to_simple(self, minimal: bool = False) -> SerializedTimespan:
"""Return simple python type form suitable for serialization.

Parameters
----------
minimal : `bool`, optional
Use minimal serialization. Has no effect on for this class.

Returns
-------
simple : `list` of `int`
The internal span as integer nanoseconds.
"""
# Return the internal nanosecond form rather than astropy ISO string
return list(self._nsec)

@classmethod
def from_simple(
cls,
simple: SerializedTimespan | None,
universe: DimensionUniverse | None = None,
registry: Registry | None = None,
) -> Timespan | None:
"""Construct a new object from simplified form.

Designed to use the data returned from the `to_simple` method.

Parameters
----------
simple : `list` of `int`, or `None`
The values returned by `to_simple()`.
universe : `DimensionUniverse`, optional
Unused.
registry : `lsst.daf.butler.Registry`, optional
Unused.

Returns
-------
result : `Timespan` or `None`
Newly-constructed object.
"""
if simple is None:
return None
nsec1, nsec2 = simple # for mypy
return cls(begin=None, end=None, _nsec=(nsec1, nsec2))

to_json = to_json_generic
from_json: ClassVar = classmethod(from_json_generic)
if intersection.nsec[0] > self.nsec[0]:
yield Timespan(None, None, _nsec=(self.nsec[0], intersection.nsec[0]))
if intersection.nsec[1] < self.nsec[1]:
yield Timespan(None, None, _nsec=(intersection.nsec[1], self.nsec[1]))

@classmethod
def to_yaml(cls, dumper: yaml.Dumper, timespan: Timespan) -> Any:
Expand Down Expand Up @@ -627,6 +572,19 @@
d = loader.construct_mapping(node)
return Timespan(d["begin"], d["end"])

@pydantic.model_validator(mode="before")
@classmethod
def _validate(cls, value: Any) -> Any:
    """Pydantic pre-validation hook.

    Accepts an existing `Timespan` instance, a mapping (assumed to
    already have an ``nsec`` key), or a bare serialized value, which is
    wrapped as ``{"nsec": value}`` so the ``nsec`` field validator can
    coerce it (the serialized form is the ``[begin, end]`` nanosecond
    pair emitted by ``_serialize``).
    """
    # NOTE: an injected Codecov UI annotation (scraped page artifact)
    # was removed from the middle of this method; the logic is unchanged.
    if isinstance(value, Timespan):
        return value
    if isinstance(value, dict):
        return value
    return {"nsec": value}

@pydantic.model_serializer(mode="plain")
def _serialize(self) -> tuple[int, int]:
    """Serialize this timespan as its compact ``(begin, end)`` pair of
    integer TAI nanoseconds since the epoch.
    """
    begin_nsec, end_nsec = self.nsec
    return (begin_nsec, end_nsec)


# Register Timespan -> YAML conversion method with Dumper class
yaml.Dumper.add_representer(Timespan, Timespan.to_yaml)
Expand Down
4 changes: 1 addition & 3 deletions python/lsst/daf/butler/arrow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,7 @@ def data_type(self) -> pa.DataType:

def append(self, value: Timespan | None, column: list[pa.StructScalar | None]) -> None:
# Docstring inherited.
column.append(
{"begin_nsec": value._nsec[0], "end_nsec": value._nsec[1]} if value is not None else None
)
column.append({"begin_nsec": value.nsec[0], "end_nsec": value.nsec[1]} if value is not None else None)

def finish(self, column: list[Any]) -> pa.Array:
# Docstring inherited.
Expand Down
59 changes: 29 additions & 30 deletions python/lsst/daf/butler/dimensions/_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
field_type = field_type | None # type: ignore
members[field.name] = (field_type, ...)
if definition.temporal:
members["timespan"] = (tuple[int, int] | None, ...) # type: ignore
members["timespan"] = (Timespan | None, ...) # type: ignore
if definition.spatial:
members["region"] = (str, ...)

Expand Down Expand Up @@ -154,7 +154,7 @@
)

# Use strict types to prevent casting
record: dict[str, None | StrictBool | StrictInt | StrictFloat | StrictStr | tuple[int, int]] = Field(
record: dict[str, None | StrictBool | StrictInt | StrictFloat | StrictStr | Timespan] = Field(
...,
title="Dimension record keys and values.",
examples=[
Expand All @@ -178,7 +178,7 @@
cls,
*,
definition: str,
record: dict[str, None | StrictFloat | StrictStr | StrictBool | StrictInt | tuple[int, int]],
record: dict[str, Any],
) -> SerializedDimensionRecord:
"""Construct a `SerializedDimensionRecord` directly without validators.

Expand All @@ -204,9 +204,12 @@
"""
# This method requires tuples as values of the mapping, but JSON
# readers will read things in as lists. Be kind and transparently
# transform to tuples
# transform to tuples.
_recItems = {
k: v if type(v) != list else tuple(v) for k, v in record.items() # type: ignore # noqa: E721
k: (
v if type(v) is not list else Timespan(begin=None, end=None, _nsec=tuple(v)) # noqa: E721
) # type: ignore
for k, v in record.items()
}

# Type ignore because the ternary statement seems to confuse mypy
Expand Down Expand Up @@ -374,21 +377,17 @@
return result

mapping = {name: getattr(self, name) for name in self.__slots__}
# If the item in mapping supports simplification update it
for k, v in mapping.items():
try:
mapping[k] = v.to_simple(minimal=minimal)
except AttributeError:
if isinstance(v, lsst.sphgeom.Region):
# YAML serialization specifies the class when it
# doesn't have to. This is partly for explicitness
# and also history. Here use a different approach.
# This code needs to be migrated to sphgeom
mapping[k] = v.encode().hex()
if isinstance(v, bytes):
# We actually can't handle serializing out to bytes for
# hash objects, encode it here to a hex string
mapping[k] = v.hex()
if isinstance(v, lsst.sphgeom.Region):
# YAML serialization specifies the class when it
# doesn't have to. This is partly for explicitness
# and also history. Here use a different approach.
# This code needs to be migrated to sphgeom
mapping[k] = v.encode().hex()
if isinstance(v, bytes):
# We actually can't handle serializing out to bytes for
# hash objects, encode it here to a hex string
mapping[k] = v.hex()

Check warning on line 390 in python/lsst/daf/butler/dimensions/_records.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/dimensions/_records.py#L390

Added line #L390 was not covered by tests
definition = self.definition.to_simple(minimal=minimal)
dimRec = SerializedDimensionRecord(definition=definition, record=mapping)
if cache is not None:
Expand Down Expand Up @@ -454,19 +453,19 @@
record_model_cls = _createSimpleRecordSubclass(definition)
record_model = record_model_cls(**simple.record)

# Timespan and region have to be converted to native form
# for now assume that those keys are special
rec = record_model.model_dump()
# Region and hash have to be converted to native form; for now assume
# that the keys are special. We make the mapping we need to pass to
# the DimensionRecord constructor via getattr, because we don't
# model_dump re-disassembling things like Timespans that we've already
# assembled.
mapping = {k: getattr(record_model, k) for k in definition.schema.names}

if (ts := "timespan") in rec:
rec[ts] = Timespan.from_simple(rec[ts], universe=universe, registry=registry)
if (reg := "region") in rec:
encoded = bytes.fromhex(rec[reg])
rec[reg] = lsst.sphgeom.Region.decode(encoded)
if (hsh := "hash") in rec:
rec[hsh] = bytes.fromhex(rec[hsh].decode())
if "region" in mapping:
mapping["region"] = lsst.sphgeom.Region.decode(bytes.fromhex(mapping["region"]))
if "hash" in mapping:
mapping["hash"] = bytes.fromhex(mapping["hash"].decode())

Check warning on line 466 in python/lsst/daf/butler/dimensions/_records.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/dimensions/_records.py#L466

Added line #L466 was not covered by tests

dimRec = _reconstructDimensionRecord(definition, rec)
dimRec = _reconstructDimensionRecord(definition, mapping)
if cache is not None:
cache[key] = dimRec
return dimRec
Expand Down