Skip to content

Commit

Permalink
Fix bugs with pint quantity input, column iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
lukelbd committed Aug 20, 2021
1 parent f297017 commit e57d238
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions proplot/axes/plot.py
Expand Up @@ -1153,6 +1153,16 @@ def _load_objects():


# Standardization utilities
def _is_array(data):
"""
Test whether input is numpy array or pint quantity.
"""
# NOTE: This is used in _iter_columns to identify 2D matrices that
# should be iterated over and omit e.g. scalar marker size or marker color.
_load_objects()
return isinstance(data, ndarray) or ndarray is not Quantity and isinstance(data, Quantity) # noqa: E501


def _is_numeric(data):
"""
Test whether input is numeric array rather than datetime or strings.
Expand Down Expand Up @@ -1222,6 +1232,7 @@ def _safe_mask(mask, *args):
Safely apply the mask to the input arrays, accounting for existing masked
or invalid values. Values matching ``False`` are set to `np.nan`.
"""
_load_objects()
invalid = ~mask # True if invalid
args_masked = []
for arg in args:
Expand All @@ -1248,6 +1259,7 @@ def _safe_range(data, lo=0, hi=100, automin=True, automax=True):
for masked values. Use min and max functions when possible for speed. Return
``None`` if we faile to get a valid range.
"""
_load_objects()
units = 1
if ndarray is not Quantity and isinstance(data, Quantity):
data, units = data.magnitude, data.units
Expand Down Expand Up @@ -1369,10 +1381,10 @@ def _get_labels(data, axis=0, always=True):
# data values metadata but that is incorrect. The paradigm for 1D plots
# is we have row coordinates representing x, data values representing y,
# and column coordinates representing individual series.
if axis not in (0, 1, 2):
raise ValueError(f'Invalid axis {axis}.')
labels = None
_load_objects()
if axis not in (0, 1, 2):
raise ValueError(f'Invalid axis {axis}.')
if isinstance(data, (ndarray, Quantity)):
if not always:
pass
Expand Down Expand Up @@ -1446,6 +1458,7 @@ def _get_units(data):
Get the unit string from the `xarray.DataArray` attributes or the
`pint.Quantity`. Format the latter with :rcraw:`unitformat`.
"""
_load_objects()
# Get units from the attributes
if ndarray is not DataArray and isinstance(data, DataArray):
units = data.attrs.get('units', None)
Expand Down Expand Up @@ -1859,6 +1872,7 @@ def _redirect_or_standardize(self, *args, **kwargs):
kwargs[key] = _get_data(data, kwargs[key])

# Auto-setup matplotlib with the input unit registry
_load_objects()
for arg in args:
if ndarray is not DataArray and isinstance(arg, DataArray):
arg = arg.data
Expand Down Expand Up @@ -2939,10 +2953,9 @@ def _iter_columns(self, *args, label=None, labels=None, values=None, **kwargs):
keyword arguments using the input label-list ``'labels'``.
"""
# Handle cycle args and label lists
# WARNING: Must convert to ndarray or can get singleton DataArrays
# WARNING: We do not handle color cycling here because we want to allow
# iterating over columns of scatter() color arrays. Handle in _parse_cycle().
n = max(1 if a.ndim < 2 else a.shape[1] for a in args if isinstance(a, ndarray))
# NOTE: Arrays here should have had metadata stripped by _standardize_1d
# but could still be pint quantities that get processed by axis converter.
n = max(1 if not _is_array(a) or a.ndim < 2 else a.shape[-1] for a in args)
labels = _not_none(label=label, values=values, labels=labels)
if not np.iterable(labels) or isinstance(labels, str):
labels = n * [labels]
Expand All @@ -2957,10 +2970,7 @@ def _iter_columns(self, *args, label=None, labels=None, values=None, **kwargs):
for i in range(n):
kw = kwargs.copy()
kw['label'] = labels[i] or None
a = tuple(
a if not isinstance(a, ndarray) or a.ndim == 1 else a[:, i]
for a in args
)
a = tuple(a if not _is_array(a) or a.ndim < 2 else a[..., i] for a in args)
yield (i, n, *a, kw)

def _parse_cycle(
Expand Down

0 comments on commit e57d238

Please sign in to comment.