Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-41027: Fix potential infinite loop in analysis_drp colorColorFitPlot #73

Merged
merged 3 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
69 changes: 53 additions & 16 deletions python/lsst/analysis/drp/colorColorFitPlot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import astropy.units as u
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -106,6 +107,13 @@ class ColorColorFitPlotConfig(pipeBase.PipelineTaskConfig,
optional=True,
)

minPointsForFit = pexConfig.RangeField(
doc="Minimum number of valid objects to bother attempting a fit.",
dtype=int,
default=5,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be validated as > 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I made it a RangeField and gave it min=1.

min=1,
)

def setDefaults(self):
super().setDefaults()
self.axisActions.xAction.magDiff.returnMillimags = False
Expand Down Expand Up @@ -205,16 +213,33 @@ def run(self, catPlot, dataId, runName, tableName, bands, plotName):
ys = plotDf[self.config.axisLabels["y"]].values

plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN, SNFlux)
if len(plotDf) == 0:
if len(plotDf) < self.config.minPointsForFit:
fig = plt.Figure()
noDataText = ("No data to plot after selectors applied\n(do you have all three of "
"the bands required: {}?)".format(bands))
if len(plotDf) == 0:
noDataText = ("No data to plot after selectors applied\n(do you have all three of "
"the bands required: {}?)".format(bands))
else:
noDataText = ("Not enough data ({} < config.minPointsForFIt = {}) after selectors for "
"fit.".format(len(plotDf), self.config.minPointsForFit))
fig.text(0.5, 0.5, noDataText, ha="center", va="center")
fig = addPlotInfo(fig, plotInfo)
else:
fitParams = stellarLocusFit(xs, ys, self.config.stellarLocusFitDict)
fig = self.colorColorFitPlot(plotDf, plotInfo, fitParams)

try:
fig = self.colorColorFitPlot(plotDf, plotInfo, fitParams)
if len(fitParams["fitPoints"]) < 1:
raise ValueError("No fitPoints for {}".format(dataId))
except ValueError as e:
self.log.warning("Fit failed for %s with: %s", dataId, e)
fig = plt.Figure()
eStr = e.args[0]
chunks, chunk_size = len(eStr), 50
eStrChunked = ""
for i in range(0, chunks, chunk_size):
eStrChunked += eStr[i:i+chunk_size] + "\n"
noDataText = "Fit failed with:\n{}".format(eStrChunked)
fig.text(0.5, 0.5, noDataText, ha="center", va="center")

fig = addPlotInfo(fig, plotInfo)
return pipeBase.Struct(colorColorFitPlot=fig)

def colorColorFitPlot(self, catPlot, plotInfo, fitParams):
Expand Down Expand Up @@ -272,10 +297,10 @@ def colorColorFitPlot(self, catPlot, plotInfo, fitParams):
the fit line is given in a histogram in the second panel.
"""

self.log.info(("Plotting %s: the values of %s against %s on a color-color plot with the area "
"used for calculating the stellar locus fits marked.",
self.config.connections.plotName, self.config.axisLabels["x"],
self.config.axisLabels["y"]))
self.log.info("Plotting %s: the values of %s against %s on a color-color plot with the area "
"used for calculating the stellar locus fits marked.",
self.config.connections.plotName, self.config.axisLabels["x"],
self.config.axisLabels["y"])

# Define a new colormap
newBlues = mkColormap(["darkblue", "paleturquoise"])
Expand All @@ -290,7 +315,10 @@ def colorColorFitPlot(self, catPlot, plotInfo, fitParams):
ys = catPlot[self.config.axisLabels["y"]].values
mags = catPlot[self.config.axisLabels["mag"]].values

if len(xs) == 0 or len(ys) == 0:
if len(xs) < self.config.minPointsForFit or len(ys) < self.config.minPointsForFit:
noDataText = ("Number of objects after cuts ({}) is less than the minimum\nrequired "
"by config.minPointsForFit ({})".format(len(xs), self.config.minPointsForFit))
fig.text(0.5, 0.5, noDataText, ha="center", va="center")
return fig

# Points used for the fit
Expand All @@ -305,10 +333,11 @@ def colorColorFitPlot(self, catPlot, plotInfo, fitParams):
SNsUsed = (catPlot[SNBand + "_" + plotInfo["SNFlux"]].values[fitPoints]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for SNsUsed to be empty? I imagine the fit likely would have failed in that case, but if there's a chance it would keep going and result in an infinite loop again, maybe an explicit check would be worthwhile.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another good idea. I put in a check on len(fitParams["fitPoints"]) above with:

if len(fitParams["fitPoints"]) < 1:
    raise ValueError("No fitPoints for {}".format(dataId))

/ catPlot[SNBand + "_" + plotInfo["SNFlux"] + "Err"].values[fitPoints])
minSnUsed = np.nanmin(SNsUsed)
magsUsed = mags[fitPoints]
fluxesUsed = catPlot[SNBand + "_" + self.config.fluxTypeForColor].values[fitPoints]
magsUsed = (fluxesUsed*u.nJy).to_value(u.ABmag)
incr = 5.0
idsUsed = (SNsUsed < minSnUsed + incr)
while sum(idsUsed) < max(0.005*len(idsUsed), 3):
while sum(idsUsed) < max(int(0.005*len(idsUsed)), 1):
incr += 5.0
idsUsed = (SNsUsed < plotInfo["SN"] + incr)
medMagUsed = np.nanmedian(magsUsed[idsUsed])
Expand All @@ -331,11 +360,18 @@ def colorColorFitPlot(self, catPlot, plotInfo, fitParams):
ax.text(0.04, yLoc, infoText, color="C0", transform=ax.transAxes,
fontsize=6, va="center")

# Calculate the density of the points
# Calculate the density of the points. Set all to 0.5 if density
# can't be calculated.
xyUsed = np.vstack([xs[fitPoints], ys[fitPoints]])
xyNotUsed = np.vstack([xs[~fitPoints], ys[~fitPoints]])
zUsed = scipy.stats.gaussian_kde(xyUsed)(xyUsed)
zNotUsed = scipy.stats.gaussian_kde(xyNotUsed)(xyNotUsed)
try:
zUsed = scipy.stats.gaussian_kde(xyUsed)(xyUsed)
except np.linalg.LinAlgError:
zUsed = [0.5]*len(xyUsed)
try:
zNotUsed = scipy.stats.gaussian_kde(xyNotUsed)(xyNotUsed)
except np.linalg.LinAlgError:
zNotUsed = [0.5]*len(xyNotUsed)

notUsedScatter = ax.scatter(xs[~fitPoints], ys[~fitPoints], c=zNotUsed, cmap=newGrays,
s=0.3)
Expand All @@ -350,6 +386,7 @@ def colorColorFitPlot(self, catPlot, plotInfo, fitParams):
ha="center", va="center", fontsize=7)
cbText.set_path_effects([pathEffects.Stroke(linewidth=1.5, foreground="w"), pathEffects.Normal()])
cbAx.set_xticks([np.min(zUsed), np.max(zUsed)], labels=["Less", "More"], fontsize=7)

cbAxNotUsed = fig.add_axes([0.12, 0.11, 0.43, 0.04])
plt.colorbar(notUsedScatter, cax=cbAxNotUsed, orientation="horizontal")
cbText = cbAxNotUsed.text(0.5, 0.5, "Number Density (not used in fit)", color="k",
Expand Down
4 changes: 2 additions & 2 deletions python/lsst/analysis/drp/colorColorPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def colorColorContourPlot(self, catPlot, plotInfo):
# Compute 2d histograms of the stars (used in the contour plot
# function).
countsStars, xEdgesStars, yEdgesStars = np.histogram2d(xsStars, ysStars, bins=self.config.nBins,
normed=False)
density=False)
zsStars = countsStars.transpose()
[vminStars, vmaxStars] = np.nanpercentile(zsStars, [1, 99])
vminStars = max(5, vminStars)
Expand All @@ -484,7 +484,7 @@ def colorColorContourPlot(self, catPlot, plotInfo):
# Compute 2d histograms of the galaxies (used in the contour plot
# function).
countsGals, xEdgesGals, yEdgesGals = np.histogram2d(xsGals, ysGals, bins=self.config.nBins,
normed=False)
density=False)
zsGals = countsGals.transpose()
[vminGals, vmaxGals] = np.nanpercentile(zsGals, [1, 99])
vminGals = max(5, vminGals)
Expand Down