Skip to content

Commit

Permalink
Better match analysis_tools standards, e.g.,
Browse files Browse the repository at this point in the history
turn on atools in the pipeline config, not the task
  • Loading branch information
mrawls committed Feb 2, 2024
1 parent f77b6b8 commit 2b03b03
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 24 deletions.
5 changes: 3 additions & 2 deletions pipelines/diaTractQualityCore.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ description: |
tasks:
analyzeDiaSourceTableTract:
class: lsst.analysis.tools.tasks.DiaSourceTableTractAnalysisTask
# This task is designed to automatically run all the relevant metric
# atools, so only plots need to be configured here.
config:
# This will be used as the first part of the butler data product name
connections.outputName: diaSourceTableTract
atools.NumDiaSources: NumDiaSourcesMetric
atools.NumStreakDiaSources: NumStreakDiaSourcesMetric
atools.NumStreakCenterDiaSources: NumStreakCenterDiaSourcesMetric
atools.streakDiaSourcePlot: PlotStreakDiaSources
python: |
from lsst.analysis.tools.atools import *
13 changes: 6 additions & 7 deletions python/lsst/analysis/tools/actions/plot/diaSkyPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ def makePlot(self, data: KeyedData, **kwargs) -> Figure:
"""
if "figsize" in kwargs:
figsize = kwargs.pop("figsize", "")
fig = plt.figure(figsize=figsize, dpi=1000)
fig = plt.figure(figsize=figsize, dpi=600)
else:
fig = plt.figure(figsize=(8, 6), dpi=1000)
fig = plt.figure(figsize=(12, 9), dpi=600)
axs = self._makeAxes(fig)
for panel, ax in zip(self.panels.values(), axs):
self._makePanel(data, panel, ax, **kwargs)
Expand Down Expand Up @@ -147,11 +147,10 @@ def _makePanel(self, data, panel, ax, **kwargs):
color : `str`
"""
for ra, dec in zip(panel.ras, panel.decs): # loop over column names (dict keys)
ax.scatter(
data[ra], data[dec], s=panel.size, alpha=panel.alpha, marker=".", linewidths=0
)
# TODO: implement lists of colors, sizes, alphas, etc.
# Right now, color is excluded so each series gets the next default
ax.scatter(data[ra], data[dec], s=panel.size, alpha=panel.alpha, marker=".", linewidths=0)
# TODO DM-42768: implement lists of colors, sizes, alphas, etc.
# and add better support for multi-panel plots.
# Right now, color is excluded, each series gets the next default.

ax.set_xlabel(panel.xlabel)
ax.set_ylabel(panel.ylabel)
Expand Down
22 changes: 15 additions & 7 deletions python/lsst/analysis/tools/atools/diaSourceTableTractMetrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ..actions.scalar import CountAction
from ..actions.vector import FlagSelector, GoodDiaSourceSelector, LoadVector
from ..contexts import DrpContext
from ..interfaces import AnalysisTool
from ..interfaces import AnalysisTool, KeyedDataSchema, Vector


class NumDiaSourcesMetric(AnalysisTool):
Expand Down Expand Up @@ -69,7 +69,8 @@ def setDefaults(self):


class NumStreakCenterDiaSourcesMetric(AnalysisTool):
"""Count DiaSources that have the STREAK flag in the center of the source."""
"""Count DiaSources that have the STREAK flag in the center
of the source."""

def setDefaults(self):
super().setDefaults()
Expand All @@ -85,6 +86,16 @@ def setDefaults(self):
# Use, e.g., `pixelFlags_thing`, not `base_PixelFlags_flag_thing`
self.applyContext(DrpContext)

def getInputSchema(self) -> KeyedDataSchema:
"""Defines the schema this plot action expects (the keys it looks
for and what type they should be). In other words, verifies that
the input data has the columns we are expecting with the right dtypes.
"""
for ra in self.panel.ras:
yield (ra, Vector)
for dec in self.panel.decs:
yield (ra, Vector)


class PlotStreakDiaSources(AnalysisTool):
"""Plot all good DiaSources, and indicate which coincide with a streak."""
Expand All @@ -102,22 +113,21 @@ def setDefaults(self):
# First, select "good" DiaSources that are not obvious garbage
self.process.buildActions.coordsGood = KeyedDataSelectorAction(vectorKeys=["ra", "dec"])
self.process.buildActions.coordsGood.selectors.selectorGood = GoodDiaSourceSelector()
self.applyContext(DrpContext)

# Second, select DiaSources with STREAK flag set in the footprint
self.process.buildActions.coordsStreak = KeyedDataSelectorAction(vectorKeys=["ra", "dec"])
self.process.buildActions.coordsStreak.selectors.selectorStreak = FlagSelector(
selectWhenTrue=["pixelFlags_streak"]
)
self.applyContext(DrpContext)

# Finally, select DiaSources with STREAK flag set in the source center
self.process.buildActions.coordsStreakCenter = KeyedDataSelectorAction(vectorKeys=["ra", "dec"])
self.process.buildActions.coordsStreakCenter.selectors.selectorStreakCenter = FlagSelector(
selectWhenTrue=["pixelFlags_streakCenter"]
)
self.applyContext(DrpContext)

# Use the DRP column names for all of the above, and generate the plot
self.applyContext(DrpContext)
self.produce.plot = DiaSkyPlot()

self.produce.plot.panels["panel_main"] = DiaSkyPanel()
Expand All @@ -136,5 +146,3 @@ def setDefaults(self):
"coordsStreakCenter_dec",
]
self.produce.plot.panels["panel_main"].rightSpinesVisible = False

# TODO: plot color, point size, and legend customizations
2 changes: 1 addition & 1 deletion python/lsst/analysis/tools/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .catalogMatch import *
from .ccdVisitTableAnalysis import *
from .diaObjectDetectorVisitAnalysis import *
from .diaSourceTableTractAnalysis import *
from .diffimTaskDetectorVisitAnalysis import *
from .diffMatchedAnalysis import *
from .gatherResourceUsage import *
Expand All @@ -18,4 +19,3 @@
from .refCatSourcePhotometricAnalysis import *
from .sourceTableVisitAnalysis import *
from .trailedDiaSrcDetectorVisitAnalysis import *
from .diaSourceTableTractAnalysis import *
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from lsst.skymap import BaseSkyMap

from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask
from ..atools import diaSourceTableTractMetrics


class DiaSourceTableTractAnalysisConnections(
Expand Down Expand Up @@ -58,13 +57,7 @@ class DiaSourceTableTractAnalysisConfig(
AnalysisBaseConfig, pipelineConnections=DiaSourceTableTractAnalysisConnections
):
def setDefaults(self):
"""All the metrics are turned on in this task by default,
so it is not necessary to configure each one in a pipeline
"""
super().setDefaults()
self.atools.NumDiaSources = diaSourceTableTractMetrics.NumDiaSourcesMetric()
self.atools.NumStreakDiaSources = diaSourceTableTractMetrics.NumStreakDiaSourcesMetric()
self.atools.NumStreakCenterDiaSources = diaSourceTableTractMetrics.NumStreakCenterDiaSourcesMetric()


class DiaSourceTableTractAnalysisTask(AnalysisPipelineTask):
Expand Down

0 comments on commit 2b03b03

Please sign in to comment.