Skip to content

Commit

Permalink
fix evaluable array const checks (#807)
Browse files Browse the repository at this point in the history
The `WithDerivative` evaluable defines a derivative of a wrapped array
to some target. The target is not included in the
`Evaluable.dependencies`, but is included in the `Array.arguments`. This
PR fixes `Evaluable.isconstant` and `Array._derivative` (default impl)
which both neglect to inspect the `.arguments`.
  • Loading branch information
joostvanzwieten committed Jun 13, 2023
2 parents c08fe88 + f161b67 commit 8551cc5
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def arguments(self):

@property
def isconstant(self):
return EVALARGS not in self.dependencies
return EVALARGS not in self.dependencies and not self.arguments

@cached_property
def ordereddeps(self):
Expand Down Expand Up @@ -987,7 +987,7 @@ def _unaligned(self):
_inflations = ()

def _derivative(self, var, seen):
if self.dtype in (bool, int) or var not in self.dependencies:
if self.dtype in (bool, int) or var not in self.arguments:
return Zeros(self.shape + var.shape, dtype=self.dtype)
raise NotImplementedError('derivative not defined for {}'.format(self.__class__.__name__))

Expand Down
20 changes: 20 additions & 0 deletions tests/test_evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,26 @@ def test_int_to_float(self):
func = evaluable.IntToFloat(evaluable.BoolToInt(evaluable.Greater(arg, evaluable.zeros(()))))
self.assertTrue(evaluable.iszero(evaluable.derivative(func, arg)))

def test_with_derivative(self):
arg = evaluable.Argument('arg', (evaluable.constant(3),), float)
deriv = numpy.arange(6, dtype=float).reshape(2, 3)
func = evaluable.zeros((evaluable.constant(2),), float)
func = evaluable.WithDerivative(func, arg, evaluable.asarray(deriv))
self.assertAllAlmostEqual(evaluable.derivative(func, arg).eval(), deriv)

def test_default_derivative(self):
# Tests whether `evaluable.Array._derivative` correctly raises an
# exception when taking a derivative to one of the arguments present in
# its `.arguments`.
class DefaultDeriv(evaluable.Array): pass
has_arg = evaluable.Argument('has_arg', (), float)
has_not_arg = evaluable.Argument('has_not_arg', (), float)
func = evaluable.WithDerivative(evaluable.Zeros((), float), has_arg, evaluable.Zeros((), float))
func = DefaultDeriv((func,), (), float)
with self.assertRaises(NotImplementedError):
evaluable.derivative(func, has_arg)
self.assertTrue(evaluable.iszero(evaluable.derivative(func, has_not_arg)))


class asciitree(TestCase):

Expand Down

0 comments on commit 8551cc5

Please sign in to comment.