diff --git a/sympy/core/basic.py b/sympy/core/basic.py index 46ae7264a19..d5007133b71 100644 --- a/sympy/core/basic.py +++ b/sympy/core/basic.py @@ -41,7 +41,8 @@ class Basic(metaclass=ManagedProperties): """ __slots__ = ['_mhash', # hash value '_args', # arguments - '_assumptions' + '_assumptions', + '__weakref__' ] # To be overridden with True in the appropriate subclasses diff --git a/sympy/core/cache.py b/sympy/core/cache.py index 7c5cc7dbe23..aa3f178256d 100644 --- a/sympy/core/cache.py +++ b/sympy/core/cache.py @@ -1,6 +1,7 @@ """ Caching facility for SymPy """ from decorator import decorator +import weakref # TODO: refactor CACHE & friends into class? @@ -81,7 +82,7 @@ def __cacheit(f): set environment variable SYMPY_USE_CACHE to 'debug'. """ - func_cache_it_cache = {} + f._cache_it_cache = func_cache_it_cache = weakref.WeakValueDictionary() CACHE.append((f, func_cache_it_cache)) def wrapper(f, *args, **kw_args): diff --git a/sympy/core/tests/test_symbol.py b/sympy/core/tests/test_symbol.py index 8e57fad04b0..0badf8300aa 100644 --- a/sympy/core/tests/test_symbol.py +++ b/sympy/core/tests/test_symbol.py @@ -1,8 +1,11 @@ +import gc +import weakref + import pytest from sympy import (Symbol, Wild, GreaterThan, LessThan, StrictGreaterThan, StrictLessThan, pi, I, Rational, sympify, symbols, Dummy, - Integer, Float, sstr) + Integer, Float, sstr, default_sort_key) def test_Symbol(): @@ -332,3 +335,12 @@ def test_call(): f = Symbol('f') assert f(2) pytest.raises(TypeError, lambda: Wild('x')(1)) + + +def test_weakref(): + x, y = Symbol('x'), Symbol('y') + d = weakref.WeakKeyDictionary([(x, 1), (y, 2)]) + assert sstr(sorted(d.keys(), key=default_sort_key)) == '[x, y]' + del x + gc.collect() + assert sstr(list(d.keys())) == '[y]' diff --git a/sympy/utilities/tests/test_pickling.py b/sympy/utilities/tests/test_pickling.py index 5fcfbc8d640..3afc093c609 100644 --- a/sympy/utilities/tests/test_pickling.py +++ b/sympy/utilities/tests/test_pickling.py @@ -24,7 +24,7 @@ from sympy import symbols, S -excluded_attrs = {'_assumptions', '_mhash'} +excluded_attrs = {'_assumptions', '_mhash', '__weakref__'} def check(a, exclude=[], check_attr=True):