Skip to content

Commit

Permalink
Make changes required by DM-42888
Browse files Browse the repository at this point in the history
  • Loading branch information
enourbakhsh committed Feb 21, 2024
1 parent 3f0e551 commit 537b17e
Showing 1 changed file with 163 additions and 26 deletions.
189 changes: 163 additions & 26 deletions python/lsst/analysis/tools/actions/plot/matrixPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import matplotlib.pyplot as plt
import numpy as np
from astropy.visualization.mpl_normalize import ImageNormalize
from lsst.pex.config import Config, ConfigDictField, DictField, Field, ListField
from lsst.pex.config import ChoiceField, Config, ConfigDictField, DictField, Field, ListField

from ...interfaces import PlotAction, Vector
from .plotUtils import addPlotInfo
Expand Down Expand Up @@ -64,85 +64,161 @@ class GuideLinesConfig(Config):


class MatrixPlot(PlotAction):
"""Make the plot of a matrix (2D array)."""
"""Make the plot of a matrix (2D array).
Notes
-----
The `xAxisTickLabels` and `yAxisTickLabels` attributes of this class serve
as dictionaries to map axis tick positions to their corresponding labels.
If any positions do not align with major ticks (either provided by
`x/yAxisTickValues` or automatically set by matplotlib), they will be
designated as minor ticks. Thus, these tick labels operate independently,
meaning their corresponding positions do not need to match those in
`x/yAxisTickValues` or anything else. The code automatically adjusts to
handle any overlaps caused by user input and across various plotting
scenarios.
Note that when `component1Key` and `component2Key` are specified, the x and
y tick values and labels will be dynamically configured, thereby
eliminating the need for providing `x/yAxisTickValues` and
`x/yAxisTickLabels`.
"""

inputDim = ChoiceField[int](
doc="The dimensionality of the input data.",
default=1,
allowed={
1: "1D inputs are automatically reshaped into square 2D matrices.",
2: "2D inputs are directly utilized as is.",
},
optional=True,
)

matrixKey = Field[str](
doc="The key for the input matrix.",
default="matrix",
)

component1Key = Field[str](
doc="The key to access a list of names for the first component set in a correlation analysis. This "
"will be used to determine x-axis tick values and labels.",
default=None,
optional=True,
)

component2Key = Field[str](
doc="The key to access a list of names for the second component set in a correlation analysis. This "
"will be used to determine y-axis tick values and labels.",
)

xAxisLabel = Field[str](
doc="The label to use for the x-axis.",
default="",
optional=True,
)

yAxisLabel = Field[str](
doc="The label to use for the y-axis.",
default="",
optional=True,
)

axisLabelFontSize = Field[float](
doc="The font size for the axis labels.",
default=9,
optional=True,
)

colorbarLabel = Field[str](
doc="The label to use for the colorbar.",
default="",
optional=True,
)

colorbarLabelFontSize = Field[float](
doc="The font size for the colorbar label.",
default=10,
optional=True,
)

colorbarTickLabelFontSize = Field[float](
doc="The font size for the colorbar tick labels.",
default=8,
optional=True,
)

vmin = Field[float](
doc="The vmin value for the colorbar.",
default=None,
optional=True,
)

vmax = Field[float](
doc="The vmax value for the colorbar.",
default=None,
optional=True,
)

figsize = ListField[float](
doc="The size of the figure.",
default=[5, 5],
maxLength=2,
optional=True,
)

title = Field[str](
doc="The title of the figure.",
default="",
optional=True,
)

titleFontSize = Field[int](
titleFontSize = Field[float](
doc="The font size for the title.",
default=12,
default=10,
optional=True,
)

xAxisTickValues = ListField[float](
doc="List of x-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
default=None,
optional=True,
)

xAxisTickLabels = DictField[float, str](
doc="Dictionary mapping x-axis tick positions to their corresponding labels. If any positions do not "
"align with major ticks (provided by `xAxisTickValues` or automatically set by matplotlib), they "
"will be set as minor ticks. Thus, these tick labels operate independently, meaning their "
"corresponding positions do not need to match those in `xAxisTickValues` or anything else. The code "
"automatically adjusts to handle any overlaps caused by user input.",
doc="Dictionary mapping x-axis tick positions to their corresponding labels. For behavior details, "
"refer to the 'Notes' section of the class docstring.",
default=None,
optional=True,
)

yAxisTickValues = ListField[float](
doc="List of y-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
default=None,
optional=True,
)

yAxisTickLabels = DictField[float, str](
doc="Dictionary mapping y-axis tick positions to their corresponding labels. If any positions do not "
"align with major ticks (provided by `yAxisTickValues` or automatically set by matplotlib), they "
"will be set as minor ticks. Thus, these tick labels operate independently, meaning their "
"corresponding positions do not need to match those in `yAxisTickValues` or anything else. The code "
"automatically adjusts to handle any overlaps caused by user input.",
doc="Dictionary mapping y-axis tick positions to their corresponding labels. For behavior details, "
"refer to the 'Notes' section of the class docstring.",
default=None,
optional=True,
)

tickLabelsFontSize = Field[float](
doc="The font size for the tick labels.",
default=8,
optional=True,
)

tickLabelsRotation = Field[float](
doc="The rotation of the tick labels.",
default=0,
optional=True,
)

setPositionsAtPixelBoundaries = Field[bool](
doc="Whether to consider the positions at the pixel boundaries rather than the center of the pixel.",
default=False,
optional=True,
)

hideMajorTicks = ListField[str](
Expand All @@ -152,6 +228,7 @@ class MatrixPlot(PlotAction):
default=[],
maxLength=2,
itemCheck=lambda s: s in ["x", "y"],
optional=True,
)

hideMinorTicks = ListField[str](
Expand All @@ -161,18 +238,21 @@ class MatrixPlot(PlotAction):
default=[],
maxLength=2,
itemCheck=lambda s: s in ["x", "y"],
optional=True,
)

dpi = Field[int](
doc="The resolution of the figure.",
default=300,
optional=True,
)

guideLines = ConfigDictField[str, GuideLinesConfig](
doc="Dictionary of guide lines for the x and y axes. The keys are 'x' and 'y', and the values are "
"instances of `GuideLinesConfig`.",
default={},
dictCheck=lambda d: all([k in ["x", "y"] for k in d]),
optional=True,
)

def getInputSchema(self) -> KeyedDataSchema:
Expand All @@ -198,6 +278,26 @@ def _validateInput(self, data: KeyedData, **kwargs: Any) -> None:
raise ValueError(
f"Only the following keyword arguments are allowed: {acceptableKwargs}. Got: {kwargs}"
)
# Check that if one component key is provided, the other must be too.
if (self.component1Key is not None and self.component2Key is None) or (
self.component1Key is None and self.component2Key is not None
):
raise ValueError(
"Both 'component1Key' and 'component2Key' must be provided together if either is provided."
)
# Check that if component keys are provided, any of the tick values or
# labels are not and vice versa.
if (self.component1Key is not None and self.component2Key is not None) and (
self.xAxisTickValues is not None
or self.yAxisTickValues is not None
or self.xAxisTickLabels is not None
or self.yAxisTickLabels is not None
):
raise ValueError(
"If 'component1Key' and 'component2Key' are provided, 'xAxisTickValues', "
"'yAxisTickValues', 'xAxisTickLabels', and 'yAxisTickLabels' should not be "
"provided as they will be dynamically configured."
)

def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs: Any) -> Figure:
"""
Expand All @@ -217,7 +317,23 @@ def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, *
fig : `~matplotlib.figure.Figure`
The resulting figure.
"""
# Retrieve the matrix info from the input data.
matrix = data[self.matrixKey]

# Fetch the components between which the correlation is calculated.
if self.component1Key is not None and self.component2Key is not None:
comp1 = data[self.component1Key]
comp2 = data[self.component2Key]

if self.inputDim == 1:
# Calculate the size of the square.
square_size = int(np.sqrt(matrix.size))
# Reshape into a square array.
matrix = matrix.reshape(square_size, square_size)
if self.component1Key is not None and self.component2Key is not None:
comp1 = comp1.reshape(square_size, square_size)
comp2 = comp2.reshape(square_size, square_size)

# Calculate default limits only if needed.
if self.vmin is None or self.vmax is None:
default_limits = apViz.PercentileInterval(98.0).get_limits(np.abs(matrix.flatten()))
Expand All @@ -242,10 +358,10 @@ def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, *
ax.set_title(self.title, fontsize=self.titleFontSize)

if self.xAxisLabel:
ax.set_xlabel(self.xAxisLabel)
ax.set_xlabel(self.xAxisLabel, fontsize=self.axisLabelFontSize)

if self.yAxisLabel:
ax.set_ylabel(self.yAxisLabel)
ax.set_ylabel(self.yAxisLabel, fontsize=self.axisLabelFontSize)

# Set the colorbar and draw the image.
norm = ImageNormalize(vmin=vrange[0], vmax=vrange[1])
Expand All @@ -255,15 +371,32 @@ def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, *
ratio = matrix.shape[0] / matrix.shape[1]

# Add the colorbar flush with the image axis.
fig.colorbar(img, fraction=0.0457 * ratio, pad=0.04, label=self.colorbarLabel)
cbar = fig.colorbar(img, fraction=0.0457 * ratio, pad=0.04)

# Set the colorbar label and its font size.
cbar.set_label(self.colorbarLabel, fontsize=self.colorbarLabelFontSize)

# Set the colorbar tick label font size.
cbar.ax.tick_params(labelsize=self.colorbarTickLabelFontSize)

# If requested, we shift all the positions by 0.5 considering the
# zero-point at a pixel boundary rather than the center of the pixel.
shift = 0.5 if self.setPositionsAtPixelBoundaries else 0

if self.component1Key is not None and self.component2Key is not None:
xAxisTickValues = np.arange(matrix.shape[0] + shift)
yAxisTickValues = np.arange(matrix.shape[1] + shift)
xAxisTickLabels = {key + shift: str(val) for key, val in zip(range(matrix.shape[0]), comp1[0, :])}
yAxisTickLabels = {key + shift: str(val) for key, val in zip(range(matrix.shape[1]), comp2[:, 0])}
else:
xAxisTickValues = self.xAxisTickValues
yAxisTickValues = self.yAxisTickValues
xAxisTickLabels = self.xAxisTickLabels
yAxisTickLabels = self.yAxisTickLabels

# If the tick values are not provided, retrieve them from the axes.
xticks = self.xAxisTickValues if self.xAxisTickValues is not None else ax.xaxis.get_ticklocs()
yticks = self.yAxisTickValues if self.yAxisTickValues is not None else ax.yaxis.get_ticklocs()
xticks = xAxisTickValues if xAxisTickValues is not None else ax.xaxis.get_ticklocs()
yticks = yAxisTickValues if yAxisTickValues is not None else ax.yaxis.get_ticklocs()

# Retrieve the current limits of the x and y axes.
xlim, ylim = ax.get_xlim(), ax.get_ylim()
Expand All @@ -276,13 +409,13 @@ def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, *
tick_data = {
"x": (
xticks - shift,
np.array(list(self.xAxisTickLabels.keys())) - shift if self.xAxisTickLabels else None,
list(self.xAxisTickLabels.values()) if self.xAxisTickLabels else None,
np.array(list(xAxisTickLabels.keys())) - shift if xAxisTickLabels else None,
list(xAxisTickLabels.values()) if xAxisTickLabels else None,
),
"y": (
yticks - shift,
np.array(list(self.yAxisTickLabels.keys())) - shift if self.yAxisTickLabels else None,
list(self.yAxisTickLabels.values()) if self.yAxisTickLabels else None,
np.array(list(yAxisTickLabels.keys())) - shift if yAxisTickLabels else None,
list(yAxisTickLabels.values()) if yAxisTickLabels else None,
),
}

Expand All @@ -300,7 +433,8 @@ def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, *
[
f"{tick + shift:.0f}" if (tick + shift).is_integer() else f"{tick + shift}"
for tick in axis.get_ticklocs()
]
],
fontsize=self.tickLabelsFontSize,
)

# Check if positions are provided.
Expand All @@ -326,14 +460,14 @@ def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, *

# Apply labels to major ticks if any exist.
if any(e for e in major_labels if e):
axis.set_ticklabels(major_labels, minor=False)
axis.set_ticklabels(major_labels, minor=False, fontsize=self.tickLabelsFontSize)
else:
# If no major labels, clear major tick labels.
axis.set_ticklabels("")

# Apply labels to minor ticks if any exist.
if any(e for e in minor_labels if e):
axis.set_ticklabels(minor_labels, minor=True)
axis.set_ticklabels(minor_labels, minor=True, fontsize=self.tickLabelsFontSize)

if dim in self.hideMajorTicks:
# Remove major tick marks for asthetic reasons.
Expand All @@ -343,6 +477,9 @@ def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, *
# Remove minor tick marks for asthetic reasons.
axis.set_tick_params(which="minor", length=0)

# Rotate the tick labels by the specified angle.
ax.tick_params(axis=dim, rotation=self.tickLabelsRotation)

# Add vertical and horizontal lines if provided.
if "x" in self.guideLines:
xLines = self.guideLines["x"]
Expand Down

0 comments on commit 537b17e

Please sign in to comment.