Skip to content

Commit

Permalink
Faster reload mode (#5267)
Browse files Browse the repository at this point in the history
* This works

* Add code

* Final touches

* Lint

* Fix bug in other dirs

* add changeset

* Reload

* lint + test

* Load from frontend

* add changeset

* Use key

* tweak frontend config generation

* tweak

* WIP ipython

* Fix robust

* fix

* Fix for jupyter notebook

* Add checks

* Lint frontend

* Undo demo changes

* add changeset

* Use is_in_or_equal

* python 3.11 changes and no if __name__

* Forward sys.argv + guide

* lint

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: pngwn <hello@pngwn.io>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
4 people committed Aug 29, 2023
1 parent 63b7a3c commit 119c834
Show file tree
Hide file tree
Showing 15 changed files with 581 additions and 194 deletions.
7 changes: 7 additions & 0 deletions .changeset/curvy-signs-pump.md
@@ -0,0 +1,7 @@
---
"@gradio/app": minor
"@gradio/client": minor
"gradio": minor
---

feat:Faster reload mode
1 change: 1 addition & 0 deletions .vscode/settings.json
Expand Up @@ -9,6 +9,7 @@
"svelte.plugin.svelte.diagnostics.enable": false,
"prettier.configPath": ".config/.prettierrc.json",
"prettier.ignorePath": ".config/.prettierignore",
"python.analysis.typeCheckingMode": "basic",
"python.testing.pytestArgs": ["."],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
Expand Down
3 changes: 2 additions & 1 deletion client/js/src/client.ts
Expand Up @@ -1133,7 +1133,8 @@ async function resolve_config(
if (
typeof window !== "undefined" &&
window.gradio_config &&
location.origin !== "http://localhost:9876"
location.origin !== "http://localhost:9876" &&
!window.gradio_config.dev_mode
) {
const path = window.gradio_config.root;
const config = window.gradio_config;
Expand Down
31 changes: 20 additions & 11 deletions gradio/blocks.py
Expand Up @@ -741,7 +741,7 @@ def __init__(
self.space_id = utils.get_space()
self.favicon_path = None
self.auth = None
self.dev_mode = True
self.dev_mode = bool(os.getenv("GRADIO_WATCH_DIRS", False))
self.app_id = random.getrandbits(64)
self.temp_file_sets = []
self.title = title
Expand Down Expand Up @@ -775,6 +775,12 @@ def __init__(
}
analytics.initiated_analytics(data)

@property
def _is_running_in_reload_thread(self):
from gradio.reload import reload_thread

return getattr(reload_thread, "running_reload", False)

@classmethod
def from_config(
cls,
Expand Down Expand Up @@ -1465,6 +1471,7 @@ def get_config_file(self):
config = {
"version": routes.VERSION,
"mode": self.mode,
"app_id": self.app_id,
"dev_mode": self.dev_mode,
"analytics_enabled": self.analytics_enabled,
"components": [],
Expand Down Expand Up @@ -1796,10 +1803,13 @@ def reverse(text):
demo = gr.Interface(reverse, "text", "text")
demo.launch(share=True, auth=("username", "password"))
"""
if self._is_running_in_reload_thread:
# We have already launched the demo
return None, None, None # type: ignore

if not self.exited:
self.__exit__()

self.dev_mode = False
if (
auth
and not callable(auth)
Expand Down Expand Up @@ -2033,11 +2043,10 @@ def reverse(text):
if self.share and self.share_url:
while not networking.url_ok(self.share_url):
time.sleep(0.25)
display(
HTML(
f'<div><iframe src="{self.share_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
)
artifact = HTML(
f'<div><iframe src="{self.share_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
)

elif self.is_colab:
# modified from /usr/local/lib/python3.7/dist-packages/google/colab/output/_util.py within Colab environment
code = """(async (port, path, width, height, cache, element) => {
Expand Down Expand Up @@ -2072,13 +2081,13 @@ def reverse(text):
cache=json.dumps(False),
)

display(Javascript(code))
artifact = Javascript(code)
else:
display(
HTML(
f'<div><iframe src="{self.local_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
)
artifact = HTML(
f'<div><iframe src="{self.local_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
)
self.artifact = artifact
display(artifact)
except ImportError:
pass

Expand Down
6 changes: 6 additions & 0 deletions gradio/exceptions.py
Expand Up @@ -29,6 +29,12 @@ class InvalidBlockError(ValueError):
pass


class ReloadError(ValueError):
"""Raised when something goes wrong when reloading the gradio app."""

pass


InvalidApiName = InvalidApiNameError # backwards compatibility


Expand Down
90 changes: 78 additions & 12 deletions gradio/ipython_ext.py
@@ -1,23 +1,89 @@
try:
from IPython.core.magic import needs_local_scope, register_cell_magic
from IPython.core.magic import (
needs_local_scope,
register_cell_magic,
)
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
except ImportError:
pass

import warnings

import gradio as gr
from gradio.networking import App
from gradio.utils import BaseReloader


class CellIdTracker:
"""Determines the most recently run cell in the notebook.
Needed to keep track of which demo the user is updating.
"""

def __init__(self, ipython):
ipython.events.register("pre_run_cell", self.pre_run_cell)
self.shell = ipython
self.current_cell: str = ""

def pre_run_cell(self, info):
self._current_cell = info.cell_id


class JupyterReloader(BaseReloader):
"""Swap a running blocks class in a notebook with the latest cell contents."""

def __init__(self, ipython) -> None:
super().__init__()
self._cell_tracker = CellIdTracker(ipython)
self._running: dict[str, gr.Blocks] = {}

@property
def current_cell(self):
return self._cell_tracker.current_cell

@property
def running_app(self) -> App:
assert self.running_demo.server
return self.running_demo.server.running_app

@property
def running_demo(self):
return self._running[self.current_cell]

def demo_tracked(self) -> bool:
return self.current_cell in self._running

def track(self, demo: gr.Blocks):
self._running[self.current_cell] = demo


def load_ipython_extension(ipython):
__demo = gr.Blocks()
reloader = JupyterReloader(ipython)

@magic_arguments()
@argument("--demo-name", default="demo", help="Name of gradio blocks instance.")
@argument(
"--share",
default=False,
const=True,
nargs="?",
help="Whether to launch with sharing. Will slow down reloading.",
)
@register_cell_magic
@needs_local_scope
def blocks(line, cell, local_ns=None):
if "gr.Interface" in cell:
warnings.warn(
"Usage of gradio.Interface with %%blocks may result in errors."
)
with __demo.clear():
exec(cell, None, local_ns)
__demo.launch(quiet=True)
def blocks(line, cell, local_ns):
"""Launch a demo defined in a cell in reload mode."""

args = parse_argstring(blocks, line)

exec(cell, None, local_ns)
demo: gr.Blocks = local_ns[args.demo_name]
if not reloader.demo_tracked():
demo.launch(share=args.share)
reloader.track(demo)
elif reloader.queue_changed(demo):
print("Queue got added or removed. Restarting demo.")
reloader.running_demo.close()
demo.launch()
reloader.track(demo)
else:
reloader.swap_blocks(demo)
return reloader.running_demo.artifact
41 changes: 40 additions & 1 deletion gradio/networking.py
Expand Up @@ -9,14 +9,17 @@
import threading
import time
import warnings
from functools import partial
from typing import TYPE_CHECKING

import requests
import uvicorn
from uvicorn.config import Config

from gradio.exceptions import ServerFailedToStartError
from gradio.routes import App
from gradio.tunneling import Tunnel
from gradio.utils import SourceFileReloader, watchfn

if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.blocks import Blocks
Expand All @@ -28,13 +31,34 @@
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request"

should_watch = bool(os.getenv("GRADIO_WATCH_DIRS", False))
GRADIO_WATCH_DIRS = (
os.getenv("GRADIO_WATCH_DIRS", "").split(",") if should_watch else []
)
GRADIO_WATCH_FILE = os.getenv("GRADIO_WATCH_FILE", "app")
GRADIO_WATCH_DEMO_NAME = os.getenv("GRADIO_WATCH_DEMO_NAME", "demo")


class Server(uvicorn.Server):
def __init__(
self, config: Config, reloader: SourceFileReloader | None = None
) -> None:
assert isinstance(config.app, App)
self.running_app = config.app
super().__init__(config)
self.reloader = reloader
if self.reloader:
self.event = threading.Event()
self.watch = partial(watchfn, self.reloader)

def install_signal_handlers(self):
pass

def run_in_thread(self):
self.thread = threading.Thread(target=self.run, daemon=True)
if self.reloader:
self.watch_thread = threading.Thread(target=self.watch, daemon=True)
self.watch_thread.start()
self.thread.start()
start = time.time()
while not self.started:
Expand All @@ -46,6 +70,9 @@ def run_in_thread(self):

def close(self):
self.should_exit = True
if self.reloader:
self.reloader.stop()
self.watch_thread.join()
self.thread.join()


Expand Down Expand Up @@ -160,7 +187,19 @@ def start_server(
ssl_keyfile_password=ssl_keyfile_password,
ws_max_size=1024 * 1024 * 1024, # Setting max websocket size to be 1 GB
)
server = Server(config=config)
reloader = None
if GRADIO_WATCH_DIRS:
change_event = threading.Event()
app.change_event = change_event
reloader = SourceFileReloader(
app=app,
watch_dirs=GRADIO_WATCH_DIRS,
watch_file=GRADIO_WATCH_FILE,
demo_name=GRADIO_WATCH_DEMO_NAME,
stop_event=threading.Event(),
change_event=change_event,
)
server = Server(config=config, reloader=reloader)
server.run_in_thread()
break
except (OSError, ServerFailedToStartError):
Expand Down
9 changes: 7 additions & 2 deletions gradio/queueing.py
Expand Up @@ -19,7 +19,12 @@
ProgressUnit,
)
from gradio.helpers import TrackedIterable
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name
from gradio.utils import (
AsyncRequest,
run_coro_in_background,
safe_get_lock,
set_task_name,
)


class Event:
Expand Down Expand Up @@ -59,7 +64,7 @@ def __init__(
self.max_thread_count = concurrency_count
self.update_intervals = update_intervals
self.active_jobs: list[None | list[Event]] = [None] * concurrency_count
self.delete_lock = asyncio.Lock()
self.delete_lock = safe_get_lock()
self.server_path = None
self.duration_history_total = 0
self.duration_history_count = 0
Expand Down

0 comments on commit 119c834

Please sign in to comment.