Skip to content

Commit

Permalink
Add more validations and features
Browse files Browse the repository at this point in the history
  • Loading branch information
enourbakhsh committed Jan 30, 2024
1 parent de5c340 commit 4832c1a
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 75 deletions.
1 change: 1 addition & 0 deletions python/lsst/analysis/tools/actions/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .diaSkyPlot import *
from .focalPlanePlot import *
from .histPlot import *
from .matrixPlot import *
from .multiVisitCoveragePlot import *
from .propertyMapPlot import *
from .rhoStatisticsPlot import *
Expand Down
284 changes: 209 additions & 75 deletions python/lsst/analysis/tools/actions/plot/matrixPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import TYPE_CHECKING, Any, Mapping

import matplotlib.pyplot as plt
from lsst.pex.config import ChoiceField, DictField, Field, ListField, FieldValidationError
from lsst.pex.config import DictField, Field, ListField

from astropy.visualization.mpl_normalize import ImageNormalize
import astropy.visualization as apViz
Expand All @@ -51,164 +51,298 @@ class MatrixPlot(PlotAction):

xAxisLabel = Field[str](
doc="The label to use for the x-axis.",
default="x",
default="",
)

yAxisLabel = Field[str](
doc="The label to use for the y-axis.",
default="y",
default="",
)

colorbarLabel = Field[str](
doc="The label to use for the colorbar.",
default="",
)

vmin = Field[float](
doc="The vmin value.",
doc="The vmin value for the colorbar.",
default=None,
optional=True,
)

vmax = Field[float](
doc="The vmax value.",
doc="The vmax value for the colorbar.",
default=None,
optional=True,
)

figsize = ListField[float](
doc="The figsize value.",
default=[14, 8],
doc="The size of the figure.",
default=[5, 5],
optional=True,
)

title = Field[str](
doc="The title value.",
doc="The title of the figure.",
default="",
)

xLines = ListField[float](
doc=("The values of x where a vertical lines is drawn."),
xLines = DictField[float, str](
doc=("Dictionary of x-values and the labels where vertical lins are drawn."),
default=None,
optional=True,
)

yLines = ListField[float](
doc=("The values of y where a horizontal lins is drawn."),
yLines = DictField[float, str](
doc=("Dictionary of y-values and the leabels where horizontal lins are drawn."),
default=None,
optional=True,
)

# DictField of float and string
xAxisTicks = ListField[float](
doc="The values of x where a vertical lines is drawn.",
default=None,
xLinesColor = Field[str](
doc="The color of the vertical lines.",
default="red",
optional=True,
)

yLinesColor = Field[str](
doc="The color of the horizontal lines.",
default="red",
optional=True,
)

xLinesStyle = Field[str](
doc="The style of the vertical lines.",
default="--",
optional=True,
)

yLinesStyle = Field[str](
doc="The style of the horizontal lines.",
default="--",
optional=True,
)

yAxisTicks = ListField[float](
doc="The values of x where a vertical lines is drawn.",
xAxisTickValues = ListField[float](
doc="List of x-axis tick values. When `centerLabelsBetweenTicks` is set to True, ensure that the "
"number of values is exactly one more than the number of tick labels.",
default=None,
)

xAxisTickLabels = ListField[str](
doc="The ... x",
doc="List of x-axis tick labels. When `centerLabelsBetweenTicks` is enabled, make sure that the "
"number of labels is exactly one less than the number of tick values.",
default=None,
)

yAxisTickValues = ListField[float](
doc="List of y-axis tick values. When `centerLabelsBetweenTicks` is set to True, ensure that the "
"number of values is exactly one more than the number of tick labels.",
default=None,
)

yAxisTickLabels = ListField[str](
doc="The ... y",
doc="List of y-axis tick labels. When `centerLabelsBetweenTicks` is enabled, make sure that the "
"number of labels is exactly one less than the number of tick values.",
default=None,
)

centerLabelsBetweenTicks = Field[bool](
doc="Whether to center the tick labels between the tick marks. If you set this to True, you must "
"also provide the `xAxisTickValues`, `xAxisTickLabels`, `yAxisTickValues` and `yAxisTickLabels` "
"fields.",
default=False,
)

def setDefaults(self):
super().setDefaults()
# self.strKwargs = {"fmt": "o"}

def validate(self):
# if (len(set(self.boolKwargs.keys()).intersection(self.numKwargs.keys())) > 0) or (
# len(set(self.boolKwargs.keys()).intersection(self.strKwargs.keys())) > 0
# ):
# raise FieldValidationError(self.boolKwargs, self, "Keywords have been repeated")

super().validate()

def getInputSchema(self) -> KeyedDataSchema:
base: list[tuple[str, type[Vector]]] = []
base.append(("matrix", Vector))
base.append((self.matrixKey, Vector))
return base

def __call__(self, data: KeyedData, **kwargs) -> Figure:
self._validateInput(data)
return self.makePlot(data, **kwargs)

def _validateInput(self, data: KeyedData) -> None:
# check the input is a 2d array
pass
# needed = set(k[0] for k in self.getInputSchema())
# if not needed.issubset(data.keys()):
# raise ValueError(f"Input data does not contain all required keys: {self.getInputSchema()}")


# Check that the input data contains all the required keys.
needed = set(k[0] for k in self.getInputSchema())
if not needed.issubset(data.keys()):
raise ValueError(f"Input data does not contain all required keys: {self.getInputSchema()}")
# Check the input data is a matrix, i.e. a 2d array.
if not isinstance(data[self.matrixKey], np.ndarray) and data[self.matrixKey].ndim != 2:
raise ValueError(f"Input data is not a 2d array: {data[self.matrixKey]}")
# Check that the set of tick values and labels are jointly set or
# unset. The logical XOR operator "^" to checks if exactly one of them
# is None.
if (self.xAxisTickValues is None) ^ (self.xAxisTickLabels is None):
raise ValueError(
"Both `xAxisTickValues` and `xAxisTickLabels` must be set if either is set. "
f"xAxisTickValues: {self.xAxisTickValues}, xAxisTickLabels: {self.xAxisTickLabels}"
)
if (self.yAxisTickValues is None) ^ (self.yAxisTickLabels is None):
raise ValueError(
"Both `yAxisTickValues` and `yAxisTickLabels` must be set if either is set. "
f"yAxisTickValues: {self.yAxisTickValues}, yAxisTickLabels: {self.yAxisTickLabels}"
)
# Check that the tick values and labels are consistent.
if self.centerLabelsBetweenTicks:
if any(
item is None
for item in (
self.xAxisTickValues,
self.xAxisTickLabels,
self.yAxisTickValues,
self.yAxisTickLabels,
)
):
raise ValueError(
"All of `xAxisTickValues`, `xAxisTickLabels`, `yAxisTickValues` and `yAxisTickLabels` "
"must be set if `centerLabelsBetweenTicks` is set to True"
)
if len(self.xAxisTickValues) != len(self.xAxisTickLabels) + 1:
raise ValueError(
f"Length of `xAxisTickValues` ({len(self.xAxisTickValues)}) must be exactly one more "
f"than the length of `xAxisTickLabels` ({len(self.xAxisTickLabels)}) since "
"`centerLabelsBetweenTicks` is set to True"
)
if len(self.yAxisTickValues) != len(self.yAxisTickLabels) + 1:
raise ValueError(
f"Length of `yAxisTickValues` ({len(self.yAxisTickValues)}) must be exactly one more "
f"than the length of `yAxisTickLabels` ({len(self.yAxisTickLabels)}) since "
"`centerLabelsBetweenTicks` is set to True"
)
else:
if self.xAxisTickValues is not None and self.xAxisTickLabels is not None:
if len(self.xAxisTickValues) != len(self.xAxisTickLabels):
raise ValueError(
f"Length of `xAxisTickValues` ({len(self.xAxisTickValues)}) must be exactly equal to "
f"the length of `xAxisTickLabels` ({len(self.xAxisTickLabels)}) since "
"`centerLabelsBetweenTicks` is set to False"
)
if self.yAxisTickValues is not None and self.yAxisTickLabels is not None:
if len(self.yAxisTickValues) != len(self.yAxisTickLabels):
raise ValueError(
f"Length of `yAxisTickValues` ({len(self.yAxisTickValues)}) must be exactly equal to "
f"the length of `yAxisTickLabels` ({len(self.yAxisTickLabels)}) since "
"`centerLabelsBetweenTicks` is set to False"
)

def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs: Any) -> Figure:
"""
Plot a matrix
Plot a matrix of values.
Parameters
----------
data : `~pandas.core.frame.DataFrame`
The catalog containing various rho statistics.
data : `~lsst.analysis.tools.interfaces.KeyedData`
The data to plot.
**kwargs
Additional keyword arguments to pass to the plot
Additional keyword arguments to pass to the plot.
Returns
-------
fig : `~matplotlib.figure.Figure`
The resulting figure.
"""

slots = kwargs.get('slots', None)
title = self.title
vmin = self.vmin
vmax = self.vmax
figsize = self.figsize

matrix = np.array(data[self.matrixKey])
if vmin is not None and vmax is not None:
vrange = (vmin, vmax)
matrix = data[self.matrixKey]
if self.vmin is not None and self.vmax is not None:
vrange = (self.vmin, self.vmax)
else:
interval = apViz.PercentileInterval(98.)
interval = apViz.PercentileInterval(98.0)
vrange = interval.get_limits(np.abs(matrix.flatten()))

# Allow for multiple curves to lie on the same plot.
# Allow for the figure object to be passed in.
fig = kwargs.get("fig", None)
if fig is None:
# fig = plt.figure(dpi=300)
# ax = fig.add_subplot(111)
fig = plt.figure(figsize=figsize, dpi=300)
axes = fig.add_subplot(111)
fig = plt.figure(figsize=self.figsize, dpi=300)
ax = fig.add_subplot(111)
else:
ax = fig.gca()

if title:
axes.set_title(title)
if self.title:
ax.set_title(self.title)

norm = ImageNormalize(vmin=vrange[0], vmax=vrange[1])
img = axes.imshow(matrix, interpolation='none', norm=norm)
cbar = plt.colorbar(img)

# slots = ["S00", "S01", "S02", "S10", "S11", "S12", "S20", "S21", "S22"]
# if slots is not None:
# amps = 16
# major_locs = [i*amps - 0.5 for i in range(len(slots) + 1)]
# minor_locs = [amps//2 + i*amps for i in range(len(slots))]
# for axis in (axes.xaxis, axes.yaxis):
# axis.set_tick_params(which='minor', length=0)
# axis.set_major_locator(ticker.FixedLocator(major_locs))
# axis.set_major_formatter(ticker.FixedFormatter(['']*len(major_locs)))
# axis.set_minor_locator(ticker.FixedLocator(minor_locs))
# axis.set_minor_formatter(ticker.FixedFormatter(slots))

# o_dict = dict(fig=fig, axes=axes, img=img, cbar=cbar)
# self._fig_dict[key] = o_dict
# return o_dict
if self.xAxisLabel:
ax.set_xlabel(self.xAxisLabel)

if self.yAxisLabel:
ax.set_ylabel(self.yAxisLabel)

# Set the colorbar and draw the image.
norm = ImageNormalize(vmin=vrange[0], vmax=vrange[1])
img = ax.imshow(matrix, interpolation="none", norm=norm)

# Calculate the aspect ratio of the image.
ratio = matrix.shape[0] / matrix.shape[1]

# Add the colorbar flush with the image axis.
fig.colorbar(img, fraction=0.0457 * ratio, pad=0.04, label=self.colorbarLabel)

if self.centerLabelsBetweenTicks:
# Create a dictionary to map dimension labels to their respective
# tick values and labels.
# Note: We shift by 0.5 to position the tick marks at the pixel
# boundaries rather than the center of the pixel.
major_tick_data = {
"x": (np.array(self.xAxisTickValues) - 0.5, self.xAxisTickLabels),
"y": (np.array(self.yAxisTickValues) - 0.5, self.yAxisTickLabels),
}
for dim, axis in [("x", ax.xaxis), ("y", ax.yaxis)]:
# Get the major tick positions and labels.
major_tick_positions, major_tick_labels = major_tick_data[dim]
# Set major ticks.
axis.set_ticks(major_tick_positions, minor=False)
# Hide the major tick labels.
axis.set_ticklabels("")
# Set positions for minor ticks at the midpoints of the major
# ticks.
axis.set_ticks((major_tick_positions[:-1] + major_tick_positions[1:]) / 2, minor=True)
# Set minor tick labels.
axis.set_ticklabels(major_tick_labels, minor=True)
# Remove minor tick marks for asthetic reasons.
axis.set_tick_params(which="minor", length=0)
else:
# Set the desired tick values and labels if provided.
if self.xAxisTickValues is not None:
ax.set_xticks(self.xAxisTickValues)
if self.xAxisTickLabels is not None:
ax.set_xticklabels(self.xAxisTickLabels)
if self.yAxisTickValues is not None:
ax.set_yticks(self.yAxisTickValues)
if self.yAxisTickLabels is not None:
ax.set_yticklabels(self.yAxisTickLabels)

# Add vertical and horizontal lines if provided.
if self.xLines is not None:
for x, label in self.xLines.items():
ax.axvline(x=x, color=self.xLinesColor, linestyle=self.xLinesStyle)
ax.text(
x,
0.03,
label,
rotation=90,
color=self.xLinesColor,
transform=ax.get_xaxis_transform(),
horizontalalignment="right",
)
if self.yLines is not None:
for y, label in self.yLines.items():
ax.axhline(y=y, color=self.yLinesColor, linestyle=self.yLinesStyle)
ax.text(
0.03,
y,
label,
color=self.yLinesColor,
transform=ax.get_yaxis_transform(),
verticalalignment="bottom",
)

# Add plot info if provided.
if plotInfo is not None:
fig = addPlotInfo(fig, plotInfo)

Expand Down

0 comments on commit 4832c1a

Please sign in to comment.