Skip to content

Commit

Permalink
Merge pull request #110 from khaeru/enh/2023-W50
Browse files Browse the repository at this point in the history
Miscellaneous enhancements for 2023-W50
  • Loading branch information
khaeru committed Dec 13, 2023
2 parents 261e307 + c789cfd commit a90dc01
Show file tree
Hide file tree
Showing 16 changed files with 194 additions and 81 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.x"

- name: Force recreation of pre-commit virtual environment for mypy
if: github.event_name == 'schedule' # Comment this line to run on a PR
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.0
rev: v1.7.1
hooks:
- id: mypy
additional_dependencies:
Expand All @@ -13,7 +13,7 @@ repos:
- types-PyYAML
- types-pytz
- types-python-dateutil
- types-setuptools
- types-requests
- xarray
args: []
- repo: https://github.com/astral-sh/ruff-pre-commit
Expand Down
3 changes: 2 additions & 1 deletion doc/compat-plotnine.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ To use :class:`.Plot`:
...: + p9.geom_point(color="blue")
...: )

2. Call :meth:`.make_task` to get a task tuple suitable for adding to a :class:`.Computer`:
2. Add the class to a :class:`.Computer` directly, using :meth:`.Computer.add`.
The :meth:`.Plot.add_tasks` method handles connecting the :attr:`.Plot.inputs` to :meth:`.Plot.save`:

.. ipython:: python
Expand Down
16 changes: 14 additions & 2 deletions doc/whatsnew.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
What's new
**********

.. Next release
.. ============
Next release
============

- New attribute :attr:`.Plot.path`, allowing control of the full path used to write plots (:pull:`110`).
- Bugfix: :meth:`.AttrSeries.sel` with a scalar indexer (for instance, :py:`qty.sel(x="foo")`) formerly did *not* drop the selected dimension; this was in contrast to :meth:`xarray.DataArray.sel`.
The behaviour is now consistent (:pull:`110`):

- :py:`qty.sel(x=["foo"])`, a length-1 sequence of indexers: the dimension is retained.
- :py:`qty.sel(x="foo")`, a single scalar indexer: the dimension is dropped.
- Small fixes in :class:`.SparseDataArray` (:pull:`110`):

- The :attr:`.Quantity.name` is preserved when an :class:`xarray.DataArray` is passed to the constructor.
- :meth:`~.SparseDataArray.to_series` works with 0-D (scalar) quantities.
- Provide typed signature for :meth:`.Quantity.squeeze` for the benefit of downstream applications (:pull:`110`).

v1.21.0 (2023-11-28)
====================
Expand Down
3 changes: 2 additions & 1 deletion genno/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dask.core import quote
from dask.core import literal, quote

from . import computations
from .config import configure
Expand All @@ -19,4 +19,5 @@
"computations",
"configure",
"quote",
"literal",
]
9 changes: 0 additions & 9 deletions genno/compat/plotnine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,3 @@
import warnings

# NB does not seem to have any effect. The entry in setup.cfg / [pytest] achieves this.
warnings.filterwarnings(
action="ignore",
message="Using or importing the ABCs from 'collections'",
module="patsy",
)

try:
import plotnine # noqa: F401
except ModuleNotFoundError: # pragma: no cover
Expand Down
48 changes: 30 additions & 18 deletions genno/compat/plotnine/plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from abc import ABC, abstractmethod
from typing import Hashable, Sequence
from pathlib import Path
from typing import Hashable, Optional, Sequence
from warnings import warn

import plotnine as p9
Expand All @@ -15,34 +16,43 @@
class Plot(ABC):
"""Class for plotting using :doc:`plotnine <plotnine:index>`."""

#: Filename base for saving the plot.
#: File name base for saving the plot.
basename = ""
#: File extension; determines file format.
suffix = ".pdf"
#: Keys for quantities needed by :meth:`generate`.
#: Path for file output. If it is not set, :meth:`save` will populate it with a
#: value constructed from :py:`config["output_dir"]`, :attr:`basename`, and
#: :attr:`suffix`. The implementation of :meth:`generate` in a Plot sub-class may
#: assign any other value, for instance one constructed at runtime from the
#: :attr:`inputs`.
path: Optional[Path] = None
#: :class:`Keys <.Key>` referring to :class:`Quantities <.Quantity>` or other inputs
#: accepted by :meth:`generate`.
inputs: Sequence[Hashable] = []
#: Keyword arguments for :meth:`plotnine.ggplot.save`.
save_args = dict(verbose=False)

# TODO add static geoms automatically in generate()
__static: Sequence = []

def save(self, config, *args, **kwargs):
def save(self, config, *args, **kwargs) -> Optional[Path]:
"""Prepare data, call :meth:`.generate`, and save to file.
This method is used as the callable in the task generated by :meth:`.make_task`.
This method is used as the callable in the task generated by :meth:`.add_tasks`.
"""
path = config["output_dir"] / f"{self.basename}{self.suffix}"
self.path = self.path or (
config["output_dir"] / f"{self.basename}{self.suffix}"
)

missing = tuple(filter(lambda arg: isinstance(arg, str), args))
if len(missing):
log.error(
f"Missing input(s) {missing!r} to plot {self.basename!r}; no output"
)
return
return None

# Convert Quantity arguments to pd.DataFrame for use with plotnine
args = map(
_args = map(
lambda arg: arg
if not isinstance(arg, Quantity)
else arg.to_series()
Expand All @@ -52,25 +62,25 @@ def save(self, config, *args, **kwargs):
args,
)

plot_or_plots = self.generate(*args, **kwargs)
plot_or_plots = self.generate(*_args, **kwargs)

if not plot_or_plots:
log.info(
f"{self.__class__.__name__}.generate() returned {plot_or_plots!r}; no "
"output"
)
return
return None

log.info(f"Save to {path}")
log.info(f"Save to {self.path}")

try:
# Single plot
plot_or_plots.save(path, **self.save_args)
plot_or_plots.save(self.path, **self.save_args)
except AttributeError:
# Iterator containing 0 or more plots
p9.save_as_pdf_pages(plot_or_plots, path, **self.save_args)
p9.save_as_pdf_pages(plot_or_plots, self.path, **self.save_args)

return path
return self.path

@classmethod
def make_task(cls, *inputs):
Expand Down Expand Up @@ -119,12 +129,14 @@ def add_tasks(
def generate(self, *args, **kwargs):
"""Generate and return the plot.
Must be implemented by subclasses.
A subclass of Plot **must** implement this method.
Parameters
----------
args : sequence of pandas.DataFrame
args : sequence of pandas.DataFrame or other
One argument is given corresponding to each of the :attr:`inputs`.
Because :doc:`plotnine <plotnine:index>` operates on pandas data structures,
:meth:`save` automatically converts :obj:`.Quantity` before they are passed
to :meth:`generate`.
:meth:`save` automatically converts any :class:`.Quantity` inputs to
:class:`pandas.DataFrame` before they are passed to :meth:`generate`.
"""
13 changes: 12 additions & 1 deletion genno/compat/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def cumprod(

def drop_vars(
self,
names: Union[Hashable, Iterable[Hashable]],
names: Union[
str, Iterable[Hashable], Callable[[Any], Union[str, Iterable[Hashable]]]
],
*,
errors="raise",
):
Expand Down Expand Up @@ -237,6 +239,15 @@ def shift(
"""Like :attr:`xarray.DataArray.shift`."""
...

def squeeze(
    self,
    dim: Union[Hashable, Iterable[Hashable], None] = None,
    drop: bool = False,
    axis: Union[int, Iterable[int], None] = None,
):
    """Like :meth:`xarray.DataArray.squeeze`.

    This is a signature-only stub (the body is ``...``): it declares the typed
    call signature and contains no implementation of its own.

    Parameters
    ----------
    dim : hashable or iterable of hashable, optional
        Dimension name(s) to squeeze. See the xarray documentation for the
        exact semantics.
    drop : bool, optional
        Same meaning as the `drop` argument of the xarray method; default
        :obj:`False`.
    axis : int or iterable of int, optional
        Same meaning as the `axis` argument of the xarray method.
    """
    ...

def sum(
self,
dim: Dims = None,
Expand Down
34 changes: 10 additions & 24 deletions genno/core/attrseries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import warnings
from functools import partial
from itertools import tee
from typing import (
Expand Down Expand Up @@ -381,13 +380,6 @@ def sel(

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")

if len(indexers) == 1:
level, key = list(indexers.items())[0]
if isinstance(key, str) and not drop:
# When using .loc[] to select 1 label on 1 level, pandas drops the
# level. Use .xs() to avoid this behaviour unless drop=True
return AttrSeries(self.xs(key, level=level, drop_level=False))

if len(indexers) and all(
isinstance(i, xr.DataArray) for i in indexers.values()
):
Expand Down Expand Up @@ -447,22 +439,14 @@ def sel(
# Get an indexer for this dimension
i = indexers.get(dim, slice(None))

if is_scalar(i) and (i != slice(None)) and drop:
if is_scalar(i) and (i != slice(None)):
to_drop.add(dim)

# Maybe unpack an xarray DataArray indexers, for pandas
idx.append(i.data if isinstance(i, xr.DataArray) else i)

# Silence a warning from pandas ≥1.4 that may be spurious
# FIXME investigate, adjust the code, remove the filter
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
".*indexing on a MultiIndex with a nested sequence.*",
FutureWarning,
)
# Select
data = self.loc[tuple(idx)]
# Select
data = self.loc[tuple(idx)]

# Only drop if not returning a scalar value
if isinstance(data, pd.Series):
Expand Down Expand Up @@ -523,9 +507,8 @@ def sum(
# Create the object on which to .sum()
return self._replace(self._maybe_groupby(dim).sum(**kwargs))

def squeeze(self, dim=None, *args, **kwargs):
def squeeze(self, dim=None, drop=False, axis=None):
"""Like :meth:`xarray.DataArray.squeeze`."""
assert kwargs.pop("drop", True)

idx = self.index.remove_unused_levels()

Expand All @@ -545,10 +528,13 @@ def squeeze(self, dim=None, *args, **kwargs):
to_drop.append(name)

if dim and not to_drop:
# Specified dimension does not exist
raise KeyError(dim)
raise KeyError(dim) # Specified dimension does not exist

return self.droplevel(to_drop)
if set(to_drop) == set(self.dims):
# Dropping all dimensions → 0-D quantity; simply reset
return self.reset_index(drop=True)
else:
return self.droplevel(to_drop)

def transpose(self, *dims):
"""Like :meth:`xarray.DataArray.transpose`."""
Expand Down
27 changes: 24 additions & 3 deletions genno/core/sparsedataarray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union
from warnings import filterwarnings

import numpy as np
import pandas as pd
Expand All @@ -9,12 +10,22 @@
HAS_SPARSE = True
except ImportError: # pragma: no cover
HAS_SPARSE = False

import xarray as xr
from xarray.core import dtypes
from xarray.core.utils import either_dict_or_kwargs

from genno.core.quantity import Quantity, possible_scalar

# Ignore the DeprecationWarning "coords should be an ndarray…" that sparse.COO raises
# when the data is 0-D / length-1; in that case self.coords.size is 0 (no dimensions =
# no coordinates)
filterwarnings(
    action="ignore",
    message="coords should be an ndarray.*",
    category=DeprecationWarning,
    module="sparse._coo.core",
)


def _binop(name: str, swap: bool = False):
"""Create a method for binary operator `name`."""
Expand Down Expand Up @@ -165,6 +176,7 @@ def __init__(
if isinstance(data, xr.DataArray):
# Possibly converted from pd.Series, above
coords = data._coords
name = name or data.name
data = data.variable

# Invoke the xr.DataArray constructor
Expand Down Expand Up @@ -241,6 +253,11 @@ def sel(
._sda.convert()
)

def squeeze(self, dim=None, drop=False, axis=None):
    """Like :meth:`xarray.DataArray.squeeze`.

    The call is forwarded, with `dim`, `drop`, and `axis` as keyword arguments, to
    the ``squeeze`` method of ``self._sda.dense_super``. The object that call
    returns is then passed through its own ``._sda.convert()`` and the outcome is
    returned.
    """
    squeezed = self._sda.dense_super.squeeze(dim=dim, drop=drop, axis=axis)
    return squeezed._sda.convert()

def to_dataframe(
self,
name: Optional[Hashable] = None,
Expand All @@ -263,7 +280,11 @@ def to_series(self) -> pd.Series:
# Use SparseArray.coords and .data (each already 1-D) to construct the pd.Series

# Construct a pd.MultiIndex without using .from_product
index = pd.MultiIndex.from_arrays(self.data.coords, names=self.dims).set_levels(
[self.coords[d].values for d in self.dims]
)
if self.dims:
index = pd.MultiIndex.from_arrays(
self.data.coords, names=self.dims
).set_levels([self.coords[d].values for d in self.dims])
else:
index = pd.MultiIndex.from_arrays([[0]], names=[None])

return pd.Series(self.data.data, index=index, name=self.name)
5 changes: 4 additions & 1 deletion genno/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Hashable,
Iterable,
Expand Down Expand Up @@ -515,7 +516,9 @@ def div(numerator: Union[Quantity, float], denominator: Quantity) -> Quantity:

def drop_vars(
qty: Quantity,
names: Union[Hashable, Iterable[Hashable]],
names: Union[
str, Iterable[Hashable], Callable[[Quantity], Union[str, Iterable[Hashable]]]
],
*,
errors="raise",
) -> Quantity:
Expand Down
4 changes: 4 additions & 0 deletions genno/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
# Pytest hooks


def pytest_sessionstart(session):
    """Pytest hook: set the level of the "numba" logger to INFO.

    The `session` argument is accepted as part of the hook signature but is not
    used.
    """
    numba_logger = logging.getLogger("numba")
    numba_logger.setLevel(logging.INFO)


def pytest_runtest_makereport(item, call):
"""Pytest hook to unwrap :class:`genno.ComputationError`.
Expand Down

0 comments on commit a90dc01

Please sign in to comment.