DM-42685 Add a Tensor interface to analysis tools #195

Merged (2 commits) on Jan 26, 2024
Changes from all commits
6 changes: 2 additions & 4 deletions python/lsst/analysis/tools/contexts/_baseContext.py
@@ -68,12 +68,10 @@ class ContextApplier:
@overload
def __get__(
self, instance: AnalysisAction, klass: type[AnalysisAction] | None = None
) -> Callable[[ContextType], None]:
...
) -> Callable[[ContextType], None]: ...

@overload
def __get__(self, instance: None, klass: type[AnalysisAction] | None = None) -> ContextApplier:
...
def __get__(self, instance: None, klass: type[AnalysisAction] | None = None) -> ContextApplier: ...

def __get__(
self, instance: AnalysisAction | None, klass: type[AnalysisAction] | None = None
12 changes: 11 additions & 1 deletion python/lsst/analysis/tools/interfaces/_actions.py
@@ -45,7 +45,7 @@
from lsst.pex.config.configurableActions import ConfigurableAction, ConfigurableActionField

from ..contexts import ContextApplier
from ._interfaces import KeyedData, KeyedDataSchema, MetricResultType, PlotResultType, Scalar, Vector
from ._interfaces import KeyedData, KeyedDataSchema, MetricResultType, PlotResultType, Scalar, Tensor, Vector


class AnalysisAction(ConfigurableAction):
@@ -155,6 +155,16 @@ def __call__(self, data: KeyedData, **kwargs) -> Vector:
raise NotImplementedError("This is not implemented on the base class")


class TensorAction(AnalysisAction):
"""A `TensorAction` is an `AnalysisAction` that returns a `Tensor` when
called.
"""

@abstractmethod
def __call__(self, data: KeyedData, **kwargs) -> Tensor:
raise NotImplementedError("This is not implemented on the base class")


class ScalarAction(AnalysisAction):
"""A `ScalarAction` is an `AnalysisAction` that returns a `Scalar` when
called.
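To make the new interface concrete, here is a minimal sketch of what a subclass of the new `TensorAction` above could look like. The class name, the ``image`` key, and the private import paths are illustrative assumptions, not part of this PR; in practice the names may be re-exported at the `lsst.analysis.tools.interfaces` package level.

# Hypothetical sketch: ExtractImageTensor and the "image" key are illustrative only.
from typing import cast

from lsst.analysis.tools.interfaces._actions import TensorAction
from lsst.analysis.tools.interfaces._interfaces import KeyedData, KeyedDataSchema, Tensor


class ExtractImageTensor(TensorAction):
    """Return the multidimensional array stored under the ``image`` key."""

    def getInputSchema(self) -> KeyedDataSchema:
        # Declare that this action expects a Tensor under the "image" key.
        yield "image", Tensor

    def __call__(self, data: KeyedData, **kwargs) -> Tensor:
        return cast(Tensor, data["image"])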
10 changes: 5 additions & 5 deletions python/lsst/analysis/tools/interfaces/_analysisTools.py
@@ -39,8 +39,7 @@

@runtime_checkable
class _HasOutputNames(Protocol):
def getOutputNames(self, config: pexConfig.Config | None = None) -> Iterable[str]:
...
def getOutputNames(self, config: pexConfig.Config | None = None) -> Iterable[str]: ...


def _finalizeWrapper(
@@ -116,6 +115,7 @@ class AnalysisTool(AnalysisAction):
The stages themselves are also configurable, allowing control over various
aspects of the individual `AnalysisAction`\ s.
"""

prep = ConfigurableActionField[AnalysisAction](doc="Action to run to prepare inputs", default=BasePrep)
process = ConfigurableActionField[AnalysisAction](
doc="Action to process data into intended form", default=BaseProcess
@@ -177,9 +177,9 @@ def _call_single(self, data: KeyedData, **kwargs) -> KeyedResults:
kwargs["metric_tags"] = list(self.metric_tags or ())
prepped: KeyedData = self.prep(data, **kwargs) # type: ignore
processed: KeyedData = self.process(prepped, **kwargs) # type: ignore
finalized: Mapping[str, PlotTypes] | PlotTypes | Mapping[
str, Measurement
] | Measurement | JointResults = self.produce(
finalized: (
Mapping[str, PlotTypes] | PlotTypes | Mapping[str, Measurement] | Measurement | JointResults
) = self.produce(
processed, **kwargs
) # type: ignore
return self._process_single_results(finalized)
39 changes: 35 additions & 4 deletions python/lsst/analysis/tools/interfaces/_interfaces.py
@@ -22,6 +22,7 @@
from __future__ import annotations

__all__ = (
"Tensor",
"Scalar",
"ScalarType",
"KeyedData",
@@ -34,7 +35,7 @@

from abc import ABCMeta
from numbers import Number
from typing import Any, Iterable, Mapping, MutableMapping, TypeAlias
from typing import Any, Iterable, Mapping, MutableMapping, Protocol, TypeAlias, runtime_checkable

import numpy as np
from healsparse import HealSparseMap
@@ -43,6 +44,36 @@
from numpy.typing import NDArray


@runtime_checkable
class Tensor(Protocol):
r"""This is an interface only class and is intended to represent data that
is 2+ dimensions.

Technically one could use this for scalars or 1D arrays,
but for those the Scalar or Vector interface should be preferred.

`Tensor`\ s abstract around the idea of a multidimensional array, and work
Contributor:
What's the purpose of the "\ s"? Is it an attempt to preserve a space? Is it needed? And will it work? I'd imagine you may need to double-escape "\" here to make this work.

Contributor Author:
In the past this has been how you make a link (to `Tensor` in this case) with an "s" at the end of it. You can't include the "s" inside the backticks or Sphinx can't find the object. The "\ " is an escape sequence on the space that tells Sphinx not to include the space in the formatting.

with a variety of backends including Numpy, CuPy, Tensorflow, PyTorch,
MXNet, TVM, and mpi4py. This intentionally has a minimum interface to
comply with the industry standard dlpack which ensures each of these
backend native types will work.

To ensure that a `Tensor` is in a desired container (e.g. ndarray) one can
call the corresponding ``from_dlpack`` method. Whenever possible this will
be a zero copy action. For instance to work with a Tensor named
``input_tensor`` as if it were a numpy object, one would do
``image = np.from_dlpack(input_tensor)``.
"""

ndim: int
shape: tuple[int, ...]
strides: tuple[int, ...]

def __dlpack__(self, /, *, stream: int | None = ...) -> Any: ...

def __dlpack_device__(self) -> tuple[int, int]: ...


class ScalarMeta(ABCMeta):
def __instancecheck__(cls: ABCMeta, instance: Any) -> Any:
return isinstance(instance, tuple(cls.mro()[1:]))
@@ -72,18 +103,18 @@ def __init__(self) -> None:
like an NDArray should be considered a Vector.
"""

KeyedData = MutableMapping[str, Vector | Scalar | HealSparseMap]
KeyedData = MutableMapping[str, Vector | Scalar | HealSparseMap | Tensor]
"""KeyedData is an interface where either a `Vector` or `Scalar` can be
retrieved using a key which is of str type.
"""

KeyedDataTypes = MutableMapping[str, type[Vector] | ScalarType | type[HealSparseMap]]
KeyedDataTypes = MutableMapping[str, type[Vector] | ScalarType | type[HealSparseMap] | type[Tensor]]
r"""A mapping of str keys to the Types which are valid in `KeyedData` objects.
This is useful in conjunction with `AnalysisAction`\ 's ``getInputSchema`` and
``getOutputSchema`` methods.
"""

KeyedDataSchema = Iterable[tuple[str, type[Vector] | ScalarType | type[HealSparseMap]]]
KeyedDataSchema = Iterable[tuple[str, type[Vector] | ScalarType | type[HealSparseMap] | type[Tensor]]]
r"""An interface that represents a type returned by `AnalysisAction`\ 's
``getInputSchema`` and ``getOutputSchema`` methods.
"""
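As a quick runtime illustration of the new protocol, the following sketch assumes NumPy 1.22 or newer (for ``np.from_dlpack``) and imports `Tensor` from the private module shown above; the public import path may differ.

import numpy as np

from lsst.analysis.tools.interfaces._interfaces import Tensor

# An ndarray exposes ndim, shape, strides and the dlpack dunders, so it
# satisfies the runtime-checkable Tensor protocol.
array = np.arange(12.0).reshape(3, 4)
assert isinstance(array, Tensor)

# Converting back through dlpack is a zero-copy view whenever possible.
image = np.from_dlpack(array)
assert image.shape == (3, 4)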
32 changes: 24 additions & 8 deletions python/lsst/analysis/tools/interfaces/_stages.py
@@ -26,6 +26,7 @@
from typing import Any, cast

import astropy.units as apu
from healsparse import HealSparseMap
from lsst.pex.config import ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pex.config.dictField import DictField
@@ -38,6 +39,7 @@
MetricAction,
MetricResultType,
NoPlot,
Tensor,
VectorAction,
)
from ._interfaces import KeyedData, KeyedDataSchema, KeyedDataTypes, Scalar, Vector
@@ -46,19 +48,27 @@
class BasePrep(KeyedDataAction):
"""Base class for actions which prepare data for processing."""

vectorKeys = ListField[str](doc="Keys to extract from KeyedData and return", default=[])
keysToLoad = ListField[str](doc="Keys to extract from KeyedData and return", default=[])

vectorKeys = ListField[str](doc="Keys from the input data which selectors will be applied", default=[])

selectors = ConfigurableActionStructField[VectorAction](
doc="Selectors for selecting rows, will be AND together",
)

def getInputSchema(self) -> KeyedDataSchema:
yield from ((column, Vector | Scalar) for column in self.vectorKeys) # type: ignore
yield from (
(column, Vector | Scalar | HealSparseMap | Tensor)
for column in set(self.keysToLoad).union(self.vectorKeys)
)
for action in self.selectors:
yield from action.getInputSchema()

def getOutputSchema(self) -> KeyedDataSchema:
return ((column, Vector | Scalar) for column in self.vectorKeys) # type: ignore
return (
(column, Vector | Scalar | HealSparseMap | Tensor)
for column in set(self.keysToLoad).union(self.vectorKeys)
)

def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
mask: Vector | None = None
@@ -69,16 +79,22 @@ def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
else:
mask *= subMask # type: ignore
result: dict[str, Any] = {}
for key in self.vectorKeys:
for key in set(self.keysToLoad).union(self.vectorKeys):
formattedKey = key.format_map(kwargs)
result[formattedKey] = cast(Vector, data[formattedKey])
if mask is not None:
return {key: cast(Vector, col)[mask] for key, col in result.items()}
else:
return result
for key in self.vectorKeys:
# ignore type since there is not fully proper mypy support for
# vector type casting. In the future there will be, and this
# makes it clearer now what type things should be.
result[key] = cast(Vector, result[key])[mask] # type: ignore
return result

def addInputSchema(self, inputSchema: KeyedDataSchema) -> None:
self.vectorKeys = [name for name, _ in inputSchema]
existing = list(self.keysToLoad)
for name, _ in inputSchema:
existing.append(name)
self.keysToLoad = existing


class BaseProcess(KeyedDataAction):
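A hedged configuration sketch of how ``keysToLoad`` and ``vectorKeys`` might be split when configuring a tool's ``prep`` stage; the tool name, key names, and package-level import are assumptions for illustration only.

from lsst.analysis.tools.interfaces import AnalysisTool


class ExampleImageTool(AnalysisTool):
    # Hypothetical tool; the key names below are illustrative only.
    def setDefaults(self):
        super().setDefaults()
        # Columns the selectors act on; BasePrep.__call__ masks these row-wise.
        self.prep.vectorKeys = ["psfFlux", "psfFluxErr"]
        # Keys loaded as-is (e.g. a Tensor or HealSparseMap) and not row-masked.
        self.prep.keysToLoad = ["coadd_image"]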
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,7 +1,7 @@
[flake8]
max-line-length = 110
max-doc-length = 79
ignore = E133, E226, E228, N802, N803, N806, N812, N813, N815, N816, W503, E203
ignore = E133, E226, E228, N802, N803, N806, N812, N813, N815, N816, W503, E203, E704
exclude =
./bin,
doc/conf.py,
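The new E704 entry in the flake8 ignore list appears to accommodate the one-line ``...`` stub bodies seen above in `_baseContext.py` (the compact overload style newer formatters produce); a minimal illustration of the pattern, with made-up function names:

from typing import overload


# E704 ("statement on same line as def") would otherwise flag these one-line
# stub bodies.
@overload
def double(value: int) -> int: ...
@overload
def double(value: float) -> float: ...
def double(value):
    return value * 2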