From e74ac06081d6f90434a68cd302c6b1b0a704f6d4 Mon Sep 17 00:00:00 2001 From: Robin De Schepper Date: Sun, 11 Oct 2020 14:54:38 +0200 Subject: [PATCH] Improved wrapping (#38) * Added init_subclass that registers a wrapper in the PythonHocInterpreter * Improved the PythonHocInterpreter's HocObject wrapping. closes #28 --- patch/__init__.py | 1 + patch/interpreter.py | 65 ++++++++++++++++++++++++++++++-------------- patch/objects.py | 17 ++++++++++++ tests/test_basics.py | 19 +++++++++++++ 4 files changed, 82 insertions(+), 20 deletions(-) diff --git a/patch/__init__.py b/patch/__init__.py index 5f0b7cc..fb4c366 100644 --- a/patch/__init__.py +++ b/patch/__init__.py @@ -23,6 +23,7 @@ def p(self): global _p if _p is None: _p = PythonHocModule.PythonHocInterpreter() + PythonHocModule.PythonHocInterpreter._process_registration_queue() return _p def connection(self, source, target, strict=True): diff --git a/patch/interpreter.py b/patch/interpreter.py index 60b8212..b3187b6 100644 --- a/patch/interpreter.py +++ b/patch/interpreter.py @@ -1,4 +1,4 @@ -from .objects import PythonHocObject, NetCon, PointProcess, VecStim, Section, IClamp, SectionRef +from .objects import PythonHocObject, NetCon, PointProcess, VecStim, Section, IClamp, SectionRef, _get_obj_registration_queue from .core import ( transform, transform_netcon, @@ -8,6 +8,7 @@ ) from .exceptions import * from .error_handler import catch_hoc_error, CatchNetCon, CatchSectionAccess, _suppress_nrn +from functools import wraps class PythonHocInterpreter: @@ -15,22 +16,52 @@ def __init__(self): from neuron import h self.__dict__["_PythonHocInterpreter__h"] = h - # Wrapping should occur around all calls to functions that share a name with - # child classes of the PythonHocObject like h.Section, h.NetStim, h.NetCon - self.__object_classes = PythonHocObject.__subclasses__().copy() - self.__requires_wrapping = [cls.__name__ for cls in self.__object_classes] self.__loaded_extensions = [] self.load_file("stdrun.hoc") self.runtime = 0 + @classmethod + def _process_registration_queue(cls): + """ + Most PythonHocObject classes (all those provided by Patch for sure) are created + before the PythonHocInterpreter class is available. Yet they require the class to + combine the original pointer from ``h.`` (e.g. ``h.Section``) with a + function that defers to their constructor so that you can call ``p.Section()`` + and create a PythonHocObject wrapped around the underlying ``h`` pointer. + + This function is called right after the PythonHocInterpreter class is created so + that PythonHocObjects can place themselves in a queue and have themselves + registered into the class right after it's ready. + """ + for hoc_object_class in _get_obj_registration_queue(): + cls.register_hoc_object(hoc_object_class) + + @classmethod + def register_hoc_object(interpreter_class, hoc_object_class): + # We shouldn't use multiple copies of h in case of monkey patches but since we + # need only native functions that return a hoc object this is fine. + from neuron import h + + if hoc_object_class.__name__ in interpreter_class.__dict__: + # The function call was overridden in the interpreter and should not be destroyed. + return + hoc_object_name = hoc_object_class.__name__ + # If the original interpreter doesn't have a function with the same name we can't + # simplify the constructor of the PythonHocObject and shouldn't wrap it. + if hasattr(h, hoc_object_name): + # Wrap it in the interpreter with a call to the underlying `h` to obtain a pointer + # and use that to make our PythonHocObject + factory = getattr(h, hoc_object_name) + @wraps(hoc_object_class.__init__) + def wrapper(interpreter_instance, *args, **kwargs): + hoc_ptr = factory(*args, **kwargs) + return hoc_object_class(interpreter_instance, hoc_ptr) + + setattr(PythonHocInterpreter, hoc_object_class.__name__, wrapper) + def __getattr__(self, attr_name): - # Get the missing attribute from h, if it requires wrapping return a wrapped - # object instead. - attr = getattr(self.__h, attr_name) - if attr_name in self.__requires_wrapping: - return self.wrap(attr, attr_name) - else: - return attr + # Get the missing attribute from h + return getattr(self.__h, attr_name) def __setattr__(self, attr, value): if hasattr(self.__h, attr): @@ -38,14 +69,6 @@ def __setattr__(self, attr, value): else: self.__dict__[attr] = value - def wrap(self, factory, name): - def wrapper(*args, **kwargs): - obj = factory(*args, **kwargs) - cls = next((c for c in self.__object_classes if c.__name__ == name), None) - return cls(self, obj) - - return wrapper - def NetCon(self, source, target, *args, **kwargs): nrn_source = transform_netcon(source) nrn_target = transform_netcon(target) @@ -337,3 +360,5 @@ def _broadcast(self, data, root=0): raise BroadcastError( "Root node did not transmit. Look for root node error." ) from None + +PythonHocInterpreter._process_registration_queue() diff --git a/patch/objects.py b/patch/objects.py index e58cc62..6156a82 100644 --- a/patch/objects.py +++ b/patch/objects.py @@ -2,7 +2,20 @@ from .error_handler import catch_hoc_error, CatchRecord +_registration_queue = [] + + class PythonHocObject: + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + try: + from .interpreter import PythonHocInterpreter + except ImportError: + _registration_queue.append(cls) + return + + PythonHocInterpreter.register_hoc_object(cls) + def __init__(self, interpreter, ptr): # Initialize ourselves with a reference to our own "pointer" # and prepare a list for other references. @@ -309,3 +322,7 @@ def stimulate(self, pattern=None, weight=0.04, delay=0.0, **kwargs): stimulus = self._interpreter.VecStim(pattern=pattern) self._interpreter.NetCon(stimulus, self, weight=weight, delay=delay) return stimulus + + +def _get_obj_registration_queue(): + return _registration_queue diff --git a/tests/test_basics.py b/tests/test_basics.py index 3d9565a..037219a 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -6,6 +6,25 @@ from patch.exceptions import * +class TestPatchRegistration(_shared.NeuronTestCase): + """ + Check that the registration of PythonHocObjects works. (Will almost never be relevant + since most actual HocObjects will be covered by Patch and use the registration queue + rather than immediate registration; and any class names that don't correspond to an + actual ``h.`` function don't create a wrapper) + """ + + def test_registration(self): + from patch import p + + # Create a new PythonHocObject, no wrapper will be added as it does not exist in h + class NewHocObject(patch.objects.PythonHocObject): + pass + + # Nothing to test, but the import inside ``PythonHocObject.__init_subclass__`` + # should complete and the call to ``PythonHocInterpreter.register_hoc_object`` + # should be covered in test coverage results. + class TestPatch(_shared.NeuronTestCase): """ Check Patch basics like object wrapping and the standard interface.