-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce configurable actions for DataFrames
This commit introduces a base class for configurable actions that works on Data Frames, basic concrete actions, and some supporting functions.
- Loading branch information
Showing
4 changed files
with
316 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from ._actions import * | ||
from ._baseDataFrameActions import DataFrameAction | ||
from ._evalColumnExpression import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
from __future__ import annotations | ||
|
||
__all__ = ("SingleColumnAction", "MultiColumnAction", "CoordColumn", "MagColumnDN", "SumColumns", "AddColumn", | ||
"DivideColumns", "SubtractColumns", "MultiplyColumns",) | ||
|
||
from typing import Iterable | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from astropy import units | ||
|
||
from ..configurableActions import ConfigurableActionStructField, ConfigurableActionField | ||
from ._baseDataFrameActions import DataFrameAction | ||
from ._evalColumnExpression import makeColumnExpressionAction | ||
|
||
from lsst.pex.config import Field | ||
|
||
|
||
class SingleColumnAction(DataFrameAction):
    """Action that selects one configured column from a DataFrame."""
    column = Field(doc="Column to load for this action", dtype=str, optional=False)

    @property
    def columns(self) -> Iterable[str]:
        """One-element tuple holding the configured column name."""
        columnName = self.column
        return (columnName,)

    def __call__(self, df, **kwargs):
        """Return ``df[self.column]``.

        Keyword arguments are accepted for interface compatibility and are
        not used here.
        """
        columnName = self.column
        return df[columnName]
|
||
|
||
class MultiColumnAction(DataFrameAction):
    """Base for actions built from several configurable sub-actions.

    ``__call__`` is not defined here; only the column bookkeeping is.
    """
    actions = ConfigurableActionStructField(doc="Configurable actions to use in a joint action")

    @property
    def columns(self) -> Iterable[str]:
        """Yield the columns of every sub-action, in iteration order."""
        for subAction in self.actions:
            for columnName in subAction.columns:
                yield columnName
|
||
|
||
class CoordColumn(SingleColumnAction):
    """Single-column action for coordinate columns with optional scaling."""
    inRadians = Field(doc="Return the column in radians if true", default=True, dtype=bool)

    def __call__(self, df, **kwargs):
        """Return the configured column, scaled by ``180 / pi`` when
        ``inRadians`` is set.

        Keyword arguments are accepted (and forwarded to the parent) because
        other actions call their sub-actions as ``action(df, **kwargs)``;
        without them any such call raised `TypeError`.
        """
        col = super().__call__(df, **kwargs)
        # NOTE(review): the field doc says the column is returned in radians
        # when ``inRadians`` is true, yet this multiplies by 180/pi in exactly
        # that case. The original behaviour is kept; confirm the units of the
        # stored column and the intended meaning of the flag.
        return col * 180 / np.pi if self.inRadians else col
|
||
|
||
class MagColumnDN(SingleColumnAction):
    """Convert a column to magnitudes as ``-2.5 * log10(value / fluxMag0)``."""
    coadd_zeropoint = Field(doc="Magnitude zero point", dtype=float, default=27)

    def __call__(self, df: pd.DataFrame, **kwargs):
        """Return the magnitude computed from ``df[self.column]``.

        Parameters
        ----------
        df : `pandas.DataFrame`
            Table containing ``self.column``.
        **kwargs
            ``fluxMag0`` may be supplied; when it is absent or falsy the
            value ``1 / 10**(-0.4 * coadd_zeropoint)`` is used instead.
        """
        # Local import: ``np.warnings`` was only an accidental alias of the
        # standard library module and is not available in NumPy >= 1.24.
        import warnings

        if not (fluxMag0 := kwargs.get('fluxMag0')):
            fluxMag0 = 1/np.power(10, -0.4*self.coadd_zeropoint)

        # Non-positive inputs produce NaN/inf in the logarithm; silence the
        # corresponding runtime warnings for this computation only.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', r'invalid value encountered')
            warnings.filterwarnings('ignore', r'divide by zero')
            return -2.5 * np.log10(df[self.column] / fluxMag0)
|
||
|
||
class NanoJansky(SingleColumnAction):
    """Scale a column by ``ab_flux_scale / fluxMag0``."""
    ab_flux_scale = Field(doc="Scaling of ab flux", dtype=float, default=(0*units.ABmag).to_value(units.nJy))
    coadd_zeropoint = Field(doc="Magnitude zero point", dtype=float, default=27)

    def __call__(self, df, **kwargs):
        """Return ``ab_flux_scale * df[self.column] / fluxMag0``.

        ``fluxMag0`` is read from the keyword arguments; when it is absent
        or falsy it is derived from ``coadd_zeropoint``.
        """
        dataNumber = super().__call__(df, **kwargs)
        fluxMag0 = kwargs.get('fluxMag0')
        if not fluxMag0:
            fluxMag0 = 1/np.power(10, -0.4*self.coadd_zeropoint)
        scaled = self.ab_flux_scale * dataNumber
        return scaled / fluxMag0

    def setDefaults(self):
        """Apply the parent defaults, then turn result caching on."""
        super().setDefaults()
        # cache this action for future calls
        self.cache = True
|
||
|
||
class NanoJanskyErr(SingleColumnAction):
    """Action intended to work on an error column together with a flux
    obtained from ``flux_action``.
    """
    flux_mag_err = Field(doc="Error in the magnitude zeropoint", dtype=float, default=0)
    flux_action = ConfigurableActionField(doc="Action to use if flux is not provided to the call method",
                                          default=NanoJansky, dtype=DataFrameAction)

    @property
    def columns(self):
        """Yield this action's own column followed by the columns required
        by ``flux_action``.

        Each item is a column name; the previous ``zip`` produced a single
        tuple pairing the two, which is not a usable column name.
        """
        yield self.column
        yield from self.flux_action.columns

    def __call__(self, df, flux_column=None, flux_mag_err=None, **kwargs):
        """Resolve the flux column and zeropoint error used by this action.

        Parameters
        ----------
        df : `pandas.DataFrame`
            Table passed on to ``flux_action`` when no flux is supplied.
        flux_column : optional
            Precomputed flux; when `None` it is computed with
            ``self.flux_action(df, **kwargs)``.
        flux_mag_err : optional
            Zeropoint error; when `None` the configured ``flux_mag_err``
            field is used.
        """
        if flux_column is None:
            flux_column = self.flux_action(df, **kwargs)
        if flux_mag_err is None:
            flux_mag_err = self.flux_mag_err
        # NOTE(review): nothing is computed from the resolved inputs and the
        # method returns None; the implementation appears to be unfinished.
|
||
|
||
# The classes below are generated by ``makeColumnExpressionAction``; each one
# evaluates a two-operand expression whose operands default to
# ``SingleColumnAction``.
_docs = """This is a `DataFrameAction` that is designed to add two columns
together and return the result.
"""
SumColumns = makeColumnExpressionAction("SumColumns", "colA+colB",
                                        exprDefaults={"colA": SingleColumnAction,
                                                      "colB": SingleColumnAction},
                                        docstring=_docs)

_docs = """This is a `DataFrameAction` that is designed to subtract one column
from another and return the result.
"""
SubtractColumns = makeColumnExpressionAction("SubtractColumns", "colA-colB",
                                             exprDefaults={"colA": SingleColumnAction,
                                                           "colB": SingleColumnAction},
                                             docstring=_docs)

_docs = """This is a `DataFrameAction` that is designed to multiply two columns
together and return the result.
"""
MultiplyColumns = makeColumnExpressionAction("MultiplyColumns", "colA*colB",
                                             exprDefaults={"colA": SingleColumnAction,
                                                           "colB": SingleColumnAction},
                                             docstring=_docs)

_docs = """This is a `DataFrameAction` that is designed to divide one column
by another and return the result.
"""
DivideColumns = makeColumnExpressionAction("DivideColumns", "colA/colB",
                                           exprDefaults={"colA": SingleColumnAction,
                                                         "colB": SingleColumnAction},
                                           docstring=_docs)
|
||
|
||
class AddColumn(DataFrameAction):
    """Compute a new column with ``aggregator`` and store it in the input
    DataFrame under the name ``newColumn``.
    """
    aggregator = ConfigurableActionField(doc="This is an instance of a Dataframe action that will be used "
                                         "to create a new column", dtype=DataFrameAction)
    newColumn = Field(doc="Name of the new column to add", dtype=str)

    @property
    def columns(self) -> Iterable[str]:
        """Yield the columns required by ``aggregator``."""
        yield from self.aggregator.columns

    def __call__(self, df, **kwargs) -> pd.DataFrame:
        """Assign the aggregator result to ``df[self.newColumn]`` and return
        ``df``.

        The DataFrame is modified in place; the returned object is the same
        one that was passed in.
        """
        # Keyword arguments must be forwarded as keywords: actions are called
        # as ``action(df, **kwargs)`` and take a single positional argument.
        df[self.newColumn] = self.aggregator(df, **kwargs)
        return df
56 changes: 56 additions & 0 deletions
56
python/lsst/pipe/tasks/dataFrameActions/_baseDataFrameActions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from __future__ import annotations | ||
|
||
__all__ = ("DataFrameAction",) | ||
|
||
from lsst.pex.config import Field, ListField | ||
from typing import Iterable, Any, Mapping | ||
|
||
from ..configurableActions import ConfigurableAction | ||
|
||
|
||
class DataFrameAction(ConfigurableAction):
    """Base class for configurable actions that are called with a DataFrame.

    Every subclass has its ``__call__`` wrapped at class-creation time so
    that results can be cached when ``cache`` is set and the action is
    frozen.
    """
    # Per-subclass result cache, created in ``__init_subclass__``. Keys are
    # tuples built from the ``cacheArgs`` values, the calling instance's id
    # and the DataFrame's id.
    _actionCache: Mapping[tuple, Any]

    cache = Field(doc="Controls if the results of this action should be cached,"
                      " only works on frozen actions",
                  dtype=bool, default=False)
    cacheArgs = ListField(doc="If cache is True, this is a list of argument keys that will be used to "
                              "compute the cache key in addition to the DataFrameId", dtype=str)

    def __init_subclass__(cls, **kwargs) -> None:
        """Give the subclass its own cache and wrap its ``__call__``."""
        cls._actionCache = {}
        # Sentinel marking a cache miss. Cached results may be falsy, or may
        # be pandas/NumPy objects whose truth value is ambiguous, so the
        # lookup must not rely on truthiness.
        missing = object()

        def call_wrapper(function):
            def inner_wrapper(self, dataFrame, **kwargs):
                extra = []
                for name in (self.cacheArgs or tuple()):
                    if name not in kwargs:
                        raise ValueError(f"{name} is not part of call signature and cant be used for "
                                         "caching")
                    extra.append(kwargs[name])
                # The cache dict is shared by all instances of a class, so
                # the instance id is part of the key; otherwise differently
                # configured instances would return each other's results.
                extra.append(id(self))
                extra.append(id(dataFrame))
                key = tuple(extra)
                useCache = self.cache and self._frozen
                if useCache:
                    # look up to see if the value is in cache already
                    cached = self._actionCache.get(key, missing)
                    if cached is not missing:
                        return cached
                result = function(self, dataFrame, **kwargs)
                if useCache:
                    self._actionCache[key] = result
                return result
            return inner_wrapper
        cls.__call__ = call_wrapper(cls.__call__)
        super().__init_subclass__(**kwargs)

    def __call__(self, dataFrame, **kwargs) -> Iterable[Any]:
        """This method should return the result of an action performed on a
        dataframe
        """
        raise NotImplementedError("This method should be overloaded in a subclass")

    @property
    def columns(self) -> Iterable[str]:
        """This property should return an iterable of columns needed by this action
        """
        raise NotImplementedError("This method should be overloaded in a subclass")
122 changes: 122 additions & 0 deletions
122
python/lsst/pipe/tasks/dataFrameActions/_evalColumnExpression.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# This file is part of pipe_tasks. | ||
# | ||
# Developed for the LSST Data Management System. | ||
# This product includes software developed by the LSST Project | ||
# (https://www.lsst.org). | ||
# See the COPYRIGHT file at the top-level directory of this distribution | ||
# for details of code ownership. | ||
# | ||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU General Public License | ||
# along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = ("makeColumnExpressionAction", ) | ||
|
||
import ast | ||
import operator as op | ||
|
||
from typing import Mapping, MutableMapping, Set, Type, Union, Optional, Any | ||
|
||
from numpy import log10 as log | ||
from numpy import (cos, sin, cosh, sinh) | ||
import pandas as pd | ||
|
||
from ..configurableActions import ConfigurableActionField | ||
from ._baseDataFrameActions import DataFrameAction | ||
|
||
|
||
# Supported operator nodes mapped to the functions that implement them.
OPERATORS = {ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul,
             ast.Div: op.truediv, ast.Pow: op.pow, ast.BitXor: op.xor,
             ast.USub: op.neg}

# Function names that may be called inside an expression; ``log`` is base 10.
EXTRA_MATH = {"cos": cos, "sin": sin, "cosh": cosh, "sinh": sinh, "log": log}


class ExpressionParser(ast.NodeVisitor):
    """Evaluate a restricted arithmetic expression from its AST.

    Parameters
    ----------
    **kwargs
        Values bound to the names that appear in the expression. The names
        in ``EXTRA_MATH`` are always bound to their math functions and take
        precedence over a keyword of the same name.

    Notes
    -----
    Any construct that is not explicitly supported raises `ValueError`.
    """

    def __init__(self, **kwargs):
        self.variables = kwargs
        # Every advertised math function must be resolvable, not only ``log``.
        self.variables.update(EXTRA_MATH)

    def visit_Name(self, node):
        """Return the value bound to a name; unknown names are an error."""
        try:
            return self.variables[node.id]
        except KeyError:
            raise ValueError(f"Name {node.id!r} is not recognized") from None

    def visit_Constant(self, node):
        """Return numeric, boolean and `None` literals; reject all others."""
        value = node.value
        if value is None or isinstance(value, (bool, int, float, complex)):
            return value
        raise ValueError("String not recognized")

    def visit_Num(self, node):
        # Only reached on interpreters that still emit ``ast.Num`` nodes.
        return node.n

    def visit_NameConstant(self, node):
        # Only reached on interpreters that still emit ``ast.NameConstant``.
        return node.value

    def visit_UnaryOp(self, node):
        """Apply a supported unary operator to its evaluated operand."""
        val = self.visit(node.operand)
        return self._operator(node.op)(val)

    def visit_BinOp(self, node):
        """Apply a supported binary operator to its evaluated operands."""
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        return self._operator(node.op)(lhs, rhs)

    def visit_Call(self, node):
        """Call a known function with its first positional argument.

        Only calls through a plain name are accepted, and only the first
        positional argument is evaluated and passed on.
        """
        target = node.func
        if not isinstance(target, ast.Name) or target.id not in self.variables:
            raise ValueError("String not recognized")
        if not node.args:
            raise ValueError("String not recognized")
        function = self.visit(target)
        return function(self.visit(node.args[0]))

    def generic_visit(self, node):
        raise ValueError("String not recognized")

    @staticmethod
    def _operator(opNode):
        """Return the function for an operator node or raise `ValueError`."""
        try:
            return OPERATORS[type(opNode)]
        except KeyError:
            raise ValueError(f"Operator {type(opNode).__name__} is not supported") from None
|
||
|
||
def makeColumnExpressionAction(className: str, expr: str,
                               exprDefaults: Optional[Mapping[str, Union[DataFrameAction,
                                                                         Type[DataFrameAction]]]] = None,
                               docstring: Optional[str] = None
                               ) -> Type[DataFrameAction]:
    """Create a `DataFrameAction` subclass that evaluates an expression.

    Each name in ``expr`` that is not one of the ``EXTRA_MATH`` functions
    becomes a `ConfigurableActionField` on the new class. Calling an instance
    runs every such action on the DataFrame and combines the results
    according to ``expr``.

    Parameters
    ----------
    className : `str`
        Name of the class to create.
    expr : `str`
        Expression to evaluate, parsed with ``ast.parse(expr, mode='eval')``.
    exprDefaults : `Mapping`, optional
        Default action for each name in ``expr``; names without an entry get
        a field with no default.
    docstring : `str`, optional
        Docstring to attach to the new class.

    Returns
    -------
    action : `type`
        The newly created `DataFrameAction` subclass.
    """
    node = ast.parse(expr, mode='eval')

    # gather the specified names
    names: Set[str] = {elm.id for elm in ast.walk(node) if isinstance(elm, ast.Name)}

    # remove the known Math names
    names -= EXTRA_MATH.keys()

    fields: MutableMapping[str, ConfigurableActionField] = {}
    for name in names:
        if exprDefaults is not None and (value := exprDefaults.get(name)) is not None:
            kwargs = {"default": value}
        else:
            kwargs = {}
        fields[name] = ConfigurableActionField(doc=f"expression action {name}", **kwargs)

    # skip flake8 on N807 because this is a stand alone function, but it is
    # intended to be patched in as a method on a dynamically generated class
    def __call__(self, df: pd.DataFrame, **kwargs) -> pd.Series:  # noqa: N807
        values_map = {}
        for name in fields:
            values_map[name] = getattr(self, name)(df, **kwargs)

        parser = ExpressionParser(**values_map)
        return parser.visit(node.body)

    # Without this the generated class inherits the base-class ``columns``,
    # which raises NotImplementedError.
    def columns(self):
        """Yield the columns needed by every action used in the expression."""
        for name in fields:
            yield from getattr(self, name).columns

    dct: MutableMapping[str, Any] = {"__call__": __call__, "columns": property(columns)}
    if docstring is not None:
        dct['__doc__'] = docstring
    dct.update(**fields)

    return type(className, (DataFrameAction, ), dct)