Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Context manager and default context under threading #38

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
8 changes: 8 additions & 0 deletions doc/ref_fundamental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ Reference: Basic Building Blocks
Context
-------

.. note:: this class implements Python's ``__copy__`` and ``__deepcopy__``
protocols. Each of these returns the context being 'copied' identically.

.. note:: during an pickle operation, the current default :class:`Context`
is always used.

.. seealso:: :ref:`sec-context-management`
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a note on pickle behavior.

Copy link
Contributor Author

@thisiscam thisiscam Oct 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.
However, I'm thinking towards a slightly better semantics of pickling of Context. I can't imagine a use case for this, but here's a sketch:

@functools.lru_cache(maxsize=None)
def _get_context_by_id(ctx_id):
        return Context()

def context_reduce(self):
        if self == get_default_context():
            return (get_default_context, ())
        return (_get_context_by_id, (hash(self), ))

What do you think?


.. autoclass:: Context()
:members:

Expand Down
12 changes: 6 additions & 6 deletions doc/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ Lifetime Helpers
function call to which they're passed. These callback return a callback
handle that must be kept alive until the callback is no longer needed.

Global Data
^^^^^^^^^^^
.. _sec-context-management:

Context Management
^^^^^^^^^^^^^^^^^^

.. data:: DEFAULT_CONTEXT
.. autofunction:: get_default_context

ISL objects being unpickled or initialized from strings will be instantiated
within this :class:`Context`.
.. autofunction:: push_context

.. versionadded:: 2015.2

Symbolic Constants
^^^^^^^^^^^^^^^^^^
Expand Down
124 changes: 111 additions & 13 deletions islpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
THE SOFTWARE.
"""

import sys
import contextlib

import islpy._isl as _isl
from islpy.version import VERSION, VERSION_TEXT # noqa
import six
Expand Down Expand Up @@ -145,14 +148,104 @@
EXPR_CLASSES = tuple(cls for cls in ALL_CLASSES
if "Aff" in cls.__name__ or "Polynomial" in cls.__name__)

DEFAULT_CONTEXT = Context()
inducer marked this conversation as resolved.
Show resolved Hide resolved

def _module_property(func):
"""Decorator to turn module functions into properties.
Function names must be prefixed with an underscore."""
module = sys.modules[func.__module__]

def base_getattr(name):
raise AttributeError(
f"module '{module.__name__}' has no attribute '{name}'")

old_getattr = getattr(module, "__getattr__", base_getattr)

def new_getattr(name):
if f"_{name}" == func.__name__:
return func()
else:
return old_getattr(name)

module.__getattr__ = new_getattr
return func


import threading


_thread_local_storage = threading.local()


def _check_init_default_context():
if not hasattr(_thread_local_storage, "islpy_default_contexts"):
_thread_local_storage.islpy_default_contexts = [Context()]


def get_default_context():
inducer marked this conversation as resolved.
Show resolved Hide resolved
"""Get or create the default context under current thread.

:return: the current default :class:`Context`

.. versionadded:: 2020.3
"""
_check_init_default_context()
return _thread_local_storage.islpy_default_contexts[-1]


def _get_default_context():
"""A callable to get the default context for the benefit of Python's
``__reduce__`` protocol.
from warnings import warn
warn("It appears that you might be deserializing an islpy.Context"
"that was serialized by a previous version of islpy."
"If so, this is discouraged and please consider to re-serialize"
"the Context with the newer version to avoid possible inconsistencies.",
UserWarning)
return get_default_context()


@contextlib.contextmanager
def push_context(ctx=None):
"""Context manager to push new default :class:`Context`

:param ctx: an optional explicit context that is pushed to
the stack of default :class:`Context` s

.. versionadded:: 2020.3

:mod:`islpy` internally maintains a stack of default :class:`Context` s
for each Python thread.
By default, each stack is initialized with a base default :class:`Context`.
ISL objects being unpickled or initialized from strings will be
instantiated within the top :class:`Context` of the stack of
the executing thread.

Usage example::

with islpy.push_context() as dctx:
s = islpy.Set("{[0]: }")
assert s.get_ctx() == dctx

"""
return DEFAULT_CONTEXT
if ctx is None:
ctx = Context()
_check_init_default_context()
_thread_local_storage.islpy_default_contexts.append(ctx)
yield ctx
_thread_local_storage.islpy_default_contexts.pop()


@_module_property
def _DEFAULT_CONTEXT(): # noqa: N802
from warnings import warn
warn("Use of islpy.DEFAULT_CONTEXT is deprecated "
"and will be removed in 2022."
" Please use `islpy.get_default_context()` instead. ",
FutureWarning,
stacklevel=3)
return get_default_context()


if sys.version_info < (3, 7):
DEFAULT_CONTEXT = get_default_context()


def _read_from_str_wrapper(cls, context, s):
Expand All @@ -168,10 +261,14 @@ def _add_functionality():
# {{{ Context

def context_reduce(self):
if self._wraps_same_instance_as(DEFAULT_CONTEXT):
return (_get_default_context, ())
else:
return (Context, ())
return (get_default_context, ())
inducer marked this conversation as resolved.
Show resolved Hide resolved

def context_copy(self):
return self

def context_deepcopy(self, memo):
del memo
return self
inducer marked this conversation as resolved.
Show resolved Hide resolved

def context_eq(self, other):
return isinstance(other, Context) and self._wraps_same_instance_as(other)
Expand All @@ -180,9 +277,10 @@ def context_ne(self, other):
return not self.__eq__(other)

Context.__reduce__ = context_reduce
Context.__copy__ = context_copy
Context.__deepcopy__ = context_deepcopy
Context.__eq__ = context_eq
Context.__ne__ = context_ne

# }}}

# {{{ generic initialization, pickling
Expand All @@ -197,7 +295,7 @@ def obj_new(cls, s=None, context=None):
return cls._prev_new(cls)

if context is None:
context = DEFAULT_CONTEXT
context = get_default_context()

result = cls.read_from_str(context, s)
return result
Expand Down Expand Up @@ -473,7 +571,7 @@ def obj_get_coefficients_by_name(self, dimtype=None, dim_to_name=None):

def id_new(cls, name, user=None, context=None):
if context is None:
context = DEFAULT_CONTEXT
context = get_default_context()

result = cls.alloc(context, name, user)
result._made_from_python = True
Expand Down Expand Up @@ -777,7 +875,7 @@ def expr_like_floordiv(self, other):

def val_new(cls, src, context=None):
if context is None:
context = DEFAULT_CONTEXT
context = get_default_context()

if isinstance(src, six.string_types):
result = cls.read_from_str(context, src)
Expand Down Expand Up @@ -1274,7 +1372,7 @@ def make_zero_and_vars(set_vars, params=[], ctx=None):
)
"""
if ctx is None:
ctx = DEFAULT_CONTEXT
ctx = get_default_context()

if isinstance(set_vars, str):
set_vars = [s.strip() for s in set_vars.split(",")]
Expand Down
49 changes: 44 additions & 5 deletions test/test_isl.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,11 @@ def cb_print_for(printer, options, node):
printer = printer.print_str("Callback For")
return printer

opts = isl.AstPrintOptions.alloc(isl.DEFAULT_CONTEXT)
opts = isl.AstPrintOptions.alloc(isl.get_default_context())
inducer marked this conversation as resolved.
Show resolved Hide resolved
opts, cb_print_user_handle = opts.set_print_user(cb_print_user)
opts, cb_print_for_handle = opts.set_print_for(cb_print_for)

printer = isl.Printer.to_str(isl.DEFAULT_CONTEXT)
printer = isl.Printer.to_str(isl.get_default_context())
printer = printer.set_output_format(isl.format.C)
printer.print_str("// Start\n")
printer = ast.print_(printer, opts)
Expand Down Expand Up @@ -248,7 +248,7 @@ def isl_ast_codegen(S): # noqa: N803
m = isl.Map.identity(m.get_space())
m = isl.Map.from_domain(S)
ast = b.ast_from_schedule(m)
p = isl.Printer.to_str(isl.DEFAULT_CONTEXT)
p = isl.Printer.to_str(isl.get_default_context())
p = p.set_output_format(isl.format.C)
p.flush()
p = p.print_ast_node(ast)
Expand Down Expand Up @@ -362,8 +362,47 @@ def test_bound():
def test_copy_context():
ctx = isl.Context()
import copy
assert not ctx._wraps_same_instance_as(copy.copy(ctx))
assert not isl.DEFAULT_CONTEXT._wraps_same_instance_as(copy.copy(ctx))
assert ctx._wraps_same_instance_as(copy.copy(ctx))
assert ctx == copy.copy(ctx)
assert not isl.get_default_context()._wraps_same_instance_as(copy.copy(ctx))


def test_context_manager():
import pickle

def transfer_copy(obj):
return pickle.loads(pickle.dumps(obj))

b1 = isl.BasicSet("{ [0] : }")
old_dctx = isl.get_default_context()
assert b1.get_ctx() == old_dctx

with isl.push_context() as dctx:
assert dctx == isl.get_default_context()
assert not old_dctx._wraps_same_instance_as(dctx)
b2 = isl.BasicSet("{ [0] : }")
assert b2.get_ctx() == dctx
# Under context manager always use `dctx`
assert transfer_copy(b2).get_ctx() == transfer_copy(b1).get_ctx() == dctx

# Check for proper exit
assert old_dctx == isl.get_default_context()

# Check for nested context
with isl.push_context() as c1:
with isl.push_context() as c2:
assert c1 != c2
with isl.push_context() as c3:
assert c2 != c3
# Check for proper exit
assert old_dctx == isl.get_default_context()


def test_deprecated_default_context():
import warnings
with warnings.catch_warnings():
dctx = isl.DEFAULT_CONTEXT
assert dctx == isl.get_default_context()


def test_ast_node_list_free():
Expand Down