Skip to content

Commit

Permalink
Merge pull request #1216 from brian-team/fix_#1212
Browse files Browse the repository at this point in the history
Correctly deal with matmul and document unsupported functions
  • Loading branch information
mstimberg committed Aug 27, 2020
2 parents 6a14c76 + 8708bfe commit 56cbdf6
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
39 changes: 38 additions & 1 deletion brian2/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
DIMENSIONLESS,
fail_for_dimension_mismatch)
from brian2.units.allunits import *
from brian2.units.stdunits import ms, mV, kHz, nS, cm, Hz, mM
from brian2.units.stdunits import ms, mV, kHz, nS, cm, Hz, mM, nA
from brian2.tests.utils import assert_allclose


Expand Down Expand Up @@ -900,6 +900,43 @@ def test_numpy_functions_change_dimensions():
1.0 / volt)


@pytest.mark.codegen_independent
def test_numpy_functions_matmul():
'''
Check support for matmul and the ``@`` operator.
'''
no_units_eye = np.eye(3)
with_units_eye = no_units_eye*Mohm
matrix_no_units = np.arange(9).reshape((3, 3))
matrix_units = matrix_no_units*nA

# First operand with units
assert_allclose(no_units_eye @ matrix_units, matrix_units)
assert have_same_dimensions(no_units_eye @ matrix_units, matrix_units)
assert_allclose(np.matmul(no_units_eye, matrix_units), matrix_units)
assert have_same_dimensions(np.matmul(no_units_eye, matrix_units), matrix_units)

# Second operand with units
assert_allclose(with_units_eye @ matrix_no_units,
matrix_no_units*Mohm)
assert have_same_dimensions(with_units_eye @ matrix_no_units,
matrix_no_units*Mohm)
assert_allclose(np.matmul(with_units_eye, matrix_no_units),
matrix_no_units*Mohm)
assert have_same_dimensions(np.matmul(with_units_eye, matrix_no_units),
matrix_no_units*Mohm)

# Both operands with units
assert_allclose(with_units_eye @ matrix_units,
no_units_eye @ matrix_no_units * nA * Mohm)
assert have_same_dimensions(with_units_eye @ matrix_units,
nA*Mohm)
assert_allclose(np.matmul(with_units_eye, matrix_units),
np.matmul(no_units_eye, matrix_no_units) * nA * Mohm)
assert have_same_dimensions(np.matmul(with_units_eye, matrix_units),
nA * Mohm)


@pytest.mark.codegen_independent
def test_numpy_functions_typeerror():
'''
Expand Down
6 changes: 3 additions & 3 deletions brian2/units/fundamentalunits.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _short_str(arr):
#: ufuncs that work on all dimensions but change the dimensions, e.g. square
UFUNCS_CHANGE_DIMENSIONS = ['multiply', 'divide', 'true_divide',
'floor_divide', 'sqrt', 'square', 'reciprocal',
'dot']
'dot', 'matmul']

#: ufuncs that work with matching dimensions, e.g. add
UFUNCS_MATCHING_DIMENSIONS = ['add', 'subtract', 'maximum', 'minimum',
Expand Down Expand Up @@ -1033,7 +1033,7 @@ def __array_prepare__(self, array, context=None):
def __array_wrap__(self, array, context=None):
dim = DIMENSIONLESS

if not context is None:
if context is not None:
uf, args, _ = context
if uf.__name__ in (UFUNCS_PRESERVE_DIMENSIONS +
UFUNCS_MATCHING_DIMENSIONS):
Expand All @@ -1058,7 +1058,7 @@ def __array_wrap__(self, array, context=None):
dim = get_dimensions(args[0]) / get_dimensions(args[1])
elif uf.__name__ == 'reciprocal':
dim = get_dimensions(args[0]) ** -1
elif uf.__name__ in ('multiply', 'dot'):
elif uf.__name__ in ('multiply', 'dot', 'matmul'):
dim = get_dimensions(args[0]) * get_dimensions(args[1])
else:
warn("Unknown ufunc '%s' in __array_wrap__" % uf.__name__)
Expand Down
5 changes: 4 additions & 1 deletion docs_sphinx/developer/units.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,12 @@ they are only using functions/methods that work with quantities:
* ``correlate`` (returns a quantity with wrong units)
* ``histogramdd`` (raises a ``DimensionMismatchError``)

**other unsupported functions**
Functions in ``numpy``'s subpackages such as ``linalg`` are not supported and will
either not work with units, or remove units from their inputs.

User-defined functions and units
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
For performance and simplicity reasons, code within the Brian core does not use
Quantity objects but unitless numpy arrays instead. See :doc:`functions` for
details on how to make use user-defined functions with Brian's unit system.

0 comments on commit 56cbdf6

Please sign in to comment.