Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-34133: Add ScatterPlotWithTwoHistsTaskTestCase #30

Merged
merged 3 commits into from
Mar 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ pytest_session.txt
.cache/
.pytest_cache
.coverage

# Pytest outputs
tests/data/test_*-failed-diff.png
49 changes: 32 additions & 17 deletions python/lsst/analysis/drp/scatterPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,35 @@ class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
itemtype=str
)

def get_requirements(self):
    """Return inputs required for a Task to run with this config.

    Every action in the axis, selector, high/low S/N statistic selector
    and source selector action structs is asked for its ``columns``.
    The union of those names, plus ``"patch"`` which is always included,
    forms the required columns. A band is taken from each column name as
    the text before the first underscore, unless that prefix is listed
    in ``nonBandColumnPrefixes``.

    Returns
    -------
    bands : `set`
        The required bands.
    columns : `set`
        The required column names.

    Notes
    -----
    This method is deliberately public: it is a pure getter that does
    not change the state of the config. Columns reported as `None` are
    skipped entirely, and a column whose prefix is empty (a leading
    underscore) is kept as a column but contributes no band.
    """
    columnNames = {"patch"}
    bands = set()
    # Resolve the exclusion list once instead of once per column.
    nonBandPrefixes = set(self.nonBandColumnPrefixes)
    actionStructs = (
        self.axisActions,
        self.selectorActions,
        self.highSnStatisticSelectorActions,
        self.lowSnStatisticSelectorActions,
        self.sourceSelectorActions,
    )
    for actionStruct in actionStructs:
        for action in actionStruct:
            for col in action.columns:
                if col is None:
                    continue
                columnNames.add(col)
                prefix, underscore, _ = col.partition("_")
                # If there's no underscore, it has no band prefix; an
                # empty prefix is not a band either.
                if underscore and prefix and prefix not in nonBandPrefixes:
                    bands.add(prefix)
    return bands, columnNames

nonBandColumnPrefixes = pexConfig.ListField(
doc="Column prefixes that are not bands and which should not be added to the set of bands",
dtype=str,
Expand Down Expand Up @@ -111,22 +140,7 @@ class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):

def runQuantum(self, butlerQC, inputRefs, outputRefs):
# Docs inherited from base class
columnNames = {"patch"}
bands = set()
for actionStruct in [self.config.axisActions, self.config.selectorActions,
self.config.highSnStatisticSelectorActions,
self.config.lowSnStatisticSelectorActions,
self.config.sourceSelectorActions]:
for action in actionStruct:
for col in action.columns:
columnNames.add(col)
column_split = col.split("_")
# If there's no underscore, it doesn't have a band prefix
if len(column_split) > 1:
band = column_split[0]
if band not in self.config.nonBandColumnPrefixes:
bands.add(band)

bands, columnNames = self.config.get_requirements()
inputs = butlerQC.get(inputRefs)
dataFrame = inputs["catPlot"].get(parameters={"columns": columnNames})
inputs['catPlot'] = dataFrame
Expand Down Expand Up @@ -233,7 +247,8 @@ def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
# Get useful information about the plot
plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN)
# Calculate the corners of the patches and some associated stats
sumStats = generateSummaryStats(plotDf, self.config.axisLabels["y"], skymap, plotInfo)
sumStats = {} if skymap is None else generateSummaryStats(
plotDf, self.config.axisLabels["y"], skymap, plotInfo)
# Make the plot
fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)

Expand Down
Binary file added tests/data/test_scatterPlot.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
137 changes: 137 additions & 0 deletions tests/test_scatterPlot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# This file is part of analysis_drp.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


import unittest
import lsst.utils.tests

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from lsst.analysis.drp.calcFunctors import MagDiff
from lsst.analysis.drp.dataSelectors import GalaxyIdentifier
from lsst.analysis.drp.scatterPlot import ScatterPlotWithTwoHistsTask, ScatterPlotWithTwoHistsTaskConfig

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.testing.compare import compare_images, ImageComparisonFailure

import numpy as np
from numpy.random import default_rng
import os
import pandas as pd
import shutil
import tempfile

# Select the non-interactive Agg backend so figures are rendered off-screen.
matplotlib.use("Agg")

# Directory containing this test module.
ROOT = os.path.abspath(os.path.dirname(__file__))
# Reference image that the figure produced by the task is compared against.
filename_figure_ref = os.path.join(ROOT, "data", "test_scatterPlot.png")


class ScatterPlotWithTwoHistsTaskTestCase(lsst.utils.tests.TestCase):
    """ScatterPlotWithTwoHistsTask test case.

    A synthetic catalog is run through the task and the saved figure is
    compared with the reference image at ``filename_figure_ref``.
    """

    def setUp(self):
        """Create the scratch directory, the task and a mock catalog."""
        self.testDir = tempfile.mkdtemp(dir=ROOT, prefix="test_output")

        # Set up a quasi-plausible measurement catalog
        mag = 12.5 + 2.5*np.log10(np.arange(10, 100000))
        flux = 10**(-0.4*(mag - (mag[-1] + 1)))
        rng = default_rng(0)
        extendedness = 0. + (rng.uniform(size=len(mag)) < 0.99*(mag - mag[0])/(mag[-1] - mag[0]))
        flux_meas = flux + rng.normal(scale=np.sqrt(flux*(1 + extendedness)))
        # Keep rows whose measured flux exceeds three times the noise scale
        good = (flux_meas/np.sqrt(flux * (1 + extendedness))) > 3
        extendedness = extendedness[good]
        flux = flux[good]
        flux_meas = flux_meas[good]
        # Computed after the cut: before it flux_meas can be negative, and
        # the square root would emit NaNs and a RuntimeWarning for rows
        # that are discarded anyway. The kept values are unchanged.
        flux_err = np.sqrt(flux_meas * (1 + extendedness))

        # Configure the plot to show observed vs true mags
        config = ScatterPlotWithTwoHistsTaskConfig(
            axisLabels={"x": "mag", "y": "mag meas - ref", "mag": "mag"},
        )
        config.selectorActions.flagSelector.bands = ["i"]
        # col1 is a placeholder here; it is pointed at the measured flux
        # column once that column's name is known (see the loop below).
        config.axisActions.yAction = MagDiff(col1="refcat_flux", col2="refcat_flux")
        config.nonBandColumnPrefixes.append("refcat")
        config.sourceSelectorActions.galaxySelector = GalaxyIdentifier
        config.highSnStatisticSelectorActions.statSelector.threshold = 50
        config.lowSnStatisticSelectorActions.statSelector.threshold = 20
        self.task = ScatterPlotWithTwoHistsTask(config=config)

        n = len(flux)
        self.bands, columns = config.get_requirements()
        data = {
            "refcat_flux": flux,
            "patch": np.zeros(n, dtype=int),
        }

        # Assign values to columns based on their unchanged default names
        for column in columns:
            if column not in data:
                if column.startswith("detect"):
                    data[column] = np.ones(n, dtype=bool)
                elif column.endswith("_flag") or "Flag" in column:
                    data[column] = np.zeros(n, dtype=bool)
                elif column.endswith("Flux"):
                    config.axisActions.yAction.col1 = column
                    data[column] = flux_meas
                elif column.endswith("FluxErr"):
                    data[column] = flux_err
                elif column.endswith("_extendedness"):
                    data[column] = extendedness
                else:
                    raise RuntimeError(f"Unexpected column {column} in ScatterPlotWithTwoHistsTaskConfig")

        self.data = pd.DataFrame(data)

    def tearDown(self):
        """Remove the scratch directory and drop the per-test fixtures."""
        if os.path.exists(self.testDir):
            shutil.rmtree(self.testDir, ignore_errors=True)
        del self.bands
        del self.data
        del self.task

    def test_ScatterPlotWithTwoHistsTask(self):
        """Run the task and compare its figure with the reference image."""
        # Protect the plotting test with the default matplotlibrc so that
        # local user settings cannot alter the rendered figure.
        plt.rcParams.update(plt.rcParamsDefault)
        result = self.task.run(self.data,
                               dataId={},
                               runName="test",
                               skymap=None,
                               tableName="test",
                               bands=self.bands,
                               plotName="test")
        # Release the figure even when the comparison below fails.
        # Assumes result.scatterPlot is a matplotlib Figure (it exposes
        # savefig) -- TODO confirm.
        self.addCleanup(plt.close, result.scatterPlot)

        filename_figure_tmp = os.path.join(self.testDir, "test_scatterPlot.png")
        result.scatterPlot.savefig(filename_figure_tmp)
        # A tolerance of 0 requires the two images to match exactly.
        diff = compare_images(filename_figure_tmp, filename_figure_ref, 0)
        if diff is not None:
            raise ImageComparisonFailure(diff)


class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Hook the stock `lsst.utils.tests.MemoryTestCase` into this module.

    Nothing is overridden; the inherited test case is used as is.
    """


def setup_module(module):
    """Module-level setup hook: initialise `lsst.utils.tests`.

    Parameters
    ----------
    module : `module`
        The test module being set up; not used here.
    """
    lsst.utils.tests.init()


# Allow the tests to be run directly as a script: perform the same
# initialisation as setup_module, then hand over to unittest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()