diff --git a/gradio/__init__.py b/gradio/__init__.py index 2399d1a9619b5..9ebac266821b1 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -97,6 +97,6 @@ TextArea, ) from gradio.themes import Base as Theme -from gradio.utils import get_package_version, set_static_paths, no_reload +from gradio.utils import get_package_version, no_reload, set_static_paths __version__ = get_package_version() diff --git a/gradio/cli/commands/reload.py b/gradio/cli/commands/reload.py index 5df1a389c9e5d..71bd7aec927a2 100644 --- a/gradio/cli/commands/reload.py +++ b/gradio/cli/commands/reload.py @@ -7,23 +7,25 @@ """ from __future__ import annotations +import atexit import inspect +import json import os import re +import shutil import site import subprocess import sys +import tempfile import threading from pathlib import Path from typing import List, Optional -import shutil import typer from rich import print import gradio from gradio import utils -import tempfile reload_thread = threading.local() @@ -99,9 +101,17 @@ def _setup_config( print(message + "\n") reload_sources: dict[str, str] = {} + + # copy the watched directories to temp directories + # so that we can reload the modules after parsing out + # the gr.no_reload contexts for s in watching_dirs: temp_dir = tempfile.mkdtemp() - shutil.copytree(s, temp_dir, dirs_exist_ok=True) + for file in Path(s).rglob("*.py"): + src = file + dst = Path(temp_dir) / file.relative_to(s) + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(src, dst) reload_sources[str(s)] = temp_dir # guaranty access to the module of an app @@ -119,7 +129,15 @@ def main( module_name, path, watch_sources, demo_name = _setup_config( demo_path, demo_name, watch_dirs, encoding ) - import json + + def delete_temp_dirs(): + for temp_dir in watch_sources.values(): + shutil.rmtree(temp_dir, ignore_errors=True) + + atexit.register(delete_temp_dirs) + + # Pass the following data as environment variables + # so that we can set up reload mode correctly in the networking.py module popen = subprocess.Popen( [sys.executable, "-u", path], env=dict( diff --git a/gradio/networking.py b/gradio/networking.py index da8e4d185618d..99de646378077 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -4,6 +4,7 @@ """ from __future__ import annotations +import json import os import socket import threading @@ -15,7 +16,6 @@ import httpx import uvicorn from uvicorn.config import Config -import json from gradio.exceptions import ServerFailedToStartError from gradio.routes import App diff --git a/gradio/utils.py b/gradio/utils.py index 7d5ab14cb0be8..6fb717992973a 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -2,11 +2,14 @@ from __future__ import annotations +import ast import asyncio +import contextlib import copy import dataclasses import functools import importlib +import importlib.util import inspect import json import json.decoder @@ -14,6 +17,7 @@ import pkgutil import re import shutil +import sys import tempfile import threading import time @@ -27,7 +31,7 @@ from io import BytesIO from numbers import Number from pathlib import Path -from types import AsyncGeneratorType, GeneratorType +from types import AsyncGeneratorType, GeneratorType, ModuleType from typing import ( TYPE_CHECKING, Any, @@ -38,9 +42,6 @@ Optional, TypeVar, ) -import ast -import importlib.util -import sys import anyio import httpx @@ -151,63 +152,83 @@ def swap_blocks(self, demo: Blocks): self.alert_change() -import ast -import importlib.util -import sys -import contextlib - - +@document() @contextlib.contextmanager def no_reload(): + """Context manager to prevent a code block from being reloaded in reload mode.""" yield -def remove_sections(string: str, sections: list[tuple[int, int]]) -> str: - lines = string.split('\n') +def _remove_sections(string: str, sections: list[tuple[int, int]]) -> str: + """Helper functions to remove lines from a string. + + Parameters: + string (str): The string to remove lines from. + sections (list[tuple[int, int]]): A list of tuples, where each tuple is a start and end index of a section to remove. + """ + lines = string.split("\n") removed_lines = set() - + for start, end in sections: for i in range(start, end + 1): removed_lines.add(i) - - new_lines = [line for i, line in enumerate(lines) if i not in removed_lines] - return '\n'.join(new_lines) - + new_lines = [line for i, line in enumerate(lines) if i not in removed_lines] + return "\n".join(new_lines) -def load_code_except_no_reload(file_path: str): +def _remove_no_reload_codeblocks(file_path: str): + """Parse the file, remove the gr.no_reload code blocks, and write the file back to disk. - # if not module_name: - # print(f"Cannot find module for {file_path}") - # return - - # if not any(is_in_or_equal(file_path, d) for d in watch_dirs): - # return + Parameters: + file_path (str): The path to the file to remove the no_reload code blocks from. + """ - with open(file_path, 'r') as file: + with open(file_path) as file: code = file.read() tree = ast.parse(code) + def _is_gr_no_reload(expr: ast.expr) -> bool: + """Find with gr.no_reload context managers.""" + return ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id == "gr" + and expr.func.attr == "no_reload" + ) + + def _is_no_reload(expr: ast.expr): + """Find with no_reload context managers.""" + return ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Name) + and expr.func.id == "no_reload" + ) + # Find the positions of the code blocks to load skip_indices = [] for node in ast.walk(tree): if isinstance(node, ast.With): ctx_manager = node.items[0].context_expr - if isinstance(ctx_manager, ast.Call) and isinstance(ctx_manager.func, ast.Attribute) and ctx_manager.func.attr == 'no_reload': + if _is_gr_no_reload(ctx_manager) or _is_no_reload(ctx_manager): start_index = node.lineno - 1 end_index = node.body[-1].lineno skip_indices.append((start_index, end_index)) - - code_removed = remove_sections(code, skip_indices) - with open(file_path, 'w') as file: + + code_removed = _remove_sections(code, skip_indices) + with open(file_path, "w") as file: file.write(code_removed) - # breakpoint() - # module = sys.modules[module_name] - # exec(code_removed, module.__dict__) - # sys.modules.pop(module_name) - # sys.modules[module_name] = module + + +def _find_module(source_file: Path) -> ModuleType: + for s, v in sys.modules.items(): + if s not in {"__main__", "__mp_main__"} and getattr(v, "__file__", None) == str( + source_file + ): + return v + raise ValueError(f"Cannot find module for source file: {source_file}") def watchfn(reloader: SourceFileReloader): @@ -252,57 +273,44 @@ def iter_py_files() -> Iterator[Path]: sys.path.insert(0, str(dir_)) mtimes = {} + # Need to import the module in this thread so that the + # module is available in the namespace of this thread module = importlib.import_module(reloader.watch_module_name) while reloader.should_watch(): changed = get_changes() if changed: print(f"Changes detected in: {changed}") - # To simulate a fresh reload, delete all module references from sys.modules - # for the modules in the package the change came from. - # dir_ = next(d for d in reload_dirs if is_in_or_equal(changed, d)) - # modules = list(sys.modules) - # for k in modules: - # v = sys.modules[k] - # sourcefile = getattr(v, "__file__", None) - # # Do not reload `reload.py` to keep thread data - # if ( - # sourcefile - # and dir_ == Path(inspect.getfile(gradio)).parent - # and sourcefile.endswith("reload.py") - # ): - # continue - # if sourcefile and is_in_or_equal(sourcefile, dir_): - # del sys.modules[k] try: - # iterate over every file in reload_dir and remove the no_reload sections - copied_dir = Path(reloader.watch_sources[str(dir_)]) - changed_in_copy = copied_dir / changed.relative_to(dir_) + # How source file reloading works + # 1. Copy the changed source file to a temp directory + # 2. Remove the gr.no_reload code blocks from the temp file + # 3. Execute the changed source code in the original module's namespace + # 4. Do 2-3 for the main demo file even if it did not change. + # This is because the main demo file may import the changed file and we need the + # changes to be reflected in the main demo file. + + changed_dir = next( + dir_ for dir_ in reload_dirs if is_in_or_equal(changed, dir_) + ) + temp_dir = Path(reloader.watch_sources[str(changed_dir)]) + changed_in_copy = temp_dir / changed.relative_to(changed_dir) + shutil.copy(changed, changed_in_copy) - for file in copied_dir.rglob("*.py"): - load_code_except_no_reload(str(file)) - print("module_file", str(changed_in_copy)) - #load_code_except_no_reload(str(changed_in_copy)) + _remove_no_reload_codeblocks(str(changed_in_copy)) + if changed != reloader.demo_file: - for s,v in sys.modules.items(): - if s not in {"__main__", "__mp_main__"} and getattr(v, "__file__", None) == str(changed): - changed_module = sys.modules[s] - exec(Path(changed_in_copy).read_text(), changed_module.__dict__) - - # for file in copied_dir.rglob("*.py"): - # load_code_except_no_reload(str(file)) - #spec = importlib.util.spec_from_file_location(reloader.watch_module_name, - #str(changed_in_copy)) - # assert spec - # assert spec.loader - #module = importlib.util.module_from_spec(spec) - # assert module - # spec.loader.exec_module(module) - changed_demo_file = copied_dir / reloader.demo_file.relative_to(dir_) + changed_module = _find_module(changed) + exec(Path(changed_in_copy).read_text(), changed_module.__dict__) + + demo_dir = next( + dir_ + for dir_ in reload_dirs + if is_in_or_equal(reloader.demo_file, dir_) + ) + + changed_demo_file = temp_dir / reloader.demo_file.relative_to(demo_dir) + _remove_no_reload_codeblocks(str(changed_demo_file)) exec(Path(changed_demo_file).read_text(), module.__dict__) - # sys.modules[reloader.watch_module_name] = module - # breakpoint() - # module = importlib.import_module(reloader.watch_module_name) - # importlib.reload(module) except Exception: print( f"Reloading {reloader.watch_module_name} failed with the following exception: "