Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sympy
from sympy import Expr, Function, Number, Tuple, cacheit, sympify
from sympy.core.decorators import call_highest_priority
from sympy.core.function import Application
from sympy.logic.boolalg import BooleanFunction

from devito.finite_differences.elementary import Max, Min
Expand Down Expand Up @@ -718,7 +719,13 @@ def __new__(cls, name, arguments=None, template=None, **kwargs):
if _template:
args.append(Tuple(*_template))

obj = Function.__new__(cls, *args)
# `Function.__new__` and `Application.__new__` are both cached by
# SymPy. DefFunction subclasses may attach reconstruction kwargs as
# side attributes after this base constructor returns; going through
# the cached route could then alias a previous object and mutate it
# during reconstruction. Call Application's uncached constructor
# explicitly instead of using super()/Function.__new__.
obj = Application.__new__.__wrapped__(cls, *args)
obj._name = name
obj._arguments = tuple(_arguments)
obj._template = tuple(_template)
Expand Down
30 changes: 30 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,36 @@ def __new__(cls, name=None, arguments=None, p0=None, p1=None, p2=None):
assert func1.p1 == (g,)
assert func1.p2 == 'bar'

def test_custom_def_function_reconstruction_no_aliasing(self):

class MyDefFunction(DefFunction):
__rargs__ = ('name', 'arguments')
__rkwargs__ = ('p0',)

def __new__(cls, name=None, arguments=None, p0=None):
obj = super().__new__(cls, name=name, arguments=arguments)
obj.p0 = p0
return obj

def _hashable_content(self):
return super()._hashable_content() + (self.p0,)

grid = Grid(shape=(4, 4))

f = Function(name='f', grid=grid)
g = Function(name='g', grid=grid)

func0 = MyDefFunction(name='foo', arguments=f.indexify(), p0=f)
h0 = hash(func0)

func1 = func0.func(p0=g)

assert func1 is not func0
assert func1 != func0
assert hash(func0) == h0
assert func0.p0 is f
assert func1.p0 is g

def test_reduce_to_number(self):
grid = Grid(shape=(4, 4))
x, _ = grid.dimensions
Expand Down
Loading