Skip to content

Commit

Permalink
Use RTLD_DEEPBIND when loading Mitsuba DLL from python
Browse files Browse the repository at this point in the history
  • Loading branch information
Speierers committed Sep 9, 2022
1 parent 484c33b commit 59d7b35
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 43 deletions.
30 changes: 28 additions & 2 deletions src/core/tests/test_python.py
@@ -1,6 +1,5 @@
from importlib import import_module as _import
import pytest
import sys

def test01_import_mitsuba_variants():
import mitsuba as mi
Expand Down Expand Up @@ -118,4 +117,31 @@ def test08_sys_module_size():
for k, v in sys.modules.items():
getattr(v, "foo", False)

assert True
assert True


@pytest.mark.parametrize('order', [0, 1, 2])
def test09_import_torch_order(order):
if order == 0:
pytest.importorskip("torch")
import mitsuba as mi
mi.set_variant(mi.variants()[0])
if order == 1:
import mitsuba as mi
pytest.importorskip("torch")
mi.set_variant(mi.variants()[0])
if order == 2:
import mitsuba as mi
mi.set_variant(mi.variants()[0])
pytest.importorskip("torch")

bmp = mi.Bitmap(mi.TensorXf([0.0, 0.0, 0.0, 0.0], [2, 2]))
bsdf = mi.load_dict({
'type': 'diffuse',
'reflectance': {
'type': 'bitmap',
'bitmap': bmp,
},
})

print(bsdf)
90 changes: 49 additions & 41 deletions src/python/__init__.py
@@ -1,32 +1,32 @@
""" Mitsuba Python extension library """

import typing
import types
import sys
import threading
import sys as _sys
import os as _os
from importlib import import_module as _import, reload as _reload
import drjit as dr
import os
import drjit as _dr

import typing, types
import threading
import logging

if sys.version_info < (3, 8):
if _sys.version_info < (3, 8):
raise ImportError("Mitsuba requires Python 3.8 or greater.")

if os.name == 'nt':
if _os.name == 'nt':
# Specify DLL search path for windows (no rpath on this platform..)
d = __file__
for i in range(3):
d = os.path.dirname(d)
d = _os.path.dirname(d)
try: # try to use Python 3.8's DLL handling
os.add_dll_directory(d)
_os.add_dll_directory(d)
except AttributeError: # otherwise use PATH
os.environ['PATH'] += os.pathsep + d
_os.environ['PATH'] += _os.pathsep + d
del d, i

mi_dir = os.path.dirname(os.path.realpath(__file__))
if os.name != 'nt' and os.path.relpath(dr.__path__[0], mi_dir) != '../drjit':
drjit_loc = os.path.realpath(dr.__path__[0])
drjit_expected_loc = os.path.realpath(os.path.join(mi_dir, "..", "drjit"))
mi_dir = _os.path.dirname(_os.path.realpath(__file__))
if _os.name != 'nt' and _os.path.relpath(_dr.__path__[0], mi_dir) != '../drjit':
drjit_loc = _os.path.realpath(_dr.__path__[0])
drjit_expected_loc = _os.path.realpath(_os.path.join(mi_dir, "..", "drjit"))
logging.warning("The `mitsuba` package relies on `drjit` and needs it "
"to be installed at a specific location. Currently, "
"`drjit` is located at \"%s\" when it is expected to be "
Expand All @@ -37,7 +37,7 @@
del mi_dir

from .config import DRJIT_VERSION_REQUIREMENT
if dr.__version__ != DRJIT_VERSION_REQUIREMENT:
if _dr.__version__ != DRJIT_VERSION_REQUIREMENT:
raise ImportError("You are using an incompatible version of `drjit`. "
"Only version \"%s\" is guaranteed to be compatible with "
"your current Mitsuba installation. Please update your "
Expand All @@ -46,7 +46,11 @@
del DRJIT_VERSION_REQUIREMENT

try:
# Use RTLD_DEEPBIND to prevent the DLL to search symbols in the global scope
old_flags = _sys.getdlopenflags()
_sys.setdlopenflags(_os.RTLD_LAZY | _os.RTLD_LOCAL | _os.RTLD_DEEPBIND)
_import('mitsuba.mitsuba_ext')
_sys.setdlopenflags(old_flags)
_tls = threading.local()
_tls.cache = {}
except (ImportError, ModuleNotFoundError) as e:
Expand All @@ -56,20 +60,20 @@

if 'Symbol not found' in str(e):
pass
elif PYTHON_EXECUTABLE != sys.executable:
elif PYTHON_EXECUTABLE != _sys.executable:
extra_msg = ("You're likely trying to use Mitsuba within a Python "
"binary (%s) that is different from the one for which "
"the native module was compiled (%s).") % (
sys.executable, PYTHON_EXECUTABLE)
_sys.executable, PYTHON_EXECUTABLE)

exc = ImportError("The 'mitsuba' native modules could not be "
"imported. %s" % extra_msg)
exc.__cause__ = e

raise exc
finally:
# Make sure mitsuba_ext isn't accessible from sys.modules
sys.modules.pop('mitsuba.mitsuba_ext', None)
# Make sure mitsuba_ext isn't accessible from _sys.modules
_sys.modules.pop('mitsuba.mitsuba_ext', None)

# Known submodules that will be directly accessible from the mitsuba package
submodules = ['warp', 'math', 'spline', 'quad', 'mueller', 'util', 'filesystem']
Expand Down Expand Up @@ -110,21 +114,25 @@ def __getattribute__(self, key):

if modules is None:
try:
# Use RTLD_DEEPBIND to prevent DLLs to search symbols in the global scope
old_flags = _sys.getdlopenflags()
_sys.setdlopenflags(_os.RTLD_LAZY | _os.RTLD_LOCAL | _os.RTLD_DEEPBIND)
modules = (
_import('mitsuba.mitsuba_ext'),
_import('mitsuba.mitsuba_' + variant + '_ext'),
)
_sys.setdlopenflags(old_flags)
super().__setattr__('_modules', modules)
except ImportError as e:
if str(e).startswith('No module named'):
raise AttributeError('Mitsuba variant "%s" not found.' % variant)
else:
raise AttributeError(e)
finally:
# Remove those modules from sys.modules as only the
# Remove those modules from _sys.modules as only the
# MitsubaVariantModule instance should hold a reference to them.
sys.modules.pop('mitsuba.mitsuba_ext', None)
sys.modules.pop('mitsuba.mitsuba_' + variant + '_ext', None)
_sys.modules.pop('mitsuba.mitsuba_ext', None)
_sys.modules.pop('mitsuba.mitsuba_' + variant + '_ext', None)

submodule = super().__getattribute__('_submodule')
sub_suffix = '' if submodule is None else f'.{submodule}'
Expand Down Expand Up @@ -228,14 +236,14 @@ def __getattribute__(self, key):
# Check whether we are trying to directly import a variant
from .config import MI_VARIANTS
if key in MI_VARIANTS:
return sys.modules[f'mitsuba.{key}']
return _sys.modules[f'mitsuba.{key}']

if not key in ['__dict__', '__wrapped__'] and variant is None:
# The variant wasn't set explicitly, we first check if a default
# variant is set in the config.py file.
from .config import MI_DEFAULT_VARIANT
import os
default_variant = os.getenv('MI_DEFAULT_VARIANT', default=MI_DEFAULT_VARIANT)
default_variant = _os.getenv('MI_DEFAULT_VARIANT', default=MI_DEFAULT_VARIANT)
if default_variant != '':
self.set_variant(default_variant)
variant = default_variant
Expand All @@ -252,11 +260,11 @@ def __getattribute__(self, key):
if not variant is None:
# Check whether we are importing a known submodule
if submodule is None and key in submodules:
return sys.modules[f'mitsuba.{variant}.{key}']
return _sys.modules[f'mitsuba.{variant}.{key}']

# Redirect all other imports to the currently enabled variant module.
sub_suffix = '' if submodule is None else f'.{submodule}'
module = sys.modules[f'mitsuba.{variant}{sub_suffix}']
module = _sys.modules[f'mitsuba.{variant}{sub_suffix}']
result = module.__getattribute__(key)

# Add set_variant(), variant() and variant modules to the __dict__
Expand All @@ -267,7 +275,7 @@ def __getattribute__(self, key):
result['variant'] = super().__getattribute__('variant')
result['variants'] = super().__getattribute__('variants')
for v in super().__getattribute__('variants')():
result[v] = sys.modules[f'mitsuba.{v}']
result[v] = _sys.modules[f'mitsuba.{v}']

# Add this lookup to the cache
cache = getattr(_tls, 'cache', None)
Expand Down Expand Up @@ -318,14 +326,14 @@ def set_variant(self, *args) -> None:
# Automatically load/reload and register Python integrators for AD variants
if value.startswith(('llvm_', 'cuda_')):
import sys
if 'mitsuba.ad.integrators' in sys.modules:
_reload(sys.modules['mitsuba.ad.integrators'])
if 'mitsuba.ad.integrators' in _sys.modules:
_reload(_sys.modules['mitsuba.ad.integrators'])
else:
_import('mitsuba.ad.integrators')
del sys

# Check whether we are reloading the mitsuba module
reload = f'mitsuba.{submodules[0]}' in sys.modules
reload = f'mitsuba.{submodules[0]}' in _sys.modules
if reload:
print(
"The Mitsuba module was reloaded (imported a second time). "
Expand All @@ -338,35 +346,35 @@ def set_variant(self, *args) -> None:
for variant in MI_VARIANTS:
name = f'mitsuba.{variant}'
if reload:
sys.modules[name].__init__(name, variant)
_sys.modules[name].__init__(name, variant)
else:
sys.modules[name] = MitsubaVariantModule(name, variant)
_sys.modules[name] = MitsubaVariantModule(name, variant)

# Register variant submodules
for variant in MI_VARIANTS:
for submodule in submodules:
name = f'mitsuba.{variant}.{submodule}'
if reload:
sys.modules[name].__init__(name, variant, submodule)
_sys.modules[name].__init__(name, variant, submodule)
else:
sys.modules[name] = MitsubaVariantModule(name, variant, submodule)
_sys.modules[name] = MitsubaVariantModule(name, variant, submodule)

# Register the virtual mitsuba module and submodules. This will overwrite the
# real mitsuba module in order to redirect future imports.
if reload:
sys.modules['mitsuba'].__init__('mitsuba')
_sys.modules['mitsuba'].__init__('mitsuba')
else:
sys.modules['mitsuba'] = MitsubaModule('mitsuba')
_sys.modules['mitsuba'] = MitsubaModule('mitsuba')

for submodule in submodules:
name = f'mitsuba.{submodule}'
if reload:
sys.modules[name].__init__(name, submodule)
_sys.modules[name].__init__(name, submodule)
else:
sys.modules[name] = MitsubaModule(name, submodule)
_sys.modules[name] = MitsubaModule(name, submodule)

# Pre-import all symbols from the python submodules to prevent the size of
# sys.modules to change during a call to __getattribute__().
# _sys.modules to change during a call to __getattribute__().
_import(f'mitsuba.python')
for submodule in submodules:
try:
Expand All @@ -379,7 +387,7 @@ def set_variant(self, *args) -> None:
del MitsubaVariantModule
del typing, types
del threading
del os
del logging
if config in locals():
del config
del MI_VARIANTS
Expand Down

0 comments on commit 59d7b35

Please sign in to comment.