Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly deal with matmul and document unsupported functions #1216

Merged
merged 1 commit into from
Aug 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 38 additions & 1 deletion brian2/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
DIMENSIONLESS,
fail_for_dimension_mismatch)
from brian2.units.allunits import *
from brian2.units.stdunits import ms, mV, kHz, nS, cm, Hz, mM
from brian2.units.stdunits import ms, mV, kHz, nS, cm, Hz, mM, nA
from brian2.tests.utils import assert_allclose


Expand Down Expand Up @@ -900,6 +900,43 @@ def test_numpy_functions_change_dimensions():
1.0 / volt)


@pytest.mark.codegen_independent
def test_numpy_functions_matmul():
'''
Check support for matmul and the ``@`` operator.
'''
no_units_eye = np.eye(3)
with_units_eye = no_units_eye*Mohm
matrix_no_units = np.arange(9).reshape((3, 3))
matrix_units = matrix_no_units*nA

# First operand with units
assert_allclose(no_units_eye @ matrix_units, matrix_units)
assert have_same_dimensions(no_units_eye @ matrix_units, matrix_units)
assert_allclose(np.matmul(no_units_eye, matrix_units), matrix_units)
assert have_same_dimensions(np.matmul(no_units_eye, matrix_units), matrix_units)

# Second operand with units
assert_allclose(with_units_eye @ matrix_no_units,
matrix_no_units*Mohm)
assert have_same_dimensions(with_units_eye @ matrix_no_units,
matrix_no_units*Mohm)
assert_allclose(np.matmul(with_units_eye, matrix_no_units),
matrix_no_units*Mohm)
assert have_same_dimensions(np.matmul(with_units_eye, matrix_no_units),
matrix_no_units*Mohm)

# Both operands with units
assert_allclose(with_units_eye @ matrix_units,
no_units_eye @ matrix_no_units * nA * Mohm)
assert have_same_dimensions(with_units_eye @ matrix_units,
nA*Mohm)
assert_allclose(np.matmul(with_units_eye, matrix_units),
np.matmul(no_units_eye, matrix_no_units) * nA * Mohm)
assert have_same_dimensions(np.matmul(with_units_eye, matrix_units),
nA * Mohm)


@pytest.mark.codegen_independent
def test_numpy_functions_typeerror():
'''
Expand Down
6 changes: 3 additions & 3 deletions brian2/units/fundamentalunits.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _short_str(arr):
#: ufuncs that work on all dimensions but change the dimensions, e.g. square
UFUNCS_CHANGE_DIMENSIONS = ['multiply', 'divide', 'true_divide',
'floor_divide', 'sqrt', 'square', 'reciprocal',
'dot']
'dot', 'matmul']

#: ufuncs that work with matching dimensions, e.g. add
UFUNCS_MATCHING_DIMENSIONS = ['add', 'subtract', 'maximum', 'minimum',
Expand Down Expand Up @@ -1033,7 +1033,7 @@ def __array_prepare__(self, array, context=None):
def __array_wrap__(self, array, context=None):
dim = DIMENSIONLESS

if not context is None:
if context is not None:
uf, args, _ = context
if uf.__name__ in (UFUNCS_PRESERVE_DIMENSIONS +
UFUNCS_MATCHING_DIMENSIONS):
Expand All @@ -1058,7 +1058,7 @@ def __array_wrap__(self, array, context=None):
dim = get_dimensions(args[0]) / get_dimensions(args[1])
elif uf.__name__ == 'reciprocal':
dim = get_dimensions(args[0]) ** -1
elif uf.__name__ in ('multiply', 'dot'):
elif uf.__name__ in ('multiply', 'dot', 'matmul'):
dim = get_dimensions(args[0]) * get_dimensions(args[1])
else:
warn("Unknown ufunc '%s' in __array_wrap__" % uf.__name__)
Expand Down
5 changes: 4 additions & 1 deletion docs_sphinx/developer/units.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,12 @@ they are only using functions/methods that work with quantities:
* ``correlate`` (returns a quantity with wrong units)
* ``histogramdd`` (raises a ``DimensionMismatchError``)

**other unsupported functions**
Functions in ``numpy``'s subpackages such as ``linalg`` are not supported and will
either not work with units, or remove units from their inputs.

User-defined functions and units
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
For performance and simplicity reasons, code within the Brian core does not use
Quantity objects but unitless numpy arrays instead. See :doc:`functions` for
details on how to make use user-defined functions with Brian's unit system.