-
-
Notifications
You must be signed in to change notification settings - Fork 4.6k
/
imports.py
156 lines (127 loc) · 4.67 KB
/
imports.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Utilities related to importing modules and symbols by name."""
import os
import sys
import warnings
from contextlib import contextmanager
from importlib import import_module, reload
try:
from importlib.metadata import entry_points
except ImportError:
from importlib_metadata import entry_points
from kombu.utils.imports import symbol_by_name
#: Billiard sets this environment variable when execv is enabled.
#: It holds the file name of the original ``__main__`` module, which lets
#: :func:`gen_task_name` rewrite the module part of task names defined in
#: that script to ``App.main``.  ``None`` when the variable is not set.
MP_MAIN_FILE = os.getenv('MP_MAIN_FILE')
# Public API of this module (what ``from ... import *`` exposes).
__all__ = (
    'NotAPackage',
    'qualname',
    'instantiate',
    'symbol_by_name',
    'cwd_in_path',
    'find_module',
    'import_from_cwd',
    'reload_from_cwd',
    'module_file',
    'gen_task_name',
)
class NotAPackage(Exception):
    """Raised when importing a package, but it's not a package.

    The single argument is the dotted name of the module that was expected
    to be a package (i.e. to have a ``__path__`` attribute).
    """
def qualname(obj):
    """Return object name.

    Objects without a ``__name__`` (typically instances) are described by
    their class.  If the qualified name contains no dot, the defining module
    is prepended, e.g. ``qualname(int) == 'builtins.int'``.

    Objects that define ``__name__`` but not ``__qualname__`` (such as
    module objects) fall back to ``__name__``.  When no ``__module__`` is
    known (missing or ``None``), the bare name is returned unprefixed.
    """
    if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
        obj = obj.__class__
    q = getattr(obj, '__qualname__', None)
    if q is None:
        # Previously this reached ``'.' not in None`` and raised TypeError.
        # At this point obj either has ``__name__`` or is a class.
        q = obj.__name__
    if '.' not in q:
        module = getattr(obj, '__module__', None)
        # A None module (e.g. functions exec'd without ``__name__`` in their
        # globals) used to make ``'.'.join`` raise TypeError.
        if module is not None:
            q = '.'.join((module, q))
    return q
def instantiate(name, *args, **kwargs):
    """Instantiate class by name.

    *name* is resolved with :func:`symbol_by_name`; the resulting callable is
    invoked with the remaining positional and keyword arguments, and its
    return value is returned.

    See Also:
        :func:`symbol_by_name`.
    """
    factory = symbol_by_name(name)
    return factory(*args, **kwargs)
@contextmanager
def cwd_in_path():
    """Context adding the current working directory to sys.path.

    If the CWD is already on :data:`sys.path`, nothing is changed and the
    context yields ``None``.  Otherwise the CWD is inserted at the front of
    :data:`sys.path` and yielded, and one occurrence of it is removed again
    when the ``with`` block exits, even if the block raised.
    """
    here = os.getcwd()
    if here in sys.path:
        # Already importable: nothing to add and nothing to undo.
        yield
        return
    sys.path.insert(0, here)
    try:
        yield here
    finally:
        try:
            sys.path.remove(here)
        except ValueError:  # pragma: no cover
            # The body already removed it; nothing left to clean up.
            pass
def find_module(module, path=None, imp=None):
    """Version of :func:`imp.find_module` supporting dots.

    Imports *module* (a possibly dotted name) with the current working
    directory temporarily on :data:`sys.path` and returns the module.

    Arguments:
        module (str): Dotted name of the module to import.
        path: Accepted for signature compatibility; not used.
        imp (Callable): Import function; defaults to
            :func:`importlib.import_module`.

    Raises:
        NotAPackage: The import failed and one of the leading
            dot-separated prefixes imports fine but has no ``__path__``.
        ImportError: The original import error in every other case.
    """
    importer = import_module if imp is None else imp
    with cwd_in_path():
        try:
            return importer(module)
        except ImportError:
            if '.' in module:
                segments = module.split('.')
                # Check every proper prefix: 'a', 'a.b', ... (not the full name).
                for depth in range(1, len(segments)):
                    prefix = '.'.join(segments[:depth])
                    try:
                        parent = importer(prefix)
                    except ImportError:
                        # The parent itself is missing; fall through and
                        # re-raise the original ImportError instead.
                        break
                    try:
                        parent.__path__
                    except AttributeError:
                        raise NotAPackage(prefix)
            # Bare raise re-raises the ImportError from the outer handler.
            raise
def import_from_cwd(module, imp=None, package=None):
    """Import module, temporarily including modules in the current directory.

    The current working directory is placed at the front of
    :data:`sys.path` for the duration of the import (unless it is already on
    the path).  *package* is forwarded as the anchor for relative imports.
    """
    importer = import_module if imp is None else imp
    with cwd_in_path():
        return importer(module, package=package)
def reload_from_cwd(module, reloader=None):
    """Reload module (ensuring that CWD is in sys.path).

    Arguments:
        module (ModuleType): Already-imported module to reload.
        reloader (Callable): Reload function; defaults to
            :func:`importlib.reload`.

    Returns:
        Whatever *reloader* returns.
    """
    do_reload = reload if reloader is None else reloader
    with cwd_in_path():
        return do_reload(module)
def module_file(module):
    """Return the correct original file name of a module.

    For a name ending in ``.pyc`` the trailing ``c`` is dropped, giving the
    ``.py`` name (which is not checked for existence); any other name is
    returned unchanged.  If ``module.__file__`` is ``None`` (as for
    namespace packages, for example), ``None`` is returned instead of
    raising on ``None.endswith``.
    """
    name = module.__file__
    if name is not None and name.endswith('.pyc'):
        return name[:-1]
    return name
def gen_task_name(app, name, module_name):
    """Generate task name from name/module pair.

    Arguments:
        app: Application; only its ``main`` attribute is read.
        name (str): Task name (empty parts are dropped from the result).
        module_name (str): Name of the defining module; ``None``/empty
            means ``'__main__'``.

    Returns:
        str: ``'<app.main>.<name>'`` for tasks defined in ``__main__`` when
        ``app.main`` is set, otherwise ``'<module>.<name>'``.
    """
    module_name = module_name or '__main__'
    # Modules missing from sys.modules: fix for manage.py shell_plus
    # (Issue #366).
    module = sys.modules.get(module_name)

    if module is not None:
        module_name = module.__name__
        # - If the task module is used as the __main__ script
        # - we need to rewrite the module part of the task name
        # - to match App.main.
        # Some modules (e.g. an interactive ``__main__``) have no
        # ``__file__`` attribute, so read it defensively.
        if MP_MAIN_FILE and getattr(module, '__file__', None) == MP_MAIN_FILE:
            # - see comment about :envvar:`MP_MAIN_FILE` above.
            module_name = '__main__'
    if module_name == '__main__' and app.main:
        return '.'.join([app.main, name])
    return '.'.join(p for p in (module_name, name) if p)
def load_extension_class_names(namespace):
    """Yield ``(name, value)`` for each entry point in group *namespace*.

    ``value`` is the entry point's target string (``'module:attr'``).

    ``entry_points()`` returns a plain ``dict`` of groups on older Pythons
    and an ``EntryPoints`` collection without ``.get`` on Python 3.12+.  The
    group is therefore chosen with ``.select(group=...)`` when that method
    exists, and through the legacy mapping interface otherwise.
    """
    all_eps = entry_points()
    if hasattr(all_eps, 'select'):
        group = all_eps.select(group=namespace)
    else:
        group = all_eps.get(namespace, [])
    for ep in group:
        yield ep.name, ep.value
def load_extension_classes(namespace):
    """Yield ``(name, cls)`` for every loadable extension in *namespace*.

    Each entry point target is resolved with :func:`symbol_by_name`.  Targets
    that fail with :exc:`ImportError` or :exc:`SyntaxError` are skipped
    after emitting a warning, so one broken extension does not stop the
    iteration.
    """
    for ext_name, class_name in load_extension_class_names(namespace):
        try:
            resolved = symbol_by_name(class_name)
        except (ImportError, SyntaxError) as exc:
            warnings.warn(
                f'Cannot load {namespace} extension {class_name!r}: {exc!r}')
            continue
        yield ext_name, resolved