implement and test Functor use with DeferredDatasetHandle
timothydmorton committed Nov 9, 2020
2 parents c161528 + 4e6e500 commit e148ebb
Showing 2 changed files with 190 additions and 35 deletions.
176 changes: 141 additions & 35 deletions python/lsst/pipe/tasks/functors.py
@@ -1,12 +1,14 @@
 import yaml
 import re
+from itertools import product
 
 import pandas as pd
 import numpy as np
 import astropy.units as u
 
 from lsst.daf.persistence import doImport
-from .parquetTable import MultilevelParquetTable
+from lsst.daf.butler import DeferredDatasetHandle
+from .parquetTable import ParquetTable, MultilevelParquetTable
 
 
 def init_fromDict(initDict, basePath="lsst.pipe.tasks.functors", typeKey="functor", name=None):
@@ -50,7 +52,8 @@ def init_fromDict(initDict, basePath="lsst.pipe.tasks.functors", typeKey="functo
 class Functor(object):
     """Define and execute a calculation on a ParquetTable
-    The `__call__` method accepts a `ParquetTable` object, and returns the
+    The `__call__` method accepts either a `ParquetTable` object or a
+    `DeferredDatasetHandle`, and returns the
     result of the calculation as a single column. Each functor defines what
     columns are needed for the calculation, and only these columns are read
     from the `ParquetTable`.
@@ -71,18 +74,19 @@ class Functor(object):
     On initialization, a `Functor` should declare what filter (`filt` kwarg)
     and dataset (e.g. `'ref'`, `'meas'`, `'forced_src'`) it is intended to be
-    applied to. This enables the `_get_cols` method to extract the proper
+    applied to. This enables the `_get_data` method to extract the proper
     columns from the parquet file. If not specified, the dataset will fall back
     on the `_defaultDataset` attribute. If filter is not specified and `dataset`
     is anything other than `'ref'`, then an error will be raised when trying to
     perform the calculation.
     As currently implemented, `Functor` is only set up to expect a
-    `ParquetTable` of the format of the `deepCoadd_obj` dataset; that is, a
-    `MultilevelParquetTable` with the levels of the column index being `filter`,
+    dataset of the format of the `deepCoadd_obj` dataset; that is, a
+    dataframe with a multi-level column index,
+    with the levels of the column index being `filter`,
     `dataset`, and `column`. This is defined in the `_columnLevels` attribute,
     as well as being implicit in the role of the `filt` and `dataset` attributes
-    defined at initialization. In addition, the `_get_cols` method that reads
+    defined at initialization. In addition, the `_get_data` method that reads
     the dataframe from the `ParquetTable` will return a dataframe with column
     index levels defined by the `_dfLevels` attribute; by default, this is
     `column`.
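
To make that layout concrete, here is a minimal sketch (not part of the commit; the filter and column names are illustrative) of a deepCoadd_obj-style column index and the slice a functor with filt="HSC-G" and dataset="meas" would read:

    import numpy as np
    import pandas as pd

    # Toy column index with the (filter, dataset, column) levels described above.
    columns = pd.MultiIndex.from_product(
        [["HSC-G", "HSC-R"], ["meas", "ref"], ["base_PsfFlux_instFlux", "coord_ra"]],
        names=["filter", "dataset", "column"],
    )
    df = pd.DataFrame(np.random.rand(3, len(columns)), columns=columns)

    # Selecting on the first two levels leaves a plain `column` index,
    # matching the `_dfLevels` default described in the docstring.
    sub = df[("HSC-G", "meas")]
    print(list(sub.columns))  # ['base_PsfFlux_instFlux', 'coord_ra']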
@@ -130,18 +134,68 @@ def columns(self):
         raise NotImplementedError("Must define columns property or _columns attribute")
         return self._columns
 
-    def multilevelColumns(self, parq):
-        if not set(parq.columnLevels) == set(self._columnLevels):
+    def _get_data_columnLevels(self, data, columnIndex=None):
+        if isinstance(data, DeferredDatasetHandle):
+            if columnIndex is None:
+                columnIndex = data.get(component="columns")
+        if columnIndex is not None:
+            return columnIndex.names
+        if isinstance(data, MultilevelParquetTable):
+            return data.columnLevels
+        else:
+            raise TypeError(f"Unknown type for data: {type(data)}!")
+
+    def _get_data_columnLevelNames(self, data, columnIndex=None):
+        if isinstance(data, DeferredDatasetHandle):
+            if columnIndex is None:
+                columnIndex = data.get(component="columns")
+        if columnIndex is not None:
+            columnLevels = columnIndex.names
+            columnLevelNames = {
+                level: list(np.unique(np.array([c for c in columnIndex])[:, i]))
+                for i, level in enumerate(columnLevels)
+            }
+            return columnLevelNames
+        if isinstance(data, MultilevelParquetTable):
+            return data.columnLevelNames
+        else:
+            raise TypeError(f"Unknown type for data: {type(data)}!")
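
For illustration (not part of the diff), the dict comprehension above reduces a pandas MultiIndex to its unique values per level; on a toy index:

    import numpy as np
    import pandas as pd

    columnIndex = pd.MultiIndex.from_product(
        [["HSC-G", "HSC-R"], ["meas", "ref"], ["flux", "ra"]],
        names=["filter", "dataset", "column"],
    )
    # The same reduction _get_data_columnLevelNames performs:
    names = {
        level: list(np.unique(np.array([c for c in columnIndex])[:, i]))
        for i, level in enumerate(columnIndex.names)
    }
    # names == {'filter': ['HSC-G', 'HSC-R'], 'dataset': ['meas', 'ref'],
    #           'column': ['flux', 'ra']}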
+
+    def _colsFromDict(self, colDict, columnIndex=None):
+        new_colDict = {}
+        columnLevels = self._get_data_columnLevels(None, columnIndex=columnIndex)
+
+        for i, l in enumerate(columnLevels):
+            if l in colDict:
+                if isinstance(colDict[l], str):
+                    new_colDict[l] = [colDict[l]]
+                else:
+                    new_colDict[l] = colDict[l]
+            else:
+                new_colDict[l] = columnIndex.levels[i]
+
+        levelCols = [new_colDict[l] for l in columnLevels]
+        cols = product(*levelCols)
+        return list(cols)
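
The expansion at the heart of _colsFromDict is itertools.product over one list per level; a tiny standalone example (values hypothetical):

    from itertools import product

    # Levels pinned in colDict keep their values; unpinned levels fall
    # back to every value the column index has at that level.
    levelCols = [["HSC-G"], ["meas"], ["flux", "ra"]]
    print(list(product(*levelCols)))
    # [('HSC-G', 'meas', 'flux'), ('HSC-G', 'meas', 'ra')]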
+
+    def multilevelColumns(self, data, columnIndex=None, returnTuple=False):
+        if isinstance(data, DeferredDatasetHandle) and columnIndex is None:
+            columnIndex = data.get(component="columns")
+
+        columnLevels = self._get_data_columnLevels(data, columnIndex)
+
+        if not set(columnLevels) == set(self._columnLevels):
             raise ValueError(
                 "ParquetTable does not have the expected column levels. "
-                f"Got {parq.columnLevels}; expected {self._columnLevels}."
+                f"Got {columnLevels}; expected {self._columnLevels}."
             )
 
         columnDict = {"column": self.columns, "dataset": self.dataset}
         if self.filt is None:
-            if "filter" in parq.columnLevels:
+            columnLevelNames = self._get_data_columnLevelNames(data, columnIndex)
+            if "filter" in columnLevels:
                 if self.dataset == "ref":
-                    columnDict["filter"] = parq.columnLevelNames["filter"][0]
+                    columnDict["filter"] = columnLevelNames["filter"][0]
                 else:
                     raise ValueError(
                         f"'filt' not set for functor {self.name}"
@@ -153,24 +207,54 @@ def multilevelColumns(self, parq):
         else:
             columnDict["filter"] = self.filt
 
-        return parq._colsFromDict(columnDict)
+        if isinstance(data, MultilevelParquetTable):
+            return data._colsFromDict(columnDict)
+        elif isinstance(data, DeferredDatasetHandle):
+            if returnTuple:
+                return self._colsFromDict(columnDict, columnIndex=columnIndex)
+            else:
+                return columnDict  # gen3 wants the dict.
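
To spell out the two return shapes (a sketch with hypothetical values, not part of the commit):

    # One specification, two shapes:
    columnDict = {"column": ["flux", "ra"], "dataset": "meas", "filter": "HSC-G"}

    # Tuple form (MultilevelParquetTable path, or returnTuple=True): expanded
    # into explicit tuples, e.g. [('HSC-G', 'meas', 'flux'), ('HSC-G', 'meas', 'ra')].
    # DeferredDatasetHandle form: columnDict itself is returned and later passed
    # through data.get(parameters={"columns": columnDict}), assuming the gen3
    # DataFrame reader accepts the dict form, as the "gen3 wants the dict"
    # comment implies.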

     def _func(self, df, dropna=True):
         raise NotImplementedError("Must define calculation on dataframe")
 
-    def _get_cols(self, parq):
+    def _get_columnIndex(self, data):
+        """Return columnIndex
+        """
+
+        if isinstance(data, DeferredDatasetHandle):
+            return data.get(component="columns")
+        else:
+            return None
+
+    def _get_data(self, data):
         """Retrieve dataframe necessary for calculation.
+        The data argument can be a DataFrame, a ParquetTable instance, or a gen3 DeferredDatasetHandle
         Returns dataframe upon which `self._func` can act.
         """
-        if isinstance(parq, MultilevelParquetTable):
-            columns = self.multilevelColumns(parq)
-            df = parq.toDataFrame(columns=columns, droplevels=False)
-            df = self._setLevels(df)
-        else:
+        if isinstance(data, pd.DataFrame):
+            return data
+
+        columnIndex = self._get_columnIndex(data)
+        is_multiLevel = isinstance(data, MultilevelParquetTable) or isinstance(columnIndex, pd.MultiIndex)
+
+        # Simple single-level parquet table
+        if isinstance(data, ParquetTable) and not is_multiLevel:
             columns = self.columns
-            df = parq.toDataFrame(columns=columns)
+            df = data.toDataFrame(columns=columns)
+            return df
+
+        if is_multiLevel:
+            columns = self.multilevelColumns(data, columnIndex=columnIndex)
+
+            if isinstance(data, MultilevelParquetTable):
+                df = data.toDataFrame(columns=columns, droplevels=False)
+            elif isinstance(data, DeferredDatasetHandle):
+                df = data.get(parameters={"columns": columns})
+
+            df = self._setLevels(df)
         return df
 
     def _setLevels(self, df):
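
_setLevels itself is collapsed in this view. A minimal sketch of the level-dropping it presumably performs, assuming it keeps only the levels in `_dfLevels` (consistent with how `_get_data` uses it):

    import numpy as np
    import pandas as pd

    columns = pd.MultiIndex.from_product(
        [["HSC-G"], ["meas"], ["flux", "ra"]], names=["filter", "dataset", "column"]
    )
    df = pd.DataFrame(np.zeros((2, 2)), columns=columns)

    # Drop every level not in _dfLevels (assumed here to be ('column',)),
    # leaving the plain column index that _func operates on.
    levelsToDrop = [n for n in df.columns.names if n not in ("column",)]
    df.columns = df.columns.droplevel(levelsToDrop)
    print(list(df.columns))  # ['flux', 'ra']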
@@ -181,9 +265,9 @@ def _setLevels(self, df):
     def _dropna(self, vals):
         return vals.dropna()
 
-    def __call__(self, parq, dropna=False):
+    def __call__(self, data, dropna=False):
         try:
-            df = self._get_cols(parq)
+            df = self._get_data(data)
             vals = self._func(df)
         except Exception:
             vals = self.fail(df)
@@ -192,10 +276,10 @@ def __call__(self, parq, dropna=False):

         return vals
 
-    def difference(self, parq1, parq2, **kwargs):
+    def difference(self, data1, data2, **kwargs):
         """Computes difference between functor called on two different ParquetTable objects
         """
-        return self(parq1, **kwargs) - self(parq2, **kwargs)
+        return self(data1, **kwargs) - self(data2, **kwargs)
 
     def fail(self, df):
         return pd.Series(np.full(len(df), np.nan), index=df.index)
@@ -283,24 +367,46 @@ def update(self, new):
     def columns(self):
         return list(set([x for y in [f.columns for f in self.funcDict.values()] for x in y]))
 
-    def multilevelColumns(self, parq):
-        return list(set([x for y in [f.multilevelColumns(parq) for f in self.funcDict.values()] for x in y]))
+    def multilevelColumns(self, data, **kwargs):
+        return list(
+            set(
+                [
+                    x
+                    for y in [
+                        f.multilevelColumns(data, returnTuple=True, **kwargs) for f in self.funcDict.values()
+                    ]
+                    for x in y
+                ]
+            )
+        )
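
The nested comprehension above just unions the per-functor column requests; in isolation, with made-up values:

    perFunctorCols = [
        [("HSC-G", "meas", "flux")],
        [("HSC-G", "meas", "flux"), ("HSC-G", "meas", "ra")],
    ]
    # set() deduplicates, so each column tuple is read once no matter
    # how many member functors ask for it.
    allCols = list(set(x for y in perFunctorCols for x in y))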
+
+    def __call__(self, data, **kwargs):
+        columnIndex = self._get_columnIndex(data)
+        is_multiLevel = isinstance(data, MultilevelParquetTable) or isinstance(columnIndex, pd.MultiIndex)
+
+        if isinstance(data, ParquetTable) and not is_multiLevel:
+            columns = self.columns
+            df = data.toDataFrame(columns=columns)
+            valDict = {k: f._func(df) for k, f in self.funcDict.items()}
+
+        if is_multiLevel:
+            columns = self.multilevelColumns(data, columnIndex=columnIndex)
+
+            if isinstance(data, MultilevelParquetTable):
+                df = data.toDataFrame(columns=columns, droplevels=False)
+            elif isinstance(data, DeferredDatasetHandle):
+                df = data.get(parameters={"columns": columns})
 
-    def __call__(self, parq, **kwargs):
-        if isinstance(parq, MultilevelParquetTable):
-            columns = self.multilevelColumns(parq)
-            df = parq.toDataFrame(columns=columns, droplevels=False)
             valDict = {}
             for k, f in self.funcDict.items():
                 try:
-                    subdf = f._setLevels(df[f.multilevelColumns(parq)])
+                    subdf = f._setLevels(
+                        df[f.multilevelColumns(data, returnTuple=True, columnIndex=columnIndex)]
+                    )
                     valDict[k] = f._func(subdf)
                 except Exception:
+                    raise
                     valDict[k] = f.fail(subdf)
-        else:
-            columns = self.columns
-            df = parq.toDataFrame(columns=columns)
-            valDict = {k: f._func(df) for k, f in self.funcDict.items()}
 
         try:
             valDf = pd.concat(valDict, axis=1)
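
For reference (not part of the diff): pd.concat on a dict of Series labels the output columns with the dict keys, which is how the composite assembles one dataframe of results keyed by functor name:

    import pandas as pd

    valDict = {
        "mag_g": pd.Series([20.1, 21.3]),
        "mag_r": pd.Series([19.8, 20.9]),
    }
    valDf = pd.concat(valDict, axis=1)
    print(list(valDf.columns))  # ['mag_g', 'mag_r']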
@@ -701,7 +807,7 @@ def _func(self, df):
     def columns(self):
         return [self.mag1.col, self.mag2.col]
 
-    def multilevelColumns(self, parq):
+    def multilevelColumns(self, parq, **kwargs):
         return [(self.dataset, self.filt1, self.col), (self.dataset, self.filt2, self.col)]
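
For a color-style functor the required multilevel columns are just two explicit tuples, one per filter; with hypothetical values (note the (dataset, filter, column) ordering this override uses):

    dataset, filt1, filt2, col = "forced_src", "HSC-G", "HSC-R", "base_PsfFlux_instFlux"
    cols = [(dataset, filt1, col), (dataset, filt2, col)]
    # [('forced_src', 'HSC-G', 'base_PsfFlux_instFlux'),
    #  ('forced_src', 'HSC-R', 'base_PsfFlux_instFlux')]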

     @property
