Skip to content

Commit

Permalink
Tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
freddyaboulton committed Mar 12, 2024
1 parent 17c60e5 commit e18d258
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 82 deletions.
2 changes: 1 addition & 1 deletion gradio/__init__.py
Expand Up @@ -97,6 +97,6 @@
TextArea,
)
from gradio.themes import Base as Theme
from gradio.utils import get_package_version, set_static_paths, no_reload
from gradio.utils import get_package_version, no_reload, set_static_paths

__version__ = get_package_version()
26 changes: 22 additions & 4 deletions gradio/cli/commands/reload.py
Expand Up @@ -7,23 +7,25 @@
"""
from __future__ import annotations

import atexit
import inspect
import json
import os
import re
import shutil
import site
import subprocess
import sys
import tempfile
import threading
from pathlib import Path
from typing import List, Optional
import shutil

import typer
from rich import print

import gradio
from gradio import utils
import tempfile

reload_thread = threading.local()

Expand Down Expand Up @@ -99,9 +101,17 @@ def _setup_config(

print(message + "\n")
reload_sources: dict[str, str] = {}

# copy the watched directories to temp directories
# so that we can reload the modules after parsing out
# the gr.no_reload contexts
for s in watching_dirs:
temp_dir = tempfile.mkdtemp()
shutil.copytree(s, temp_dir, dirs_exist_ok=True)
for file in Path(s).rglob("*.py"):
src = file
dst = Path(temp_dir) / file.relative_to(s)
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(src, dst)
reload_sources[str(s)] = temp_dir

# guaranty access to the module of an app
Expand All @@ -119,7 +129,15 @@ def main(
module_name, path, watch_sources, demo_name = _setup_config(
demo_path, demo_name, watch_dirs, encoding
)
import json

def delete_temp_dirs():
for temp_dir in watch_sources.values():
shutil.rmtree(temp_dir, ignore_errors=True)

atexit.register(delete_temp_dirs)

# Pass the following data as environment variables
# so that we can set up reload mode correctly in the networking.py module
popen = subprocess.Popen(
[sys.executable, "-u", path],
env=dict(
Expand Down
2 changes: 1 addition & 1 deletion gradio/networking.py
Expand Up @@ -4,6 +4,7 @@
"""
from __future__ import annotations

import json
import os
import socket
import threading
Expand All @@ -15,7 +16,6 @@
import httpx
import uvicorn
from uvicorn.config import Config
import json

from gradio.exceptions import ServerFailedToStartError
from gradio.routes import App
Expand Down
160 changes: 84 additions & 76 deletions gradio/utils.py
Expand Up @@ -2,18 +2,22 @@

from __future__ import annotations

import ast
import asyncio
import contextlib
import copy
import dataclasses
import functools
import importlib
import importlib.util
import inspect
import json
import json.decoder
import os
import pkgutil
import re
import shutil
import sys
import tempfile
import threading
import time
Expand All @@ -27,7 +31,7 @@
from io import BytesIO
from numbers import Number
from pathlib import Path
from types import AsyncGeneratorType, GeneratorType
from types import AsyncGeneratorType, GeneratorType, ModuleType
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -38,9 +42,6 @@
Optional,
TypeVar,
)
import ast
import importlib.util
import sys

import anyio
import httpx
Expand Down Expand Up @@ -151,63 +152,83 @@ def swap_blocks(self, demo: Blocks):
self.alert_change()


import ast
import importlib.util
import sys
import contextlib


@document()
@contextlib.contextmanager
def no_reload():
"""Context manager to prevent a code block from being reloaded in reload mode."""
yield


def remove_sections(string: str, sections: list[tuple[int, int]]) -> str:
lines = string.split('\n')
def _remove_sections(string: str, sections: list[tuple[int, int]]) -> str:
"""Helper functions to remove lines from a string.
Parameters:
string (str): The string to remove lines from.
sections (list[tuple[int, int]]): A list of tuples, where each tuple is a start and end index of a section to remove.
"""
lines = string.split("\n")
removed_lines = set()

for start, end in sections:
for i in range(start, end + 1):
removed_lines.add(i)

new_lines = [line for i, line in enumerate(lines) if i not in removed_lines]
return '\n'.join(new_lines)


new_lines = [line for i, line in enumerate(lines) if i not in removed_lines]
return "\n".join(new_lines)


def load_code_except_no_reload(file_path: str):
def _remove_no_reload_codeblocks(file_path: str):
"""Parse the file, remove the gr.no_reload code blocks, and write the file back to disk.
# if not module_name:
# print(f"Cannot find module for {file_path}")
# return

# if not any(is_in_or_equal(file_path, d) for d in watch_dirs):
# return
Parameters:
file_path (str): The path to the file to remove the no_reload code blocks from.
"""

with open(file_path, 'r') as file:
with open(file_path) as file:
code = file.read()

tree = ast.parse(code)

def _is_gr_no_reload(expr: ast.expr) -> bool:
"""Find with gr.no_reload context managers."""
return (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and isinstance(expr.func.value, ast.Name)
and expr.func.value.id == "gr"
and expr.func.attr == "no_reload"
)

def _is_no_reload(expr: ast.expr):
"""Find with no_reload context managers."""
return (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Name)
and expr.func.id == "no_reload"
)

# Find the positions of the code blocks to load
skip_indices = []
for node in ast.walk(tree):
if isinstance(node, ast.With):
ctx_manager = node.items[0].context_expr
if isinstance(ctx_manager, ast.Call) and isinstance(ctx_manager.func, ast.Attribute) and ctx_manager.func.attr == 'no_reload':
if _is_gr_no_reload(ctx_manager) or _is_no_reload(ctx_manager):
start_index = node.lineno - 1
end_index = node.body[-1].lineno
skip_indices.append((start_index, end_index))
code_removed = remove_sections(code, skip_indices)
with open(file_path, 'w') as file:

code_removed = _remove_sections(code, skip_indices)
with open(file_path, "w") as file:
file.write(code_removed)
# breakpoint()
# module = sys.modules[module_name]
# exec(code_removed, module.__dict__)
# sys.modules.pop(module_name)
# sys.modules[module_name] = module


def _find_module(source_file: Path) -> ModuleType:
for s, v in sys.modules.items():
if s not in {"__main__", "__mp_main__"} and getattr(v, "__file__", None) == str(
source_file
):
return v
raise ValueError(f"Cannot find module for source file: {source_file}")


def watchfn(reloader: SourceFileReloader):
Expand Down Expand Up @@ -252,57 +273,44 @@ def iter_py_files() -> Iterator[Path]:
sys.path.insert(0, str(dir_))

mtimes = {}
# Need to import the module in this thread so that the
# module is available in the namespace of this thread
module = importlib.import_module(reloader.watch_module_name)
while reloader.should_watch():
changed = get_changes()
if changed:
print(f"Changes detected in: {changed}")
# To simulate a fresh reload, delete all module references from sys.modules
# for the modules in the package the change came from.
# dir_ = next(d for d in reload_dirs if is_in_or_equal(changed, d))
# modules = list(sys.modules)
# for k in modules:
# v = sys.modules[k]
# sourcefile = getattr(v, "__file__", None)
# # Do not reload `reload.py` to keep thread data
# if (
# sourcefile
# and dir_ == Path(inspect.getfile(gradio)).parent
# and sourcefile.endswith("reload.py")
# ):
# continue
# if sourcefile and is_in_or_equal(sourcefile, dir_):
# del sys.modules[k]
try:
# iterate over every file in reload_dir and remove the no_reload sections
copied_dir = Path(reloader.watch_sources[str(dir_)])
changed_in_copy = copied_dir / changed.relative_to(dir_)
# How source file reloading works
# 1. Copy the changed source file to a temp directory
# 2. Remove the gr.no_reload code blocks from the temp file
# 3. Execute the changed source code in the original module's namespace
# 4. Do 2-3 for the main demo file even if it did not change.
# This is because the main demo file may import the changed file and we need the
# changes to be reflected in the main demo file.

changed_dir = next(
dir_ for dir_ in reload_dirs if is_in_or_equal(changed, dir_)
)
temp_dir = Path(reloader.watch_sources[str(changed_dir)])
changed_in_copy = temp_dir / changed.relative_to(changed_dir)

shutil.copy(changed, changed_in_copy)
for file in copied_dir.rglob("*.py"):
load_code_except_no_reload(str(file))
print("module_file", str(changed_in_copy))
#load_code_except_no_reload(str(changed_in_copy))
_remove_no_reload_codeblocks(str(changed_in_copy))

if changed != reloader.demo_file:
for s,v in sys.modules.items():
if s not in {"__main__", "__mp_main__"} and getattr(v, "__file__", None) == str(changed):
changed_module = sys.modules[s]
exec(Path(changed_in_copy).read_text(), changed_module.__dict__)

# for file in copied_dir.rglob("*.py"):
# load_code_except_no_reload(str(file))
#spec = importlib.util.spec_from_file_location(reloader.watch_module_name,
#str(changed_in_copy))
# assert spec
# assert spec.loader
#module = importlib.util.module_from_spec(spec)
# assert module
# spec.loader.exec_module(module)
changed_demo_file = copied_dir / reloader.demo_file.relative_to(dir_)
changed_module = _find_module(changed)
exec(Path(changed_in_copy).read_text(), changed_module.__dict__)

demo_dir = next(
dir_
for dir_ in reload_dirs
if is_in_or_equal(reloader.demo_file, dir_)
)

changed_demo_file = temp_dir / reloader.demo_file.relative_to(demo_dir)
_remove_no_reload_codeblocks(str(changed_demo_file))
exec(Path(changed_demo_file).read_text(), module.__dict__)
# sys.modules[reloader.watch_module_name] = module
# breakpoint()
# module = importlib.import_module(reloader.watch_module_name)
# importlib.reload(module)
except Exception:
print(
f"Reloading {reloader.watch_module_name} failed with the following exception: "
Expand Down

0 comments on commit e18d258

Please sign in to comment.