Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions python/lsst/analysis/tools/actions/plot/histPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@

__all__ = ("HistPanel", "HistPlot", "HistStatsPanel")

import importlib.resources as importResources
import logging
from collections import defaultdict
from typing import Mapping

import lsst.analysis.tools
import numpy as np
import yaml
from lsst.pex.config import (
ChoiceField,
Config,
Expand All @@ -37,7 +40,7 @@
FieldValidationError,
ListField,
)
from lsst.utils.plotting import make_figure
from lsst.utils.plotting import get_multiband_plot_colors, make_figure, set_rubin_plotstyle
from matplotlib import cm
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
Expand Down Expand Up @@ -176,6 +179,10 @@ class HistPanel(Config):
"default stats: N, median, sigma mad are shown",
default=None,
)
addThresholds = Field[bool](
doc="Read in the predefined thresholds and indicate them on the histogram.",
default=False,
)

def validate(self):
super().validate()
Expand Down Expand Up @@ -216,7 +223,7 @@ class HistPlot(PlotAction):
cmap = Field[str](
doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
"A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
default="newtab10",
default="rubin",
)

def getInputSchema(self) -> KeyedDataSchema:
Expand Down Expand Up @@ -274,6 +281,7 @@ def makePlot(

# set up figure
fig = make_figure(dpi=300)
set_rubin_plotstyle()
hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[2.9, 1.1])
axs, ncols, nrows = self._makeAxes(hist_fig)

Expand All @@ -298,6 +306,7 @@ def makePlot(
label_font_size=label_font_size,
legend_font_size=legend_font_size,
ncols=ncols,
addThresholds=self.panels[panel].addThresholds,
)

all_handles, all_nums, all_meds, all_mads = [], [], [], []
Expand Down Expand Up @@ -391,6 +400,19 @@ def _assignColors(self):
"#009988",
"#BBBBBB",
],
rubin=[
"#0173B2",
"#DE8F05",
"#029E73",
"#D55E00",
"#CC78BC",
"#CA9161",
"#FBAFE4",
"#949494",
"#ECE133",
"#56B4E9",
],
bands=[get_multiband_plot_colors()],
)
if self.cmap in custom_cmaps.keys():
all_colors = custom_cmaps[self.cmap]
Expand All @@ -402,13 +424,16 @@ def _assignColors(self):

counter = 0
colors = defaultdict(list)

for panel in self.panels:
for hist in self.panels[panel].hists:
colors[panel].append(all_colors[counter % len(all_colors)])
counter += 1
return colors

def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
def _makePanel(
self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1, addThresholds=False
):
"""Plot a single panel containing histograms."""
nums, meds, mads = [], [], []
for i, hist in enumerate(self.panels[panel].hists):
Expand All @@ -418,6 +443,10 @@ def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_siz
meds.append(med)
mads.append(mad)
panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)
if self.panels[panel].addThresholds:
metricThresholdFile = importResources.read_text(lsst.analysis.tools, "metricInformation.yaml")
metricDefs = yaml.safe_load(metricThresholdFile)

if all(np.isfinite(panel_range)):
nHist = 0
for i, hist in enumerate(self.panels[panel].hists):
Expand All @@ -434,6 +463,15 @@ def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_siz
label=self.panels[panel].hists[hist],
)
ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])
if self.panels[panel].addThresholds and hist in metricDefs:
if "lowThreshold" in metricDefs[hist].keys():
lowThreshold = metricDefs[hist]["lowThreshold"]
if np.isfinite(lowThreshold):
ax.axvline(lowThreshold, color=colors[i])
if "highThreshold" in metricDefs[hist].keys():
highThreshold = metricDefs[hist]["highThreshold"]
if np.isfinite(highThreshold):
ax.axvline(highThreshold, color=colors[i])
nHist += 1

if nHist > 0:
Expand Down
Loading
Loading