diff --git a/marimo/_runtime/reload/autoreload.py b/marimo/_runtime/reload/autoreload.py index aee137e5e0..ebabea986d 100644 --- a/marimo/_runtime/reload/autoreload.py +++ b/marimo/_runtime/reload/autoreload.py @@ -11,11 +11,13 @@ import gc import io +import modulefinder import os import sys import threading import traceback import types +import warnings import weakref from dataclasses import dataclass from importlib import reload @@ -69,6 +71,59 @@ def modules_imported_by_cell( return modules +class ModuleDependencyFinder: + def __init__(self) -> None: + # __file__ -> + self._module_dependencies: dict[str, dict[str, types.ModuleType]] = {} + self._failed_module_filenames: set[str] = set() + + def find_dependencies( + self, module: types.ModuleType, excludes: list[str] + ) -> dict[str, types.ModuleType]: + if not hasattr(module, "__file__") or module.__file__ is None: + return {} + + file = module.__file__ + if module.__file__ in self._failed_module_filenames: + return {} + + if file in self._module_dependencies: + return self._module_dependencies[file] + + finder = modulefinder.ModuleFinder(excludes=excludes) + try: + with warnings.catch_warnings(): + # We temporarily ignore warnings to avoid spamming the console, + # since the watcher runs in a loop + warnings.simplefilter("ignore") + finder.run_script(module.__file__) + except SyntaxError: + # user introduced a syntax error, maybe; still check if the + # module itself has been modified + return {} + except Exception: + # some modules like numpy fail when called with run_script; + # run_script takes a long time before failing on them, so + # don't try to analyze them again + self._failed_module_filenames.add(file) + return {} + else: + # False positives + self._module_dependencies[file] = finder.modules # type: ignore[assignment] + return finder.modules # type: ignore[return-value] + + def cached(self, module: types.ModuleType) -> bool: + if not hasattr(module, "__file__") or module.__file__ is None: + return False + + return module.__file__ in self._module_dependencies + + def evict_from_cache(self, module: types.ModuleType) -> None: + file = module.__file__ + if file in self._module_dependencies: + del self._module_dependencies[file] + + class ModuleReloader: """Thread-safe module reloader.""" @@ -83,6 +138,7 @@ def __init__(self) -> None: self.stale_modules: set[str] = set() # for thread-safety self.lock = threading.Lock() + self._module_dependency_finder = ModuleDependencyFinder() # Timestamp existing modules self.check(modules=sys.modules, reload=False) @@ -170,6 +226,7 @@ def check( self.modules_mtimes[modname] = pymtime modified_modules.add(m) self.stale_modules.add(modname) + self._module_dependency_finder.evict_from_cache(m) if not reload: return modified_modules @@ -197,9 +254,20 @@ def check( msg.format(modname, traceback.format_exc(10)), ) self.failed[py_filename] = pymtime + else: + # TODO or always evict? + self._module_dependency_finder.evict_from_cache(m) + self.stale_modules.clear() return modified_modules + def get_module_dependencies( + self, module: types.ModuleType, excludes: list[str] + ) -> dict[str, types.ModuleType]: + return self._module_dependency_finder.find_dependencies( + module, excludes + ) + def update_function(old: object, new: object) -> None: """Upgrade the code object of a function""" diff --git a/marimo/_runtime/reload/module_watcher.py b/marimo/_runtime/reload/module_watcher.py index 6b49ae13cb..220d8ab971 100644 --- a/marimo/_runtime/reload/module_watcher.py +++ b/marimo/_runtime/reload/module_watcher.py @@ -6,8 +6,6 @@ import sys import threading import time -import warnings -from modulefinder import ModuleFinder from typing import TYPE_CHECKING, Callable, Literal from marimo._messaging.types import Stream @@ -38,37 +36,23 @@ def is_submodule(src_name: str, target_name: str) -> bool: def _depends_on( src_module: types.ModuleType, target_modules: set[types.ModuleType], - failed_filenames: set[str], - finder: ModuleFinder, + excludes: list[str], + reloader: ModuleReloader, ) -> bool: """Returns whether src_module depends on any of target_filenames""" - if not hasattr(src_module, "__file__") or src_module.__file__ is None: - return False - - if src_module.__file__ in failed_filenames: - return False + if src_module in target_modules: + return True - try: - with warnings.catch_warnings(): - # We temporarily ignore warnings to avoid spamming the console, - # since the watcher runs in a loop - warnings.simplefilter("ignore") - finder.run_script(src_module.__file__) - except SyntaxError: - # user introduced a syntax error, maybe; still check if the - # module itself has been modified - pass - except Exception: - # some modules like numpy fail when called with run_script; - # run_script takes a long time before failing on them, so - # don't try to analyze them again - failed_filenames.add(src_module.__file__) - return False + module_dependencies = reloader.get_module_dependencies( + src_module, excludes=excludes + ) target_filenames = set( t.__file__ for t in target_modules if hasattr(t, "__file__") ) - for found_module in itertools.chain([src_module], finder.modules.values()): + for found_module in itertools.chain( + [src_module], module_dependencies.values() + ): file = getattr(found_module, "__file__", None) if file is None: continue @@ -107,8 +91,6 @@ def _get_excluded_modules(modules: dict[str, types.ModuleType]) -> list[str]: def _check_modules( modules: dict[str, types.ModuleType], reloader: ModuleReloader, - failed_filenames: set[str], - finder: ModuleFinder, sys_modules: dict[str, types.ModuleType], ) -> dict[str, types.ModuleType]: """Returns the set of modules used by the graph that have been modified""" @@ -116,15 +98,15 @@ def _check_modules( modified_modules = reloader.check(modules=sys_modules, reload=False) # TODO(akshayka): could also exclude modules part of the standard library; # haven't found a reliable way to do this, however. + excludes = _get_excluded_modules(sys_modules) for modname, module in modules.items(): if _depends_on( src_module=module, target_modules=set(m for m in modified_modules if m is not None), - failed_filenames=failed_filenames, - finder=finder, + excludes=excludes, + reloader=reloader, ): stale_modules[modname] = module - return stale_modules @@ -143,12 +125,9 @@ def watch_modules( modules imported by the notebook as well as the modules imported by those modules, recursively. """ - # modules that failed to be analyzed - failed_filenames: set[str] = set() # work with a copy to avoid race conditions # in CPython, dict.copy() is atomic sys_modules = sys.modules.copy() - finder = ModuleFinder(excludes=_get_excluded_modules(sys_modules)) while not should_exit.is_set(): # Collect the modules used by each cell modules: dict[str, types.ModuleType] = {} @@ -159,13 +138,13 @@ def watch_modules( if modname in sys_modules: modules[modname] = sys_modules[modname] modname_to_cell_id[modname] = cell_id + stale_modules = _check_modules( modules=modules, reloader=reloader, - failed_filenames=failed_filenames, - finder=finder, sys_modules=sys_modules, ) + if stale_modules: with graph.lock: # If any modules are stale, communicate that to the FE @@ -181,14 +160,13 @@ def watch_modules( if mode == "autorun": run_is_processed.clear() enqueue_run_stale_cells() + # Don't proceed until enqueue_run_stale_cells() has been processed, # ie until stale cells have been rerun run_is_processed.wait() time.sleep(1) # Update our snapshot of sys.modules sys_modules = sys.modules.copy() - # Update excluded modules in case the module set has changed. - finder.excludes = _get_excluded_modules(sys_modules) class ModuleWatcher: diff --git a/marimo/_smoke_tests/sidebar.py b/marimo/_smoke_tests/sidebar.py index 9a47249ac8..7f5ae21035 100644 --- a/marimo/_smoke_tests/sidebar.py +++ b/marimo/_smoke_tests/sidebar.py @@ -1,3 +1,4 @@ +# Copyright 2024 Marimo. All rights reserved. import marimo __generated_with = "0.5.1" diff --git a/pyproject.toml b/pyproject.toml index 2f79c8734d..9b6372e004 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -223,12 +223,12 @@ extend-select = [ "TRY002", # Prohibit use of `raise Exception`, use specific exceptions instead. ] -[tool.ruff.lint.per-file-ignores] -"**/{tests}/*" = ["ANN201", "ANN202"] - # Never try to fix `F401` (unused imports). unfixable = ["F401"] +[tool.ruff.lint.per-file-ignores] +"**/{tests}/*" = ["ANN201", "ANN202"] + [tool.ruff.lint.isort] required-imports = ["from __future__ import annotations"] combine-as-imports = true diff --git a/tests/_runtime/reload/reload_data/__init__.py b/tests/_runtime/reload/reload_data/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/_runtime/reload/reload_data/a.py b/tests/_runtime/reload/reload_data/a.py new file mode 100644 index 0000000000..0f17215843 --- /dev/null +++ b/tests/_runtime/reload/reload_data/a.py @@ -0,0 +1 @@ +import reload_data.c # noqa:F401 diff --git a/tests/_runtime/reload/reload_data/b.py b/tests/_runtime/reload/reload_data/b.py new file mode 100644 index 0000000000..b598c5354b --- /dev/null +++ b/tests/_runtime/reload/reload_data/b.py @@ -0,0 +1 @@ +import reload_data.d # noqa:F401 diff --git a/tests/_runtime/reload/reload_data/c.py b/tests/_runtime/reload/reload_data/c.py new file mode 100644 index 0000000000..2ae28399f5 --- /dev/null +++ b/tests/_runtime/reload/reload_data/c.py @@ -0,0 +1 @@ +pass diff --git a/tests/_runtime/reload/reload_data/d.py b/tests/_runtime/reload/reload_data/d.py new file mode 100644 index 0000000000..2ae28399f5 --- /dev/null +++ b/tests/_runtime/reload/reload_data/d.py @@ -0,0 +1 @@ +pass diff --git a/tests/_runtime/reload/test_autoreload.py b/tests/_runtime/reload/test_autoreload.py index faeaa03964..12412289bb 100644 --- a/tests/_runtime/reload/test_autoreload.py +++ b/tests/_runtime/reload/test_autoreload.py @@ -7,7 +7,10 @@ from reload_test_utils import update_file -from marimo._runtime.reload.autoreload import ModuleReloader +from marimo._runtime.reload.autoreload import ( + ModuleDependencyFinder, + ModuleReloader, +) def test_reload_function(tmp_path: pathlib.Path, py_modname: str): @@ -98,3 +101,31 @@ def foo(): assert py_modname in sys.modules # ... but it's basically empty assert not hasattr(mod, "foo") + + +class TestModuleDependencyFinder: + def test_dependencies_isolated(self): + from tests._runtime.reload.reload_data import a, b, c, d + + finder = ModuleDependencyFinder() + a_deps = set(list(finder.find_dependencies(a, excludes=[]).keys())) + b_deps = set(list(finder.find_dependencies(b, excludes=[]).keys())) + c_deps = set(list(finder.find_dependencies(c, excludes=[]).keys())) + d_deps = set(list(finder.find_dependencies(d, excludes=[]).keys())) + + assert a_deps == set(["__main__", "reload_data", "reload_data.c"]) + assert b_deps == set(["__main__", "reload_data", "reload_data.d"]) + assert c_deps == set(["__main__"]) + assert d_deps == set(["__main__"]) + + def test_dependencies_cached(self): + from tests._runtime.reload.reload_data import a + + finder = ModuleDependencyFinder() + assert not finder.cached(a) + + finder.find_dependencies(a, excludes=[]) + assert finder.cached(a) + + finder.evict_from_cache(a) + assert not finder.cached(a)