Permalink
Browse files

Merge pull request #1457 from bfroehle/dreload

Update deepreload to use a rewritten knee.py. 

This fixes the long-standing bug with `dreload(numpy)` not working.

`knee.py`, a Python re-implementation of hierarchical module import was removed from the standard library because it no longer functioned properly.  This uses `deepreload.py`, a fixed version of `knee.py` which overrides `__builtin__.__import__` to ensure that each module is re-imported once (before just referring to sys.modules as usual).

In addition, `os.path` was added to the default excluded modules, since somehow it has an entry in sys.modules without `os' being a package.

Added tests for this functionality (which had never had any tests).

Closes gh-32.
  • Loading branch information...
2 parents 2464446 + 2e6a444 commit a206ee29897390864c52d1ef1321db363d877e1c @fperez fperez committed Apr 18, 2012
Showing with 299 additions and 114 deletions.
  1. +246 −114 IPython/lib/deepreload.py
  2. +53 −0 IPython/lib/tests/test_deepreload.py
View
360 IPython/lib/deepreload.py
@@ -14,9 +14,9 @@
__builtin__.dreload = deepreload.reload
-This code is almost entirely based on knee.py from the standard library.
+This code is almost entirely based on knee.py, which is a Python
+re-implementation of hierarchical module import.
"""
-
#*****************************************************************************
# Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
#
@@ -28,135 +28,267 @@
import imp
import sys
-# Replacement for __import__()
-def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
- # For now level is ignored, it's just there to prevent crash
- # with from __future__ import absolute_import
- parent = determine_parent(globals)
- q, tail = find_head_package(parent, name)
- m = load_tail(q, tail)
- if not fromlist:
- return q
- if hasattr(m, "__path__"):
- ensure_fromlist(m, fromlist)
- return m
+from types import ModuleType
+from warnings import warn
+
+def get_parent(globals, level):
+ """
+ parent, name = get_parent(globals, level)
+
+ Return the package that an import is being performed in. If globals comes
+ from the module foo.bar.bat (not itself a package), this returns the
+ sys.modules entry for foo.bar. If globals is from a package's __init__.py,
+ the package's entry in sys.modules is returned.
+
+ If globals doesn't come from a package or a module in a package, or a
+ corresponding entry is not found in sys.modules, None is returned.
+ """
+ orig_level = level
+
+ if not level or not isinstance(globals, dict):
+ return None, ''
-def determine_parent(globals):
- if not globals or not globals.has_key("__name__"):
- return None
- pname = globals['__name__']
- if globals.has_key("__path__"):
- parent = sys.modules[pname]
- assert globals is parent.__dict__
- return parent
- if '.' in pname:
- i = pname.rfind('.')
- pname = pname[:i]
- parent = sys.modules[pname]
- assert parent.__name__ == pname
- return parent
- return None
-
-def find_head_package(parent, name):
- # Import the first
- if '.' in name:
- # 'some.nested.package' -> head = 'some', tail = 'nested.package'
- i = name.find('.')
- head = name[:i]
- tail = name[i+1:]
+ pkgname = globals.get('__package__', None)
+
+ if pkgname is not None:
+ # __package__ is set, so use it
+ if not hasattr(pkgname, 'rindex'):
+ raise ValueError('__package__ set to non-string')
+ if len(pkgname) == 0:
+ if level > 0:
+ raise ValueError('Attempted relative import in non-package')
+ return None, ''
+ name = pkgname
else:
- # 'packagename' -> head = 'packagename', tail = ''
- head = name
- tail = ""
- if parent:
- # If this is a subpackage then qname = parent's name + head
- qname = "%s.%s" % (parent.__name__, head)
+ # __package__ not set, so figure it out and set it
+ if '__name__' not in globals:
+ return None, ''
+ modname = globals['__name__']
+
+ if '__path__' in globals:
+ # __path__ is set, so modname is already the package name
+ globals['__package__'] = name = modname
+ else:
+ # Normal module, so work out the package name if any
+ lastdot = modname.rfind('.')
+ if lastdot < 0 and level > 0:
+ raise ValueError("Attempted relative import in non-package")
+ if lastdot < 0:
+ globals['__package__'] = None
+ return None, ''
+ globals['__package__'] = name = modname[:lastdot]
+
+ dot = len(name)
+ for x in xrange(level, 1, -1):
+ try:
+ dot = name.rindex('.', 0, dot)
+ except ValueError:
+ raise ValueError("attempted relative import beyond top-level "
+ "package")
+ name = name[:dot]
+
+ try:
+ parent = sys.modules[name]
+ except:
+ if orig_level < 1:
+ warn("Parent module '%.200s' not found while handling absolute "
+ "import" % name)
+ parent = None
+ else:
+ raise SystemError("Parent module '%.200s' not loaded, cannot "
+ "perform relative import" % name)
+
+ # We expect, but can't guarantee, if parent != None, that:
+ # - parent.__name__ == name
+ # - parent.__dict__ is globals
+ # If this is violated... Who cares?
+ return parent, name
+
+def load_next(mod, altmod, name, buf):
+ """
+ mod, name, buf = load_next(mod, altmod, name, buf)
+
+ altmod is either None or same as mod
+ """
+
+ if len(name) == 0:
+ # completely empty module name should only happen in
+ # 'from . import' (or '__import__("")')
+ return mod, None, buf
+
+ dot = name.find('.')
+ if dot == 0:
+ raise ValueError('Empty module name')
+
+ if dot < 0:
+ subname = name
+ next = None
else:
- qname = head
- q = import_module(head, qname, parent)
- if q: return q, tail
- if parent:
- qname = head
- parent = None
- q = import_module(head, qname, parent)
- if q: return q, tail
- raise ImportError, "No module named " + qname
-
-def load_tail(q, tail):
- m = q
- while tail:
- i = tail.find('.')
- if i < 0: i = len(tail)
- head, tail = tail[:i], tail[i+1:]
-
- # fperez: fix dotted.name reloading failures by changing:
- #mname = "%s.%s" % (m.__name__, head)
- # to:
- mname = m.__name__
- # This needs more testing!!! (I don't understand this module too well)
-
- #print '** head,tail=|%s|->|%s|, mname=|%s|' % (head,tail,mname) # dbg
- m = import_module(head, mname, m)
- if not m:
- raise ImportError, "No module named " + mname
- return m
+ subname = name[:dot]
+ next = name[dot+1:]
+
+ if buf != '':
+ buf += '.'
+ buf += subname
+
+ result = import_submodule(mod, subname, buf)
+ if result is None and mod != altmod:
+ result = import_submodule(altmod, subname, subname)
+ if result is not None:
+ buf = subname
+
+ if result is None:
+ raise ImportError("No module named %.200s" % name)
-def ensure_fromlist(m, fromlist, recursive=0):
- for sub in fromlist:
- if sub == "*":
- if not recursive:
- try:
- all = m.__all__
- except AttributeError:
- pass
- else:
- ensure_fromlist(m, all, 1)
- continue
- if sub != "*" and not hasattr(m, sub):
- subname = "%s.%s" % (m.__name__, sub)
- submod = import_module(sub, subname, m)
- if not submod:
- raise ImportError, "No module named " + subname
+ return result, next, buf
# Need to keep track of what we've already reloaded to prevent cyclic evil
found_now = {}
-def import_module(partname, fqname, parent):
+def import_submodule(mod, subname, fullname):
+ """m = import_submodule(mod, subname, fullname)"""
+ # Require:
+ # if mod == None: subname == fullname
+ # else: mod.__name__ + "." + subname == fullname
+
global found_now
- if found_now.has_key(fqname):
+ if fullname in found_now and fullname in sys.modules:
+ m = sys.modules[fullname]
+ else:
+ print 'Reloading', fullname
+ found_now[fullname] = 1
+ oldm = sys.modules.get(fullname, None)
+
+ if mod is None:
+ path = None
+ elif hasattr(mod, '__path__'):
+ path = mod.__path__
+ else:
+ return None
+
try:
- return sys.modules[fqname]
- except KeyError:
- pass
+ fp, filename, stuff = imp.find_module(subname, path)
+ except ImportError:
+ return None
+
+ try:
+ m = imp.load_module(fullname, fp, filename, stuff)
+ except:
+ # load_module probably removed name from modules because of
+ # the error. Put back the original module object.
+ if oldm:
+ sys.modules[fullname] = oldm
+ raise
+ finally:
+ if fp: fp.close()
+
+ add_submodule(mod, m, fullname, subname)
+
+ return m
+
+def add_submodule(mod, submod, fullname, subname):
+ """mod.{subname} = submod"""
+ if mod is None:
+ return #Nothing to do here.
+
+ if submod is None:
+ submod = sys.modules[fullname]
+
+ setattr(mod, subname, submod)
+
+ return
+
+def ensure_fromlist(mod, fromlist, buf, recursive):
+ """Handle 'from module import a, b, c' imports."""
+ if not hasattr(mod, '__path__'):
+ return
+ for item in fromlist:
+ if not hasattr(item, 'rindex'):
+ raise TypeError("Item in ``from list'' not a string")
+ if item == '*':
+ if recursive:
+ continue # avoid endless recursion
+ try:
+ all = mod.__all__
+ except AttributeError:
+ pass
+ else:
+ ret = ensure_fromlist(mod, all, buf, 1)
+ if not ret:
+ return 0
+ elif not hasattr(mod, item):
+ import_submodule(mod, item, buf + '.' + item)
+
+def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
+ """Replacement for __import__()"""
+ parent, buf = get_parent(globals, level)
+
+ head, name, buf = load_next(parent, None if level < 0 else parent, name, buf)
+
+ tail = head
+ while name:
+ tail, name, buf = load_next(tail, tail, name, buf)
+
+ # If tail is None, both get_parent and load_next found
+ # an empty module name: someone called __import__("") or
+ # doctored faulty bytecode
+ if tail is None:
+ raise ValueError('Empty module name')
- print 'Reloading', fqname #, sys.excepthook is sys.__excepthook__, \
- #sys.displayhook is sys.__displayhook__
+ if not fromlist:
+ return head
+
+ ensure_fromlist(tail, fromlist, buf, 0)
+ return tail
+
+modules_reloading = {}
+
+def deep_reload_hook(m):
+ """Replacement for reload()."""
+ if not isinstance(m, ModuleType):
+ raise TypeError("reload() argument must be module")
+
+ name = m.__name__
+
+ if name not in sys.modules:
+ raise ImportError("reload(): module %.200s not in sys.modules" % name)
+
+ global modules_reloading
+ try:
+ return modules_reloading[name]
+ except:
+ modules_reloading[name] = m
+
+ dot = name.rfind('.')
+ if dot < 0:
+ subname = name
+ path = None
+ else:
+ try:
+ parent = sys.modules[name[:dot]]
+ except KeyError:
+ modules_reloading.clear()
+ raise ImportError("reload(): parent %.200s not in sys.modules" % name[:dot])
+ subname = name[dot+1:]
+ path = getattr(parent, "__path__", None)
- found_now[fqname] = 1
try:
- fp, pathname, stuff = imp.find_module(partname,
- parent and parent.__path__)
- except ImportError:
- return None
+ fp, filename, stuff = imp.find_module(subname, path)
+ finally:
+ modules_reloading.clear()
try:
- m = imp.load_module(fqname, fp, pathname, stuff)
+ newm = imp.load_module(name, fp, filename, stuff)
+ except:
+ # load_module probably removed name from modules because of
+ # the error. Put back the original module object.
+ sys.modules[name] = m
+ raise
finally:
if fp: fp.close()
- if parent:
- setattr(parent, partname, m)
-
- return m
-
-def deep_reload_hook(module):
- name = module.__name__
- if '.' not in name:
- return import_module(name, name, None)
- i = name.rfind('.')
- pname = name[:i]
- parent = sys.modules[pname]
- return import_module(name[i+1:], name, parent)
+ modules_reloading.clear()
+ return newm
# Save the original hooks
try:
@@ -165,7 +297,7 @@ def deep_reload_hook(module):
original_reload = imp.reload # Python 3
# Replacement for reload()
-def reload(module, exclude=['sys', '__builtin__', '__main__']):
+def reload(module, exclude=['sys', 'os.path', '__builtin__', '__main__']):
"""Recursively reload all modules used in the given module. Optionally
takes a list of modules to exclude from reloading. The default exclude
list contains sys, __main__, and __builtin__, to prevent, e.g., resetting
View
53 IPython/lib/tests/test_deepreload.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+"""Test suite for the deepreload module."""
+
+#-----------------------------------------------------------------------------
+# Imports
+#-----------------------------------------------------------------------------
+
+import os
+import sys
+
+import nose.tools as nt
+
+from IPython.testing import decorators as dec
+from IPython.utils.syspathcontext import prepended_to_syspath
+from IPython.utils.tempdir import TemporaryDirectory
+from IPython.lib.deepreload import reload as dreload
+
+#-----------------------------------------------------------------------------
+# Test functions begin
+#-----------------------------------------------------------------------------
+
+@dec.skipif_not_numpy
+def test_deepreload_numpy():
+ "Test that NumPy can be deep reloaded."
+ import numpy
+ exclude = [
+ # Standard exclusions:
+ 'sys', 'os.path', '__builtin__', '__main__',
+ # Test-related exclusions:
+ 'unittest',
+ ]
+ dreload(numpy, exclude=exclude)
+
+def test_deepreload():
+ "Test that dreload does deep reloads and skips excluded modules."
+ with TemporaryDirectory() as tmpdir:
+ with prepended_to_syspath(tmpdir):
+ with open(os.path.join(tmpdir, 'A.py'), 'w') as f:
+ f.write("class Object(object):\n pass\n")
+ with open(os.path.join(tmpdir, 'B.py'), 'w') as f:
+ f.write("import A\n")
+ import A
+ import B
+
+ # Test that A is not reloaded.
+ obj = A.Object()
+ dreload(B, exclude=['A'])
+ nt.assert_true(isinstance(obj, A.Object))
+
+ # Test that A is reloaded.
+ obj = A.Object()
+ dreload(B)
+ nt.assert_false(isinstance(obj, A.Object))

0 comments on commit a206ee2

Please sign in to comment.