Skip to content

Commit

Permalink
Merge pull request #55 from khaeru/enh-2022-W08
Browse files Browse the repository at this point in the history
Improve typing of arithmetic, Quantity.shift(), Quantity.to_dataframe()
  • Loading branch information
khaeru committed Feb 23, 2022
2 parents d4f0d4b + f133e00 commit ca4c4c0
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 13 deletions.
7 changes: 5 additions & 2 deletions doc/whatsnew.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ What's new
:backlinks: none
:depth: 1

.. Next release
.. ============
Next release
============

- Add explicit implementations of :meth:`~.object.__radd__`, :meth:`~.object.__rmul__`, :meth:`~.object.__rsub__` and :meth:`~.object.__rtruediv__` for e.g. ``4.2 * Quantity(...)`` (:pull:`55`)
- Improve typing of :meth:`.Quantity.shift` (:pull:`55`)

v1.9.1 (2022-01-27)
===================
Expand Down
2 changes: 1 addition & 1 deletion genno/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def pow(a, b):
if isinstance(a, AttrSeries):
result = a ** b.align_levels(a)
else:
result = a ** b
result = a**b

result.attrs["_unit"] = (
a.attrs["_unit"] ** unit_exponent
Expand Down
10 changes: 8 additions & 2 deletions genno/core/attrseries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from functools import partial
from typing import Any, Hashable, Iterable, Mapping, Union
from typing import Any, Hashable, Iterable, List, Mapping, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -461,8 +461,14 @@ def transpose(self, *dims):
"""Like :meth:`xarray.DataArray.transpose`."""
return self.reorder_levels(dims)

def to_dataframe(self):
def to_dataframe(
self, name: Hashable = None, dim_order: List[Hashable] = None
) -> pd.DataFrame:
"""Like :meth:`xarray.DataArray.to_dataframe`."""
if dim_order is not None:
raise NotImplementedError("dim_order arg to to_dataframe()")

self.name = name or self.name or "value" # type: ignore
return self.to_frame()

def to_series(self):
Expand Down
20 changes: 20 additions & 0 deletions genno/core/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ def units(self, value):
def __len__(self) -> int:
... # pragma: no cover

def __radd__(self, other):
... # pragma: no cover

def __rmul__(self, other):
... # pragma: no cover

def __rsub__(self, other):
... # pragma: no cover

def __rtruediv__(self, other):
... # pragma: no cover

def __truediv__(self, other) -> "Quantity":
... # pragma: no cover

Expand Down Expand Up @@ -107,6 +119,14 @@ def sel(
) -> "Quantity":
... # pragma: no cover

def shift(
self,
shifts: Mapping[Hashable, int] = None,
fill_value: Any = None,
**shifts_kwargs: int,
): # NB "Quantity" here offends mypy
... # pragma: no cover

def to_numpy(self) -> np.ndarray:
... # pragma: no cover

Expand Down
10 changes: 7 additions & 3 deletions genno/core/sparsedataarray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Hashable, Mapping, Sequence, Tuple, Union
from typing import Any, Dict, Hashable, List, Mapping, Sequence, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -200,12 +200,16 @@ def sel(
._sda.convert()
)

def to_dataframe(self, name=None):
def to_dataframe(
self, name: Hashable = None, dim_order: List[Hashable] = None
) -> pd.DataFrame:
"""Convert this array and its coords into a :class:`~xarray.DataFrame`.
Overrides :meth:`~xarray.DataArray.to_dataframe`.
"""
return self.to_series().to_frame(name)
if dim_order is not None:
raise NotImplementedError("dim_order arg to to_dataframe()")
return self.to_series().to_frame(name or self.name or "value")

def to_series(self) -> pd.Series:
"""Convert this array into a :class:`~pandas.Series`.
Expand Down
2 changes: 1 addition & 1 deletion genno/tests/core/test_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def test_add_product(ureg):
assert key == "x squared:t-y"

# Product has the expected value
assert_qty_equal(Quantity(x * x, units=ureg.kilogram ** 2), c.get(key))
assert_qty_equal(Quantity(x * x, units=ureg.kilogram**2), c.get(key))

# add('product', ...) works
key = c.add("product", "x_squared", "x", "x", sums=True)
Expand Down
26 changes: 25 additions & 1 deletion genno/tests/core/test_quantity.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Tests for genno.quantity."""
import logging
import operator
import re

import pandas as pd
import pint
import pytest
import xarray as xr
from numpy import nan
from pytest import param

from genno import Computer, Quantity, computations
from genno.core.attrseries import AttrSeries
Expand Down Expand Up @@ -270,7 +272,18 @@ def test_size(self):

def test_to_dataframe(self, a):
"""Test Quantity.to_dataframe()."""
assert isinstance(a.to_dataframe(), pd.DataFrame)
# Returns pd.DataFrame
result = a.to_dataframe()
assert isinstance(result, pd.DataFrame)

# "value" is used as a column name
assert ["value"] == result.columns

# Explicitly passed name produces a named column
assert ["foo"] == a.to_dataframe("foo").columns

with pytest.raises(NotImplementedError):
a.to_dataframe(dim_order=["foo", "bar"])

def test_to_series(self, a):
"""Test .to_series() on child classes, and Quantity.from_series."""
Expand All @@ -290,3 +303,14 @@ def test_units(self, a):
# Can be set to dimensionless
a.units = ""
assert a.units.dimensionless

@pytest.mark.parametrize(
"op", [operator.add, operator.mul, operator.sub, operator.truediv]
)
@pytest.mark.parametrize("type_", [int, float, param(str, marks=pytest.mark.xfail)])
def test_arithmetic(self, op, type_, a):
"""Quantity can be added to int or float."""
result = op(type_(4.2), a)

assert (2,) == result.shape
assert a.dtype == result.dtype
6 changes: 3 additions & 3 deletions genno/tests/test_computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def test_pow(ureg):
result = computations.pow(A, 2)

# Expected units
assert ureg.kg ** 2 == result.attrs["_unit"]
assert ureg.kg**2 == result.attrs["_unit"]

# 2D ** 1D
B = random_qty(dict(y=3))
Expand Down Expand Up @@ -347,9 +347,9 @@ def test_product0():
"dims, exp_size",
(
# Some overlapping dimensions
((dict(a=2, b=2, c=2, d=2), dict(b=2, c=2, d=2, e=2, f=2)), 2 ** 6),
((dict(a=2, b=2, c=2, d=2), dict(b=2, c=2, d=2, e=2, f=2)), 2**6),
# 1D with disjoint dimensions ** 3 = 3D
((dict(a=2), dict(b=2), dict(c=2)), 2 ** 3),
((dict(a=2), dict(b=2), dict(c=2)), 2**3),
# 2D × scalar × scalar = 2D
((dict(a=2, b=2), dict(), dict()), 4),
# scalar × 1D × scalar = 1D
Expand Down

0 comments on commit ca4c4c0

Please sign in to comment.