Skip to content

Commit

Permalink
Linearize (#774)
Browse files Browse the repository at this point in the history
  • Loading branch information
gertjanvanzwieten committed Feb 7, 2023
2 parents 214b838 + d76518b commit 1094b6b
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 0 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ The following overview lists user facing changes as well as newly added
features in inverse chronological order.


NEW: function.linearize

Similar to `derivative`, the new function `linearize` takes the derivative of
an array to one or more arguments, but with the derivative directions
represented by arguments rather than array axes. This is particularly useful in
situations where weak forms are made up of symmetric, energy like components,
combined with terms that require dedicated test fields.


NEW: support for searchsorted and interp

Numpy's ufunc support has been extended to include numpy.searchsorted and
Expand Down
40 changes: 40 additions & 0 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2880,6 +2880,46 @@ def replace_arguments(__array: IntoArray, __arguments: Mapping[str, IntoArray])
return _Replace(array, {k: Array.cast(v) for k, v in __arguments.items()}) * scale


def linearize(__array: IntoArray, __arguments: Union[str, Dict[str, str], Iterable[str], Iterable[Tuple[str, str]]]):
'''Linearize functional.
Similar to :func:`derivative`, linearize takes the derivative of an array
to one or more arguments, but with the derivative directions represented by
arguments rather than array axes. The result is by definition linear in the
new arguments.
Parameters
----------
array : :class:`Array` or something that can be :meth:`~Array.cast` into one
arguments : :class:`str`, :class:`dict` or iterable of strings
Example
-------
The following example demonstrates the use of linearize with four
equivalent argument specifications:
>>> u, v, p, q = [Argument(s, (), float) for s in 'uvpq']
>>> f = u**2 + p
>>> lin1 = linearize(f, 'u:v,p:q')
>>> lin2 = linearize(f, dict(u='v', p='q'))
>>> lin3 = linearize(f, ('u:v', 'p:q'))
>>> lin4 = linearize(f, (('u', 'v'), ('p', 'q')))
>>> # lin1 = lin2 == lin3 == lin4 == 2 * u * v + q
'''

array, scale = Array.cast_withscale(__array)
args = __arguments.split(',') if isinstance(__arguments, str) \
else __arguments.items() if isinstance(__arguments, dict) \
else __arguments
parts = []
for kv in args:
k, v = kv.split(':', 1) if isinstance(kv, str) else kv
f = derivative(array, k)
parts.append(numpy.sum(f * Argument(v, f.shape[array.ndim:]), tuple(range(array.ndim, f.ndim))))
return util.sum(parts) * scale


def broadcast_arrays(*arrays: IntoArray) -> Tuple[Array, ...]:
'''Broadcast the given arrays.
Expand Down
12 changes: 12 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,3 +1403,15 @@ def test_divide(self):
f = function.Argument('test', shape=(2, 3), dtype=int)
self.assertIsNot(f / 1, f)
self.assertIsNot(f / 1., f)


class linearize(TestCase):

def test(self):
f = function.linearize(function.Argument('u', shape=(3, 4), dtype=float)**3
+ function.Argument('p', shape=(), dtype=float), 'u:v,p:q')
# test linearization of u**3 + p -> 3 u**2 v + q through evaluation
_u = numpy.arange(3, dtype=float)[:,numpy.newaxis].repeat(4, 1)
_v = numpy.arange(4, dtype=float)[numpy.newaxis,:].repeat(3, 0)
_q = 5.
self.assertAllEqual(f.eval(u=_u, v=_v, q=_q).export('dense'), 3 * _u**2 * _v + _q)

0 comments on commit 1094b6b

Please sign in to comment.