Skip to content

Commit

Permalink
Merge pull request #37 from khaeru/issue/27
Browse files Browse the repository at this point in the history
Quantity as an actual class, rather than factory
  • Loading branch information
khaeru committed Mar 7, 2021
2 parents c79a727 + b8b02cc commit 1ad7e5b
Show file tree
Hide file tree
Showing 14 changed files with 463 additions and 183 deletions.
10 changes: 4 additions & 6 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,10 @@ Top-level classes and functions
<foo:a-b-c>


.. autodata:: genno.Quantity(data, *args, **kwargs)
:annotation:
.. autoclass:: genno.Quantity
:members:

The :data:`.Quantity` constructor converts its arguments to an internal, :class:`xarray.DataArray`-like data format:
The :class:`.Quantity` constructor converts its arguments to an internal, :class:`xarray.DataArray`-like data format:

.. code-block:: python
Expand All @@ -307,9 +307,7 @@ Computations
.. automodule:: genno.computations
:members:

Unless otherwise specified, these methods accept and return
:class:`Quantity <genno.utils.Quantity>` objects for data
arguments/return values.
Unless otherwise specified, these methods accept and return :class:`.Quantity` objects for data arguments/return values.

Genno's :ref:`compatibility modules <compat>` each provide additional computations.

Expand Down
2 changes: 2 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"sphinx.ext.todo",
"sphinx.ext.viewcode",
]

# Add any paths that contain templates here, relative to this directory.
Expand Down Expand Up @@ -60,6 +61,7 @@
"dask": ("https://docs.dask.org/en/stable/", None),
"ixmp": ("https://docs.messageix.org/projects/ixmp/en/latest", None),
"message_ix": ("https://docs.messageix.org/en/latest", None),
"pandas": ("https://pandas.pydata.org/docs/", None),
"pint": ("https://pint.readthedocs.io/en/stable/", None),
"plotnine": ("https://plotnine.readthedocs.io/en/stable/", None),
"pyam": ("https://pyam-iamc.readthedocs.io/en/stable/", None),
Expand Down
10 changes: 8 additions & 2 deletions doc/whatsnew.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@ What's new
:backlinks: none
:depth: 1

.. Next release
.. ============
Next release
============

- :class:`.Quantity` becomes an actual class, rather than a factory function; :class:`.AttrSeries` and :class:`.SparseDataArray` are subclasses (:pull:`37`).
- :class:`.AttrSeries` gains methods :meth:`~.AttrSeries.bfill`, :meth:`~.AttrSeries.cumprod`, :meth:`~.AttrSeries.ffill`, and :meth:`~.AttrSeries.shift` (:pull:`37`)
- :func:`.computations.load_file` uses the `skipinitialspace` parameter to :func:`pandas.read_csv`; extra dimensions not mentioned in the `dims` parameter are preserved (:pull:`37`).
- :meth:`.AttrSeries.sel` accepts :class:`xarray.DataArray` for xarray-style indexing (:pull:`37`).


v1.1.1 (2021-02-22)
===================
Expand Down
101 changes: 59 additions & 42 deletions genno/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import pandas as pd
import pint

from .core.quantity import Quantity, assert_quantity
from .util import collect_units, filter_concat_args
from genno.core.attrseries import AttrSeries
from genno.core.quantity import Quantity, assert_quantity
from genno.util import collect_units, filter_concat_args

__all__ = [
"add",
Expand Down Expand Up @@ -55,12 +56,12 @@ def add(*quantities, fill_value=0.0):
# Ensure arguments are all quantities
assert_quantity(*quantities)

if Quantity.CLASS == "SparseDataArray":
# Use xarray's built-in broadcasting, return to Quantity class
quantities = map(Quantity, xr.broadcast(*quantities))
else:
if isinstance(quantities[0], AttrSeries):
# map() returns an iterable
quantities = iter(quantities)
else:
# Use xarray's built-in broadcasting, return to Quantity class
quantities = map(Quantity, xr.broadcast(*quantities))

# Initialize result values with first entry
result = next(quantities)
Expand All @@ -74,7 +75,7 @@ def add(*quantities, fill_value=0.0):

factor = u.from_(1.0, strict=False).to(ref_unit).magnitude

if Quantity.CLASS == "AttrSeries":
if isinstance(q, AttrSeries):
result = result.add(factor * q, fill_value=fill_value).dropna()
else:
result = result + factor * q
Expand All @@ -89,12 +90,12 @@ def aggregate(quantity, groups, keep):
----------
quantity : :class:`Quantity <genno.utils.Quantity>`
groups: dict of dict
Top-level keys are the names of dimensions in `quantity`. Second-level
keys are group names; second-level values are lists of labels along the
dimension to sum into a group.
Top-level keys are the names of dimensions in `quantity`. Second-level keys are
group names; second-level values are lists of labels along the dimension to sum
into a group.
keep : bool
If True, the members that are aggregated into a group are returned with
the group sums. If False, they are discarded.
If True, the members that are aggregated into a group are returned with the
group sums. If False, they are discarded.
Returns
-------
Expand All @@ -113,7 +114,8 @@ def aggregate(quantity, groups, keep):
agg = (
quantity.sel({dim: members}).sum(dim=dim).assign_coords(**{dim: group})
)
if Quantity.CLASS == "AttrSeries":

if isinstance(agg, AttrSeries):
# .transpose() is necessary for AttrSeries
agg = agg.transpose(*quantity.dims)
else:
Expand Down Expand Up @@ -220,8 +222,8 @@ def combine(*quantities, select=None, weights=None): # noqa: F811
# Check units
units = collect_units(*quantities)
for u in units:
# TODO relax this condition: modify the weights with conversion factors
# if the units are compatible, but not the same
# TODO relax this condition: modify the weights with conversion factors if the
# units are compatible, but not the same
if u != units[0]:
raise ValueError(f"Cannot combine() units {units[0]} and {u}")
units = units[0]
Expand All @@ -247,16 +249,27 @@ def combine(*quantities, select=None, weights=None): # noqa: F811


def concat(*objs, **kwargs):
"""Concatenate Quantity *objs*.
"""Concatenate Quantity `objs`.
Any strings included amongst *args* are discarded, with a logged warning;
these usually indicate that a quantity is referenced which is not in the
Computer.
Any strings included amongst `objs` are discarded, with a logged warning; these
usually indicate that a quantity is referenced which is not in the Computer.
"""
objs = filter_concat_args(objs)
if Quantity.CLASS == "AttrSeries":
# Silently discard any "dim" keyword argument
kwargs.pop("dim", None)
if Quantity._get_class() is AttrSeries:
try:
# Retrieve a "dim" keyword argument
dim = kwargs.pop("dim")
except KeyError:
pass
else:
if isinstance(dim, pd.Index):
# Convert a pd.Index argument to names and keys
kwargs["names"] = [dim.name]
kwargs["keys"] = dim.values
else:
# Something else; warn and discard
log.warning(f"Ignore concat(…, dim={repr(dim)})")

return pd.concat(objs, **kwargs)
else:
# Correct fill-values
Expand Down Expand Up @@ -284,22 +297,22 @@ def group_sum(qty, group, sum):
def load_file(path, dims={}, units=None, name=None):
"""Read the file at *path* and return its contents as a :class:`.Quantity`.
Some file formats are automatically converted into objects for direct use
in genno computations:
Some file formats are automatically converted into objects for direct use in genno
computations:
:file:`.csv`:
Converted to :class:`.Quantity`. CSV files must have a 'value' column;
all others are treated as indices, except as given by `dims`. Lines
beginning with '#' are ignored.
Converted to :class:`.Quantity`. CSV files must have a 'value' column; all others
are treated as indices, except as given by `dims`. Lines beginning with '#' are
ignored.
Parameters
----------
path : pathlib.Path
Path to the file to read.
dims : collections.abc.Collection or collections.abc.Mapping, optional
If a collection of names, other columns besides these and 'value' are
discarded. If a mapping, the keys are the column labels in `path`, and
the values are the target dimension names.
If a collection of names, other columns besides these and 'value' are discarded.
If a mapping, the keys are the column labels in `path`, and the values are the
target dimension names.
units : str or pint.Unit
Units to apply to the loaded Quantity.
name : str
Expand All @@ -308,7 +321,7 @@ def load_file(path, dims={}, units=None, name=None):
# TODO optionally cache: if the same Computer is used repeatedly, then the file will
# be read each time; instead cache the contents in memory.
if path.suffix == ".csv":
data = pd.read_csv(path, comment="#")
data = pd.read_csv(path, comment="#", skipinitialspace=True)

# Index columns
index_columns = data.columns.tolist()
Expand All @@ -324,26 +337,26 @@ def load_file(path, dims={}, units=None, name=None):
# Use a unique value for units of the quantity
if len(units_col) > 1:
raise ValueError(
f"Cannot load {path} with non-unique units " + repr(units_col)
f"Cannot load {path} with non-unique units {repr(units_col)}"
)
elif units and units not in units_col:
raise ValueError(
f"Explicit units {units} do not match " f"{units_col[0]} in {path}"
f"Explicit units {units} do not match {units_col[0]} in {path}"
)
units = units_col[0]

if len(dims):
# Use specified dimensions
if not isinstance(dims, Mapping):
# Convert a list, set, etc. to a dict
dims = {d: d for d in dims}
# Convert a list, set, etc. to a dict
dims = dims if isinstance(dims, Mapping) else {d: d for d in dims}

# - Drop columns not mentioned in *dims*
# - Rename columns according to *dims*
data = data.drop(columns=set(index_columns) - set(dims.keys())).rename(
columns=dims
)
index_columns = list(dims.values())

index_columns = list(data.columns)
index_columns.pop(index_columns.index("value"))

return Quantity(data.set_index(index_columns)["value"], units=units, name=name)
elif path.suffix in (".xls", ".xlsx"):
Expand Down Expand Up @@ -381,7 +394,7 @@ def pow(a, b):
if not u_b.dimensionless:
raise ValueError(f"Cannot raise to a power with units ({u_b:~})")

if Quantity.CLASS == "AttrSeries":
if isinstance(a, AttrSeries):
result = a ** b.align_levels(a)
else:
result = a ** b
Expand All @@ -405,7 +418,7 @@ def product(*quantities):

# Iterate over remaining entries
for q, u in items:
if Quantity.CLASS == "AttrSeries":
if isinstance(q, AttrSeries):
# Work around pandas-dev/pandas#25760; see attrseries.py
result = (result * q.align_levels(result)).dropna()
else:
Expand All @@ -428,7 +441,7 @@ def ratio(numerator, denominator):
# Handle units
u_num, u_denom = collect_units(numerator, denominator)

if Quantity.CLASS == "AttrSeries":
if isinstance(numerator, AttrSeries):
result = numerator / denominator.align_levels(numerator)
else:
result = numerator / denominator
Expand All @@ -439,12 +452,16 @@ def ratio(numerator, denominator):
ureg = pint.get_application_registry()
result.attrs["_unit"] = ureg.Unit(u_num) / ureg.Unit(u_denom)

if Quantity.CLASS == "AttrSeries":
if isinstance(result, AttrSeries):
result.dropna(inplace=True)

return result


#: TODO make this the actual method name; emit DeprecationWarning if ratio() is used
div = ratio


def select(qty, indexers, inverse=False):
"""Select from *qty* based on *indexers*.
Expand Down

0 comments on commit 1ad7e5b

Please sign in to comment.