Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: isolate modules in module watcher #1358

Merged
merged 2 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 68 additions & 0 deletions marimo/_runtime/reload/autoreload.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@

import gc
import io
import modulefinder
import os
import sys
import threading
import traceback
import types
import warnings
import weakref
from dataclasses import dataclass
from importlib import reload
Expand Down Expand Up @@ -69,6 +71,59 @@ def modules_imported_by_cell(
return modules


class ModuleDependencyFinder:
def __init__(self) -> None:
# __file__ ->
self._module_dependencies: dict[str, dict[str, types.ModuleType]] = {}
self._failed_module_filenames: set[str] = set()

def find_dependencies(
self, module: types.ModuleType, excludes: list[str]
) -> dict[str, types.ModuleType]:
if not hasattr(module, "__file__") or module.__file__ is None:
return {}

file = module.__file__
if module.__file__ in self._failed_module_filenames:
return {}

if file in self._module_dependencies:
return self._module_dependencies[file]

finder = modulefinder.ModuleFinder(excludes=excludes)
try:
with warnings.catch_warnings():
# We temporarily ignore warnings to avoid spamming the console,
# since the watcher runs in a loop
warnings.simplefilter("ignore")
finder.run_script(module.__file__)
except SyntaxError:
# user introduced a syntax error, maybe; still check if the
# module itself has been modified
return {}
except Exception:
# some modules like numpy fail when called with run_script;
# run_script takes a long time before failing on them, so
# don't try to analyze them again
self._failed_module_filenames.add(file)
return {}
else:
# False positives
self._module_dependencies[file] = finder.modules # type: ignore[assignment]
return finder.modules # type: ignore[return-value]

def cached(self, module: types.ModuleType) -> bool:
if not hasattr(module, "__file__") or module.__file__ is None:
return False

return module.__file__ in self._module_dependencies

def evict_from_cache(self, module: types.ModuleType) -> None:
file = module.__file__
if file in self._module_dependencies:
del self._module_dependencies[file]


class ModuleReloader:
"""Thread-safe module reloader."""

Expand All @@ -83,6 +138,7 @@ def __init__(self) -> None:
self.stale_modules: set[str] = set()
# for thread-safety
self.lock = threading.Lock()
self._module_dependency_finder = ModuleDependencyFinder()

# Timestamp existing modules
self.check(modules=sys.modules, reload=False)
Expand Down Expand Up @@ -170,6 +226,7 @@ def check(
self.modules_mtimes[modname] = pymtime
modified_modules.add(m)
self.stale_modules.add(modname)
self._module_dependency_finder.evict_from_cache(m)

if not reload:
return modified_modules
Expand Down Expand Up @@ -197,9 +254,20 @@ def check(
msg.format(modname, traceback.format_exc(10)),
)
self.failed[py_filename] = pymtime
else:
# TODO or always evict?
self._module_dependency_finder.evict_from_cache(m)

self.stale_modules.clear()
return modified_modules

def get_module_dependencies(
self, module: types.ModuleType, excludes: list[str]
) -> dict[str, types.ModuleType]:
return self._module_dependency_finder.find_dependencies(
module, excludes
)


def update_function(old: object, new: object) -> None:
"""Upgrade the code object of a function"""
Expand Down
54 changes: 16 additions & 38 deletions marimo/_runtime/reload/module_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import sys
import threading
import time
import warnings
from modulefinder import ModuleFinder
from typing import TYPE_CHECKING, Callable, Literal

from marimo._messaging.types import Stream
Expand Down Expand Up @@ -38,37 +36,23 @@ def is_submodule(src_name: str, target_name: str) -> bool:
def _depends_on(
src_module: types.ModuleType,
target_modules: set[types.ModuleType],
failed_filenames: set[str],
finder: ModuleFinder,
excludes: list[str],
reloader: ModuleReloader,
) -> bool:
"""Returns whether src_module depends on any of target_filenames"""
if not hasattr(src_module, "__file__") or src_module.__file__ is None:
return False

if src_module.__file__ in failed_filenames:
return False
if src_module in target_modules:
return True

try:
with warnings.catch_warnings():
# We temporarily ignore warnings to avoid spamming the console,
# since the watcher runs in a loop
warnings.simplefilter("ignore")
finder.run_script(src_module.__file__)
except SyntaxError:
# user introduced a syntax error, maybe; still check if the
# module itself has been modified
pass
except Exception:
# some modules like numpy fail when called with run_script;
# run_script takes a long time before failing on them, so
# don't try to analyze them again
failed_filenames.add(src_module.__file__)
return False
module_dependencies = reloader.get_module_dependencies(
src_module, excludes=excludes
)

target_filenames = set(
t.__file__ for t in target_modules if hasattr(t, "__file__")
)
for found_module in itertools.chain([src_module], finder.modules.values()):
for found_module in itertools.chain(
[src_module], module_dependencies.values()
):
file = getattr(found_module, "__file__", None)
if file is None:
continue
Expand Down Expand Up @@ -107,24 +91,22 @@ def _get_excluded_modules(modules: dict[str, types.ModuleType]) -> list[str]:
def _check_modules(
modules: dict[str, types.ModuleType],
reloader: ModuleReloader,
failed_filenames: set[str],
finder: ModuleFinder,
sys_modules: dict[str, types.ModuleType],
) -> dict[str, types.ModuleType]:
"""Returns the set of modules used by the graph that have been modified"""
stale_modules: dict[str, types.ModuleType] = {}
modified_modules = reloader.check(modules=sys_modules, reload=False)
# TODO(akshayka): could also exclude modules part of the standard library;
# haven't found a reliable way to do this, however.
excludes = _get_excluded_modules(sys_modules)
for modname, module in modules.items():
if _depends_on(
src_module=module,
target_modules=set(m for m in modified_modules if m is not None),
failed_filenames=failed_filenames,
finder=finder,
excludes=excludes,
reloader=reloader,
):
stale_modules[modname] = module

return stale_modules


Expand All @@ -143,12 +125,9 @@ def watch_modules(
modules imported by the notebook as well as the modules imported by those
modules, recursively.
"""
# modules that failed to be analyzed
failed_filenames: set[str] = set()
# work with a copy to avoid race conditions
# in CPython, dict.copy() is atomic
sys_modules = sys.modules.copy()
finder = ModuleFinder(excludes=_get_excluded_modules(sys_modules))
while not should_exit.is_set():
# Collect the modules used by each cell
modules: dict[str, types.ModuleType] = {}
Expand All @@ -159,13 +138,13 @@ def watch_modules(
if modname in sys_modules:
modules[modname] = sys_modules[modname]
modname_to_cell_id[modname] = cell_id

stale_modules = _check_modules(
modules=modules,
reloader=reloader,
failed_filenames=failed_filenames,
finder=finder,
sys_modules=sys_modules,
)

if stale_modules:
with graph.lock:
# If any modules are stale, communicate that to the FE
Expand All @@ -181,14 +160,13 @@ def watch_modules(
if mode == "autorun":
run_is_processed.clear()
enqueue_run_stale_cells()

# Don't proceed until enqueue_run_stale_cells() has been processed,
# ie until stale cells have been rerun
run_is_processed.wait()
time.sleep(1)
# Update our snapshot of sys.modules
sys_modules = sys.modules.copy()
# Update excluded modules in case the module set has changed.
finder.excludes = _get_excluded_modules(sys_modules)


class ModuleWatcher:
Expand Down
1 change: 1 addition & 0 deletions marimo/_smoke_tests/sidebar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright 2024 Marimo. All rights reserved.
import marimo

__generated_with = "0.5.1"
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,12 @@ extend-select = [
"TRY002", # Prohibit use of `raise Exception`, use specific exceptions instead.
]

[tool.ruff.lint.per-file-ignores]
"**/{tests}/*" = ["ANN201", "ANN202"]

# Never try to fix `F401` (unused imports).
unfixable = ["F401"]

[tool.ruff.lint.per-file-ignores]
"**/{tests}/*" = ["ANN201", "ANN202"]

[tool.ruff.lint.isort]
required-imports = ["from __future__ import annotations"]
combine-as-imports = true
Expand Down
Empty file.
1 change: 1 addition & 0 deletions tests/_runtime/reload/reload_data/a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import reload_data.c # noqa:F401
1 change: 1 addition & 0 deletions tests/_runtime/reload/reload_data/b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import reload_data.d # noqa:F401
1 change: 1 addition & 0 deletions tests/_runtime/reload/reload_data/c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pass
1 change: 1 addition & 0 deletions tests/_runtime/reload/reload_data/d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pass
33 changes: 32 additions & 1 deletion tests/_runtime/reload/test_autoreload.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

from reload_test_utils import update_file

from marimo._runtime.reload.autoreload import ModuleReloader
from marimo._runtime.reload.autoreload import (
ModuleDependencyFinder,
ModuleReloader,
)


def test_reload_function(tmp_path: pathlib.Path, py_modname: str):
Expand Down Expand Up @@ -98,3 +101,31 @@ def foo():
assert py_modname in sys.modules
# ... but it's basically empty
assert not hasattr(mod, "foo")


class TestModuleDependencyFinder:
def test_dependencies_isolated(self):
from tests._runtime.reload.reload_data import a, b, c, d

finder = ModuleDependencyFinder()
a_deps = set(list(finder.find_dependencies(a, excludes=[]).keys()))
b_deps = set(list(finder.find_dependencies(b, excludes=[]).keys()))
c_deps = set(list(finder.find_dependencies(c, excludes=[]).keys()))
d_deps = set(list(finder.find_dependencies(d, excludes=[]).keys()))

assert a_deps == set(["__main__", "reload_data", "reload_data.c"])
assert b_deps == set(["__main__", "reload_data", "reload_data.d"])
assert c_deps == set(["__main__"])
assert d_deps == set(["__main__"])

def test_dependencies_cached(self):
from tests._runtime.reload.reload_data import a

finder = ModuleDependencyFinder()
assert not finder.cached(a)

finder.find_dependencies(a, excludes=[])
assert finder.cached(a)

finder.evict_from_cache(a)
assert not finder.cached(a)