DM-38386: Add utility function to calculate safe plotting limits #151

Merged
merged 8 commits into from
Mar 30, 2023
Changes from 5 commits
Empty file added doc/changes/DM-38386.feature.md
Empty file.
10 changes: 10 additions & 0 deletions pyproject.toml
@@ -116,3 +116,13 @@ convention = "numpy"
# D200, D205 and D400 all complain if the first sentence of the docstring does
# not fit on one line.
add-ignore = ["E133", "E226", "E228", "N802", "N803", "N806", "N812", "N815", "N816", "W503", "E203"]

[tool.coverage.report]
show_missing = true
exclude_lines = [
    "pragma: no cover",
    "raise AssertionError",
    "raise NotImplementedError",
    "if __name__ == .__main__.:",
    "if TYPE_CHECKING:",
]
89 changes: 89 additions & 0 deletions python/lsst/utils/plotting/limits.py
@@ -0,0 +1,89 @@
# This file is part of utils.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# Use of this source code is governed by a 3-clause BSD-style
# license that can be found in the LICENSE file.

from __future__ import annotations

from collections.abc import Sequence
from typing import Iterable, Optional

import numpy as np


def calculate_safe_plotting_limits(
    data_series: Sequence,  # a sequence of sequences is still a sequence
    percentile: float = 99.9,
    constant_extra: Optional[float] = None,
    symmetric_around_zero: bool = False,
) -> tuple[float, float]:
    """Calculate the right limits for plotting for one or more data series.

    Given one or more data series with potential outliers, calculate the
    values to pass for ymin, ymax so that extreme outliers don't ruin the plot.
    If you are plotting several series on a single axis, pass them all in and
    the overall plotting range will be given.

    Parameters
    ----------
    data_series : `iterable` or `iterable` of `iterable`
        One or more data series which will be going on the same axis, and
        therefore want to have their common plotting limits calculated.
    percentile : `float`, optional
        The percentile used to clip the outliers from the data.
    constant_extra : `float`, optional
        The amount that's added on each side of the range so that data does
        not quite touch the axes. If the default ``None`` is left then 5% of
        the data range is added for cosmetics. Setting this to zero overrides
        that behaviour, so no padding is added.
    symmetric_around_zero : `bool`, optional
        Make the limits symmetric around zero?

    Returns
    -------
    ymin : `float`
        The value to set the ylim minimum to.
    ymax : `float`
        The value to set the ylim maximum to.
    """
    if not isinstance(data_series, Iterable):
        raise ValueError("data_series must be either an iterable, or an iterable of iterables")

Codecov / codecov/patch warning on python/lsst/utils/plotting/limits.py#L55: added line #L55 was not covered by tests.
Member

Please add a quick test that this ValueError happens if you pass in something like a single float.

Contributor Author

Done, and changed it to a TypeError too, because, well, it is 🙂


    # now we're sure we have an iterable, if it's just one make it a list of it
    # lsst.utils.ensure_iterable is not suitable here as we already have one,
    # we would need ensure_iterable_of_iterables here
    if not isinstance(data_series[0], Iterable):  # np.array is Iterable but not a Sequence, so check against Iterable
Member

I am concerned that we are being a bit inconsistent. Remind me what the type error was if Iterable is used? numpy not being a Sequence is a bit odd.

Contributor Author

Yes, I hate hate hate this. Getting this to work with a list, a numpy array, and a column from a pandas.dataframe, and an iterable of all of these was not super fun. This was the best solution I found. It is quite likely someone who knows what they're doing with mypy could do better, but I could not.
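
For reference, the mismatch under discussion can be reproduced with a short check. This is only an illustration of how the `collections.abc` checks behave for a list and a NumPy array, not a proposed change to the PR:

```python
from collections.abc import Iterable, Sequence

import numpy as np

as_list = [1.0, 2.0, 3.0]
as_array = np.array(as_list)

# Lists are both a Sequence and an Iterable.
print(isinstance(as_list, Sequence), isinstance(as_list, Iterable))

# ndarray defines __iter__, so it counts as Iterable, but NumPy does not
# register it as a collections.abc.Sequence.
print(isinstance(as_array, Sequence), isinstance(as_array, Iterable))

# The element check used to tell "one series" from "many series":
single = as_array
many = [as_array, as_array]
print(isinstance(single[0], Iterable))  # first element is a NumPy scalar
print(isinstance(many[0], Iterable))    # first element is itself an array
```

This is why the runtime check is written against `Iterable` even though the annotation says `Sequence`.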

Member

I am absolutely not suggesting a change to this particular code (especially now that it's working and merged), but in general when you run into trouble making a whole host of differently-behaving argument types work in single function, it might be time to make it into multiple functions that share code. Python was not designed to support function overloading, and while we often want it to avoid having to come up with or export multiple names, sometimes that's the better price to pay.
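
As an illustration of that suggestion (not a proposed change to this PR), the single-series and multi-series cases could be exposed as two functions built on a shared helper. All names below are hypothetical, and the padding rule simply mirrors the 5% default described in the docstring:

```python
from __future__ import annotations

from collections.abc import Iterable

import numpy as np


def _clipped_range(data, percentile: float) -> tuple[float, float]:
    """Shared helper: percentile-clipped (low, high) of one series, ignoring NaNs."""
    low = float(np.nanpercentile(data, 100 - percentile))
    high = float(np.nanpercentile(data, percentile))
    return low, high


def limits_for_series(data, percentile: float = 99.9, pad_fraction: float = 0.05) -> tuple[float, float]:
    """Limits for exactly one series (hypothetical name)."""
    low, high = _clipped_range(data, percentile)
    pad = pad_fraction * (high - low)
    return low - pad, high + pad


def limits_for_many(series: Iterable, percentile: float = 99.9, pad_fraction: float = 0.05) -> tuple[float, float]:
    """Common limits for several series sharing one axis (hypothetical name)."""
    per_series = [limits_for_series(s, percentile, pad_fraction) for s in series]
    return min(lo for lo, _ in per_series), max(hi for _, hi in per_series)
```

Because each function accepts only one input shape, neither needs to inspect `data_series[0]` at runtime to work out what it was given.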

        # we have a single data series, not multiple, wrap in [] so we can
        # iterate over it as if we were given many
        data_series = [data_series]

    mins = []
    maxs = []

    for data in data_series:
        max_val = np.nanpercentile(data, percentile)
        min_val = np.nanpercentile(data, 100 - percentile)

        # Compute the padding per series. Assigning to constant_extra here
        # would make the first series' 5% padding apply to all later series.
        if constant_extra is None:
            extra = 0.05 * (max_val - min_val)
        else:
            extra = constant_extra

        maxs.append(max_val + extra)
        mins.append(min_val - extra)

    max_val = max(maxs)
    min_val = min(mins)

    if symmetric_around_zero:
        biggest_abs = max(abs(min_val), abs(max_val))
        return -biggest_abs, biggest_abs

    return min_val, max_val
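
To see why percentile clipping helps, the following self-contained sketch compares the raw minimum and maximum against `np.nanpercentile` on a series with two injected outliers and a NaN. The data is synthetic and chosen only for illustration:

```python
import numpy as np

values = np.sin(np.linspace(0, 10, 10000))  # well-behaved data in [-1, 1]
values[1000] = 20.0     # injected outliers
values[2000] = -1000.0
values[3000] = np.nan   # nanpercentile ignores NaNs

raw_min, raw_max = np.nanmin(values), np.nanmax(values)
clip_min = np.nanpercentile(values, 100 - 99.9)
clip_max = np.nanpercentile(values, 99.9)

print(raw_min, raw_max)    # dominated by the two outliers
print(clip_min, clip_max)  # close to the true range of the sine wave
```

With the default `percentile=99.9`, only the most extreme 0.1% of points at each end are discarded, so a handful of bad values no longer sets the axis range.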
88 changes: 88 additions & 0 deletions tests/test_plotting_limits.py
@@ -0,0 +1,88 @@
# This file is part of utils.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# Use of this source code is governed by a 3-clause BSD-style
# license that can be found in the LICENSE file.
#

import unittest

import numpy as np
from lsst.utils.plotting.limits import calculate_safe_plotting_limits


class PlottingLimitsTestCase(unittest.TestCase):
    """Tests for `calculate_safe_plotting_limits` function."""

    xs = np.linspace(0, 10, 10000)
    series1 = np.sin(xs + 3.1415 / 2) + 0.75  # min=-0.24999, max=1.74999
    series1_min = min(series1)
    series1_max = max(series1)

    series2 = np.sin(xs) + 1.2  # min=0.2, max=2.19999
    series2_min = min(series2)
    series2_max = max(series2)

    # Slicing a numpy array returns a view, so take an explicit copy to avoid
    # writing the outliers into series1 as well.
    outliers = series1.copy()
    outliers[1000] = 20
    outliers[2000] = -1000

    def testSingleSeries(self):
        """Test that a single series works and the outliers exclusion works."""
        # Deliberately test the bounds are the same when using the series
        # itself, and the copy with the outlier values, i.e. using
        # self.series1_min/max inside the loop despite changing the series we
        # loop over is the intent here, not a bug.
        for series in [self.series1, self.outliers]:
            ymin, ymax = calculate_safe_plotting_limits(series)
            self.assertLess(ymin, self.series1_min)
            self.assertGreater(ymin, self.series1_min - 1)
            self.assertLess(ymax, self.series1_max + 1)
            self.assertGreater(ymax, self.series1_max)

    def testMultipleSeries(self):
        """Test that passing several series in works wrt outliers."""
        ymin, ymax = calculate_safe_plotting_limits([self.series1, self.outliers])
        self.assertLess(ymin, self.series1_min)
        self.assertGreater(ymin, self.series1_min - 1)
        self.assertLess(ymax, self.series1_max + 1)
        self.assertGreater(ymax, self.series1_max)

    def testMultipleSeriesCommonRange(self):
        """Test that the limits for several series cover their common range."""
        ymin, ymax = calculate_safe_plotting_limits([self.series1, self.series2])
        # lower bound less than the lowest of the two
        self.assertLess(ymin, min(self.series1_min, self.series2_min))
        # lower bound less than the lowest of the two, but not by much
        self.assertGreater(ymin, min(self.series1_min, self.series2_min) - 1)
        # upper bound greater than the highest of the two
        self.assertGreater(ymax, max(self.series1_max, self.series2_max))
        # upper bound greater than the highest of the two, but not by much
        self.assertLess(ymax, max(self.series1_max, self.series2_max) + 1)

    def testSymmetric(self):
        """Test that the symmetric option works."""
        ymin, ymax = calculate_safe_plotting_limits([self.series1, self.outliers], symmetric_around_zero=True)
        self.assertEqual(ymin, -ymax)
        self.assertGreater(ymax, self.series1_max)
        self.assertLess(ymin, self.series1_min)

    def testConstantExtra(self):
        """Test that the constant_extra option works."""
        strictMin, strictMax = calculate_safe_plotting_limits([self.series1, self.outliers], constant_extra=0)
        self.assertAlmostEqual(strictMin, self.series1_min, places=4)
        self.assertAlmostEqual(strictMax, self.series1_max, places=4)

        for extra in [-2.123, -1, 0, 1, 1.5, 23]:
            ymin, ymax = calculate_safe_plotting_limits([self.series1, self.outliers], constant_extra=extra)
            self.assertAlmostEqual(ymin, self.series1_min - extra, places=4)
            self.assertAlmostEqual(ymax, self.series1_max + extra, places=4)


if __name__ == "__main__":
    unittest.main()